From eda5bc26f44ee9a6f83dcf8c91f17296d7fc509d Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 12 Feb 2024 14:52:43 +0100 Subject: Move into version control --- .../include/armadillo_bits/BaseCube_bones.hpp | 86 + .../include/armadillo_bits/BaseCube_meat.hpp | 498 + .../include/armadillo_bits/Base_bones.hpp | 167 + src/armadillo/include/armadillo_bits/Base_meat.hpp | 1031 ++ src/armadillo/include/armadillo_bits/Col_bones.hpp | 288 + src/armadillo/include/armadillo_bits/Col_meat.hpp | 1888 ++++ .../include/armadillo_bits/CubeToMatOp_bones.hpp | 46 + .../include/armadillo_bits/CubeToMatOp_meat.hpp | 54 + .../include/armadillo_bits/Cube_bones.hpp | 564 + src/armadillo/include/armadillo_bits/Cube_meat.hpp | 5920 +++++++++++ .../include/armadillo_bits/GenCube_bones.hpp | 58 + .../include/armadillo_bits/GenCube_meat.hpp | 188 + src/armadillo/include/armadillo_bits/Gen_bones.hpp | 61 + src/armadillo/include/armadillo_bits/Gen_meat.hpp | 232 + .../include/armadillo_bits/GlueCube_bones.hpp | 42 + .../include/armadillo_bits/GlueCube_meat.hpp | 44 + .../include/armadillo_bits/Glue_bones.hpp | 66 + src/armadillo/include/armadillo_bits/Glue_meat.hpp | 56 + .../include/armadillo_bits/MapMat_bones.hpp | 247 + .../include/armadillo_bits/MapMat_meat.hpp | 1778 ++++ src/armadillo/include/armadillo_bits/Mat_bones.hpp | 945 ++ src/armadillo/include/armadillo_bits/Mat_meat.hpp | 10169 +++++++++++++++++++ .../include/armadillo_bits/OpCube_bones.hpp | 47 + .../include/armadillo_bits/OpCube_meat.hpp | 87 + src/armadillo/include/armadillo_bits/Op_bones.hpp | 69 + src/armadillo/include/armadillo_bits/Op_meat.hpp | 79 + src/armadillo/include/armadillo_bits/Proxy.hpp | 2537 +++++ src/armadillo/include/armadillo_bits/ProxyCube.hpp | 488 + src/armadillo/include/armadillo_bits/Row_bones.hpp | 288 + src/armadillo/include/armadillo_bits/Row_meat.hpp | 1888 ++++ .../include/armadillo_bits/SizeCube_bones.hpp | 52 + .../include/armadillo_bits/SizeCube_meat.hpp | 155 + .../include/armadillo_bits/SizeMat_bones.hpp | 51 + .../include/armadillo_bits/SizeMat_meat.hpp | 146 + .../include/armadillo_bits/SpBase_bones.hpp | 116 + .../include/armadillo_bits/SpBase_meat.hpp | 883 ++ .../include/armadillo_bits/SpCol_bones.hpp | 82 + .../include/armadillo_bits/SpCol_meat.hpp | 432 + .../include/armadillo_bits/SpGlue_bones.hpp | 49 + .../include/armadillo_bits/SpGlue_meat.hpp | 66 + .../include/armadillo_bits/SpMat_bones.hpp | 747 ++ .../armadillo_bits/SpMat_iterators_meat.hpp | 964 ++ .../include/armadillo_bits/SpMat_meat.hpp | 6855 +++++++++++++ .../include/armadillo_bits/SpOp_bones.hpp | 51 + src/armadillo/include/armadillo_bits/SpOp_meat.hpp | 76 + src/armadillo/include/armadillo_bits/SpProxy.hpp | 688 ++ .../include/armadillo_bits/SpRow_bones.hpp | 85 + .../include/armadillo_bits/SpRow_meat.hpp | 433 + .../include/armadillo_bits/SpSubview_bones.hpp | 418 + .../armadillo_bits/SpSubview_col_list_bones.hpp | 96 + .../armadillo_bits/SpSubview_col_list_meat.hpp | 719 ++ .../armadillo_bits/SpSubview_iterators_meat.hpp | 1154 +++ .../include/armadillo_bits/SpSubview_meat.hpp | 2006 ++++ .../include/armadillo_bits/SpToDGlue_bones.hpp | 45 + .../include/armadillo_bits/SpToDGlue_meat.hpp | 44 + .../include/armadillo_bits/SpToDOp_bones.hpp | 57 + .../include/armadillo_bits/SpToDOp_meat.hpp | 54 + .../include/armadillo_bits/SpValProxy_bones.hpp | 86 + .../include/armadillo_bits/SpValProxy_meat.hpp | 364 + src/armadillo/include/armadillo_bits/access.hpp | 45 + .../include/armadillo_bits/arma_cmath.hpp | 378 + .../include/armadillo_bits/arma_config.hpp | 252 + .../include/armadillo_bits/arma_forward.hpp | 475 + .../include/armadillo_bits/arma_ostream_bones.hpp | 79 + .../include/armadillo_bits/arma_ostream_meat.hpp | 1274 +++ .../armadillo_bits/arma_rel_comparators.hpp | 170 + src/armadillo/include/armadillo_bits/arma_rng.hpp | 1042 ++ .../include/armadillo_bits/arma_rng_cxx03.hpp | 185 + .../include/armadillo_bits/arma_static_check.hpp | 30 + src/armadillo/include/armadillo_bits/arma_str.hpp | 437 + .../include/armadillo_bits/arma_version.hpp | 61 + .../include/armadillo_bits/arrayops_bones.hpp | 229 + .../include/armadillo_bits/arrayops_meat.hpp | 1108 ++ .../include/armadillo_bits/auxlib_bones.hpp | 467 + .../include/armadillo_bits/auxlib_meat.hpp | 7050 +++++++++++++ .../include/armadillo_bits/band_helper.hpp | 379 + .../include/armadillo_bits/compiler_check.hpp | 78 + .../include/armadillo_bits/compiler_setup.hpp | 511 + .../include/armadillo_bits/compiler_setup_post.hpp | 24 + .../include/armadillo_bits/cond_rel_bones.hpp | 42 + .../include/armadillo_bits/cond_rel_meat.hpp | 134 + src/armadillo/include/armadillo_bits/config.hpp | 351 + .../include/armadillo_bits/config.hpp.cmake | 351 + src/armadillo/include/armadillo_bits/constants.hpp | 263 + .../include/armadillo_bits/constants_old.hpp | 93 + src/armadillo/include/armadillo_bits/csv_name.hpp | 138 + src/armadillo/include/armadillo_bits/debug.hpp | 1467 +++ .../include/armadillo_bits/def_arpack.hpp | 109 + src/armadillo/include/armadillo_bits/def_atlas.hpp | 79 + src/armadillo/include/armadillo_bits/def_blas.hpp | 161 + src/armadillo/include/armadillo_bits/def_fftw3.hpp | 49 + .../include/armadillo_bits/def_lapack.hpp | 1178 +++ .../include/armadillo_bits/def_superlu.hpp | 78 + .../include/armadillo_bits/diagmat_proxy.hpp | 375 + .../include/armadillo_bits/diagview_bones.hpp | 117 + .../include/armadillo_bits/diagview_meat.hpp | 1025 ++ .../include/armadillo_bits/diskio_bones.hpp | 229 + .../include/armadillo_bits/diskio_meat.hpp | 5356 ++++++++++ .../include/armadillo_bits/distr_param.hpp | 91 + .../include/armadillo_bits/eGlueCube_bones.hpp | 54 + .../include/armadillo_bits/eGlueCube_meat.hpp | 153 + .../include/armadillo_bits/eGlue_bones.hpp | 58 + .../include/armadillo_bits/eGlue_meat.hpp | 136 + .../include/armadillo_bits/eOpCube_bones.hpp | 62 + .../include/armadillo_bits/eOpCube_meat.hpp | 173 + src/armadillo/include/armadillo_bits/eOp_bones.hpp | 64 + src/armadillo/include/armadillo_bits/eOp_meat.hpp | 151 + .../include/armadillo_bits/eglue_core_bones.hpp | 86 + .../include/armadillo_bits/eglue_core_meat.hpp | 1250 +++ src/armadillo/include/armadillo_bits/eop_aux.hpp | 191 + .../include/armadillo_bits/eop_core_bones.hpp | 116 + .../include/armadillo_bits/eop_core_meat.hpp | 1168 +++ .../include/armadillo_bits/fft_engine_fftw3.hpp | 104 + .../include/armadillo_bits/fft_engine_kissfft.hpp | 392 + .../include/armadillo_bits/field_bones.hpp | 357 + .../include/armadillo_bits/field_meat.hpp | 2999 ++++++ src/armadillo/include/armadillo_bits/fill.hpp | 116 + src/armadillo/include/armadillo_bits/fn_accu.hpp | 1002 ++ src/armadillo/include/armadillo_bits/fn_all.hpp | 95 + src/armadillo/include/armadillo_bits/fn_any.hpp | 95 + .../include/armadillo_bits/fn_approx_equal.hpp | 471 + .../include/armadillo_bits/fn_as_scalar.hpp | 379 + .../include/armadillo_bits/fn_chi2rnd.hpp | 182 + src/armadillo/include/armadillo_bits/fn_chol.hpp | 149 + src/armadillo/include/armadillo_bits/fn_clamp.hpp | 117 + .../include/armadillo_bits/fn_cond_rcond.hpp | 63 + src/armadillo/include/armadillo_bits/fn_conv.hpp | 74 + .../include/armadillo_bits/fn_conv_to.hpp | 720 ++ src/armadillo/include/armadillo_bits/fn_cor.hpp | 54 + src/armadillo/include/armadillo_bits/fn_cov.hpp | 54 + src/armadillo/include/armadillo_bits/fn_cross.hpp | 43 + .../include/armadillo_bits/fn_cumprod.hpp | 89 + src/armadillo/include/armadillo_bits/fn_cumsum.hpp | 89 + src/armadillo/include/armadillo_bits/fn_det.hpp | 82 + .../include/armadillo_bits/fn_diagmat.hpp | 93 + .../include/armadillo_bits/fn_diags_spdiags.hpp | 134 + .../include/armadillo_bits/fn_diagvec.hpp | 64 + src/armadillo/include/armadillo_bits/fn_diff.hpp | 91 + src/armadillo/include/armadillo_bits/fn_dot.hpp | 340 + .../include/armadillo_bits/fn_eig_gen.hpp | 170 + .../include/armadillo_bits/fn_eig_pair.hpp | 144 + .../include/armadillo_bits/fn_eig_sym.hpp | 161 + .../include/armadillo_bits/fn_eigs_gen.hpp | 425 + .../include/armadillo_bits/fn_eigs_sym.hpp | 290 + src/armadillo/include/armadillo_bits/fn_elem.hpp | 1167 +++ src/armadillo/include/armadillo_bits/fn_eps.hpp | 106 + src/armadillo/include/armadillo_bits/fn_expmat.hpp | 103 + src/armadillo/include/armadillo_bits/fn_eye.hpp | 114 + src/armadillo/include/armadillo_bits/fn_fft.hpp | 136 + src/armadillo/include/armadillo_bits/fn_fft2.hpp | 136 + src/armadillo/include/armadillo_bits/fn_find.hpp | 469 + .../include/armadillo_bits/fn_find_unique.hpp | 69 + src/armadillo/include/armadillo_bits/fn_flip.hpp | 76 + src/armadillo/include/armadillo_bits/fn_hess.hpp | 174 + src/armadillo/include/armadillo_bits/fn_hist.hpp | 76 + src/armadillo/include/armadillo_bits/fn_histc.hpp | 58 + .../include/armadillo_bits/fn_index_max.hpp | 164 + .../include/armadillo_bits/fn_index_min.hpp | 164 + .../include/armadillo_bits/fn_inplace_strans.hpp | 95 + .../include/armadillo_bits/fn_inplace_trans.hpp | 131 + .../include/armadillo_bits/fn_interp1.hpp | 351 + .../include/armadillo_bits/fn_interp2.hpp | 264 + .../include/armadillo_bits/fn_intersect.hpp | 65 + src/armadillo/include/armadillo_bits/fn_inv.hpp | 138 + .../include/armadillo_bits/fn_inv_sympd.hpp | 138 + src/armadillo/include/armadillo_bits/fn_join.hpp | 502 + src/armadillo/include/armadillo_bits/fn_kmeans.hpp | 59 + src/armadillo/include/armadillo_bits/fn_kron.hpp | 104 + .../include/armadillo_bits/fn_log_det.hpp | 157 + .../include/armadillo_bits/fn_log_normpdf.hpp | 205 + src/armadillo/include/armadillo_bits/fn_logmat.hpp | 127 + src/armadillo/include/armadillo_bits/fn_lu.hpp | 88 + src/armadillo/include/armadillo_bits/fn_max.hpp | 277 + src/armadillo/include/armadillo_bits/fn_mean.hpp | 145 + src/armadillo/include/armadillo_bits/fn_median.hpp | 73 + src/armadillo/include/armadillo_bits/fn_min.hpp | 277 + src/armadillo/include/armadillo_bits/fn_misc.hpp | 587 ++ src/armadillo/include/armadillo_bits/fn_mvnrnd.hpp | 110 + .../include/armadillo_bits/fn_n_unique.hpp | 132 + .../include/armadillo_bits/fn_nonzeros.hpp | 49 + src/armadillo/include/armadillo_bits/fn_norm.hpp | 342 + .../include/armadillo_bits/fn_normalise.hpp | 116 + .../include/armadillo_bits/fn_normcdf.hpp | 201 + .../include/armadillo_bits/fn_normpdf.hpp | 205 + src/armadillo/include/armadillo_bits/fn_numel.hpp | 95 + src/armadillo/include/armadillo_bits/fn_ones.hpp | 161 + .../include/armadillo_bits/fn_orth_null.hpp | 98 + src/armadillo/include/armadillo_bits/fn_pinv.hpp | 110 + .../include/armadillo_bits/fn_polyfit.hpp | 67 + .../include/armadillo_bits/fn_polyval.hpp | 42 + src/armadillo/include/armadillo_bits/fn_powext.hpp | 179 + src/armadillo/include/armadillo_bits/fn_powmat.hpp | 108 + .../include/armadillo_bits/fn_princomp.hpp | 180 + src/armadillo/include/armadillo_bits/fn_prod.hpp | 81 + src/armadillo/include/armadillo_bits/fn_qr.hpp | 145 + .../include/armadillo_bits/fn_quantile.hpp | 58 + src/armadillo/include/armadillo_bits/fn_qz.hpp | 66 + src/armadillo/include/armadillo_bits/fn_randg.hpp | 241 + src/armadillo/include/armadillo_bits/fn_randi.hpp | 270 + src/armadillo/include/armadillo_bits/fn_randn.hpp | 357 + .../include/armadillo_bits/fn_randperm.hpp | 153 + src/armadillo/include/armadillo_bits/fn_randu.hpp | 357 + src/armadillo/include/armadillo_bits/fn_range.hpp | 62 + src/armadillo/include/armadillo_bits/fn_rank.hpp | 57 + .../include/armadillo_bits/fn_regspace.hpp | 265 + .../include/armadillo_bits/fn_repelem.hpp | 55 + src/armadillo/include/armadillo_bits/fn_repmat.hpp | 55 + .../include/armadillo_bits/fn_reshape.hpp | 138 + src/armadillo/include/armadillo_bits/fn_resize.hpp | 102 + .../include/armadillo_bits/fn_reverse.hpp | 100 + src/armadillo/include/armadillo_bits/fn_roots.hpp | 67 + src/armadillo/include/armadillo_bits/fn_schur.hpp | 114 + src/armadillo/include/armadillo_bits/fn_shift.hpp | 118 + .../include/armadillo_bits/fn_shuffle.hpp | 88 + src/armadillo/include/armadillo_bits/fn_size.hpp | 327 + src/armadillo/include/armadillo_bits/fn_solve.hpp | 224 + src/armadillo/include/armadillo_bits/fn_sort.hpp | 151 + .../include/armadillo_bits/fn_sort_index.hpp | 112 + src/armadillo/include/armadillo_bits/fn_speye.hpp | 93 + src/armadillo/include/armadillo_bits/fn_spones.hpp | 47 + .../include/armadillo_bits/fn_sprandn.hpp | 127 + .../include/armadillo_bits/fn_sprandu.hpp | 127 + .../include/armadillo_bits/fn_spsolve.hpp | 192 + .../include/armadillo_bits/fn_sqrtmat.hpp | 125 + src/armadillo/include/armadillo_bits/fn_stddev.hpp | 89 + src/armadillo/include/armadillo_bits/fn_strans.hpp | 110 + src/armadillo/include/armadillo_bits/fn_sum.hpp | 147 + src/armadillo/include/armadillo_bits/fn_svd.hpp | 206 + src/armadillo/include/armadillo_bits/fn_svds.hpp | 352 + .../include/armadillo_bits/fn_sylvester.hpp | 137 + src/armadillo/include/armadillo_bits/fn_symmat.hpp | 135 + .../include/armadillo_bits/fn_toeplitz.hpp | 63 + src/armadillo/include/armadillo_bits/fn_trace.hpp | 663 ++ src/armadillo/include/armadillo_bits/fn_trans.hpp | 99 + src/armadillo/include/armadillo_bits/fn_trapz.hpp | 59 + src/armadillo/include/armadillo_bits/fn_trig.hpp | 493 + src/armadillo/include/armadillo_bits/fn_trimat.hpp | 143 + .../include/armadillo_bits/fn_trimat_ind.hpp | 139 + .../include/armadillo_bits/fn_trunc_exp.hpp | 93 + .../include/armadillo_bits/fn_trunc_log.hpp | 100 + src/armadillo/include/armadillo_bits/fn_unique.hpp | 57 + src/armadillo/include/armadillo_bits/fn_var.hpp | 143 + .../include/armadillo_bits/fn_vecnorm.hpp | 385 + .../include/armadillo_bits/fn_vectorise.hpp | 114 + .../include/armadillo_bits/fn_wishrnd.hpp | 204 + src/armadillo/include/armadillo_bits/fn_zeros.hpp | 192 + .../include/armadillo_bits/glue_affmul_bones.hpp | 55 + .../include/armadillo_bits/glue_affmul_meat.hpp | 490 + .../include/armadillo_bits/glue_atan2_bones.hpp | 47 + .../include/armadillo_bits/glue_atan2_meat.hpp | 228 + .../include/armadillo_bits/glue_conv_bones.hpp | 57 + .../include/armadillo_bits/glue_conv_meat.hpp | 385 + .../include/armadillo_bits/glue_cor_bones.hpp | 43 + .../include/armadillo_bits/glue_cor_meat.hpp | 71 + .../include/armadillo_bits/glue_cov_bones.hpp | 43 + .../include/armadillo_bits/glue_cov_meat.hpp | 69 + .../include/armadillo_bits/glue_cross_bones.hpp | 42 + .../include/armadillo_bits/glue_cross_meat.hpp | 81 + .../include/armadillo_bits/glue_hist_bones.hpp | 54 + .../include/armadillo_bits/glue_hist_meat.hpp | 253 + .../include/armadillo_bits/glue_histc_bones.hpp | 54 + .../include/armadillo_bits/glue_histc_meat.hpp | 167 + .../include/armadillo_bits/glue_hypot_bones.hpp | 47 + .../include/armadillo_bits/glue_hypot_meat.hpp | 172 + .../armadillo_bits/glue_intersect_bones.hpp | 46 + .../include/armadillo_bits/glue_intersect_meat.hpp | 148 + .../include/armadillo_bits/glue_join_bones.hpp | 90 + .../include/armadillo_bits/glue_join_meat.hpp | 379 + .../include/armadillo_bits/glue_kron_bones.hpp | 46 + .../include/armadillo_bits/glue_kron_meat.hpp | 147 + .../include/armadillo_bits/glue_max_bones.hpp | 47 + .../include/armadillo_bits/glue_max_meat.hpp | 183 + .../include/armadillo_bits/glue_min_bones.hpp | 47 + .../include/armadillo_bits/glue_min_meat.hpp | 183 + .../include/armadillo_bits/glue_mixed_bones.hpp | 98 + .../include/armadillo_bits/glue_mixed_meat.hpp | 560 + .../include/armadillo_bits/glue_mvnrnd_bones.hpp | 58 + .../include/armadillo_bits/glue_mvnrnd_meat.hpp | 175 + .../include/armadillo_bits/glue_polyfit_bones.hpp | 47 + .../include/armadillo_bits/glue_polyfit_meat.hpp | 133 + .../include/armadillo_bits/glue_polyval_bones.hpp | 45 + .../include/armadillo_bits/glue_polyval_meat.hpp | 83 + .../include/armadillo_bits/glue_powext_bones.hpp | 70 + .../include/armadillo_bits/glue_powext_meat.hpp | 674 ++ .../include/armadillo_bits/glue_quantile_bones.hpp | 58 + .../include/armadillo_bits/glue_quantile_meat.hpp | 230 + .../armadillo_bits/glue_relational_bones.hpp | 136 + .../armadillo_bits/glue_relational_meat.hpp | 419 + .../include/armadillo_bits/glue_solve_bones.hpp | 175 + .../include/armadillo_bits/glue_solve_meat.hpp | 587 ++ .../include/armadillo_bits/glue_times_bones.hpp | 168 + .../include/armadillo_bits/glue_times_meat.hpp | 952 ++ .../armadillo_bits/glue_times_misc_bones.hpp | 88 + .../armadillo_bits/glue_times_misc_meat.hpp | 646 ++ .../include/armadillo_bits/glue_toeplitz_bones.hpp | 35 + .../include/armadillo_bits/glue_toeplitz_meat.hpp | 73 + .../include/armadillo_bits/glue_trapz_bones.hpp | 56 + .../include/armadillo_bits/glue_trapz_meat.hpp | 168 + .../include/armadillo_bits/gmm_diag_bones.hpp | 179 + .../include/armadillo_bits/gmm_diag_meat.hpp | 2655 +++++ .../include/armadillo_bits/gmm_full_bones.hpp | 167 + .../include/armadillo_bits/gmm_full_meat.hpp | 2739 +++++ .../include/armadillo_bits/gmm_misc_bones.hpp | 119 + .../include/armadillo_bits/gmm_misc_meat.hpp | 193 + src/armadillo/include/armadillo_bits/hdf5_misc.hpp | 772 ++ src/armadillo/include/armadillo_bits/hdf5_name.hpp | 93 + .../include/armadillo_bits/include_hdf5.hpp | 45 + .../include/armadillo_bits/include_superlu.hpp | 393 + .../include/armadillo_bits/injector_bones.hpp | 84 + .../include/armadillo_bits/injector_meat.hpp | 379 + src/armadillo/include/armadillo_bits/memory.hpp | 224 + src/armadillo/include/armadillo_bits/mp_misc.hpp | 91 + .../include/armadillo_bits/mtGlueCube_bones.hpp | 43 + .../include/armadillo_bits/mtGlueCube_meat.hpp | 56 + .../include/armadillo_bits/mtGlue_bones.hpp | 47 + .../include/armadillo_bits/mtGlue_meat.hpp | 56 + .../include/armadillo_bits/mtOpCube_bones.hpp | 60 + .../include/armadillo_bits/mtOpCube_meat.hpp | 105 + .../include/armadillo_bits/mtOp_bones.hpp | 62 + src/armadillo/include/armadillo_bits/mtOp_meat.hpp | 104 + .../include/armadillo_bits/mtSpGlue_bones.hpp | 48 + .../include/armadillo_bits/mtSpGlue_meat.hpp | 55 + .../include/armadillo_bits/mtSpOp_bones.hpp | 57 + .../include/armadillo_bits/mtSpOp_meat.hpp | 79 + src/armadillo/include/armadillo_bits/mul_gemm.hpp | 435 + .../include/armadillo_bits/mul_gemm_mixed.hpp | 291 + src/armadillo/include/armadillo_bits/mul_gemv.hpp | 495 + src/armadillo/include/armadillo_bits/mul_herk.hpp | 492 + src/armadillo/include/armadillo_bits/mul_syrk.hpp | 501 + .../newarp_DenseGenMatProd_bones.hpp | 43 + .../armadillo_bits/newarp_DenseGenMatProd_meat.hpp | 51 + .../armadillo_bits/newarp_DoubleShiftQR_bones.hpp | 76 + .../armadillo_bits/newarp_DoubleShiftQR_meat.hpp | 399 + .../include/armadillo_bits/newarp_EigsSelect.hpp | 52 + .../armadillo_bits/newarp_GenEigsSolver_bones.hpp | 109 + .../armadillo_bits/newarp_GenEigsSolver_meat.hpp | 492 + .../armadillo_bits/newarp_SortEigenvalue.hpp | 203 + .../newarp_SparseGenMatProd_bones.hpp | 44 + .../newarp_SparseGenMatProd_meat.hpp | 63 + .../newarp_SparseGenRealShiftSolve_bones.hpp | 51 + .../newarp_SparseGenRealShiftSolve_meat.hpp | 138 + .../newarp_SymEigsShiftSolver_bones.hpp | 43 + .../newarp_SymEigsShiftSolver_meat.hpp | 50 + .../armadillo_bits/newarp_SymEigsSolver_bones.hpp | 107 + .../armadillo_bits/newarp_SymEigsSolver_meat.hpp | 508 + .../armadillo_bits/newarp_TridiagEigen_bones.hpp | 58 + .../armadillo_bits/newarp_TridiagEigen_meat.hpp | 132 + .../newarp_UpperHessenbergEigen_bones.hpp | 59 + .../newarp_UpperHessenbergEigen_meat.hpp | 168 + .../newarp_UpperHessenbergQR_bones.hpp | 86 + .../newarp_UpperHessenbergQR_meat.hpp | 310 + .../include/armadillo_bits/newarp_cx_attrib.hpp | 37 + .../include/armadillo_bits/op_all_bones.hpp | 81 + .../include/armadillo_bits/op_all_meat.hpp | 406 + .../include/armadillo_bits/op_any_bones.hpp | 81 + .../include/armadillo_bits/op_any_meat.hpp | 377 + .../include/armadillo_bits/op_chi2rnd_bones.hpp | 54 + .../include/armadillo_bits/op_chi2rnd_meat.hpp | 176 + .../include/armadillo_bits/op_chol_bones.hpp | 38 + .../include/armadillo_bits/op_chol_meat.hpp | 74 + .../include/armadillo_bits/op_clamp_bones.hpp | 74 + .../include/armadillo_bits/op_clamp_meat.hpp | 577 ++ .../include/armadillo_bits/op_col_as_mat_bones.hpp | 33 + .../include/armadillo_bits/op_col_as_mat_meat.hpp | 53 + .../include/armadillo_bits/op_cond_bones.hpp | 36 + .../include/armadillo_bits/op_cond_meat.hpp | 174 + .../include/armadillo_bits/op_cor_bones.hpp | 36 + .../include/armadillo_bits/op_cor_meat.hpp | 126 + .../include/armadillo_bits/op_cov_bones.hpp | 36 + .../include/armadillo_bits/op_cov_meat.hpp | 104 + .../include/armadillo_bits/op_cumprod_bones.hpp | 49 + .../include/armadillo_bits/op_cumprod_meat.hpp | 174 + .../include/armadillo_bits/op_cumsum_bones.hpp | 49 + .../include/armadillo_bits/op_cumsum_meat.hpp | 174 + .../include/armadillo_bits/op_cx_scalar_bones.hpp | 168 + .../include/armadillo_bits/op_cx_scalar_meat.hpp | 564 + .../include/armadillo_bits/op_det_bones.hpp | 54 + .../include/armadillo_bits/op_det_meat.hpp | 178 + .../include/armadillo_bits/op_diagmat_bones.hpp | 61 + .../include/armadillo_bits/op_diagmat_meat.hpp | 767 ++ .../include/armadillo_bits/op_diagvec_bones.hpp | 58 + .../include/armadillo_bits/op_diagvec_meat.hpp | 536 + .../include/armadillo_bits/op_diff_bones.hpp | 49 + .../include/armadillo_bits/op_diff_meat.hpp | 224 + .../include/armadillo_bits/op_dot_bones.hpp | 121 + .../include/armadillo_bits/op_dot_meat.hpp | 580 ++ .../include/armadillo_bits/op_dotext_bones.hpp | 50 + .../include/armadillo_bits/op_dotext_meat.hpp | 214 + .../include/armadillo_bits/op_expmat_bones.hpp | 53 + .../include/armadillo_bits/op_expmat_meat.hpp | 256 + .../include/armadillo_bits/op_fft_bones.hpp | 61 + .../include/armadillo_bits/op_fft_meat.hpp | 325 + .../include/armadillo_bits/op_find_bones.hpp | 130 + .../include/armadillo_bits/op_find_meat.hpp | 660 ++ .../armadillo_bits/op_find_unique_bones.hpp | 76 + .../include/armadillo_bits/op_find_unique_meat.hpp | 130 + .../include/armadillo_bits/op_flip_bones.hpp | 59 + .../include/armadillo_bits/op_flip_meat.hpp | 341 + .../include/armadillo_bits/op_hist_bones.hpp | 39 + .../include/armadillo_bits/op_hist_meat.hpp | 125 + .../include/armadillo_bits/op_htrans_bones.hpp | 107 + .../include/armadillo_bits/op_htrans_meat.hpp | 419 + .../include/armadillo_bits/op_index_max_bones.hpp | 57 + .../include/armadillo_bits/op_index_max_meat.hpp | 433 + .../include/armadillo_bits/op_index_min_bones.hpp | 57 + .../include/armadillo_bits/op_index_min_meat.hpp | 433 + .../include/armadillo_bits/op_inv_gen_bones.hpp | 143 + .../include/armadillo_bits/op_inv_gen_meat.hpp | 428 + .../include/armadillo_bits/op_inv_spd_bones.hpp | 76 + .../include/armadillo_bits/op_inv_spd_meat.hpp | 365 + .../include/armadillo_bits/op_log_det_bones.hpp | 52 + .../include/armadillo_bits/op_log_det_meat.hpp | 239 + .../include/armadillo_bits/op_logmat_bones.hpp | 82 + .../include/armadillo_bits/op_logmat_meat.hpp | 572 ++ .../include/armadillo_bits/op_max_bones.hpp | 112 + .../include/armadillo_bits/op_max_meat.hpp | 1325 +++ .../include/armadillo_bits/op_mean_bones.hpp | 115 + .../include/armadillo_bits/op_mean_meat.hpp | 713 ++ .../include/armadillo_bits/op_median_bones.hpp | 77 + .../include/armadillo_bits/op_median_meat.hpp | 338 + .../include/armadillo_bits/op_min_bones.hpp | 112 + .../include/armadillo_bits/op_min_meat.hpp | 1325 +++ .../include/armadillo_bits/op_misc_bones.hpp | 80 + .../include/armadillo_bits/op_misc_meat.hpp | 404 + .../include/armadillo_bits/op_nonzeros_bones.hpp | 52 + .../include/armadillo_bits/op_nonzeros_meat.hpp | 151 + .../include/armadillo_bits/op_norm2est_bones.hpp | 60 + .../include/armadillo_bits/op_norm2est_meat.hpp | 248 + .../include/armadillo_bits/op_norm_bones.hpp | 52 + .../include/armadillo_bits/op_norm_meat.hpp | 905 ++ .../include/armadillo_bits/op_normalise_bones.hpp | 47 + .../include/armadillo_bits/op_normalise_meat.hpp | 148 + .../include/armadillo_bits/op_orth_null_bones.hpp | 53 + .../include/armadillo_bits/op_orth_null_meat.hpp | 181 + .../include/armadillo_bits/op_pinv_bones.hpp | 55 + .../include/armadillo_bits/op_pinv_meat.hpp | 313 + .../include/armadillo_bits/op_powmat_bones.hpp | 56 + .../include/armadillo_bits/op_powmat_meat.hpp | 261 + .../include/armadillo_bits/op_princomp_bones.hpp | 75 + .../include/armadillo_bits/op_princomp_meat.hpp | 319 + .../include/armadillo_bits/op_prod_bones.hpp | 42 + .../include/armadillo_bits/op_prod_meat.hpp | 217 + .../include/armadillo_bits/op_range_bones.hpp | 40 + .../include/armadillo_bits/op_range_meat.hpp | 96 + .../include/armadillo_bits/op_rank_bones.hpp | 41 + .../include/armadillo_bits/op_rank_meat.hpp | 184 + .../include/armadillo_bits/op_rcond_bones.hpp | 32 + .../include/armadillo_bits/op_rcond_meat.hpp | 113 + .../include/armadillo_bits/op_relational_bones.hpp | 164 + .../include/armadillo_bits/op_relational_meat.hpp | 510 + .../include/armadillo_bits/op_repelem_bones.hpp | 37 + .../include/armadillo_bits/op_repelem_meat.hpp | 103 + .../include/armadillo_bits/op_repmat_bones.hpp | 37 + .../include/armadillo_bits/op_repmat_meat.hpp | 124 + .../include/armadillo_bits/op_reshape_bones.hpp | 49 + .../include/armadillo_bits/op_reshape_meat.hpp | 246 + .../include/armadillo_bits/op_resize_bones.hpp | 47 + .../include/armadillo_bits/op_resize_meat.hpp | 169 + .../include/armadillo_bits/op_reverse_bones.hpp | 46 + .../include/armadillo_bits/op_reverse_meat.hpp | 128 + .../include/armadillo_bits/op_roots_bones.hpp | 41 + .../include/armadillo_bits/op_roots_meat.hpp | 140 + .../include/armadillo_bits/op_row_as_mat_bones.hpp | 33 + .../include/armadillo_bits/op_row_as_mat_meat.hpp | 63 + .../include/armadillo_bits/op_shift_bones.hpp | 45 + .../include/armadillo_bits/op_shift_meat.hpp | 181 + .../include/armadillo_bits/op_shuffle_bones.hpp | 47 + .../include/armadillo_bits/op_shuffle_meat.hpp | 234 + .../include/armadillo_bits/op_sort_bones.hpp | 61 + .../include/armadillo_bits/op_sort_index_bones.hpp | 137 + .../include/armadillo_bits/op_sort_index_meat.hpp | 206 + .../include/armadillo_bits/op_sort_meat.hpp | 242 + .../include/armadillo_bits/op_sp_minus_bones.hpp | 72 + .../include/armadillo_bits/op_sp_minus_meat.hpp | 255 + .../include/armadillo_bits/op_sp_plus_bones.hpp | 48 + .../include/armadillo_bits/op_sp_plus_meat.hpp | 139 + .../include/armadillo_bits/op_sqrtmat_bones.hpp | 78 + .../include/armadillo_bits/op_sqrtmat_meat.hpp | 549 + .../include/armadillo_bits/op_stddev_bones.hpp | 38 + .../include/armadillo_bits/op_stddev_meat.hpp | 112 + .../include/armadillo_bits/op_strans_bones.hpp | 85 + .../include/armadillo_bits/op_strans_meat.hpp | 465 + .../include/armadillo_bits/op_sum_bones.hpp | 59 + .../include/armadillo_bits/op_sum_meat.hpp | 430 + .../include/armadillo_bits/op_symmat_bones.hpp | 68 + .../include/armadillo_bits/op_symmat_meat.hpp | 278 + .../include/armadillo_bits/op_toeplitz_bones.hpp | 46 + .../include/armadillo_bits/op_toeplitz_meat.hpp | 110 + .../include/armadillo_bits/op_trimat_bones.hpp | 76 + .../include/armadillo_bits/op_trimat_meat.hpp | 381 + .../include/armadillo_bits/op_unique_bones.hpp | 79 + .../include/armadillo_bits/op_unique_meat.hpp | 174 + .../include/armadillo_bits/op_var_bones.hpp | 67 + .../include/armadillo_bits/op_var_meat.hpp | 330 + .../include/armadillo_bits/op_vecnorm_bones.hpp | 55 + .../include/armadillo_bits/op_vecnorm_meat.hpp | 254 + .../include/armadillo_bits/op_vectorise_bones.hpp | 81 + .../include/armadillo_bits/op_vectorise_meat.hpp | 463 + .../include/armadillo_bits/op_wishrnd_bones.hpp | 63 + .../include/armadillo_bits/op_wishrnd_meat.hpp | 281 + .../include/armadillo_bits/operator_cube_div.hpp | 197 + .../include/armadillo_bits/operator_cube_minus.hpp | 213 + .../include/armadillo_bits/operator_cube_plus.hpp | 213 + .../armadillo_bits/operator_cube_relational.hpp | 301 + .../include/armadillo_bits/operator_cube_schur.hpp | 131 + .../include/armadillo_bits/operator_cube_times.hpp | 124 + .../include/armadillo_bits/operator_div.hpp | 382 + .../include/armadillo_bits/operator_minus.hpp | 570 ++ .../include/armadillo_bits/operator_ostream.hpp | 186 + .../include/armadillo_bits/operator_plus.hpp | 540 + .../include/armadillo_bits/operator_relational.hpp | 483 + .../include/armadillo_bits/operator_schur.hpp | 366 + .../include/armadillo_bits/operator_times.hpp | 482 + .../include/armadillo_bits/podarray_bones.hpp | 90 + .../include/armadillo_bits/podarray_meat.hpp | 309 + .../include/armadillo_bits/promote_type.hpp | 216 + .../include/armadillo_bits/restrictors.hpp | 214 + .../include/armadillo_bits/running_stat_bones.hpp | 121 + .../include/armadillo_bits/running_stat_meat.hpp | 463 + .../armadillo_bits/running_stat_vec_bones.hpp | 157 + .../armadillo_bits/running_stat_vec_meat.hpp | 636 ++ .../include/armadillo_bits/sp_auxlib_bones.hpp | 283 + .../include/armadillo_bits/sp_auxlib_meat.hpp | 2814 +++++ src/armadillo/include/armadillo_bits/span.hpp | 90 + .../include/armadillo_bits/spdiagview_bones.hpp | 113 + .../include/armadillo_bits/spdiagview_meat.hpp | 1073 ++ .../include/armadillo_bits/spglue_join_bones.hpp | 78 + .../include/armadillo_bits/spglue_join_meat.hpp | 350 + .../include/armadillo_bits/spglue_kron_bones.hpp | 45 + .../include/armadillo_bits/spglue_kron_meat.hpp | 159 + .../include/armadillo_bits/spglue_max_bones.hpp | 56 + .../include/armadillo_bits/spglue_max_meat.hpp | 222 + .../include/armadillo_bits/spglue_merge_bones.hpp | 43 + .../include/armadillo_bits/spglue_merge_meat.hpp | 554 + .../include/armadillo_bits/spglue_min_bones.hpp | 56 + .../include/armadillo_bits/spglue_min_meat.hpp | 222 + .../include/armadillo_bits/spglue_minus_bones.hpp | 59 + .../include/armadillo_bits/spglue_minus_meat.hpp | 340 + .../include/armadillo_bits/spglue_plus_bones.hpp | 55 + .../include/armadillo_bits/spglue_plus_meat.hpp | 295 + .../armadillo_bits/spglue_relational_bones.hpp | 80 + .../armadillo_bits/spglue_relational_meat.hpp | 545 + .../include/armadillo_bits/spglue_schur_bones.hpp | 66 + .../include/armadillo_bits/spglue_schur_meat.hpp | 382 + .../include/armadillo_bits/spglue_times_bones.hpp | 66 + .../include/armadillo_bits/spglue_times_meat.hpp | 369 + .../include/armadillo_bits/spop_diagmat_bones.hpp | 64 + .../include/armadillo_bits/spop_diagmat_meat.hpp | 456 + .../include/armadillo_bits/spop_htrans_bones.hpp | 46 + .../include/armadillo_bits/spop_htrans_meat.hpp | 61 + .../include/armadillo_bits/spop_max_bones.hpp | 61 + .../include/armadillo_bits/spop_max_meat.hpp | 686 ++ .../include/armadillo_bits/spop_mean_bones.hpp | 62 + .../include/armadillo_bits/spop_mean_meat.hpp | 376 + .../include/armadillo_bits/spop_min_bones.hpp | 61 + .../include/armadillo_bits/spop_min_meat.hpp | 722 ++ .../include/armadillo_bits/spop_misc_bones.hpp | 265 + .../include/armadillo_bits/spop_misc_meat.hpp | 596 ++ .../include/armadillo_bits/spop_norm_bones.hpp | 39 + .../include/armadillo_bits/spop_norm_meat.hpp | 129 + .../armadillo_bits/spop_normalise_bones.hpp | 37 + .../include/armadillo_bits/spop_normalise_meat.hpp | 133 + .../include/armadillo_bits/spop_repmat_bones.hpp | 41 + .../include/armadillo_bits/spop_repmat_meat.hpp | 166 + .../include/armadillo_bits/spop_reverse_bones.hpp | 40 + .../include/armadillo_bits/spop_reverse_meat.hpp | 185 + .../include/armadillo_bits/spop_strans_bones.hpp | 49 + .../include/armadillo_bits/spop_strans_meat.hpp | 152 + .../include/armadillo_bits/spop_sum_bones.hpp | 33 + .../include/armadillo_bits/spop_sum_meat.hpp | 104 + .../include/armadillo_bits/spop_symmat_bones.hpp | 46 + .../include/armadillo_bits/spop_symmat_meat.hpp | 87 + .../include/armadillo_bits/spop_trimat_bones.hpp | 66 + .../include/armadillo_bits/spop_trimat_meat.hpp | 366 + .../include/armadillo_bits/spop_var_bones.hpp | 64 + .../include/armadillo_bits/spop_var_meat.hpp | 414 + .../include/armadillo_bits/spop_vecnorm_bones.hpp | 52 + .../include/armadillo_bits/spop_vecnorm_meat.hpp | 209 + .../armadillo_bits/spop_vectorise_bones.hpp | 58 + .../include/armadillo_bits/spop_vectorise_meat.hpp | 126 + .../armadillo_bits/spsolve_factoriser_bones.hpp | 57 + .../armadillo_bits/spsolve_factoriser_meat.hpp | 289 + src/armadillo/include/armadillo_bits/strip.hpp | 231 + .../include/armadillo_bits/subview_bones.hpp | 673 ++ .../include/armadillo_bits/subview_cube_bones.hpp | 248 + .../armadillo_bits/subview_cube_each_bones.hpp | 161 + .../armadillo_bits/subview_cube_each_meat.hpp | 1035 ++ .../include/armadillo_bits/subview_cube_meat.hpp | 2722 +++++ .../armadillo_bits/subview_cube_slices_bones.hpp | 92 + .../armadillo_bits/subview_cube_slices_meat.hpp | 555 + .../include/armadillo_bits/subview_each_bones.hpp | 166 + .../include/armadillo_bits/subview_each_meat.hpp | 1404 +++ .../include/armadillo_bits/subview_elem1_bones.hpp | 109 + .../include/armadillo_bits/subview_elem1_meat.hpp | 953 ++ .../include/armadillo_bits/subview_elem2_bones.hpp | 112 + .../include/armadillo_bits/subview_elem2_meat.hpp | 873 ++ .../include/armadillo_bits/subview_field_bones.hpp | 95 + .../include/armadillo_bits/subview_field_meat.hpp | 558 + .../include/armadillo_bits/subview_meat.hpp | 4974 +++++++++ .../include/armadillo_bits/sym_helper.hpp | 485 + src/armadillo/include/armadillo_bits/traits.hpp | 1315 +++ .../include/armadillo_bits/translate_arpack.hpp | 114 + .../include/armadillo_bits/translate_atlas.hpp | 282 + .../include/armadillo_bits/translate_blas.hpp | 261 + .../include/armadillo_bits/translate_fftw3.hpp | 106 + .../include/armadillo_bits/translate_lapack.hpp | 1347 +++ .../include/armadillo_bits/translate_superlu.hpp | 348 + .../include/armadillo_bits/trimat_helper.hpp | 165 + .../include/armadillo_bits/typedef_elem.hpp | 175 + .../include/armadillo_bits/typedef_elem_check.hpp | 48 + .../include/armadillo_bits/typedef_mat.hpp | 144 + .../include/armadillo_bits/typedef_mat_fixed.hpp | 326 + src/armadillo/include/armadillo_bits/unwrap.hpp | 3421 +++++++ .../include/armadillo_bits/unwrap_cube.hpp | 133 + .../include/armadillo_bits/unwrap_spmat.hpp | 196 + .../include/armadillo_bits/upgrade_val.hpp | 161 + .../include/armadillo_bits/wall_clock_bones.hpp | 43 + .../include/armadillo_bits/wall_clock_meat.hpp | 72 + .../include/armadillo_bits/xtrans_mat_bones.hpp | 56 + .../include/armadillo_bits/xtrans_mat_meat.hpp | 87 + .../include/armadillo_bits/xvec_htrans_bones.hpp | 54 + .../include/armadillo_bits/xvec_htrans_meat.hpp | 90 + 622 files changed, 202027 insertions(+) create mode 100644 src/armadillo/include/armadillo_bits/BaseCube_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/BaseCube_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/Base_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/Base_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/Col_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/Col_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/CubeToMatOp_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/CubeToMatOp_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/Cube_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/Cube_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/GenCube_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/GenCube_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/Gen_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/Gen_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/GlueCube_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/GlueCube_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/Glue_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/Glue_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/MapMat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/MapMat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/Mat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/Mat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/OpCube_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/OpCube_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/Op_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/Op_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/Proxy.hpp create mode 100644 src/armadillo/include/armadillo_bits/ProxyCube.hpp create mode 100644 src/armadillo/include/armadillo_bits/Row_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/Row_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SizeCube_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SizeCube_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SizeMat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SizeMat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpBase_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpBase_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpCol_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpCol_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpGlue_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpGlue_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpMat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpMat_iterators_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpMat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpOp_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpOp_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpProxy.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpRow_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpRow_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpSubview_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpSubview_col_list_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpSubview_col_list_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpSubview_iterators_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpSubview_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpToDGlue_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpToDGlue_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpToDOp_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpToDOp_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpValProxy_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/SpValProxy_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/access.hpp create mode 100644 src/armadillo/include/armadillo_bits/arma_cmath.hpp create mode 100644 src/armadillo/include/armadillo_bits/arma_config.hpp create mode 100644 src/armadillo/include/armadillo_bits/arma_forward.hpp create mode 100644 src/armadillo/include/armadillo_bits/arma_ostream_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/arma_ostream_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/arma_rel_comparators.hpp create mode 100644 src/armadillo/include/armadillo_bits/arma_rng.hpp create mode 100644 src/armadillo/include/armadillo_bits/arma_rng_cxx03.hpp create mode 100644 src/armadillo/include/armadillo_bits/arma_static_check.hpp create mode 100644 src/armadillo/include/armadillo_bits/arma_str.hpp create mode 100644 src/armadillo/include/armadillo_bits/arma_version.hpp create mode 100644 src/armadillo/include/armadillo_bits/arrayops_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/arrayops_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/auxlib_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/auxlib_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/band_helper.hpp create mode 100644 src/armadillo/include/armadillo_bits/compiler_check.hpp create mode 100644 src/armadillo/include/armadillo_bits/compiler_setup.hpp create mode 100644 src/armadillo/include/armadillo_bits/compiler_setup_post.hpp create mode 100644 src/armadillo/include/armadillo_bits/cond_rel_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/cond_rel_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/config.hpp create mode 100644 src/armadillo/include/armadillo_bits/config.hpp.cmake create mode 100644 src/armadillo/include/armadillo_bits/constants.hpp create mode 100644 src/armadillo/include/armadillo_bits/constants_old.hpp create mode 100644 src/armadillo/include/armadillo_bits/csv_name.hpp create mode 100644 src/armadillo/include/armadillo_bits/debug.hpp create mode 100644 src/armadillo/include/armadillo_bits/def_arpack.hpp create mode 100644 src/armadillo/include/armadillo_bits/def_atlas.hpp create mode 100644 src/armadillo/include/armadillo_bits/def_blas.hpp create mode 100644 src/armadillo/include/armadillo_bits/def_fftw3.hpp create mode 100644 src/armadillo/include/armadillo_bits/def_lapack.hpp create mode 100644 src/armadillo/include/armadillo_bits/def_superlu.hpp create mode 100644 src/armadillo/include/armadillo_bits/diagmat_proxy.hpp create mode 100644 src/armadillo/include/armadillo_bits/diagview_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/diagview_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/diskio_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/diskio_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/distr_param.hpp create mode 100644 src/armadillo/include/armadillo_bits/eGlueCube_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/eGlueCube_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/eGlue_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/eGlue_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/eOpCube_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/eOpCube_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/eOp_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/eOp_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/eglue_core_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/eglue_core_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/eop_aux.hpp create mode 100644 src/armadillo/include/armadillo_bits/eop_core_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/eop_core_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/fft_engine_fftw3.hpp create mode 100644 src/armadillo/include/armadillo_bits/fft_engine_kissfft.hpp create mode 100644 src/armadillo/include/armadillo_bits/field_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/field_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/fill.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_accu.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_all.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_any.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_approx_equal.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_as_scalar.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_chi2rnd.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_chol.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_clamp.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_cond_rcond.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_conv.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_conv_to.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_cor.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_cov.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_cross.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_cumprod.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_cumsum.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_det.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_diagmat.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_diags_spdiags.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_diagvec.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_diff.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_dot.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_eig_gen.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_eig_pair.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_eig_sym.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_eigs_gen.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_eigs_sym.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_elem.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_eps.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_expmat.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_eye.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_fft.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_fft2.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_find.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_find_unique.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_flip.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_hess.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_hist.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_histc.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_index_max.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_index_min.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_inplace_strans.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_inplace_trans.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_interp1.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_interp2.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_intersect.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_inv.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_inv_sympd.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_join.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_kmeans.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_kron.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_log_det.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_log_normpdf.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_logmat.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_lu.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_max.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_mean.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_median.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_min.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_misc.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_mvnrnd.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_n_unique.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_nonzeros.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_norm.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_normalise.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_normcdf.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_normpdf.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_numel.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_ones.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_orth_null.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_pinv.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_polyfit.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_polyval.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_powext.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_powmat.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_princomp.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_prod.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_qr.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_quantile.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_qz.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_randg.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_randi.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_randn.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_randperm.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_randu.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_range.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_rank.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_regspace.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_repelem.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_repmat.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_reshape.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_resize.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_reverse.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_roots.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_schur.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_shift.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_shuffle.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_size.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_solve.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_sort.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_sort_index.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_speye.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_spones.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_sprandn.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_sprandu.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_spsolve.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_sqrtmat.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_stddev.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_strans.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_sum.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_svd.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_svds.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_sylvester.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_symmat.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_toeplitz.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_trace.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_trans.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_trapz.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_trig.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_trimat.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_trimat_ind.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_trunc_exp.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_trunc_log.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_unique.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_var.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_vecnorm.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_vectorise.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_wishrnd.hpp create mode 100644 src/armadillo/include/armadillo_bits/fn_zeros.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_affmul_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_affmul_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_atan2_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_atan2_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_conv_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_conv_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_cor_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_cor_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_cov_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_cov_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_cross_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_cross_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_hist_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_hist_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_histc_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_histc_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_hypot_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_hypot_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_intersect_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_intersect_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_join_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_join_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_kron_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_kron_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_max_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_max_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_min_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_min_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_mixed_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_mixed_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_mvnrnd_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_mvnrnd_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_polyfit_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_polyfit_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_polyval_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_polyval_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_powext_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_powext_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_quantile_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_quantile_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_relational_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_relational_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_solve_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_solve_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_times_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_times_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_times_misc_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_times_misc_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_toeplitz_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_toeplitz_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_trapz_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/glue_trapz_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/gmm_diag_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/gmm_diag_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/gmm_full_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/gmm_full_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/gmm_misc_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/gmm_misc_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/hdf5_misc.hpp create mode 100644 src/armadillo/include/armadillo_bits/hdf5_name.hpp create mode 100644 src/armadillo/include/armadillo_bits/include_hdf5.hpp create mode 100644 src/armadillo/include/armadillo_bits/include_superlu.hpp create mode 100644 src/armadillo/include/armadillo_bits/injector_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/injector_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/memory.hpp create mode 100644 src/armadillo/include/armadillo_bits/mp_misc.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtGlueCube_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtGlueCube_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtGlue_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtGlue_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtOpCube_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtOpCube_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtOp_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtOp_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtSpGlue_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtSpGlue_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtSpOp_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/mtSpOp_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/mul_gemm.hpp create mode 100644 src/armadillo/include/armadillo_bits/mul_gemm_mixed.hpp create mode 100644 src/armadillo/include/armadillo_bits/mul_gemv.hpp create mode 100644 src/armadillo/include/armadillo_bits/mul_herk.hpp create mode 100644 src/armadillo/include/armadillo_bits/mul_syrk.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_DenseGenMatProd_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_DenseGenMatProd_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_DoubleShiftQR_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_DoubleShiftQR_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_EigsSelect.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_GenEigsSolver_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_GenEigsSolver_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_SortEigenvalue.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_SparseGenMatProd_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_SparseGenMatProd_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_SparseGenRealShiftSolve_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_SparseGenRealShiftSolve_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_SymEigsShiftSolver_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_SymEigsShiftSolver_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_SymEigsSolver_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_SymEigsSolver_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_TridiagEigen_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_TridiagEigen_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_UpperHessenbergEigen_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_UpperHessenbergEigen_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_UpperHessenbergQR_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_UpperHessenbergQR_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/newarp_cx_attrib.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_all_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_all_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_any_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_any_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_chi2rnd_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_chi2rnd_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_chol_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_chol_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_clamp_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_clamp_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_col_as_mat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_col_as_mat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cond_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cond_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cor_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cor_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cov_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cov_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cumprod_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cumprod_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cumsum_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cumsum_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cx_scalar_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_cx_scalar_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_det_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_det_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_diagmat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_diagmat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_diagvec_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_diagvec_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_diff_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_diff_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_dot_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_dot_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_dotext_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_dotext_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_expmat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_expmat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_fft_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_fft_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_find_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_find_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_find_unique_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_find_unique_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_flip_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_flip_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_hist_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_hist_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_htrans_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_htrans_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_index_max_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_index_max_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_index_min_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_index_min_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_inv_gen_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_inv_gen_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_inv_spd_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_inv_spd_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_log_det_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_log_det_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_logmat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_logmat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_max_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_max_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_mean_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_mean_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_median_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_median_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_min_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_min_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_misc_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_misc_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_nonzeros_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_nonzeros_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_norm2est_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_norm2est_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_norm_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_norm_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_normalise_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_normalise_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_orth_null_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_orth_null_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_pinv_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_pinv_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_powmat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_powmat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_princomp_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_princomp_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_prod_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_prod_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_range_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_range_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_rank_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_rank_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_rcond_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_rcond_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_relational_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_relational_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_repelem_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_repelem_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_repmat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_repmat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_reshape_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_reshape_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_resize_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_resize_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_reverse_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_reverse_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_roots_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_roots_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_row_as_mat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_row_as_mat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_shift_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_shift_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_shuffle_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_shuffle_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sort_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sort_index_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sort_index_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sort_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sp_minus_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sp_minus_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sp_plus_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sp_plus_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sqrtmat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sqrtmat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_stddev_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_stddev_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_strans_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_strans_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sum_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_sum_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_symmat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_symmat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_toeplitz_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_toeplitz_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_trimat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_trimat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_unique_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_unique_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_var_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_var_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_vecnorm_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_vecnorm_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_vectorise_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_vectorise_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_wishrnd_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/op_wishrnd_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_cube_div.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_cube_minus.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_cube_plus.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_cube_relational.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_cube_schur.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_cube_times.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_div.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_minus.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_ostream.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_plus.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_relational.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_schur.hpp create mode 100644 src/armadillo/include/armadillo_bits/operator_times.hpp create mode 100644 src/armadillo/include/armadillo_bits/podarray_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/podarray_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/promote_type.hpp create mode 100644 src/armadillo/include/armadillo_bits/restrictors.hpp create mode 100644 src/armadillo/include/armadillo_bits/running_stat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/running_stat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/running_stat_vec_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/running_stat_vec_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/sp_auxlib_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/sp_auxlib_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/span.hpp create mode 100644 src/armadillo/include/armadillo_bits/spdiagview_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spdiagview_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_join_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_join_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_kron_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_kron_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_max_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_max_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_merge_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_merge_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_min_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_min_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_minus_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_minus_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_plus_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_plus_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_relational_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_relational_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_schur_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_schur_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_times_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spglue_times_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_diagmat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_diagmat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_htrans_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_htrans_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_max_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_max_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_mean_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_mean_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_min_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_min_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_misc_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_misc_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_norm_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_norm_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_normalise_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_normalise_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_repmat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_repmat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_reverse_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_reverse_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_strans_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_strans_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_sum_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_sum_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_symmat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_symmat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_trimat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_trimat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_var_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_var_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_vecnorm_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_vecnorm_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_vectorise_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spop_vectorise_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/spsolve_factoriser_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/spsolve_factoriser_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/strip.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_cube_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_cube_each_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_cube_each_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_cube_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_cube_slices_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_cube_slices_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_each_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_each_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_elem1_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_elem1_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_elem2_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_elem2_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_field_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_field_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/subview_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/sym_helper.hpp create mode 100644 src/armadillo/include/armadillo_bits/traits.hpp create mode 100644 src/armadillo/include/armadillo_bits/translate_arpack.hpp create mode 100644 src/armadillo/include/armadillo_bits/translate_atlas.hpp create mode 100644 src/armadillo/include/armadillo_bits/translate_blas.hpp create mode 100644 src/armadillo/include/armadillo_bits/translate_fftw3.hpp create mode 100644 src/armadillo/include/armadillo_bits/translate_lapack.hpp create mode 100644 src/armadillo/include/armadillo_bits/translate_superlu.hpp create mode 100644 src/armadillo/include/armadillo_bits/trimat_helper.hpp create mode 100644 src/armadillo/include/armadillo_bits/typedef_elem.hpp create mode 100644 src/armadillo/include/armadillo_bits/typedef_elem_check.hpp create mode 100644 src/armadillo/include/armadillo_bits/typedef_mat.hpp create mode 100644 src/armadillo/include/armadillo_bits/typedef_mat_fixed.hpp create mode 100644 src/armadillo/include/armadillo_bits/unwrap.hpp create mode 100644 src/armadillo/include/armadillo_bits/unwrap_cube.hpp create mode 100644 src/armadillo/include/armadillo_bits/unwrap_spmat.hpp create mode 100644 src/armadillo/include/armadillo_bits/upgrade_val.hpp create mode 100644 src/armadillo/include/armadillo_bits/wall_clock_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/wall_clock_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/xtrans_mat_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/xtrans_mat_meat.hpp create mode 100644 src/armadillo/include/armadillo_bits/xvec_htrans_bones.hpp create mode 100644 src/armadillo/include/armadillo_bits/xvec_htrans_meat.hpp (limited to 'src/armadillo/include/armadillo_bits') diff --git a/src/armadillo/include/armadillo_bits/BaseCube_bones.hpp b/src/armadillo/include/armadillo_bits/BaseCube_bones.hpp new file mode 100644 index 0000000..15d6a4c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/BaseCube_bones.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup BaseCube +//! @{ + + + +template +struct BaseCube_eval_Cube + { + arma_warn_unused arma_inline const derived& eval() const; + }; + + +template +struct BaseCube_eval_expr + { + arma_warn_unused inline Cube eval() const; //!< force the immediate evaluation of a delayed expression + }; + + +template +struct BaseCube_eval {}; + +template +struct BaseCube_eval { typedef BaseCube_eval_Cube result; }; + +template +struct BaseCube_eval { typedef BaseCube_eval_expr result; }; + + + +//! Analog of the Base class, intended for cubes +template +struct BaseCube + : public BaseCube_eval::value>::result + { + arma_inline const derived& get_ref() const; + + arma_cold inline void print( const std::string extra_text = "") const; + arma_cold inline void print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_cold inline void raw_print( const std::string extra_text = "") const; + arma_cold inline void raw_print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_cold inline void brief_print( const std::string extra_text = "") const; + arma_cold inline void brief_print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_warn_unused inline elem_type min() const; + arma_warn_unused inline elem_type max() const; + + arma_warn_unused inline uword index_min() const; + arma_warn_unused inline uword index_max() const; + + arma_warn_unused inline bool is_zero(const typename get_pod_type::result tol = 0) const; + + arma_warn_unused inline bool is_empty() const; + arma_warn_unused inline bool is_finite() const; + + arma_warn_unused inline bool has_inf() const; + arma_warn_unused inline bool has_nan() const; + arma_warn_unused inline bool has_nonfinite() const; + + arma_warn_unused inline const CubeToMatOp row_as_mat(const uword in_row) const; + arma_warn_unused inline const CubeToMatOp col_as_mat(const uword in_col) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/BaseCube_meat.hpp b/src/armadillo/include/armadillo_bits/BaseCube_meat.hpp new file mode 100644 index 0000000..2d0df91 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/BaseCube_meat.hpp @@ -0,0 +1,498 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup BaseCube +//! @{ + + + +template +arma_inline +const derived& +BaseCube::get_ref() const + { + return static_cast(*this); + } + + + +template +inline +void +BaseCube::print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, true); + } + + + +template +inline +void +BaseCube::print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, true); + } + + + +template +inline +void +BaseCube::raw_print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, false); + } + + + +template +inline +void +BaseCube::raw_print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, false); + } + + + +template +inline +void +BaseCube::brief_print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::brief_print(get_cout_stream(), tmp.M); + } + + + +template +inline +void +BaseCube::brief_print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::brief_print(user_stream, tmp.M); + } + + + +template +inline +elem_type +BaseCube::min() const + { + return op_min::min( (*this).get_ref() ); + } + + + +template +inline +elem_type +BaseCube::max() const + { + return op_max::max( (*this).get_ref() ); + } + + + +template +inline +uword +BaseCube::index_min() const + { + const ProxyCube P( (*this).get_ref() ); + + uword index = 0; + + if(P.get_n_elem() == 0) + { + arma_debug_check(true, "index_min(): object has no elements"); + } + else + { + op_min::min_with_index(P, index); + } + + return index; + } + + + +template +inline +uword +BaseCube::index_max() const + { + const ProxyCube P( (*this).get_ref() ); + + uword index = 0; + + if(P.get_n_elem() == 0) + { + arma_debug_check(true, "index_max(): object has no elements"); + } + else + { + op_max::max_with_index(P, index); + } + + return index; + } + + + +template +inline +bool +BaseCube::is_zero(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + arma_debug_check( (tol < T(0)), "is_zero(): parameter 'tol' must be >= 0" ); + + if(ProxyCube::use_at || is_Cube::stored_type>::value) + { + const unwrap_cube U( (*this).get_ref() ); + + return arrayops::is_zero( U.M.memptr(), U.M.n_elem, tol ); + } + + const ProxyCube P( (*this).get_ref() ); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) { return false; } + + const typename ProxyCube::ea_type Pea = P.get_ea(); + + if(is_cx::yes) + { + for(uword i=0; i tol) { return false; } + if(eop_aux::arma_abs(val_imag) > tol) { return false; } + } + } + else // not complex + { + for(uword i=0; i < n_elem; ++i) + { + if(eop_aux::arma_abs(Pea[i]) > tol) { return false; } + } + } + + return true; + } + + + +template +inline +bool +BaseCube::is_empty() const + { + arma_extra_debug_sigprint(); + + const ProxyCube P( (*this).get_ref() ); + + return (P.get_n_elem() == uword(0)); + } + + + +template +inline +bool +BaseCube::is_finite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "is_finite(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_Cube::stored_type>::value) + { + const unwrap_cube U( (*this).get_ref() ); + + return arrayops::is_finite( U.M.memptr(), U.M.n_elem ); + } + else + { + const ProxyCube P( (*this).get_ref() ); + + const uword n_r = P.get_n_rows(); + const uword n_c = P.get_n_cols(); + const uword n_s = P.get_n_slices(); + + for(uword s=0; s +inline +bool +BaseCube::has_inf() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_inf(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_Cube::stored_type>::value) + { + const unwrap_cube U( (*this).get_ref() ); + + return arrayops::has_inf( U.M.memptr(), U.M.n_elem ); + } + else + { + const ProxyCube P( (*this).get_ref() ); + + const uword n_r = P.get_n_rows(); + const uword n_c = P.get_n_cols(); + const uword n_s = P.get_n_slices(); + + for(uword s=0; s +inline +bool +BaseCube::has_nan() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nan(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_Cube::stored_type>::value) + { + const unwrap_cube U( (*this).get_ref() ); + + return arrayops::has_nan( U.M.memptr(), U.M.n_elem ); + } + else + { + const ProxyCube P( (*this).get_ref() ); + + const uword n_r = P.get_n_rows(); + const uword n_c = P.get_n_cols(); + const uword n_s = P.get_n_slices(); + + for(uword s=0; s +inline +bool +BaseCube::has_nonfinite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nonfinite(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_Cube::stored_type>::value) + { + const unwrap_cube U( (*this).get_ref() ); + + return (arrayops::is_finite( U.M.memptr(), U.M.n_elem ) == false); + } + else + { + const ProxyCube P( (*this).get_ref() ); + + const uword n_r = P.get_n_rows(); + const uword n_c = P.get_n_cols(); + const uword n_s = P.get_n_slices(); + + for(uword s=0; s +inline +const CubeToMatOp +BaseCube::row_as_mat(const uword in_row) const + { + return CubeToMatOp( (*this).get_ref(), in_row ); + } + + + +template +inline +const CubeToMatOp +BaseCube::col_as_mat(const uword in_col) const + { + return CubeToMatOp( (*this).get_ref(), in_col ); + } + + + +// +// extra functions defined in BaseCube_eval_Cube + +template +arma_inline +const derived& +BaseCube_eval_Cube::eval() const + { + arma_extra_debug_sigprint(); + + return static_cast(*this); + } + + + +// +// extra functions defined in BaseCube_eval_expr + +template +inline +Cube +BaseCube_eval_expr::eval() const + { + arma_extra_debug_sigprint(); + + return Cube( static_cast(*this) ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Base_bones.hpp b/src/armadillo/include/armadillo_bits/Base_bones.hpp new file mode 100644 index 0000000..ac94785 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Base_bones.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Base +//! @{ + + + +template +struct Base_extra_yes + { + arma_warn_unused inline const Op i() const; //!< matrix inverse + + arma_warn_unused inline bool is_sympd() const; + arma_warn_unused inline bool is_sympd(typename get_pod_type::result tol) const; + }; + + +template +struct Base_extra_no + { + }; + + +template +struct Base_extra {}; + +template +struct Base_extra { typedef Base_extra_yes result; }; + +template +struct Base_extra { typedef Base_extra_no result; }; + + + +template +struct Base_eval_Mat + { + arma_warn_unused arma_inline const derived& eval() const; + }; + + +template +struct Base_eval_expr + { + arma_warn_unused inline Mat eval() const; //!< force the immediate evaluation of a delayed expression + }; + + +template +struct Base_eval {}; + +template +struct Base_eval { typedef Base_eval_Mat result; }; + +template +struct Base_eval { typedef Base_eval_expr result; }; + + + +template +struct Base_trans_cx + { + arma_warn_unused arma_inline const Op t() const; + arma_warn_unused arma_inline const Op ht() const; + arma_warn_unused arma_inline const Op st() const; // simple transpose: no complex conjugates + }; + + +template +struct Base_trans_default + { + arma_warn_unused arma_inline const Op t() const; + arma_warn_unused arma_inline const Op ht() const; + arma_warn_unused arma_inline const Op st() const; // return op_htrans instead of op_strans, as it's handled better by matrix multiplication code + }; + + +template +struct Base_trans {}; + +template +struct Base_trans { typedef Base_trans_cx result; }; + +template +struct Base_trans { typedef Base_trans_default result; }; + + + +//! Class for static polymorphism, modelled after the "Curiously Recurring Template Pattern" (CRTP). +//! Used for type-safe downcasting in functions that restrict their input(s) to be classes that are +//! derived from Base (eg. Mat, Op, Glue, diagview, subview). +//! A Base object can be converted to a Mat object by the unwrap class. + +template +struct Base + : public Base_extra::value>::result + , public Base_eval::value>::result + , public Base_trans::value>::result + { + arma_inline const derived& get_ref() const; + + arma_cold inline void print( const std::string extra_text = "") const; + arma_cold inline void print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_cold inline void raw_print( const std::string extra_text = "") const; + arma_cold inline void raw_print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_cold inline void brief_print( const std::string extra_text = "") const; + arma_cold inline void brief_print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_warn_unused inline elem_type min() const; + arma_warn_unused inline elem_type max() const; + + inline elem_type min(uword& index_of_min_val) const; + inline elem_type max(uword& index_of_max_val) const; + + inline elem_type min(uword& row_of_min_val, uword& col_of_min_val) const; + inline elem_type max(uword& row_of_max_val, uword& col_of_max_val) const; + + arma_warn_unused inline uword index_min() const; + arma_warn_unused inline uword index_max() const; + + arma_warn_unused inline bool is_symmetric() const; + arma_warn_unused inline bool is_symmetric(const typename get_pod_type::result tol) const; + + arma_warn_unused inline bool is_hermitian() const; + arma_warn_unused inline bool is_hermitian(const typename get_pod_type::result tol) const; + + arma_warn_unused inline bool is_zero(const typename get_pod_type::result tol = 0) const; + + arma_warn_unused inline bool is_trimatu() const; + arma_warn_unused inline bool is_trimatl() const; + arma_warn_unused inline bool is_diagmat() const; + arma_warn_unused inline bool is_empty() const; + arma_warn_unused inline bool is_square() const; + arma_warn_unused inline bool is_vec() const; + arma_warn_unused inline bool is_colvec() const; + arma_warn_unused inline bool is_rowvec() const; + arma_warn_unused inline bool is_finite() const; + + arma_warn_unused inline bool has_inf() const; + arma_warn_unused inline bool has_nan() const; + arma_warn_unused inline bool has_nonfinite() const; + + arma_warn_unused inline const Op as_col() const; + arma_warn_unused inline const Op as_row() const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Base_meat.hpp b/src/armadillo/include/armadillo_bits/Base_meat.hpp new file mode 100644 index 0000000..646f33a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Base_meat.hpp @@ -0,0 +1,1031 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Base +//! @{ + + + +template +arma_inline +const derived& +Base::get_ref() const + { + return static_cast(*this); + } + + + +template +inline +void +Base::print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, true); + } + + + +template +inline +void +Base::print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, true); + } + + + +template +inline +void +Base::raw_print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, false); + } + + + +template +inline +void +Base::raw_print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, false); + } + + + +template +inline +void +Base::brief_print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::brief_print(get_cout_stream(), tmp.M); + } + + + +template +inline +void +Base::brief_print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::brief_print(user_stream, tmp.M); + } + + + +template +inline +elem_type +Base::min() const + { + return op_min::min( (*this).get_ref() ); + } + + + +template +inline +elem_type +Base::max() const + { + return op_max::max( (*this).get_ref() ); + } + + + +template +inline +elem_type +Base::min(uword& index_of_min_val) const + { + const Proxy P( (*this).get_ref() ); + + return op_min::min_with_index(P, index_of_min_val); + } + + + +template +inline +elem_type +Base::max(uword& index_of_max_val) const + { + const Proxy P( (*this).get_ref() ); + + return op_max::max_with_index(P, index_of_max_val); + } + + + +template +inline +elem_type +Base::min(uword& row_of_min_val, uword& col_of_min_val) const + { + const Proxy P( (*this).get_ref() ); + + uword index = 0; + + const elem_type val = op_min::min_with_index(P, index); + + const uword local_n_rows = P.get_n_rows(); + + row_of_min_val = index % local_n_rows; + col_of_min_val = index / local_n_rows; + + return val; + } + + + +template +inline +elem_type +Base::max(uword& row_of_max_val, uword& col_of_max_val) const + { + const Proxy P( (*this).get_ref() ); + + uword index = 0; + + const elem_type val = op_max::max_with_index(P, index); + + const uword local_n_rows = P.get_n_rows(); + + row_of_max_val = index % local_n_rows; + col_of_max_val = index / local_n_rows; + + return val; + } + + + +template +inline +uword +Base::index_min() const + { + const Proxy P( (*this).get_ref() ); + + uword index = 0; + + if(P.get_n_elem() == 0) + { + arma_debug_check(true, "index_min(): object has no elements"); + } + else + { + op_min::min_with_index(P, index); + } + + return index; + } + + + +template +inline +uword +Base::index_max() const + { + const Proxy P( (*this).get_ref() ); + + uword index = 0; + + if(P.get_n_elem() == 0) + { + arma_debug_check(true, "index_max(): object has no elements"); + } + else + { + op_max::max_with_index(P, index); + } + + return index; + } + + + +template +inline +bool +Base::is_symmetric() const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap U( (*this).get_ref() ); + + const Mat& A = U.M; + + if(A.n_rows != A.n_cols) { return false; } + if(A.n_elem <= 1 ) { return true; } + + const uword N = A.n_rows; + const uword Nm1 = N-1; + + const elem_type* A_col = A.memptr(); + + for(uword j=0; j < Nm1; ++j) + { + const uword jp1 = j+1; + + const elem_type* A_row = &(A.at(j,jp1)); + + for(uword i=jp1; i < N; ++i) + { + if(A_col[i] != (*A_row)) { return false; } + + A_row += N; + } + + A_col += N; + } + + return true; + } + + + +template +inline +bool +Base::is_symmetric(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(tol == T(0)) { return (*this).is_symmetric(); } + + arma_debug_check( (tol < T(0)), "is_symmetric(): parameter 'tol' must be >= 0" ); + + const quasi_unwrap U( (*this).get_ref() ); + + const Mat& A = U.M; + + if(A.n_rows != A.n_cols) { return false; } + if(A.n_elem <= 1 ) { return true; } + + const T norm_A = as_scalar( arma::max(sum(abs(A), 1), 0) ); + + if(norm_A == T(0)) { return true; } + + const T norm_A_Ast = as_scalar( arma::max(sum(abs(A - A.st()), 1), 0) ); + + return ( (norm_A_Ast / norm_A) <= tol ); + } + + + +template +inline +bool +Base::is_hermitian() const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const quasi_unwrap U( (*this).get_ref() ); + + const Mat& A = U.M; + + if(A.n_rows != A.n_cols) { return false; } + if(A.n_elem == 0 ) { return true; } + + const uword N = A.n_rows; + + const elem_type* A_col = A.memptr(); + + for(uword j=0; j < N; ++j) + { + if( access::tmp_imag(A_col[j]) != T(0) ) { return false; } + + A_col += N; + } + + A_col = A.memptr(); + + const uword Nm1 = N-1; + + for(uword j=0; j < Nm1; ++j) + { + const uword jp1 = j+1; + + const elem_type* A_row = &(A.at(j,jp1)); + + for(uword i=jp1; i < N; ++i) + { + if(A_col[i] != access::alt_conj(*A_row)) { return false; } + + A_row += N; + } + + A_col += N; + } + + return true; + } + + + +template +inline +bool +Base::is_hermitian(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(tol == T(0)) { return (*this).is_hermitian(); } + + arma_debug_check( (tol < T(0)), "is_hermitian(): parameter 'tol' must be >= 0" ); + + const quasi_unwrap U( (*this).get_ref() ); + + const Mat& A = U.M; + + if(A.n_rows != A.n_cols) { return false; } + if(A.n_elem == 0 ) { return true; } + + const T norm_A = as_scalar( arma::max(sum(abs(A), 1), 0) ); + + if(norm_A == T(0)) { return true; } + + const T norm_A_At = as_scalar( arma::max(sum(abs(A - A.t()), 1), 0) ); + + return ( (norm_A_At / norm_A) <= tol ); + } + + + +template +inline +bool +Base::is_zero(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + arma_debug_check( (tol < T(0)), "is_zero(): parameter 'tol' must be >= 0" ); + + if(Proxy::use_at || is_Mat::stored_type>::value) + { + const quasi_unwrap U( (*this).get_ref() ); + + return arrayops::is_zero( U.M.memptr(), U.M.n_elem, tol ); + } + + const Proxy P( (*this).get_ref() ); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) { return false; } + + const typename Proxy::ea_type Pea = P.get_ea(); + + if(is_cx::yes) + { + for(uword i=0; i tol) { return false; } + if(eop_aux::arma_abs(val_imag) > tol) { return false; } + } + } + else // not complex + { + for(uword i=0; i tol) { return false; } + } + } + + return true; + } + + + +template +inline +bool +Base::is_trimatu() const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap U( (*this).get_ref() ); + + if(U.M.n_rows != U.M.n_cols) { return false; } + + if(U.M.n_elem <= 1) { return true; } + + return trimat_helper::is_triu(U.M); + } + + + +template +inline +bool +Base::is_trimatl() const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap U( (*this).get_ref() ); + + if(U.M.n_rows != U.M.n_cols) { return false; } + + if(U.M.n_elem <= 1) { return true; } + + return trimat_helper::is_tril(U.M); + } + + + +template +inline +bool +Base::is_diagmat() const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap U( (*this).get_ref() ); + + const Mat& A = U.M; + + if(A.n_elem <= 1) { return true; } + + // NOTE: we're NOT assuming the matrix has a square size + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const elem_type* A_mem = A.memptr(); + + if(A_mem[1] != elem_type(0)) { return false; } + + // if we got to this point, do a thorough check + + for(uword A_col=0; A_col < A_n_cols; ++A_col) + { + for(uword A_row=0; A_row < A_n_rows; ++A_row) + { + if( (A_mem[A_row] != elem_type(0)) && (A_row != A_col) ) { return false; } + } + + A_mem += A_n_rows; + } + + return true; + } + + + +template +inline +bool +Base::is_empty() const + { + arma_extra_debug_sigprint(); + + const Proxy P( (*this).get_ref() ); + + return (P.get_n_elem() == uword(0)); + } + + + +template +inline +bool +Base::is_square() const + { + arma_extra_debug_sigprint(); + + const Proxy P( (*this).get_ref() ); + + return (P.get_n_rows() == P.get_n_cols()); + } + + + +template +inline +bool +Base::is_vec() const + { + arma_extra_debug_sigprint(); + + if( (Proxy::is_row) || (Proxy::is_col) || (Proxy::is_xvec) ) { return true; } + + const Proxy P( (*this).get_ref() ); + + return ( (P.get_n_rows() == uword(1)) || (P.get_n_cols() == uword(1)) ); + } + + + +template +inline +bool +Base::is_colvec() const + { + arma_extra_debug_sigprint(); + + if(Proxy::is_col) { return true; } + + const Proxy P( (*this).get_ref() ); + + return (P.get_n_cols() == uword(1)); + } + + + +template +inline +bool +Base::is_rowvec() const + { + arma_extra_debug_sigprint(); + + if(Proxy::is_row) { return true; } + + const Proxy P( (*this).get_ref() ); + + return (P.get_n_rows() == uword(1)); + } + + + +template +inline +bool +Base::is_finite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "is_finite(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_Mat::stored_type>::value) + { + const quasi_unwrap U( (*this).get_ref() ); + + return arrayops::is_finite( U.M.memptr(), U.M.n_elem ); + } + else + { + const Proxy P( (*this).get_ref() ); + + if(Proxy::use_at == false) + { + const typename Proxy::ea_type Pea = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + for(uword i=0; i +inline +bool +Base::has_inf() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_inf(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_Mat::stored_type>::value) + { + const quasi_unwrap U( (*this).get_ref() ); + + return arrayops::has_inf( U.M.memptr(), U.M.n_elem ); + } + else + { + const Proxy P( (*this).get_ref() ); + + if(Proxy::use_at == false) + { + const typename Proxy::ea_type Pea = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + for(uword i=0; i +inline +bool +Base::has_nan() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nan(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_Mat::stored_type>::value) + { + const quasi_unwrap U( (*this).get_ref() ); + + return arrayops::has_nan( U.M.memptr(), U.M.n_elem ); + } + else + { + const Proxy P( (*this).get_ref() ); + + if(Proxy::use_at == false) + { + const typename Proxy::ea_type Pea = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + for(uword i=0; i +inline +bool +Base::has_nonfinite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nonfinite(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_Mat::stored_type>::value) + { + const quasi_unwrap U( (*this).get_ref() ); + + return (arrayops::is_finite( U.M.memptr(), U.M.n_elem ) == false); + } + else + { + const Proxy P( (*this).get_ref() ); + + if(Proxy::use_at == false) + { + const typename Proxy::ea_type Pea = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + for(uword i=0; i +inline +const Op +Base::as_col() const + { + return Op( (*this).get_ref() ); + } + + + +template +inline +const Op +Base::as_row() const + { + return Op( (*this).get_ref() ); + } + + + +// +// extra functions defined in Base_extra_yes + +template +inline +const Op +Base_extra_yes::i() const + { + return Op(static_cast(*this)); + } + + + +template +inline +bool +Base_extra_yes::is_sympd() const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + Mat X = static_cast(*this); + + // default value for tol + const T tol = T(100) * std::numeric_limits::epsilon() * norm(X, "fro"); + + if(X.is_hermitian(tol) == false) { return false; } + + if(X.is_empty()) { return false; } + + X.diag() -= elem_type(tol); + + return auxlib::chol_simple(X); + } + + + +template +inline +bool +Base_extra_yes::is_sympd(typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + arma_debug_check( (tol < T(0)), "is_sympd(): parameter 'tol' must be >= 0" ); + + Mat X = static_cast(*this); + + if(X.is_hermitian(tol) == false) { return false; } + + if(X.is_empty()) { return false; } + + X.diag() -= elem_type(tol); + + return auxlib::chol_simple(X); + } + + + +// +// extra functions defined in Base_eval_Mat + +template +arma_inline +const derived& +Base_eval_Mat::eval() const + { + arma_extra_debug_sigprint(); + + return static_cast(*this); + } + + + +// +// extra functions defined in Base_eval_expr + +template +inline +Mat +Base_eval_expr::eval() const + { + arma_extra_debug_sigprint(); + + return Mat( static_cast(*this) ); + } + + + +// +// extra functions defined in Base_trans_cx + +template +arma_inline +const Op +Base_trans_cx::t() const + { + return Op( static_cast(*this) ); + } + + + +template +arma_inline +const Op +Base_trans_cx::ht() const + { + return Op( static_cast(*this) ); + } + + + +template +arma_inline +const Op +Base_trans_cx::st() const + { + return Op( static_cast(*this) ); + } + + + +// +// extra functions defined in Base_trans_default + +template +arma_inline +const Op +Base_trans_default::t() const + { + return Op( static_cast(*this) ); + } + + + +template +arma_inline +const Op +Base_trans_default::ht() const + { + return Op( static_cast(*this) ); + } + + + +template +arma_inline +const Op +Base_trans_default::st() const + { + return Op( static_cast(*this) ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Col_bones.hpp b/src/armadillo/include/armadillo_bits/Col_bones.hpp new file mode 100644 index 0000000..b3f0ab6 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Col_bones.hpp @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Col +//! @{ + +//! Class for column vectors (matrices with only one column) + +template +class Col : public Mat + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_col = true; + static constexpr bool is_row = false; + static constexpr bool is_xvec = false; + + inline Col(); + inline Col(const Col& X); + + inline explicit Col(const uword n_elem); + inline explicit Col(const uword in_rows, const uword in_cols); + inline explicit Col(const SizeMat& s); + + template inline explicit Col(const uword n_elem, const arma_initmode_indicator&); + template inline explicit Col(const uword in_rows, const uword in_cols, const arma_initmode_indicator&); + template inline explicit Col(const SizeMat& s, const arma_initmode_indicator&); + + template inline Col(const uword n_elem, const fill::fill_class& f); + template inline Col(const uword in_rows, const uword in_cols, const fill::fill_class& f); + template inline Col(const SizeMat& s, const fill::fill_class& f); + + inline Col(const uword N, const fill::scalar_holder f); + inline Col(const uword in_rows, const uword in_cols, const fill::scalar_holder f); + inline Col(const SizeMat& s, const fill::scalar_holder f); + + inline Col(const char* text); + inline Col& operator=(const char* text); + + inline Col(const std::string& text); + inline Col& operator=(const std::string& text); + + inline Col(const std::vector& x); + inline Col& operator=(const std::vector& x); + + inline Col(const std::initializer_list& list); + inline Col& operator=(const std::initializer_list& list); + + inline Col(Col&& m); + inline Col& operator=(Col&& m); + + // inline Col(Mat&& m); + // inline Col& operator=(Mat&& m); + + inline Col& operator=(const eT val); + inline Col& operator=(const Col& m); + + template inline Col(const Base& X); + template inline Col& operator=(const Base& X); + + template inline explicit Col(const SpBase& X); + template inline Col& operator=(const SpBase& X); + + inline Col( eT* aux_mem, const uword aux_length, const bool copy_aux_mem = true, const bool strict = false); + inline Col(const eT* aux_mem, const uword aux_length); + + template + inline explicit Col(const Base& A, const Base& B); + + template inline Col(const BaseCube& X); + template inline Col& operator=(const BaseCube& X); + + inline Col(const subview_cube& X); + inline Col& operator=(const subview_cube& X); + + arma_frown("use braced initialiser list instead") inline mat_injector operator<<(const eT val); + + arma_warn_unused arma_inline const Op,op_htrans> t() const; + arma_warn_unused arma_inline const Op,op_htrans> ht() const; + arma_warn_unused arma_inline const Op,op_strans> st() const; + + arma_warn_unused arma_inline const Op,op_strans> as_row() const; + + arma_inline subview_col row(const uword row_num); + arma_inline const subview_col row(const uword row_num) const; + + using Mat::rows; + using Mat::operator(); + + arma_inline subview_col rows(const uword in_row1, const uword in_row2); + arma_inline const subview_col rows(const uword in_row1, const uword in_row2) const; + + arma_inline subview_col subvec(const uword in_row1, const uword in_row2); + arma_inline const subview_col subvec(const uword in_row1, const uword in_row2) const; + + arma_inline subview_col rows(const span& row_span); + arma_inline const subview_col rows(const span& row_span) const; + + arma_inline subview_col subvec(const span& row_span); + arma_inline const subview_col subvec(const span& row_span) const; + + arma_inline subview_col operator()(const span& row_span); + arma_inline const subview_col operator()(const span& row_span) const; + + arma_inline subview_col subvec(const uword start_row, const SizeMat& s); + arma_inline const subview_col subvec(const uword start_row, const SizeMat& s) const; + + arma_inline subview_col head(const uword N); + arma_inline const subview_col head(const uword N) const; + + arma_inline subview_col tail(const uword N); + arma_inline const subview_col tail(const uword N) const; + + arma_inline subview_col head_rows(const uword N); + arma_inline const subview_col head_rows(const uword N) const; + + arma_inline subview_col tail_rows(const uword N); + arma_inline const subview_col tail_rows(const uword N) const; + + + inline void shed_row (const uword row_num); + inline void shed_rows(const uword in_row1, const uword in_row2); + + template inline void shed_rows(const Base& indices); + + arma_deprecated inline void insert_rows(const uword row_num, const uword N, const bool set_to_zero); + inline void insert_rows(const uword row_num, const uword N); + + template inline void insert_rows(const uword row_num, const Base& X); + + + arma_warn_unused arma_inline eT& at(const uword i); + arma_warn_unused arma_inline const eT& at(const uword i) const; + + arma_warn_unused arma_inline eT& at(const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at(const uword in_row, const uword in_col) const; + + + typedef eT* row_iterator; + typedef const eT* const_row_iterator; + + inline row_iterator begin_row(const uword row_num); + inline const_row_iterator begin_row(const uword row_num) const; + + inline row_iterator end_row (const uword row_num); + inline const_row_iterator end_row (const uword row_num) const; + + + template class fixed; + + + protected: + + inline Col(const arma_fixed_indicator&, const uword in_n_elem, const eT* in_mem); + + + public: + + #if defined(ARMA_EXTRA_COL_PROTO) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_COL_PROTO) + #endif + }; + + + +template +template +class Col::fixed : public Col + { + private: + + static constexpr bool use_extra = (fixed_n_elem > arma_config::mat_prealloc); + + arma_align_mem eT mem_local_extra[ (use_extra) ? fixed_n_elem : 1 ]; + + + public: + + typedef fixed Col_fixed_type; + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_col = true; + static constexpr bool is_row = false; + static constexpr bool is_xvec = false; + + static const uword n_rows; // value provided below the class definition + static const uword n_cols; // value provided below the class definition + static const uword n_elem; // value provided below the class definition + + arma_inline fixed(); + arma_inline fixed(const fixed& X); + inline fixed(const subview_cube& X); + + inline fixed(const fill::scalar_holder f); + template inline fixed(const fill::fill_class& f); + template inline fixed(const Base& A); + template inline fixed(const Base& A, const Base& B); + + inline fixed(const eT* aux_mem); + + inline fixed(const char* text); + inline fixed(const std::string& text); + + template inline Col& operator=(const Base& A); + + inline Col& operator=(const eT val); + inline Col& operator=(const char* text); + inline Col& operator=(const std::string& text); + inline Col& operator=(const subview_cube& X); + + using Col::operator(); + + inline fixed(const std::initializer_list& list); + inline Col& operator=(const std::initializer_list& list); + + arma_inline Col& operator=(const fixed& X); + + #if defined(ARMA_GOOD_COMPILER) + template inline Col& operator=(const eOp& X); + template inline Col& operator=(const eGlue& X); + #endif + + arma_warn_unused arma_inline const Op< Col_fixed_type, op_htrans > t() const; + arma_warn_unused arma_inline const Op< Col_fixed_type, op_htrans > ht() const; + arma_warn_unused arma_inline const Op< Col_fixed_type, op_strans > st() const; + + arma_warn_unused arma_inline const eT& at_alt (const uword i) const; + + arma_warn_unused arma_inline eT& operator[] (const uword i); + arma_warn_unused arma_inline const eT& operator[] (const uword i) const; + arma_warn_unused arma_inline eT& at (const uword i); + arma_warn_unused arma_inline const eT& at (const uword i) const; + arma_warn_unused arma_inline eT& operator() (const uword i); + arma_warn_unused arma_inline const eT& operator() (const uword i) const; + + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col) const; + + arma_warn_unused arma_inline eT* memptr(); + arma_warn_unused arma_inline const eT* memptr() const; + + inline const Col& fill(const eT val); + inline const Col& zeros(); + inline const Col& ones(); + }; + + + +// these definitions are outside of the class due to bizarre C++ rules; +// C++17 has inline variables to address this shortcoming + +template +template +const uword Col::fixed::n_rows = fixed_n_elem; + +template +template +const uword Col::fixed::n_cols = 1u; + +template +template +const uword Col::fixed::n_elem = fixed_n_elem; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Col_meat.hpp b/src/armadillo/include/armadillo_bits/Col_meat.hpp new file mode 100644 index 0000000..a6a1945 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Col_meat.hpp @@ -0,0 +1,1888 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Col +//! @{ + + +//! construct an empty column vector +template +inline +Col::Col() + : Mat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Col::Col(const Col& X) + : Mat(arma_vec_indicator(), X.n_elem, 1, 1) + { + arma_extra_debug_sigprint(); + + arrayops::copy((*this).memptr(), X.memptr(), X.n_elem); + } + + + +//! construct a column vector with the specified number of n_elem +template +inline +Col::Col(const uword in_n_elem) + : Mat(arma_vec_indicator(), in_n_elem, 1, 1) + { + arma_extra_debug_sigprint(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +template +inline +Col::Col(const uword in_n_rows, const uword in_n_cols) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +template +inline +Col::Col(const SizeMat& s) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Col::Col(const uword in_n_elem, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), in_n_elem, 1, 1) + { + arma_extra_debug_sigprint(); + + if(do_zeros) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Col::Col(const uword in_n_rows, const uword in_n_cols, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + if(do_zeros) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Col::Col(const SizeMat& s, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + if(do_zeros) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +template +template +inline +Col::Col(const uword in_n_elem, const fill::fill_class& f) + : Mat(arma_vec_indicator(), in_n_elem, 1, 1) + { + arma_extra_debug_sigprint(); + + (*this).fill(f); + } + + + +template +template +inline +Col::Col(const uword in_n_rows, const uword in_n_cols, const fill::fill_class& f) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + (*this).fill(f); + } + + + +template +template +inline +Col::Col(const SizeMat& s, const fill::fill_class& f) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + (*this).fill(f); + } + + + +template +inline +Col::Col(const uword in_n_elem, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), in_n_elem, 1, 1) + { + arma_extra_debug_sigprint(); + + (*this).fill(f.scalar); + } + + + +template +inline +Col::Col(const uword in_n_rows, const uword in_n_cols, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + (*this).fill(f.scalar); + } + + + +template +inline +Col::Col(const SizeMat& s, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + (*this).fill(f.scalar); + } + + + +template +inline +Col::Col(const char* text) + : Mat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(); + + (*this).operator=(text); + } + + + +template +inline +Col& +Col::operator=(const char* text) + { + arma_extra_debug_sigprint(); + + Mat tmp(text); + + arma_debug_check( ((tmp.n_elem > 0) && (tmp.is_vec() == false)), "Mat::init(): requested size is not compatible with column vector layout" ); + + access::rw(tmp.n_rows) = tmp.n_elem; + access::rw(tmp.n_cols) = 1; + + (*this).steal_mem(tmp); + + return *this; + } + + + +template +inline +Col::Col(const std::string& text) + : Mat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(); + + (*this).operator=(text); + } + + + +template +inline +Col& +Col::operator=(const std::string& text) + { + arma_extra_debug_sigprint(); + + Mat tmp(text); + + arma_debug_check( ((tmp.n_elem > 0) && (tmp.is_vec() == false)), "Mat::init(): requested size is not compatible with column vector layout" ); + + access::rw(tmp.n_rows) = tmp.n_elem; + access::rw(tmp.n_cols) = 1; + + (*this).steal_mem(tmp); + + return *this; + } + + + +//! create a column vector from std::vector +template +inline +Col::Col(const std::vector& x) + : Mat(arma_vec_indicator(), uword(x.size()), 1, 1) + { + arma_extra_debug_sigprint_this(this); + + const uword N = uword(x.size()); + + if(N > 0) { arrayops::copy( Mat::memptr(), &(x[0]), N ); } + } + + + +//! create a column vector from std::vector +template +inline +Col& +Col::operator=(const std::vector& x) + { + arma_extra_debug_sigprint(); + + const uword N = uword(x.size()); + + Mat::init_warm(N, 1); + + if(N > 0) { arrayops::copy( Mat::memptr(), &(x[0]), N ); } + + return *this; + } + + + +template +inline +Col::Col(const std::initializer_list& list) + : Mat(arma_vec_indicator(), uword(list.size()), 1, 1) + { + arma_extra_debug_sigprint_this(this); + + const uword N = uword(list.size()); + + if(N > 0) { arrayops::copy( Mat::memptr(), list.begin(), N ); } + } + + + +template +inline +Col& +Col::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + const uword N = uword(list.size()); + + Mat::init_warm(N, 1); + + if(N > 0) { arrayops::copy( Mat::memptr(), list.begin(), N ); } + + return *this; + } + + + +template +inline +Col::Col(Col&& X) + : Mat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + + access::rw(Mat::n_rows) = X.n_rows; + access::rw(Mat::n_cols) = 1; + access::rw(Mat::n_elem) = X.n_elem; + access::rw(Mat::n_alloc) = X.n_alloc; + + if( (X.n_alloc > arma_config::mat_prealloc) || (X.mem_state == 1) || (X.mem_state == 2) ) + { + access::rw(Mat::mem_state) = X.mem_state; + access::rw(Mat::mem) = X.mem; + + access::rw(X.n_rows) = 0; + access::rw(X.n_cols) = 1; + access::rw(X.n_elem) = 0; + access::rw(X.n_alloc) = 0; + access::rw(X.mem_state) = 0; + access::rw(X.mem) = nullptr; + } + else // condition: (X.n_alloc <= arma_config::mat_prealloc) || (X.mem_state == 0) || (X.mem_state == 3) + { + (*this).init_cold(); + + arrayops::copy( (*this).memptr(), X.mem, X.n_elem ); + + if( (X.mem_state == 0) && (X.n_alloc <= arma_config::mat_prealloc) ) + { + access::rw(X.n_rows) = 0; + access::rw(X.n_cols) = 1; + access::rw(X.n_elem) = 0; + access::rw(X.mem) = nullptr; + } + } + } + + + +template +inline +Col& +Col::operator=(Col&& X) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + + (*this).steal_mem(X, true); + + return *this; + } + + + +// template +// inline +// Col::Col(Mat&& X) +// : Mat(arma_vec_indicator(), 1) +// { +// arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); +// +// if(X.n_cols != 1) { const Mat& XX = X; Mat::operator=(XX); return; } +// +// access::rw(Mat::n_rows) = X.n_rows; +// access::rw(Mat::n_cols) = 1; +// access::rw(Mat::n_elem) = X.n_elem; +// access::rw(Mat::n_alloc) = X.n_alloc; +// +// if( (X.n_alloc > arma_config::mat_prealloc) || (X.mem_state == 1) || (X.mem_state == 2) ) +// { +// access::rw(Mat::mem_state) = X.mem_state; +// access::rw(Mat::mem) = X.mem; +// +// access::rw(X.n_rows) = 0; +// access::rw(X.n_elem) = 0; +// access::rw(X.n_alloc) = 0; +// access::rw(X.mem_state) = 0; +// access::rw(X.mem) = nullptr; +// } +// else // condition: (X.n_alloc <= arma_config::mat_prealloc) || (X.mem_state == 0) || (X.mem_state == 3) +// { +// (*this).init_cold(); +// +// arrayops::copy( (*this).memptr(), X.mem, X.n_elem ); +// +// if( (X.mem_state == 0) && (X.n_alloc <= arma_config::mat_prealloc) ) +// { +// access::rw(X.n_rows) = 0; +// access::rw(X.n_elem) = 0; +// access::rw(X.mem) = nullptr; +// } +// } +// } +// +// +// +// template +// inline +// Col& +// Col::operator=(Mat&& X) +// { +// arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); +// +// if(X.n_cols != 1) { const Mat& XX = X; Mat::operator=(XX); return *this; } +// +// (*this).steal_mem(X, true); +// +// return *this; +// } + + + +template +inline +Col& +Col::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + Mat::operator=(val); + + return *this; + } + + + +template +inline +Col& +Col::operator=(const Col& X) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X); + + return *this; + } + + + +template +template +inline +Col::Col(const Base& X) + : Mat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X.get_ref()); + } + + + +template +template +inline +Col& +Col::operator=(const Base& X) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X.get_ref()); + + return *this; + } + + + +template +template +inline +Col::Col(const SpBase& X) + : Mat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint_this(this); + + Mat::operator=(X.get_ref()); + } + + + +template +template +inline +Col& +Col::operator=(const SpBase& X) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X.get_ref()); + + return *this; + } + + + +//! construct a column vector from a given auxiliary array of eTs +template +inline +Col::Col(eT* aux_mem, const uword aux_length, const bool copy_aux_mem, const bool strict) + : Mat(aux_mem, aux_length, 1, copy_aux_mem, strict) + { + arma_extra_debug_sigprint(); + + access::rw(Mat::vec_state) = 1; + } + + + +//! construct a column vector from a given auxiliary array of eTs +template +inline +Col::Col(const eT* aux_mem, const uword aux_length) + : Mat(aux_mem, aux_length, 1) + { + arma_extra_debug_sigprint(); + + access::rw(Mat::vec_state) = 1; + } + + + +template +template +inline +Col::Col + ( + const Base::pod_type, T1>& A, + const Base::pod_type, T2>& B + ) + { + arma_extra_debug_sigprint(); + + access::rw(Mat::vec_state) = 1; + + Mat::init(A,B); + } + + + +template +template +inline +Col::Col(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + access::rw(Mat::vec_state) = 1; + + Mat::operator=(X); + } + + + +template +template +inline +Col& +Col::operator=(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X); + + return *this; + } + + + +template +inline +Col::Col(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + access::rw(Mat::vec_state) = 1; + + Mat::operator=(X); + } + + + +template +inline +Col& +Col::operator=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X); + + return *this; + } + + + +template +inline +mat_injector< Col > +Col::operator<<(const eT val) + { + return mat_injector< Col >(*this, val); + } + + + +template +arma_inline +const Op,op_htrans> +Col::t() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_htrans> +Col::ht() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_strans> +Col::st() const + { + return Op,op_strans>(*this); + } + + + +template +arma_inline +const Op,op_strans> +Col::as_row() const + { + return Op,op_strans>(*this); + } + + + +template +arma_inline +subview_col +Col::row(const uword in_row1) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (in_row1 >= Mat::n_rows), "Col::row(): indices out of bounds or incorrectly used" ); + + return subview_col(*this, 0, in_row1, 1); + } + + + +template +arma_inline +const subview_col +Col::row(const uword in_row1) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (in_row1 >= Mat::n_rows), "Col::row(): indices out of bounds or incorrectly used" ); + + return subview_col(*this, 0, in_row1, 1); + } + + + +template +arma_inline +subview_col +Col::rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::rows(): indices out of bounds or incorrectly used" ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + return subview_col(*this, 0, in_row1, subview_n_rows); + } + + + +template +arma_inline +const subview_col +Col::rows(const uword in_row1, const uword in_row2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::rows(): indices out of bounds or incorrectly used" ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + return subview_col(*this, 0, in_row1, subview_n_rows); + } + + + +template +arma_inline +subview_col +Col::subvec(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::subvec(): indices out of bounds or incorrectly used" ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + return subview_col(*this, 0, in_row1, subview_n_rows); + } + + + +template +arma_inline +const subview_col +Col::subvec(const uword in_row1, const uword in_row2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::subvec(): indices out of bounds or incorrectly used" ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + return subview_col(*this, 0, in_row1, subview_n_rows); + } + + + +template +arma_inline +subview_col +Col::rows(const span& row_span) + { + arma_extra_debug_sigprint(); + + return subvec(row_span); + } + + + +template +arma_inline +const subview_col +Col::rows(const span& row_span) const + { + arma_extra_debug_sigprint(); + + return subvec(row_span); + } + + + +template +arma_inline +subview_col +Col::subvec(const span& row_span) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + + const uword local_n_rows = Mat::n_rows; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword subvec_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + arma_debug_check_bounds( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ), "Col::subvec(): indices out of bounds or incorrectly used" ); + + return subview_col(*this, 0, in_row1, subvec_n_rows); + } + + + +template +arma_inline +const subview_col +Col::subvec(const span& row_span) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + + const uword local_n_rows = Mat::n_rows; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword subvec_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + arma_debug_check_bounds( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ), "Col::subvec(): indices out of bounds or incorrectly used" ); + + return subview_col(*this, 0, in_row1, subvec_n_rows); + } + + + +template +arma_inline +subview_col +Col::operator()(const span& row_span) + { + arma_extra_debug_sigprint(); + + return subvec(row_span); + } + + + +template +arma_inline +const subview_col +Col::operator()(const span& row_span) const + { + arma_extra_debug_sigprint(); + + return subvec(row_span); + } + + + +template +arma_inline +subview_col +Col::subvec(const uword start_row, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (s.n_cols != 1), "Col::subvec(): given size does not specify a column vector" ); + + arma_debug_check_bounds( ( (start_row >= Mat::n_rows) || ((start_row + s.n_rows) > Mat::n_rows) ), "Col::subvec(): size out of bounds" ); + + return subview_col(*this, 0, start_row, s.n_rows); + } + + + +template +arma_inline +const subview_col +Col::subvec(const uword start_row, const SizeMat& s) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (s.n_cols != 1), "Col::subvec(): given size does not specify a column vector" ); + + arma_debug_check_bounds( ( (start_row >= Mat::n_rows) || ((start_row + s.n_rows) > Mat::n_rows) ), "Col::subvec(): size out of bounds" ); + + return subview_col(*this, 0, start_row, s.n_rows); + } + + + +template +arma_inline +subview_col +Col::head(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > Mat::n_rows), "Col::head(): size out of bounds" ); + + return subview_col(*this, 0, 0, N); + } + + + +template +arma_inline +const subview_col +Col::head(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > Mat::n_rows), "Col::head(): size out of bounds" ); + + return subview_col(*this, 0, 0, N); + } + + + +template +arma_inline +subview_col +Col::tail(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > Mat::n_rows), "Col::tail(): size out of bounds" ); + + const uword start_row = Mat::n_rows - N; + + return subview_col(*this, 0, start_row, N); + } + + + +template +arma_inline +const subview_col +Col::tail(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > Mat::n_rows), "Col::tail(): size out of bounds" ); + + const uword start_row = Mat::n_rows - N; + + return subview_col(*this, 0, start_row, N); + } + + + +template +arma_inline +subview_col +Col::head_rows(const uword N) + { + arma_extra_debug_sigprint(); + + return (*this).head(N); + } + + + +template +arma_inline +const subview_col +Col::head_rows(const uword N) const + { + arma_extra_debug_sigprint(); + + return (*this).head(N); + } + + + +template +arma_inline +subview_col +Col::tail_rows(const uword N) + { + arma_extra_debug_sigprint(); + + return (*this).tail(N); + } + + + +template +arma_inline +const subview_col +Col::tail_rows(const uword N) const + { + arma_extra_debug_sigprint(); + + return (*this).tail(N); + } + + + +//! remove specified row +template +inline +void +Col::shed_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( row_num >= Mat::n_rows, "Col::shed_row(): index out of bounds" ); + + shed_rows(row_num, row_num); + } + + + +//! remove specified rows +template +inline +void +Col::shed_rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= Mat::n_rows), + "Col::shed_rows(): indices out of bounds or incorrectly used" + ); + + const uword n_keep_front = in_row1; + const uword n_keep_back = Mat::n_rows - (in_row2 + 1); + + Col X(n_keep_front + n_keep_back, arma_nozeros_indicator()); + + eT* X_mem = X.memptr(); + const eT* t_mem = (*this).memptr(); + + if(n_keep_front > 0) + { + arrayops::copy( X_mem, t_mem, n_keep_front ); + } + + if(n_keep_back > 0) + { + arrayops::copy( &(X_mem[n_keep_front]), &(t_mem[in_row2+1]), n_keep_back); + } + + Mat::steal_mem(X); + } + + + +//! remove specified rows +template +template +inline +void +Col::shed_rows(const Base& indices) + { + arma_extra_debug_sigprint(); + + Mat::shed_rows(indices); + } + + + +template +inline +void +Col::insert_rows(const uword row_num, const uword N, const bool set_to_zero) + { + arma_extra_debug_sigprint(); + + arma_ignore(set_to_zero); + + (*this).insert_rows(row_num, N); + } + + + +template +inline +void +Col::insert_rows(const uword row_num, const uword N) + { + arma_extra_debug_sigprint(); + + const uword t_n_rows = Mat::n_rows; + + const uword A_n_rows = row_num; + const uword B_n_rows = t_n_rows - row_num; + + // insertion at row_num == n_rows is in effect an append operation + arma_debug_check_bounds( (row_num > t_n_rows), "Col::insert_rows(): index out of bounds" ); + + if(N == 0) { return; } + + Col out(t_n_rows + N, arma_nozeros_indicator()); + + eT* out_mem = out.memptr(); + const eT* t_mem = (*this).memptr(); + + if(A_n_rows > 0) + { + arrayops::copy( out_mem, t_mem, A_n_rows ); + } + + if(B_n_rows > 0) + { + arrayops::copy( &(out_mem[row_num + N]), &(t_mem[row_num]), B_n_rows ); + } + + arrayops::fill_zeros( &(out_mem[row_num]), N ); + + Mat::steal_mem(out); + } + + + +//! insert the given object at the specified row position; +//! the given object must have one column +template +template +inline +void +Col::insert_rows(const uword row_num, const Base& X) + { + arma_extra_debug_sigprint(); + + Mat::insert_rows(row_num, X); + } + + + +template +arma_inline +eT& +Col::at(const uword i) + { + return access::rw(Mat::mem[i]); + } + + + +template +arma_inline +const eT& +Col::at(const uword i) const + { + return Mat::mem[i]; + } + + + +template +arma_inline +eT& +Col::at(const uword in_row, const uword) + { + return access::rw( Mat::mem[in_row] ); + } + + + +template +arma_inline +const eT& +Col::at(const uword in_row, const uword) const + { + return Mat::mem[in_row]; + } + + + +template +inline +typename Col::row_iterator +Col::begin_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Col::begin_row(): index out of bounds" ); + + return Mat::memptr() + row_num; + } + + + +template +inline +typename Col::const_row_iterator +Col::begin_row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Col::begin_row(): index out of bounds" ); + + return Mat::memptr() + row_num; + } + + + +template +inline +typename Col::row_iterator +Col::end_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Col::end_row(): index out of bounds" ); + + return Mat::memptr() + row_num + 1; + } + + + +template +inline +typename Col::const_row_iterator +Col::end_row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Col::end_row(): index out of bounds" ); + + return Mat::memptr() + row_num + 1; + } + + + +template +template +arma_inline +Col::fixed::fixed() + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Col::fixed::constructor: zeroing memory"); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(Mat::mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(0) ); + } + } + + + +template +template +arma_inline +Col::fixed::fixed(const fixed& X) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + eT* dest = (use_extra) ? mem_local_extra : Mat::mem_local; + const eT* src = (use_extra) ? X.mem_local_extra : X.mem_local; + + arrayops::copy( dest, src, fixed_n_elem ); + } + + + +template +template +arma_inline +Col::fixed::fixed(const subview_cube& X) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Col::operator=(X); + } + + + +template +template +inline +Col::fixed::fixed(const fill::scalar_holder f) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + (*this).fill(f.scalar); + } + + + +template +template +template +inline +Col::fixed::fixed(const fill::fill_class&) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).eye(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } + } + + + +template +template +template +arma_inline +Col::fixed::fixed(const Base& A) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Col::operator=(A.get_ref()); + } + + + +template +template +template +arma_inline +Col::fixed::fixed(const Base& A, const Base& B) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Col::init(A,B); + } + + + +template +template +inline +Col::fixed::fixed(const eT* aux_mem) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + eT* dest = (use_extra) ? mem_local_extra : Mat::mem_local; + + arrayops::copy( dest, aux_mem, fixed_n_elem ); + } + + + +template +template +inline +Col::fixed::fixed(const char* text) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Col::operator=(text); + } + + + +template +template +inline +Col::fixed::fixed(const std::string& text) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Col::operator=(text); + } + + + +template +template +template +Col& +Col::fixed::operator=(const Base& A) + { + arma_extra_debug_sigprint(); + + Col::operator=(A.get_ref()); + + return *this; + } + + + +template +template +Col& +Col::fixed::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + Col::operator=(val); + + return *this; + } + + + +template +template +Col& +Col::fixed::operator=(const char* text) + { + arma_extra_debug_sigprint(); + + Col::operator=(text); + + return *this; + } + + + +template +template +Col& +Col::fixed::operator=(const std::string& text) + { + arma_extra_debug_sigprint(); + + Col::operator=(text); + + return *this; + } + + + +template +template +Col& +Col::fixed::operator=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + Col::operator=(X); + + return *this; + } + + + +template +template +inline +Col::fixed::fixed(const std::initializer_list& list) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(list); + } + + + +template +template +inline +Col& +Col::fixed::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + const uword N = uword(list.size()); + + arma_debug_check( (N > fixed_n_elem), "Col::fixed: initialiser list is too long" ); + + eT* this_mem = (*this).memptr(); + + arrayops::copy( this_mem, list.begin(), N ); + + for(uword iq=N; iq < fixed_n_elem; ++iq) { this_mem[iq] = eT(0); } + + return *this; + } + + + +template +template +arma_inline +Col& +Col::fixed::operator=(const fixed& X) + { + arma_extra_debug_sigprint(); + + if(this != &X) + { + eT* dest = (use_extra) ? mem_local_extra : Mat::mem_local; + const eT* src = (use_extra) ? X.mem_local_extra : X.mem_local; + + arrayops::copy( dest, src, fixed_n_elem ); + } + + return *this; + } + + + +#if defined(ARMA_GOOD_COMPILER) + + template + template + template + inline + Col& + Col::fixed::operator=(const eOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias == false) + { + arma_debug_assert_same_size(fixed_n_elem, uword(1), X.get_n_rows(), X.get_n_cols(), "Col::fixed::operator="); + + eop_type::apply(*this, X); + } + else + { + arma_extra_debug_print("bad_alias = true"); + + Col tmp(X); + + (*this) = tmp; + } + + return *this; + } + + + + template + template + template + inline + Col& + Col::fixed::operator=(const eGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias == false) + { + arma_debug_assert_same_size(fixed_n_elem, uword(1), X.get_n_rows(), X.get_n_cols(), "Col::fixed::operator="); + + eglue_type::apply(*this, X); + } + else + { + arma_extra_debug_print("bad_alias = true"); + + Col tmp(X); + + (*this) = tmp; + } + + return *this; + } + +#endif + + + +template +template +arma_inline +const Op< typename Col::template fixed::Col_fixed_type, op_htrans > +Col::fixed::t() const + { + return Op< typename Col::template fixed::Col_fixed_type, op_htrans >(*this); + } + + + +template +template +arma_inline +const Op< typename Col::template fixed::Col_fixed_type, op_htrans > +Col::fixed::ht() const + { + return Op< typename Col::template fixed::Col_fixed_type, op_htrans >(*this); + } + + + +template +template +arma_inline +const Op< typename Col::template fixed::Col_fixed_type, op_strans > +Col::fixed::st() const + { + return Op< typename Col::template fixed::Col_fixed_type, op_strans >(*this); + } + + + +template +template +arma_inline +const eT& +Col::fixed::at_alt(const uword ii) const + { + #if defined(ARMA_HAVE_ALIGNED_ATTRIBUTE) + + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + + #else + const eT* mem_aligned = (use_extra) ? mem_local_extra : Mat::mem_local; + + memory::mark_as_aligned(mem_aligned); + + return mem_aligned[ii]; + #endif + } + + + +template +template +arma_inline +eT& +Col::fixed::operator[] (const uword ii) + { + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +const eT& +Col::fixed::operator[] (const uword ii) const + { + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +eT& +Col::fixed::at(const uword ii) + { + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +const eT& +Col::fixed::at(const uword ii) const + { + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +eT& +Col::fixed::operator() (const uword ii) + { + arma_debug_check_bounds( (ii >= fixed_n_elem), "Col::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +const eT& +Col::fixed::operator() (const uword ii) const + { + arma_debug_check_bounds( (ii >= fixed_n_elem), "Col::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +eT& +Col::fixed::at(const uword in_row, const uword) + { + return (use_extra) ? mem_local_extra[in_row] : Mat::mem_local[in_row]; + } + + + +template +template +arma_inline +const eT& +Col::fixed::at(const uword in_row, const uword) const + { + return (use_extra) ? mem_local_extra[in_row] : Mat::mem_local[in_row]; + } + + + +template +template +arma_inline +eT& +Col::fixed::operator() (const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row >= fixed_n_elem) || (in_col > 0)), "Col::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[in_row] : Mat::mem_local[in_row]; + } + + + +template +template +arma_inline +const eT& +Col::fixed::operator() (const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row >= fixed_n_elem) || (in_col > 0)), "Col::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[in_row] : Mat::mem_local[in_row]; + } + + + +template +template +arma_inline +eT* +Col::fixed::memptr() + { + return (use_extra) ? mem_local_extra : Mat::mem_local; + } + + + +template +template +arma_inline +const eT* +Col::fixed::memptr() const + { + return (use_extra) ? mem_local_extra : Mat::mem_local; + } + + + +template +template +inline +const Col& +Col::fixed::fill(const eT val) + { + arma_extra_debug_sigprint(); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(Mat::mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, val ); + + return *this; + } + + + +template +template +inline +const Col& +Col::fixed::zeros() + { + arma_extra_debug_sigprint(); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(Mat::mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(0) ); + + return *this; + } + + + +template +template +inline +const Col& +Col::fixed::ones() + { + arma_extra_debug_sigprint(); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(Mat::mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(1) ); + + return *this; + } + + + +template +inline +Col::Col(const arma_fixed_indicator&, const uword in_n_elem, const eT* in_mem) + : Mat(arma_fixed_indicator(), in_n_elem, 1, 1, in_mem) + { + arma_extra_debug_sigprint_this(this); + } + + + +#if defined(ARMA_EXTRA_COL_MEAT) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_COL_MEAT) +#endif + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/CubeToMatOp_bones.hpp b/src/armadillo/include/armadillo_bits/CubeToMatOp_bones.hpp new file mode 100644 index 0000000..cd2ba59 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/CubeToMatOp_bones.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup CubeToMatOp +//! @{ + + + +template +class CubeToMatOp : public Base< typename T1::elem_type, CubeToMatOp > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + inline explicit CubeToMatOp(const T1& in_m); + inline CubeToMatOp(const T1& in_m, const uword in_aux_uword); + inline ~CubeToMatOp(); + + arma_aligned const T1& m; //!< the operand; must be derived from BaseCube + arma_aligned uword aux_uword; //!< auxiliary data, uword format + + static constexpr bool is_row = op_type::template traits::is_row; + static constexpr bool is_col = op_type::template traits::is_col; + static constexpr bool is_xvec = op_type::template traits::is_xvec; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/CubeToMatOp_meat.hpp b/src/armadillo/include/armadillo_bits/CubeToMatOp_meat.hpp new file mode 100644 index 0000000..abe83e8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/CubeToMatOp_meat.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup CubeToMatOp +//! @{ + + + +template +inline +CubeToMatOp::CubeToMatOp(const T1& in_m) + : m(in_m) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +CubeToMatOp::CubeToMatOp(const T1& in_m, const uword in_aux_uword) + : m(in_m) + , aux_uword(in_aux_uword) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +CubeToMatOp::~CubeToMatOp() + { + arma_extra_debug_sigprint(); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Cube_bones.hpp b/src/armadillo/include/armadillo_bits/Cube_bones.hpp new file mode 100644 index 0000000..5cf364a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Cube_bones.hpp @@ -0,0 +1,564 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Cube +//! @{ + + + +struct Cube_prealloc + { + static constexpr uword mat_ptrs_size = 4; + static constexpr uword mem_n_elem = 64; + }; + + + +//! Dense cube class + +template +class Cube : public BaseCube< eT, Cube > + { + public: + + typedef eT elem_type; //!< the type of elements stored in the cube + typedef typename get_pod_type::result pod_type; //!< if eT is std::complex, pod_type is T; otherwise pod_type is eT + + const uword n_rows; //!< number of rows in each slice (read-only) + const uword n_cols; //!< number of columns in each slice (read-only) + const uword n_elem_slice; //!< number of elements in each slice (read-only) + const uword n_slices; //!< number of slices in the cube (read-only) + const uword n_elem; //!< number of elements in the cube (read-only) + const uword n_alloc; //!< number of allocated elements (read-only); NOTE: n_alloc can be 0, even if n_elem > 0 + const uword mem_state; + + // mem_state = 0: normal cube which manages its own memory + // mem_state = 1: use auxiliary memory until a size change + // mem_state = 2: use auxiliary memory and don't allow the number of elements to be changed + // mem_state = 3: fixed size (eg. via template based size specification) + + arma_aligned const eT* const mem; //!< pointer to the memory used for storing elements (memory is read-only) + + + protected: + + using mat_type = Mat; + + #if defined(ARMA_USE_OPENMP) + using raw_mat_ptr_type = mat_type*; + using atomic_mat_ptr_type = mat_type*; + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + using raw_mat_ptr_type = mat_type*; + using atomic_mat_ptr_type = std::atomic; + #else + using raw_mat_ptr_type = mat_type*; + using atomic_mat_ptr_type = mat_type*; + #endif + + atomic_mat_ptr_type* mat_ptrs = nullptr; + + #if (!defined(ARMA_DONT_USE_STD_MUTEX)) + mutable std::mutex mat_mutex; // required for slice() + #endif + + arma_aligned atomic_mat_ptr_type mat_ptrs_local[ Cube_prealloc::mat_ptrs_size ]; + arma_align_mem eT mem_local[ Cube_prealloc::mem_n_elem ]; // local storage, for small cubes + + + public: + + inline ~Cube(); + inline Cube(); + + inline explicit Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices); + inline explicit Cube(const SizeCube& s); + + template inline explicit Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices, const arma_initmode_indicator&); + template inline explicit Cube(const SizeCube& s, const arma_initmode_indicator&); + + template inline Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices, const fill::fill_class& f); + template inline Cube(const SizeCube& s, const fill::fill_class& f); + + inline Cube(const uword in_rows, const uword in_cols, const uword in_slices, const fill::scalar_holder f); + inline Cube(const SizeCube& s, const fill::scalar_holder f); + + inline Cube(Cube&& m); + inline Cube& operator=(Cube&& m); + + inline Cube( eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols, const uword aux_n_slices, const bool copy_aux_mem = true, const bool strict = false, const bool prealloc_mat = false); + inline Cube(const eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols, const uword aux_n_slices); + + inline Cube& operator= (const eT val); + inline Cube& operator+=(const eT val); + inline Cube& operator-=(const eT val); + inline Cube& operator*=(const eT val); + inline Cube& operator/=(const eT val); + + inline Cube(const Cube& m); + inline Cube& operator= (const Cube& m); + inline Cube& operator+=(const Cube& m); + inline Cube& operator-=(const Cube& m); + inline Cube& operator%=(const Cube& m); + inline Cube& operator/=(const Cube& m); + + template + inline explicit Cube(const BaseCube& A, const BaseCube& B); + + inline Cube(const subview_cube& X); + inline Cube& operator= (const subview_cube& X); + inline Cube& operator+=(const subview_cube& X); + inline Cube& operator-=(const subview_cube& X); + inline Cube& operator%=(const subview_cube& X); + inline Cube& operator/=(const subview_cube& X); + + template inline Cube(const subview_cube_slices& X); + template inline Cube& operator= (const subview_cube_slices& X); + template inline Cube& operator+=(const subview_cube_slices& X); + template inline Cube& operator-=(const subview_cube_slices& X); + template inline Cube& operator%=(const subview_cube_slices& X); + template inline Cube& operator/=(const subview_cube_slices& X); + + arma_inline subview_cube row(const uword in_row); + arma_inline const subview_cube row(const uword in_row) const; + + arma_inline subview_cube col(const uword in_col); + arma_inline const subview_cube col(const uword in_col) const; + + inline Mat& slice(const uword in_slice); + inline const Mat& slice(const uword in_slice) const; + + arma_inline subview_cube rows(const uword in_row1, const uword in_row2); + arma_inline const subview_cube rows(const uword in_row1, const uword in_row2) const; + + arma_inline subview_cube cols(const uword in_col1, const uword in_col2); + arma_inline const subview_cube cols(const uword in_col1, const uword in_col2) const; + + arma_inline subview_cube slices(const uword in_slice1, const uword in_slice2); + arma_inline const subview_cube slices(const uword in_slice1, const uword in_slice2) const; + + arma_inline subview_cube subcube(const uword in_row1, const uword in_col1, const uword in_slice1, const uword in_row2, const uword in_col2, const uword in_slice2); + arma_inline const subview_cube subcube(const uword in_row1, const uword in_col1, const uword in_slice1, const uword in_row2, const uword in_col2, const uword in_slice2) const; + + inline subview_cube subcube(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s); + inline const subview_cube subcube(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) const; + + inline subview_cube subcube(const span& row_span, const span& col_span, const span& slice_span); + inline const subview_cube subcube(const span& row_span, const span& col_span, const span& slice_span) const; + + inline subview_cube operator()(const span& row_span, const span& col_span, const span& slice_span); + inline const subview_cube operator()(const span& row_span, const span& col_span, const span& slice_span) const; + + inline subview_cube operator()(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s); + inline const subview_cube operator()(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) const; + + arma_inline subview_cube tube(const uword in_row1, const uword in_col1); + arma_inline const subview_cube tube(const uword in_row1, const uword in_col1) const; + + arma_inline subview_cube tube(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2); + arma_inline const subview_cube tube(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const; + + arma_inline subview_cube tube(const uword in_row1, const uword in_col1, const SizeMat& s); + arma_inline const subview_cube tube(const uword in_row1, const uword in_col1, const SizeMat& s) const; + + inline subview_cube tube(const span& row_span, const span& col_span); + inline const subview_cube tube(const span& row_span, const span& col_span) const; + + inline subview_cube head_slices(const uword N); + inline const subview_cube head_slices(const uword N) const; + + inline subview_cube tail_slices(const uword N); + inline const subview_cube tail_slices(const uword N) const; + + template arma_inline subview_elem1 elem(const Base& a); + template arma_inline const subview_elem1 elem(const Base& a) const; + + template arma_inline subview_elem1 operator()(const Base& a); + template arma_inline const subview_elem1 operator()(const Base& a) const; + + + arma_inline subview_cube_each1 each_slice(); + arma_inline const subview_cube_each1 each_slice() const; + + template inline subview_cube_each2 each_slice(const Base& indices); + template inline const subview_cube_each2 each_slice(const Base& indices) const; + + inline Cube& each_slice(const std::function< void( Mat&) >& F); + inline const Cube& each_slice(const std::function< void(const Mat&) >& F) const; + + inline Cube& each_slice(const std::function< void( Mat&) >& F, const bool use_mp); + inline const Cube& each_slice(const std::function< void(const Mat&) >& F, const bool use_mp) const; + + + template arma_inline subview_cube_slices slices(const Base& indices); + template arma_inline const subview_cube_slices slices(const Base& indices) const; + + + inline void shed_row(const uword row_num); + inline void shed_col(const uword col_num); + inline void shed_slice(const uword slice_num); + + inline void shed_rows(const uword in_row1, const uword in_row2); + inline void shed_cols(const uword in_col1, const uword in_col2); + inline void shed_slices(const uword in_slice1, const uword in_slice2); + + template inline void shed_slices(const Base& indices); + + arma_deprecated inline void insert_rows(const uword row_num, const uword N, const bool set_to_zero); + arma_deprecated inline void insert_cols(const uword row_num, const uword N, const bool set_to_zero); + arma_deprecated inline void insert_slices(const uword slice_num, const uword N, const bool set_to_zero); + + inline void insert_rows(const uword row_num, const uword N); + inline void insert_cols(const uword row_num, const uword N); + inline void insert_slices(const uword slice_num, const uword N); + + template inline void insert_rows(const uword row_num, const BaseCube& X); + template inline void insert_cols(const uword col_num, const BaseCube& X); + template inline void insert_slices(const uword slice_num, const BaseCube& X); + template inline void insert_slices(const uword slice_num, const Base& X); + + + template inline Cube(const GenCube& X); + template inline Cube& operator= (const GenCube& X); + template inline Cube& operator+=(const GenCube& X); + template inline Cube& operator-=(const GenCube& X); + template inline Cube& operator%=(const GenCube& X); + template inline Cube& operator/=(const GenCube& X); + + template inline Cube(const OpCube& X); + template inline Cube& operator= (const OpCube& X); + template inline Cube& operator+=(const OpCube& X); + template inline Cube& operator-=(const OpCube& X); + template inline Cube& operator%=(const OpCube& X); + template inline Cube& operator/=(const OpCube& X); + + template inline Cube(const eOpCube& X); + template inline Cube& operator= (const eOpCube& X); + template inline Cube& operator+=(const eOpCube& X); + template inline Cube& operator-=(const eOpCube& X); + template inline Cube& operator%=(const eOpCube& X); + template inline Cube& operator/=(const eOpCube& X); + + template inline Cube(const mtOpCube& X); + template inline Cube& operator= (const mtOpCube& X); + template inline Cube& operator+=(const mtOpCube& X); + template inline Cube& operator-=(const mtOpCube& X); + template inline Cube& operator%=(const mtOpCube& X); + template inline Cube& operator/=(const mtOpCube& X); + + template inline Cube(const GlueCube& X); + template inline Cube& operator= (const GlueCube& X); + template inline Cube& operator+=(const GlueCube& X); + template inline Cube& operator-=(const GlueCube& X); + template inline Cube& operator%=(const GlueCube& X); + template inline Cube& operator/=(const GlueCube& X); + + template inline Cube(const eGlueCube& X); + template inline Cube& operator= (const eGlueCube& X); + template inline Cube& operator+=(const eGlueCube& X); + template inline Cube& operator-=(const eGlueCube& X); + template inline Cube& operator%=(const eGlueCube& X); + template inline Cube& operator/=(const eGlueCube& X); + + template inline Cube(const mtGlueCube& X); + template inline Cube& operator= (const mtGlueCube& X); + template inline Cube& operator+=(const mtGlueCube& X); + template inline Cube& operator-=(const mtGlueCube& X); + template inline Cube& operator%=(const mtGlueCube& X); + template inline Cube& operator/=(const mtGlueCube& X); + + + arma_warn_unused arma_inline const eT& at_alt (const uword i) const; + + arma_warn_unused arma_inline eT& operator[] (const uword i); + arma_warn_unused arma_inline const eT& operator[] (const uword i) const; + + arma_warn_unused arma_inline eT& at(const uword i); + arma_warn_unused arma_inline const eT& at(const uword i) const; + + arma_warn_unused arma_inline eT& operator() (const uword i); + arma_warn_unused arma_inline const eT& operator() (const uword i) const; + + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline eT& operator[] (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& operator[] (const uword in_row, const uword in_col, const uword in_slice) const; + #endif + + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col, const uword in_slice) const; + + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col, const uword in_slice) const; + + arma_inline const Cube& operator++(); + arma_inline void operator++(int); + + arma_inline const Cube& operator--(); + arma_inline void operator--(int); + + arma_warn_unused arma_inline bool is_empty() const; + + arma_warn_unused inline bool internal_is_finite() const; + arma_warn_unused inline bool internal_has_inf() const; + arma_warn_unused inline bool internal_has_nan() const; + arma_warn_unused inline bool internal_has_nonfinite() const; + + arma_warn_unused arma_inline bool in_range(const uword i) const; + arma_warn_unused arma_inline bool in_range(const span& x) const; + + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const uword in_slice) const; + arma_warn_unused inline bool in_range(const span& row_span, const span& col_span, const span& slice_span) const; + + arma_warn_unused inline bool in_range(const uword in_row, const uword in_col, const uword in_slice, const SizeCube& s) const; + + arma_warn_unused arma_inline eT* memptr(); + arma_warn_unused arma_inline const eT* memptr() const; + + arma_warn_unused arma_inline eT* slice_memptr(const uword slice); + arma_warn_unused arma_inline const eT* slice_memptr(const uword slice) const; + + arma_warn_unused arma_inline eT* slice_colptr(const uword in_slice, const uword in_col); + arma_warn_unused arma_inline const eT* slice_colptr(const uword in_slice, const uword in_col) const; + + inline Cube& set_size(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& set_size(const SizeCube& s); + + inline Cube& reshape(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& reshape(const SizeCube& s); + + inline Cube& resize(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& resize(const SizeCube& s); + + + template inline Cube& copy_size(const Cube& m); + + template inline Cube& for_each(functor F); + template inline const Cube& for_each(functor F) const; + + template inline Cube& transform(functor F); + template inline Cube& imbue(functor F); + + inline Cube& replace(const eT old_val, const eT new_val); + + inline Cube& clean(const pod_type threshold); + + inline Cube& clamp(const eT min_val, const eT max_val); + + inline Cube& fill(const eT val); + + inline Cube& zeros(); + inline Cube& zeros(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& zeros(const SizeCube& s); + + inline Cube& ones(); + inline Cube& ones(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& ones(const SizeCube& s); + + inline Cube& randu(); + inline Cube& randu(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& randu(const SizeCube& s); + + inline Cube& randn(); + inline Cube& randn(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& randn(const SizeCube& s); + + inline void reset(); + inline void soft_reset(); + + + template inline void set_real(const BaseCube& X); + template inline void set_imag(const BaseCube& X); + + + arma_warn_unused inline eT min() const; + arma_warn_unused inline eT max() const; + + inline eT min(uword& index_of_min_val) const; + inline eT max(uword& index_of_max_val) const; + + inline eT min(uword& row_of_min_val, uword& col_of_min_val, uword& slice_of_min_val) const; + inline eT max(uword& row_of_max_val, uword& col_of_max_val, uword& slice_of_max_val) const; + + + arma_cold inline bool save(const std::string name, const file_type type = arma_binary) const; + arma_cold inline bool save(const hdf5_name& spec, const file_type type = hdf5_binary) const; + arma_cold inline bool save( std::ostream& os, const file_type type = arma_binary) const; + + arma_cold inline bool load(const std::string name, const file_type type = auto_detect); + arma_cold inline bool load(const hdf5_name& spec, const file_type type = hdf5_binary); + arma_cold inline bool load( std::istream& is, const file_type type = auto_detect); + + arma_deprecated inline bool quiet_save(const std::string name, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save(const hdf5_name& spec, const file_type type = hdf5_binary) const; + arma_deprecated inline bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; + + arma_deprecated inline bool quiet_load(const std::string name, const file_type type = auto_detect); + arma_deprecated inline bool quiet_load(const hdf5_name& spec, const file_type type = hdf5_binary); + arma_deprecated inline bool quiet_load( std::istream& is, const file_type type = auto_detect); + + + // iterators + + typedef eT* iterator; + typedef const eT* const_iterator; + + typedef eT* slice_iterator; + typedef const eT* const_slice_iterator; + + inline iterator begin(); + inline const_iterator begin() const; + inline const_iterator cbegin() const; + + inline iterator end(); + inline const_iterator end() const; + inline const_iterator cend() const; + + inline slice_iterator begin_slice(const uword slice_num); + inline const_slice_iterator begin_slice(const uword slice_num) const; + + inline slice_iterator end_slice(const uword slice_num); + inline const_slice_iterator end_slice(const uword slice_num) const; + + inline void clear(); + inline bool empty() const; + inline uword size() const; + + arma_warn_unused inline eT& front(); + arma_warn_unused inline const eT& front() const; + + arma_warn_unused inline eT& back(); + arma_warn_unused inline const eT& back() const; + + inline void swap(Cube& B); + + inline void steal_mem(Cube& X); //!< don't use this unless you're writing code internal to Armadillo + inline void steal_mem(Cube& X, const bool is_move); //!< don't use this unless you're writing code internal to Armadillo + + template class fixed; + + + protected: + + inline void init_cold(); + inline void init_warm(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices); + + template + inline void init(const BaseCube& A, const BaseCube& B); + + inline void delete_mat(); + inline void create_mat(); + + inline Mat* create_mat_ptr(const uword in_slice) const; + inline Mat* get_mat_ptr(const uword in_slice) const; + + friend class glue_join; + friend class op_reshape; + friend class op_resize; + friend class subview_cube; + + + public: + + #if defined(ARMA_EXTRA_CUBE_PROTO) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_CUBE_PROTO) + #endif + }; + + + +template +template +class Cube::fixed : public Cube + { + private: + + static constexpr uword fixed_n_elem = fixed_n_rows * fixed_n_cols * fixed_n_slices; + static constexpr uword fixed_n_elem_slice = fixed_n_rows * fixed_n_cols; + + static constexpr bool use_extra = (fixed_n_elem > Cube_prealloc::mem_n_elem); + + arma_aligned atomic_mat_ptr_type mat_ptrs_local_extra[ (fixed_n_slices > Cube_prealloc::mat_ptrs_size) ? fixed_n_slices : 1 ]; + arma_align_mem eT mem_local_extra[ use_extra ? fixed_n_elem : 1 ]; + + arma_inline void mem_setup(); + + + public: + + inline fixed(); + inline fixed(const fixed& X); + + inline fixed(const fill::scalar_holder f); + template inline fixed(const fill::fill_class& f); + template inline fixed(const BaseCube& A); + template inline fixed(const BaseCube& A, const BaseCube& B); + + using Cube::operator=; + using Cube::operator(); + + inline Cube& operator=(const fixed& X); + + + arma_warn_unused arma_inline eT& operator[] (const uword i); + arma_warn_unused arma_inline const eT& operator[] (const uword i) const; + + arma_warn_unused arma_inline eT& at (const uword i); + arma_warn_unused arma_inline const eT& at (const uword i) const; + + arma_warn_unused arma_inline eT& operator() (const uword i); + arma_warn_unused arma_inline const eT& operator() (const uword i) const; + + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline eT& operator[] (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& operator[] (const uword in_row, const uword in_col, const uword in_slice) const; + #endif + + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col, const uword in_slice) const; + + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col, const uword in_slice) const; + }; + + + +class Cube_aux + { + public: + + template arma_inline static void prefix_pp(Cube& x); + template arma_inline static void prefix_pp(Cube< std::complex >& x); + + template arma_inline static void postfix_pp(Cube& x); + template arma_inline static void postfix_pp(Cube< std::complex >& x); + + template arma_inline static void prefix_mm(Cube& x); + template arma_inline static void prefix_mm(Cube< std::complex >& x); + + template arma_inline static void postfix_mm(Cube& x); + template arma_inline static void postfix_mm(Cube< std::complex >& x); + + template inline static void set_real(Cube& out, const BaseCube& X); + template inline static void set_imag(Cube& out, const BaseCube& X); + + template inline static void set_real(Cube< std::complex >& out, const BaseCube< T,T1>& X); + template inline static void set_imag(Cube< std::complex >& out, const BaseCube< T,T1>& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Cube_meat.hpp b/src/armadillo/include/armadillo_bits/Cube_meat.hpp new file mode 100644 index 0000000..265dcf4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Cube_meat.hpp @@ -0,0 +1,5920 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Cube +//! @{ + + +template +inline +Cube::~Cube() + { + arma_extra_debug_sigprint_this(this); + + delete_mat(); + + if( (mem_state == 0) && (n_alloc > 0) ) + { + arma_extra_debug_print("Cube::destructor: releasing memory"); + memory::release( access::rw(mem) ); + } + + // try to expose buggy user code that accesses deleted objects + if(arma_config::debug) { access::rw(mem) = nullptr; } + + arma_type_check(( is_supported_elem_type::value == false )); + } + + + +template +inline +Cube::Cube() + : n_rows(0) + , n_cols(0) + , n_elem_slice(0) + , n_slices(0) + , n_elem(0) + , n_alloc(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + } + + + +//! construct the cube to have user specified dimensions +template +inline +Cube::Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem_slice(in_n_rows*in_n_cols) + , n_slices(in_n_slices) + , n_elem(in_n_rows*in_n_cols*in_n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Cube::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +template +inline +Cube::Cube(const SizeCube& s) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem_slice(s.n_rows*s.n_cols) + , n_slices(s.n_slices) + , n_elem(s.n_rows*s.n_cols*s.n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Cube::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +//! internal use only +template +template +inline +Cube::Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices, const arma_initmode_indicator&) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem_slice(in_n_rows*in_n_cols) + , n_slices(in_n_slices) + , n_elem(in_n_rows*in_n_cols*in_n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(do_zeros) + { + arma_extra_debug_print("Cube::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +//! internal use only +template +template +inline +Cube::Cube(const SizeCube& s, const arma_initmode_indicator&) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem_slice(s.n_rows*s.n_cols) + , n_slices(s.n_slices) + , n_elem(s.n_rows*s.n_cols*s.n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(do_zeros) + { + arma_extra_debug_print("Cube::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +//! construct the cube to have user specified dimensions and fill with specified pattern +template +template +inline +Cube::Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices, const fill::fill_class&) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem_slice(in_n_rows*in_n_cols) + , n_slices(in_n_slices) + , n_elem(in_n_rows*in_n_cols*in_n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } + + arma_static_check( (is_same_type::yes), "Cube::Cube(): unsupported fill type" ); + } + + + +template +template +inline +Cube::Cube(const SizeCube& s, const fill::fill_class&) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem_slice(s.n_rows*s.n_cols) + , n_slices(s.n_slices) + , n_elem(s.n_rows*s.n_cols*s.n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } + + arma_static_check( (is_same_type::yes), "Cube::Cube(): unsupported fill type" ); + } + + + +//! construct the cube to have user specified dimensions and fill with specified value +template +inline +Cube::Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices, const fill::scalar_holder f) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem_slice(in_n_rows*in_n_cols) + , n_slices(in_n_slices) + , n_elem(in_n_rows*in_n_cols*in_n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + (*this).fill(f.scalar); + } + + + +template +inline +Cube::Cube(const SizeCube& s, const fill::scalar_holder f) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem_slice(s.n_rows*s.n_cols) + , n_slices(s.n_slices) + , n_elem(s.n_rows*s.n_cols*s.n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + (*this).fill(f.scalar); + } + + + +template +inline +Cube::Cube(Cube&& in_cube) + : n_rows(0) + , n_cols(0) + , n_elem_slice(0) + , n_slices(0) + , n_elem(0) + , n_alloc(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint(arma_str::format("this = %x in_cube = %x") % this % &in_cube); + + (*this).steal_mem(in_cube, true); + } + + + +template +inline +Cube& +Cube::operator=(Cube&& in_cube) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in_cube = %x") % this % &in_cube); + + (*this).steal_mem(in_cube, true); + + return *this; + } + + + +template +inline +void +Cube::init_cold() + { + arma_extra_debug_sigprint( arma_str::format("n_rows = %u, n_cols = %u, n_slices = %u") % n_rows % n_cols % n_slices ); + + #if defined(ARMA_64BIT_WORD) + const char* error_message = "Cube::init(): requested size is too large"; + #else + const char* error_message = "Cube::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; + #endif + + arma_debug_check + ( + ( + ( (n_rows > 0x0FFF) || (n_cols > 0x0FFF) || (n_slices > 0xFF) ) + ? ( (double(n_rows) * double(n_cols) * double(n_slices)) > double(ARMA_MAX_UWORD) ) + : false + ), + error_message + ); + + + if(n_elem <= Cube_prealloc::mem_n_elem) + { + if(n_elem > 0) { arma_extra_debug_print("Cube::init(): using local memory"); } + + access::rw(mem) = (n_elem == 0) ? nullptr : mem_local; + access::rw(n_alloc) = 0; + } + else + { + arma_extra_debug_print("Cube::init(): acquiring memory"); + + access::rw(mem) = memory::acquire(n_elem); + access::rw(n_alloc) = n_elem; + } + + create_mat(); + } + + + +template +inline +void +Cube::init_warm(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices) + { + arma_extra_debug_sigprint( arma_str::format("in_n_rows = %u, in_n_cols = %u, in_n_slices = %u") % in_n_rows % in_n_cols % in_n_slices ); + + if( (n_rows == in_n_rows) && (n_cols == in_n_cols) && (n_slices == in_n_slices) ) { return; } + + const uword t_mem_state = mem_state; + + bool err_state = false; + char* err_msg = nullptr; + + const char* error_message_1 = "Cube::init(): size is fixed and hence cannot be changed"; + + arma_debug_set_error( err_state, err_msg, (t_mem_state == 3), error_message_1 ); + + #if defined(ARMA_64BIT_WORD) + const char* error_message_2 = "Cube::init(): requested size is too large"; + #else + const char* error_message_2 = "Cube::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; + #endif + + arma_debug_set_error + ( + err_state, + err_msg, + ( + ( (in_n_rows > 0x0FFF) || (in_n_cols > 0x0FFF) || (in_n_slices > 0xFF) ) + ? ( (double(in_n_rows) * double(in_n_cols) * double(in_n_slices)) > double(ARMA_MAX_UWORD) ) + : false + ), + error_message_2 + ); + + arma_debug_check(err_state, err_msg); + + const uword old_n_elem = n_elem; + const uword new_n_elem = in_n_rows * in_n_cols * in_n_slices; + + if(old_n_elem == new_n_elem) + { + arma_extra_debug_print("Cube::init(): reusing memory"); + + delete_mat(); + + access::rw(n_rows) = in_n_rows; + access::rw(n_cols) = in_n_cols; + access::rw(n_elem_slice) = in_n_rows*in_n_cols; + access::rw(n_slices) = in_n_slices; + + create_mat(); + + return; + } + + arma_debug_check( (t_mem_state == 2), "Cube::init(): mismatch between size of auxiliary memory and requested size" ); + + delete_mat(); + + if(new_n_elem <= Cube_prealloc::mem_n_elem) + { + if(n_alloc > 0) + { + arma_extra_debug_print("Cube::init(): releasing memory"); + memory::release( access::rw(mem) ); + } + + if(new_n_elem > 0) { arma_extra_debug_print("Cube::init(): using local memory"); } + + access::rw(mem) = (new_n_elem == 0) ? nullptr : mem_local; + access::rw(n_alloc) = 0; + } + else // condition: new_n_elem > Cube_prealloc::mem_n_elem + { + if(new_n_elem > n_alloc) + { + if(n_alloc > 0) + { + arma_extra_debug_print("Cube::init(): releasing memory"); + memory::release( access::rw(mem) ); + + // in case memory::acquire() throws an exception + access::rw(mem) = nullptr; + access::rw(n_rows) = 0; + access::rw(n_cols) = 0; + access::rw(n_elem_slice) = 0; + access::rw(n_slices) = 0; + access::rw(n_elem) = 0; + access::rw(n_alloc) = 0; + } + + arma_extra_debug_print("Cube::init(): acquiring memory"); + access::rw(mem) = memory::acquire(new_n_elem); + access::rw(n_alloc) = new_n_elem; + } + else // condition: new_n_elem <= n_alloc + { + arma_extra_debug_print("Cube::init(): reusing memory"); + } + } + + access::rw(n_rows) = in_n_rows; + access::rw(n_cols) = in_n_cols; + access::rw(n_elem_slice) = in_n_rows*in_n_cols; + access::rw(n_slices) = in_n_slices; + access::rw(n_elem) = new_n_elem; + access::rw(mem_state) = 0; + + create_mat(); + } + + + +//! for constructing a complex cube out of two non-complex cubes +template +template +inline +void +Cube::init + ( + const BaseCube::pod_type,T1>& X, + const BaseCube::pod_type,T2>& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type T; + + arma_type_check(( is_cx::no )); //!< compile-time abort if eT is not std::complex + arma_type_check(( is_cx< T>::yes )); //!< compile-time abort if T is std::complex + + arma_type_check(( is_same_type< std::complex, eT >::no )); //!< compile-time abort if types are not compatible + + const ProxyCube PX(X.get_ref()); + const ProxyCube PY(Y.get_ref()); + + arma_debug_assert_same_size(PX, PY, "Cube()"); + + const uword local_n_rows = PX.get_n_rows(); + const uword local_n_cols = PX.get_n_cols(); + const uword local_n_slices = PX.get_n_slices(); + + init_warm(local_n_rows, local_n_cols, local_n_slices); + + eT* out_mem = (*this).memptr(); + + const bool use_at = ( ProxyCube::use_at || ProxyCube::use_at ); + + if(use_at == false) + { + typedef typename ProxyCube::ea_type ea_type1; + typedef typename ProxyCube::ea_type ea_type2; + + const uword N = n_elem; + + ea_type1 A = PX.get_ea(); + ea_type2 B = PY.get_ea(); + + for(uword i=0; i(A[i], B[i]); + } + } + else + { + for(uword uslice = 0; uslice < local_n_slices; ++uslice) + for(uword ucol = 0; ucol < local_n_cols; ++ucol ) + for(uword urow = 0; urow < local_n_rows; ++urow ) + { + *out_mem = std::complex( PX.at(urow,ucol,uslice), PY.at(urow,ucol,uslice) ); + out_mem++; + } + } + } + + + +template +inline +void +Cube::delete_mat() + { + arma_extra_debug_sigprint(); + + if((n_slices > 0) && (mat_ptrs != nullptr)) + { + for(uword s=0; s < n_slices; ++s) + { + raw_mat_ptr_type mat_ptr = raw_mat_ptr_type(mat_ptrs[s]); // explicit cast to indicate load from std::atomic*> + + if(mat_ptr != nullptr) + { + arma_extra_debug_print( arma_str::format("Cube::delete_mat(): destroying matrix %u") % s ); + delete mat_ptr; + mat_ptrs[s] = nullptr; + } + } + + if( (mem_state <= 2) && (n_slices > Cube_prealloc::mat_ptrs_size) ) + { + arma_extra_debug_print("Cube::delete_mat(): freeing mat_ptrs array"); + delete [] mat_ptrs; + mat_ptrs = nullptr; + } + } + } + + + +template +inline +void +Cube::create_mat() + { + arma_extra_debug_sigprint(); + + if(n_slices == 0) { mat_ptrs = nullptr; return; } + + if(mem_state <= 2) + { + if(n_slices <= Cube_prealloc::mat_ptrs_size) + { + arma_extra_debug_print("Cube::create_mat(): using local memory for mat_ptrs array"); + + mat_ptrs = mat_ptrs_local; + } + else + { + arma_extra_debug_print("Cube::create_mat(): allocating mat_ptrs array"); + + mat_ptrs = new(std::nothrow) atomic_mat_ptr_type[n_slices]; + + arma_check_bad_alloc( (mat_ptrs == nullptr), "Cube::create_mat(): out of memory" ); + } + } + + for(uword s=0; s < n_slices; ++s) { mat_ptrs[s] = nullptr; } + } + + + +template +inline +Mat* +Cube::create_mat_ptr(const uword in_slice) const + { + arma_extra_debug_sigprint(); + + arma_extra_debug_print( arma_str::format("Cube::create_mat_ptr(): creating matrix %u") % in_slice ); + + const eT* mat_mem = (n_elem_slice > 0) ? slice_memptr(in_slice) : nullptr; + + Mat* mat_ptr = new(std::nothrow) Mat('j', mat_mem, n_rows, n_cols); + + return mat_ptr; + } + + + +template +inline +Mat* +Cube::get_mat_ptr(const uword in_slice) const + { + arma_extra_debug_sigprint(); + + raw_mat_ptr_type mat_ptr = nullptr; + + #if defined(ARMA_USE_OPENMP) + { + #pragma omp atomic read + mat_ptr = mat_ptrs[in_slice]; + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + mat_ptr = mat_ptrs[in_slice].load(); + } + #else + { + mat_ptr = mat_ptrs[in_slice]; + } + #endif + + if(mat_ptr == nullptr) + { + #if defined(ARMA_USE_OPENMP) + { + #pragma omp critical (arma_Cube_mat_ptrs) + { + #pragma omp atomic read + mat_ptr = mat_ptrs[in_slice]; + + if(mat_ptr == nullptr) { mat_ptr = create_mat_ptr(in_slice); } + + #pragma omp atomic write + mat_ptrs[in_slice] = mat_ptr; + } + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + const std::lock_guard lock(mat_mutex); + + mat_ptr = mat_ptrs[in_slice].load(); + + if(mat_ptr == nullptr) { mat_ptr = create_mat_ptr(in_slice); } + + mat_ptrs[in_slice].store(mat_ptr); + } + #else + { + mat_ptr = create_mat_ptr(in_slice); + + mat_ptrs[in_slice] = mat_ptr; + } + #endif + + arma_check_bad_alloc( (mat_ptr == nullptr), "Cube::get_mat_ptr(): out of memory" ); + } + + return mat_ptr; + } + + + +//! Set the cube to be equal to the specified scalar. +//! NOTE: the size of the cube will be 1x1x1 +template +inline +Cube& +Cube::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + init_warm(1,1,1); + + access::rw(mem[0]) = val; + + return *this; + } + + + +//! In-place addition of a scalar to all elements of the cube +template +inline +Cube& +Cube::operator+=(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_plus( memptr(), val, n_elem ); + + return *this; + } + + + +//! In-place subtraction of a scalar from all elements of the cube +template +inline +Cube& +Cube::operator-=(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_minus( memptr(), val, n_elem ); + + return *this; + } + + + +//! In-place multiplication of all elements of the cube with a scalar +template +inline +Cube& +Cube::operator*=(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_mul( memptr(), val, n_elem ); + + return *this; + } + + + +//! In-place division of all elements of the cube with a scalar +template +inline +Cube& +Cube::operator/=(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_div( memptr(), val, n_elem ); + + return *this; + } + + + +//! construct a cube from a given cube +template +inline +Cube::Cube(const Cube& x) + : n_rows(x.n_rows) + , n_cols(x.n_cols) + , n_elem_slice(x.n_elem_slice) + , n_slices(x.n_slices) + , n_elem(x.n_elem) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + arma_extra_debug_sigprint(arma_str::format("this = %x in_cube = %x") % this % &x); + + init_cold(); + + arrayops::copy( memptr(), x.mem, n_elem ); + } + + + +//! construct a cube from a given cube +template +inline +Cube& +Cube::operator=(const Cube& x) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in_cube = %x") % this % &x); + + if(this != &x) + { + init_warm(x.n_rows, x.n_cols, x.n_slices); + + arrayops::copy( memptr(), x.mem, n_elem ); + } + + return *this; + } + + + +//! construct a cube from a given auxiliary array of eTs. +//! if copy_aux_mem is true, new memory is allocated and the array is copied. +//! if copy_aux_mem is false, the auxiliary array is used directly (without allocating memory and copying). +template +inline +Cube::Cube(eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols, const uword aux_n_slices, const bool copy_aux_mem, const bool strict, const bool prealloc_mat) + : n_rows ( aux_n_rows ) + , n_cols ( aux_n_cols ) + , n_elem_slice( aux_n_rows*aux_n_cols ) + , n_slices ( aux_n_slices ) + , n_elem ( aux_n_rows*aux_n_cols*aux_n_slices ) + , n_alloc ( 0 ) + , mem_state ( copy_aux_mem ? 0 : (strict ? 2 : 1) ) + , mem ( copy_aux_mem ? nullptr : aux_mem ) + { + arma_extra_debug_sigprint_this(this); + + arma_ignore(prealloc_mat); // kept only for compatibility with old user code + + if(copy_aux_mem) + { + init_cold(); + + arrayops::copy( memptr(), aux_mem, n_elem ); + } + else + { + create_mat(); + } + } + + + +//! construct a cube from a given auxiliary read-only array of eTs. +//! the array is copied. +template +inline +Cube::Cube(const eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols, const uword aux_n_slices) + : n_rows(aux_n_rows) + , n_cols(aux_n_cols) + , n_elem_slice(aux_n_rows*aux_n_cols) + , n_slices(aux_n_slices) + , n_elem(aux_n_rows*aux_n_cols*aux_n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + arrayops::copy( memptr(), aux_mem, n_elem ); + } + + + +//! in-place cube addition +template +inline +Cube& +Cube::operator+=(const Cube& m) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(*this, m, "addition"); + + arrayops::inplace_plus( memptr(), m.memptr(), n_elem ); + + return *this; + } + + + +//! in-place cube subtraction +template +inline +Cube& +Cube::operator-=(const Cube& m) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(*this, m, "subtraction"); + + arrayops::inplace_minus( memptr(), m.memptr(), n_elem ); + + return *this; + } + + + +//! in-place element-wise cube multiplication +template +inline +Cube& +Cube::operator%=(const Cube& m) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(*this, m, "element-wise multiplication"); + + arrayops::inplace_mul( memptr(), m.memptr(), n_elem ); + + return *this; + } + + + +//! in-place element-wise cube division +template +inline +Cube& +Cube::operator/=(const Cube& m) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(*this, m, "element-wise division"); + + arrayops::inplace_div( memptr(), m.memptr(), n_elem ); + + return *this; + } + + + +//! for constructing a complex cube out of two non-complex cubes +template +template +inline +Cube::Cube + ( + const BaseCube::pod_type,T1>& A, + const BaseCube::pod_type,T2>& B + ) + : n_rows(0) + , n_cols(0) + , n_elem_slice(0) + , n_slices(0) + , n_elem(0) + , n_alloc(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init(A,B); + } + + + +//! construct a cube from a subview_cube instance (eg. construct a cube from a delayed subcube operation) +template +inline +Cube::Cube(const subview_cube& X) + : n_rows(X.n_rows) + , n_cols(X.n_cols) + , n_elem_slice(X.n_elem_slice) + , n_slices(X.n_slices) + , n_elem(X.n_elem) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + subview_cube::extract(*this, X); + } + + + +//! construct a cube from a subview_cube instance (eg. construct a cube from a delayed subcube operation) +template +inline +Cube& +Cube::operator=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + const bool alias = (this == &(X.m)); + + if(alias == false) + { + init_warm(X.n_rows, X.n_cols, X.n_slices); + + subview_cube::extract(*this, X); + } + else + { + Cube tmp(X); + + steal_mem(tmp); + } + + return *this; + } + + + +//! in-place cube addition (using a subcube on the right-hand-side) +template +inline +Cube& +Cube::operator+=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + subview_cube::plus_inplace(*this, X); + + return *this; + } + + + +//! in-place cube subtraction (using a subcube on the right-hand-side) +template +inline +Cube& +Cube::operator-=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + subview_cube::minus_inplace(*this, X); + + return *this; + } + + + +//! in-place element-wise cube mutiplication (using a subcube on the right-hand-side) +template +inline +Cube& +Cube::operator%=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + subview_cube::schur_inplace(*this, X); + + return *this; + } + + + +//! in-place element-wise cube division (using a subcube on the right-hand-side) +template +inline +Cube& +Cube::operator/=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + subview_cube::div_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Cube::Cube(const subview_cube_slices& X) + : n_rows(0) + , n_cols(0) + , n_elem_slice(0) + , n_slices(0) + , n_elem(0) + , n_alloc(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + subview_cube_slices::extract(*this, X); + } + + + +template +template +inline +Cube& +Cube::operator=(const subview_cube_slices& X) + { + arma_extra_debug_sigprint(); + + const bool alias = (this == &(X.m)); + + if(alias == false) + { + subview_cube_slices::extract(*this, X); + } + else + { + Cube tmp(X); + + steal_mem(tmp); + } + + return *this; + } + + + +template +template +inline +Cube& +Cube::operator+=(const subview_cube_slices& X) + { + arma_extra_debug_sigprint(); + + subview_cube_slices::plus_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Cube& +Cube::operator-=(const subview_cube_slices& X) + { + arma_extra_debug_sigprint(); + + subview_cube_slices::minus_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Cube& +Cube::operator%=(const subview_cube_slices& X) + { + arma_extra_debug_sigprint(); + + subview_cube_slices::schur_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Cube& +Cube::operator/=(const subview_cube_slices& X) + { + arma_extra_debug_sigprint(); + + subview_cube_slices::div_inplace(*this, X); + + return *this; + } + + + +//! creation of subview_cube (subcube comprised of specified row) +template +arma_inline +subview_cube +Cube::row(const uword in_row) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (in_row >= n_rows), "Cube::row(): index out of bounds" ); + + return (*this).rows(in_row, in_row); + } + + + +//! creation of subview_cube (subcube comprised of specified row) +template +arma_inline +const subview_cube +Cube::row(const uword in_row) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (in_row >= n_rows), "Cube::row(): index out of bounds" ); + + return (*this).rows(in_row, in_row); + } + + + +//! creation of subview_cube (subcube comprised of specified column) +template +arma_inline +subview_cube +Cube::col(const uword in_col) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (in_col >= n_cols), "Cube::col(): index out of bounds" ); + + return (*this).cols(in_col, in_col); + } + + + +//! creation of subview_cube (subcube comprised of specified column) +template +arma_inline +const subview_cube +Cube::col(const uword in_col) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (in_col >= n_cols), "Cube::col(): index out of bounds" ); + + return (*this).cols(in_col, in_col); + } + + + +//! provide the reference to the matrix representing a single slice +template +inline +Mat& +Cube::slice(const uword in_slice) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (in_slice >= n_slices), "Cube::slice(): index out of bounds" ); + + return *(get_mat_ptr(in_slice)); + } + + + +//! provide the reference to the matrix representing a single slice +template +inline +const Mat& +Cube::slice(const uword in_slice) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (in_slice >= n_slices), "Cube::slice(): index out of bounds" ); + + return *(get_mat_ptr(in_slice)); + } + + + +//! creation of subview_cube (subcube comprised of specified rows) +template +arma_inline +subview_cube +Cube::rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "Cube::rows(): indices out of bounds or incorrectly used" + ); + + const uword subcube_n_rows = in_row2 - in_row1 + 1; + + return subview_cube(*this, in_row1, 0, 0, subcube_n_rows, n_cols, n_slices); + } + + + +//! creation of subview_cube (subcube comprised of specified rows) +template +arma_inline +const subview_cube +Cube::rows(const uword in_row1, const uword in_row2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "Cube::rows(): indices out of bounds or incorrectly used" + ); + + const uword subcube_n_rows = in_row2 - in_row1 + 1; + + return subview_cube(*this, in_row1, 0, 0, subcube_n_rows, n_cols, n_slices); + } + + + +//! creation of subview_cube (subcube comprised of specified columns) +template +arma_inline +subview_cube +Cube::cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "Cube::cols(): indices out of bounds or incorrectly used" + ); + + const uword subcube_n_cols = in_col2 - in_col1 + 1; + + return subview_cube(*this, 0, in_col1, 0, n_rows, subcube_n_cols, n_slices); + } + + + +//! creation of subview_cube (subcube comprised of specified columns) +template +arma_inline +const subview_cube +Cube::cols(const uword in_col1, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "Cube::cols(): indices out of bounds or incorrectly used" + ); + + const uword subcube_n_cols = in_col2 - in_col1 + 1; + + return subview_cube(*this, 0, in_col1, 0, n_rows, subcube_n_cols, n_slices); + } + + + +//! creation of subview_cube (subcube comprised of specified slices) +template +arma_inline +subview_cube +Cube::slices(const uword in_slice1, const uword in_slice2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_slice1 > in_slice2) || (in_slice2 >= n_slices), + "Cube::slices(): indices out of bounds or incorrectly used" + ); + + const uword subcube_n_slices = in_slice2 - in_slice1 + 1; + + return subview_cube(*this, 0, 0, in_slice1, n_rows, n_cols, subcube_n_slices); + } + + + +//! creation of subview_cube (subcube comprised of specified slices) +template +arma_inline +const subview_cube +Cube::slices(const uword in_slice1, const uword in_slice2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_slice1 > in_slice2) || (in_slice2 >= n_slices), + "Cube::slices(): indices out of bounds or incorrectly used" + ); + + const uword subcube_n_slices = in_slice2 - in_slice1 + 1; + + return subview_cube(*this, 0, 0, in_slice1, n_rows, n_cols, subcube_n_slices); + } + + + +//! creation of subview_cube (generic subcube) +template +arma_inline +subview_cube +Cube::subcube(const uword in_row1, const uword in_col1, const uword in_slice1, const uword in_row2, const uword in_col2, const uword in_slice2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_slice1 > in_slice2) || + (in_row2 >= n_rows) || (in_col2 >= n_cols) || (in_slice2 >= n_slices), + "Cube::subcube(): indices out of bounds or incorrectly used" + ); + + const uword subcube_n_rows = in_row2 - in_row1 + 1; + const uword subcube_n_cols = in_col2 - in_col1 + 1; + const uword subcube_n_slices = in_slice2 - in_slice1 + 1; + + return subview_cube(*this, in_row1, in_col1, in_slice1, subcube_n_rows, subcube_n_cols, subcube_n_slices); + } + + + +//! creation of subview_cube (generic subcube) +template +arma_inline +const subview_cube +Cube::subcube(const uword in_row1, const uword in_col1, const uword in_slice1, const uword in_row2, const uword in_col2, const uword in_slice2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_slice1 > in_slice2) || + (in_row2 >= n_rows) || (in_col2 >= n_cols) || (in_slice2 >= n_slices), + "Cube::subcube(): indices out of bounds or incorrectly used" + ); + + const uword subcube_n_rows = in_row2 - in_row1 + 1; + const uword subcube_n_cols = in_col2 - in_col1 + 1; + const uword subcube_n_slices = in_slice2 - in_slice1 + 1; + + return subview_cube(*this, in_row1, in_col1, in_slice1, subcube_n_rows, subcube_n_cols, subcube_n_slices); + } + + + +//! creation of subview_cube (generic subcube) +template +inline +subview_cube +Cube::subcube(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) + { + arma_extra_debug_sigprint(); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + const uword l_n_slices = n_slices; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + const uword s_n_slices = s.n_slices; + + arma_debug_check_bounds + ( + ( in_row1 >= l_n_rows) || ( in_col1 >= l_n_cols) || ( in_slice1 >= l_n_slices) + || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols) || ((in_slice1 + s_n_slices) > l_n_slices), + "Cube::subcube(): indices or size out of bounds" + ); + + return subview_cube(*this, in_row1, in_col1, in_slice1, s_n_rows, s_n_cols, s_n_slices); + } + + + +//! creation of subview_cube (generic subcube) +template +inline +const subview_cube +Cube::subcube(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) const + { + arma_extra_debug_sigprint(); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + const uword l_n_slices = n_slices; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + const uword s_n_slices = s.n_slices; + + arma_debug_check_bounds + ( + ( in_row1 >= l_n_rows) || ( in_col1 >= l_n_cols) || ( in_slice1 >= l_n_slices) + || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols) || ((in_slice1 + s_n_slices) > l_n_slices), + "Cube::subcube(): indices or size out of bounds" + ); + + return subview_cube(*this, in_row1, in_col1, in_slice1, s_n_rows, s_n_cols, s_n_slices); + } + + + +//! creation of subview_cube (generic subcube) +template +inline +subview_cube +Cube::subcube(const span& row_span, const span& col_span, const span& slice_span) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + const bool slice_all = slice_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword subcube_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword subcube_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + const uword in_slice1 = slice_all ? 0 : slice_span.a; + const uword in_slice2 = slice_span.b; + const uword subcube_n_slices = slice_all ? local_n_slices : in_slice2 - in_slice1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + || + ( slice_all ? false : ((in_slice1 > in_slice2) || (in_slice2 >= local_n_slices)) ) + , + "Cube::subcube(): indices out of bounds or incorrectly used" + ); + + return subview_cube(*this, in_row1, in_col1, in_slice1, subcube_n_rows, subcube_n_cols, subcube_n_slices); + } + + + +//! creation of subview_cube (generic subcube) +template +inline +const subview_cube +Cube::subcube(const span& row_span, const span& col_span, const span& slice_span) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + const bool slice_all = slice_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword subcube_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword subcube_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + const uword in_slice1 = slice_all ? 0 : slice_span.a; + const uword in_slice2 = slice_span.b; + const uword subcube_n_slices = slice_all ? local_n_slices : in_slice2 - in_slice1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + || + ( slice_all ? false : ((in_slice1 > in_slice2) || (in_slice2 >= local_n_slices)) ) + , + "Cube::subcube(): indices out of bounds or incorrectly used" + ); + + return subview_cube(*this, in_row1, in_col1, in_slice1, subcube_n_rows, subcube_n_cols, subcube_n_slices); + } + + + +template +inline +subview_cube +Cube::operator()(const span& row_span, const span& col_span, const span& slice_span) + { + arma_extra_debug_sigprint(); + + return (*this).subcube(row_span, col_span, slice_span); + } + + + +template +inline +const subview_cube +Cube::operator()(const span& row_span, const span& col_span, const span& slice_span) const + { + arma_extra_debug_sigprint(); + + return (*this).subcube(row_span, col_span, slice_span); + } + + + +template +inline +subview_cube +Cube::operator()(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) + { + arma_extra_debug_sigprint(); + + return (*this).subcube(in_row1, in_col1, in_slice1, s); + } + + + +template +inline +const subview_cube +Cube::operator()(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) const + { + arma_extra_debug_sigprint(); + + return (*this).subcube(in_row1, in_col1, in_slice1, s); + } + + + +template +arma_inline +subview_cube +Cube::tube(const uword in_row1, const uword in_col1) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + ((in_row1 >= n_rows) || (in_col1 >= n_cols)), + "Cube::tube(): indices out of bounds" + ); + + return subview_cube(*this, in_row1, in_col1, 0, 1, 1, n_slices); + } + + + +template +arma_inline +const subview_cube +Cube::tube(const uword in_row1, const uword in_col1) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + ((in_row1 >= n_rows) || (in_col1 >= n_cols)), + "Cube::tube(): indices out of bounds" + ); + + return subview_cube(*this, in_row1, in_col1, 0, 1, 1, n_slices); + } + + + +template +arma_inline +subview_cube +Cube::tube(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || + (in_row2 >= n_rows) || (in_col2 >= n_cols), + "Cube::tube(): indices out of bounds or incorrectly used" + ); + + const uword subcube_n_rows = in_row2 - in_row1 + 1; + const uword subcube_n_cols = in_col2 - in_col1 + 1; + + return subview_cube(*this, in_row1, in_col1, 0, subcube_n_rows, subcube_n_cols, n_slices); + } + + + +template +arma_inline +const subview_cube +Cube::tube(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || + (in_row2 >= n_rows) || (in_col2 >= n_cols), + "Cube::tube(): indices out of bounds or incorrectly used" + ); + + const uword subcube_n_rows = in_row2 - in_row1 + 1; + const uword subcube_n_cols = in_col2 - in_col1 + 1; + + return subview_cube(*this, in_row1, in_col1, 0, subcube_n_rows, subcube_n_cols, n_slices); + } + + + +template +arma_inline +subview_cube +Cube::tube(const uword in_row1, const uword in_col1, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + arma_debug_check_bounds + ( + ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), + "Cube::tube(): indices or size out of bounds" + ); + + return subview_cube(*this, in_row1, in_col1, 0, s_n_rows, s_n_cols, n_slices); + } + + + +template +arma_inline +const subview_cube +Cube::tube(const uword in_row1, const uword in_col1, const SizeMat& s) const + { + arma_extra_debug_sigprint(); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + arma_debug_check_bounds + ( + ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), + "Cube::tube(): indices or size out of bounds" + ); + + return subview_cube(*this, in_row1, in_col1, 0, s_n_rows, s_n_cols, n_slices); + } + + + +template +inline +subview_cube +Cube::tube(const span& row_span, const span& col_span) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword subcube_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword subcube_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "Cube::tube(): indices out of bounds or incorrectly used" + ); + + return subview_cube(*this, in_row1, in_col1, 0, subcube_n_rows, subcube_n_cols, n_slices); + } + + + +template +inline +const subview_cube +Cube::tube(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword subcube_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword subcube_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "Cube::tube(): indices out of bounds or incorrectly used" + ); + + return subview_cube(*this, in_row1, in_col1, 0, subcube_n_rows, subcube_n_cols, n_slices); + } + + + +template +inline +subview_cube +Cube::head_slices(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_slices), "Cube::head_slices(): size out of bounds" ); + + return subview_cube(*this, 0, 0, 0, n_rows, n_cols, N); + } + + + +template +inline +const subview_cube +Cube::head_slices(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_slices), "Cube::head_slices(): size out of bounds" ); + + return subview_cube(*this, 0, 0, 0, n_rows, n_cols, N); + } + + + +template +inline +subview_cube +Cube::tail_slices(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_slices), "Cube::tail_slices(): size out of bounds" ); + + const uword start_slice = n_slices - N; + + return subview_cube(*this, 0, 0, start_slice, n_rows, n_cols, N); + } + + + +template +inline +const subview_cube +Cube::tail_slices(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_slices), "Cube::tail_slices(): size out of bounds" ); + + const uword start_slice = n_slices - N; + + return subview_cube(*this, 0, 0, start_slice, n_rows, n_cols, N); + } + + + +template +template +arma_inline +subview_elem1 +Cube::elem(const Base& a) + { + arma_extra_debug_sigprint(); + + return subview_elem1(*this, a); + } + + + +template +template +arma_inline +const subview_elem1 +Cube::elem(const Base& a) const + { + arma_extra_debug_sigprint(); + + return subview_elem1(*this, a); + } + + + +template +template +arma_inline +subview_elem1 +Cube::operator()(const Base& a) + { + arma_extra_debug_sigprint(); + + return subview_elem1(*this, a); + } + + + +template +template +arma_inline +const subview_elem1 +Cube::operator()(const Base& a) const + { + arma_extra_debug_sigprint(); + + return subview_elem1(*this, a); + } + + + +template +arma_inline +subview_cube_each1 +Cube::each_slice() + { + arma_extra_debug_sigprint(); + + return subview_cube_each1(*this); + } + + + +template +arma_inline +const subview_cube_each1 +Cube::each_slice() const + { + arma_extra_debug_sigprint(); + + return subview_cube_each1(*this); + } + + + +template +template +inline +subview_cube_each2 +Cube::each_slice(const Base& indices) + { + arma_extra_debug_sigprint(); + + return subview_cube_each2(*this, indices); + } + + + +template +template +inline +const subview_cube_each2 +Cube::each_slice(const Base& indices) const + { + arma_extra_debug_sigprint(); + + return subview_cube_each2(*this, indices); + } + + + +//! apply a lambda function to each slice, where each slice is interpreted as a matrix +template +inline +Cube& +Cube::each_slice(const std::function< void(Mat&) >& F) + { + arma_extra_debug_sigprint(); + + for(uword slice_id=0; slice_id < n_slices; ++slice_id) + { + Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); + + F(tmp); + } + + return *this; + } + + + +template +inline +const Cube& +Cube::each_slice(const std::function< void(const Mat&) >& F) const + { + arma_extra_debug_sigprint(); + + for(uword slice_id=0; slice_id < n_slices; ++slice_id) + { + const Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); + + F(tmp); + } + + return *this; + } + + + +template +inline +Cube& +Cube::each_slice(const std::function< void(Mat&) >& F, const bool use_mp) + { + arma_extra_debug_sigprint(); + + if((use_mp == false) || (arma_config::openmp == false)) + { + return (*this).each_slice(F); + } + + #if defined(ARMA_USE_OPENMP) + { + const uword local_n_slices = n_slices; + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword slice_id=0; slice_id < local_n_slices; ++slice_id) + { + Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); + + F(tmp); + } + } + #endif + + return *this; + } + + + +template +inline +const Cube& +Cube::each_slice(const std::function< void(const Mat&) >& F, const bool use_mp) const + { + arma_extra_debug_sigprint(); + + if((use_mp == false) || (arma_config::openmp == false)) + { + return (*this).each_slice(F); + } + + #if defined(ARMA_USE_OPENMP) + { + const uword local_n_slices = n_slices; + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword slice_id=0; slice_id < local_n_slices; ++slice_id) + { + Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); + + F(tmp); + } + } + #endif + + return *this; + } + + + +template +template +inline +subview_cube_slices +Cube::slices(const Base& indices) + { + arma_extra_debug_sigprint(); + + return subview_cube_slices(*this, indices); + } + + + +template +template +inline +const subview_cube_slices +Cube::slices(const Base& indices) const + { + arma_extra_debug_sigprint(); + + return subview_cube_slices(*this, indices); + } + + + +//! remove specified row +template +inline +void +Cube::shed_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( row_num >= n_rows, "Cube::shed_row(): index out of bounds" ); + + shed_rows(row_num, row_num); + } + + + +//! remove specified column +template +inline +void +Cube::shed_col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= n_cols, "Cube::shed_col(): index out of bounds" ); + + shed_cols(col_num, col_num); + } + + + +//! remove specified slice +template +inline +void +Cube::shed_slice(const uword slice_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( slice_num >= n_slices, "Cube::shed_slice(): index out of bounds" ); + + shed_slices(slice_num, slice_num); + } + + + +//! remove specified rows +template +inline +void +Cube::shed_rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "Cube::shed_rows(): indices out of bounds or incorrectly used" + ); + + const uword n_keep_front = in_row1; + const uword n_keep_back = n_rows - (in_row2 + 1); + + Cube X(n_keep_front + n_keep_back, n_cols, n_slices, arma_nozeros_indicator()); + + if(n_keep_front > 0) + { + X.rows( 0, (n_keep_front-1) ) = rows( 0, (in_row1-1) ); + } + + if(n_keep_back > 0) + { + X.rows( n_keep_front, (n_keep_front+n_keep_back-1) ) = rows( (in_row2+1), (n_rows-1) ); + } + + steal_mem(X); + } + + + +//! remove specified columns +template +inline +void +Cube::shed_cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "Cube::shed_cols(): indices out of bounds or incorrectly used" + ); + + const uword n_keep_front = in_col1; + const uword n_keep_back = n_cols - (in_col2 + 1); + + Cube X(n_rows, n_keep_front + n_keep_back, n_slices, arma_nozeros_indicator()); + + if(n_keep_front > 0) + { + X.cols( 0, (n_keep_front-1) ) = cols( 0, (in_col1-1) ); + } + + if(n_keep_back > 0) + { + X.cols( n_keep_front, (n_keep_front+n_keep_back-1) ) = cols( (in_col2+1), (n_cols-1) ); + } + + steal_mem(X); + } + + + +//! remove specified slices +template +inline +void +Cube::shed_slices(const uword in_slice1, const uword in_slice2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_slice1 > in_slice2) || (in_slice2 >= n_slices), + "Cube::shed_slices(): indices out of bounds or incorrectly used" + ); + + const uword n_keep_front = in_slice1; + const uword n_keep_back = n_slices - (in_slice2 + 1); + + Cube X(n_rows, n_cols, n_keep_front + n_keep_back, arma_nozeros_indicator()); + + if(n_keep_front > 0) + { + X.slices( 0, (n_keep_front-1) ) = slices( 0, (in_slice1-1) ); + } + + if(n_keep_back > 0) + { + X.slices( n_keep_front, (n_keep_front+n_keep_back-1) ) = slices( (in_slice2+1), (n_slices-1) ); + } + + steal_mem(X); + } + + + +//! remove specified slices +template +template +inline +void +Cube::shed_slices(const Base& indices) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap U(indices.get_ref()); + const Mat& tmp1 = U.M; + + arma_debug_check( ((tmp1.is_vec() == false) && (tmp1.is_empty() == false)), "Cube::shed_slices(): list of indices must be a vector" ); + + if(tmp1.is_empty()) { return; } + + const Col tmp2(const_cast(tmp1.memptr()), tmp1.n_elem, false, false); + + const Col& slices_to_shed = (tmp2.is_sorted("strictascend") == false) + ? Col(unique(tmp2)) + : Col(const_cast(tmp2.memptr()), tmp2.n_elem, false, false); + + const uword* slices_to_shed_mem = slices_to_shed.memptr(); + const uword N = slices_to_shed.n_elem; + + if(arma_config::debug) + { + for(uword i=0; i= n_slices), "Cube::shed_slices(): indices out of bounds" ); + } + } + + Col tmp3(n_slices, arma_nozeros_indicator()); + + uword* tmp3_mem = tmp3.memptr(); + + uword i = 0; + uword count = 0; + + for(uword j=0; j < n_slices; ++j) + { + if(i < N) + { + if( j != slices_to_shed_mem[i] ) + { + tmp3_mem[count] = j; + ++count; + } + else + { + ++i; + } + } + else + { + tmp3_mem[count] = j; + ++count; + } + } + + const Col slices_to_keep(tmp3.memptr(), count, false, false); + + Cube X = (*this).slices(slices_to_keep); + + steal_mem(X); + } + + + +template +inline +void +Cube::insert_rows(const uword row_num, const uword N, const bool set_to_zero) + { + arma_extra_debug_sigprint(); + + arma_ignore(set_to_zero); + + (*this).insert_rows(row_num, N); + } + + + +template +inline +void +Cube::insert_rows(const uword row_num, const uword N) + { + arma_extra_debug_sigprint(); + + const uword t_n_rows = n_rows; + + const uword A_n_rows = row_num; + const uword B_n_rows = t_n_rows - row_num; + + // insertion at row_num == n_rows is in effect an append operation + arma_debug_check_bounds( (row_num > t_n_rows), "Cube::insert_rows(): index out of bounds" ); + + if(N == 0) { return; } + + Cube out(t_n_rows + N, n_cols, n_slices, arma_nozeros_indicator()); + + if(A_n_rows > 0) + { + out.rows(0, A_n_rows-1) = rows(0, A_n_rows-1); + } + + if(B_n_rows > 0) + { + out.rows(row_num + N, t_n_rows + N - 1) = rows(row_num, t_n_rows-1); + } + + out.rows(row_num, row_num + N - 1).zeros(); + + steal_mem(out); + } + + + +template +inline +void +Cube::insert_cols(const uword col_num, const uword N, const bool set_to_zero) + { + arma_extra_debug_sigprint(); + + arma_ignore(set_to_zero); + + (*this).insert_cols(col_num, N); + } + + + +template +inline +void +Cube::insert_cols(const uword col_num, const uword N) + { + arma_extra_debug_sigprint(); + + const uword t_n_cols = n_cols; + + const uword A_n_cols = col_num; + const uword B_n_cols = t_n_cols - col_num; + + // insertion at col_num == n_cols is in effect an append operation + arma_debug_check_bounds( (col_num > t_n_cols), "Cube::insert_cols(): index out of bounds" ); + + if(N == 0) { return; } + + Cube out(n_rows, t_n_cols + N, n_slices, arma_nozeros_indicator()); + + if(A_n_cols > 0) + { + out.cols(0, A_n_cols-1) = cols(0, A_n_cols-1); + } + + if(B_n_cols > 0) + { + out.cols(col_num + N, t_n_cols + N - 1) = cols(col_num, t_n_cols-1); + } + + out.cols(col_num, col_num + N - 1).zeros(); + + steal_mem(out); + } + + + +template +inline +void +Cube::insert_slices(const uword slice_num, const uword N, const bool set_to_zero) + { + arma_extra_debug_sigprint(); + + arma_ignore(set_to_zero); + + (*this).insert_slices(slice_num, N); + } + + + +template +inline +void +Cube::insert_slices(const uword slice_num, const uword N) + { + arma_extra_debug_sigprint(); + + const uword t_n_slices = n_slices; + + const uword A_n_slices = slice_num; + const uword B_n_slices = t_n_slices - slice_num; + + // insertion at slice_num == n_slices is in effect an append operation + arma_debug_check_bounds( (slice_num > t_n_slices), "Cube::insert_slices(): index out of bounds" ); + + if(N == 0) { return; } + + Cube out(n_rows, n_cols, t_n_slices + N, arma_nozeros_indicator()); + + if(A_n_slices > 0) + { + out.slices(0, A_n_slices-1) = slices(0, A_n_slices-1); + } + + if(B_n_slices > 0) + { + out.slices(slice_num + N, t_n_slices + N - 1) = slices(slice_num, t_n_slices-1); + } + + //out.slices(slice_num, slice_num + N - 1).zeros(); + + for(uword i=slice_num; i < (slice_num + N); ++i) + { + arrayops::fill_zeros(out.slice_memptr(i), out.n_elem_slice); + } + + steal_mem(out); + } + + + +template +template +inline +void +Cube::insert_rows(const uword row_num, const BaseCube& X) + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp(X.get_ref()); + const Cube& C = tmp.M; + + const uword N = C.n_rows; + + const uword t_n_rows = n_rows; + + const uword A_n_rows = row_num; + const uword B_n_rows = t_n_rows - row_num; + + // insertion at row_num == n_rows is in effect an append operation + arma_debug_check_bounds( (row_num > t_n_rows), "Cube::insert_rows(): index out of bounds" ); + + arma_debug_check + ( + ( (C.n_cols != n_cols) || (C.n_slices != n_slices) ), + "Cube::insert_rows(): given object has incompatible dimensions" + ); + + if(N == 0) { return; } + + Cube out(t_n_rows + N, n_cols, n_slices, arma_nozeros_indicator()); + + if(A_n_rows > 0) + { + out.rows(0, A_n_rows-1) = rows(0, A_n_rows-1); + } + + if(B_n_rows > 0) + { + out.rows(row_num + N, t_n_rows + N - 1) = rows(row_num, t_n_rows - 1); + } + + out.rows(row_num, row_num + N - 1) = C; + + steal_mem(out); + } + + + +template +template +inline +void +Cube::insert_cols(const uword col_num, const BaseCube& X) + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp(X.get_ref()); + const Cube& C = tmp.M; + + const uword N = C.n_cols; + + const uword t_n_cols = n_cols; + + const uword A_n_cols = col_num; + const uword B_n_cols = t_n_cols - col_num; + + // insertion at col_num == n_cols is in effect an append operation + arma_debug_check_bounds( (col_num > t_n_cols), "Cube::insert_cols(): index out of bounds" ); + + arma_debug_check + ( + ( (C.n_rows != n_rows) || (C.n_slices != n_slices) ), + "Cube::insert_cols(): given object has incompatible dimensions" + ); + + if(N == 0) { return; } + + Cube out(n_rows, t_n_cols + N, n_slices, arma_nozeros_indicator()); + + if(A_n_cols > 0) + { + out.cols(0, A_n_cols-1) = cols(0, A_n_cols-1); + } + + if(B_n_cols > 0) + { + out.cols(col_num + N, t_n_cols + N - 1) = cols(col_num, t_n_cols - 1); + } + + out.cols(col_num, col_num + N - 1) = C; + + steal_mem(out); + } + + + +//! insert the given object at the specified slice position; +//! the given object must have the same number of rows and columns as the cube +template +template +inline +void +Cube::insert_slices(const uword slice_num, const BaseCube& X) + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp(X.get_ref()); + const Cube& C = tmp.M; + + const uword N = C.n_slices; + + const uword t_n_slices = n_slices; + + const uword A_n_slices = slice_num; + const uword B_n_slices = t_n_slices - slice_num; + + // insertion at slice_num == n_slices is in effect an append operation + arma_debug_check_bounds( (slice_num > t_n_slices), "Cube::insert_slices(): index out of bounds" ); + + arma_debug_check + ( + ( (C.n_rows != n_rows) || (C.n_cols != n_cols) ), + "Cube::insert_slices(): given object has incompatible dimensions" + ); + + if(N == 0) { return; } + + Cube out(n_rows, n_cols, t_n_slices + N, arma_nozeros_indicator()); + + if(A_n_slices > 0) + { + out.slices(0, A_n_slices-1) = slices(0, A_n_slices-1); + } + + if(B_n_slices > 0) + { + out.slices(slice_num + N, t_n_slices + N - 1) = slices(slice_num, t_n_slices - 1); + } + + out.slices(slice_num, slice_num + N - 1) = C; + + steal_mem(out); + } + + + +template +template +inline +void +Cube::insert_slices(const uword slice_num, const Base& X) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap U(X.get_ref()); + + const Cube C(const_cast(U.M.memptr()), U.M.n_rows, U.M.n_cols, uword(1), false, true); + + (*this).insert_slices(slice_num, C); + } + + + +//! create a cube from GenCube, ie. run the previously delayed element generation operations +template +template +inline +Cube::Cube(const GenCube& X) + : n_rows(X.n_rows) + , n_cols(X.n_cols) + , n_elem_slice(X.n_rows*X.n_cols) + , n_slices(X.n_slices) + , n_elem(X.n_rows*X.n_cols*X.n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + X.apply(*this); + } + + + +template +template +inline +Cube& +Cube::operator=(const GenCube& X) + { + arma_extra_debug_sigprint(); + + init_warm(X.n_rows, X.n_cols, X.n_slices); + + X.apply(*this); + + return *this; + } + + + +template +template +inline +Cube& +Cube::operator+=(const GenCube& X) + { + arma_extra_debug_sigprint(); + + X.apply_inplace_plus(*this); + + return *this; + } + + + +template +template +inline +Cube& +Cube::operator-=(const GenCube& X) + { + arma_extra_debug_sigprint(); + + X.apply_inplace_minus(*this); + + return *this; + } + + + +template +template +inline +Cube& +Cube::operator%=(const GenCube& X) + { + arma_extra_debug_sigprint(); + + X.apply_inplace_schur(*this); + + return *this; + } + + + +template +template +inline +Cube& +Cube::operator/=(const GenCube& X) + { + arma_extra_debug_sigprint(); + + X.apply_inplace_div(*this); + + return *this; + } + + + +//! create a cube from OpCube, ie. run the previously delayed unary operations +template +template +inline +Cube::Cube(const OpCube& X) + : n_rows(0) + , n_cols(0) + , n_elem_slice(0) + , n_slices(0) + , n_elem(0) + , n_alloc(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + op_type::apply(*this, X); + } + + + +//! create a cube from OpCube, ie. run the previously delayed unary operations +template +template +inline +Cube& +Cube::operator=(const OpCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + op_type::apply(*this, X); + + return *this; + } + + + +//! in-place cube addition, with the right-hand-side operand having delayed operations +template +template +inline +Cube& +Cube::operator+=(const OpCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Cube m(X); + + return (*this).operator+=(m); + } + + + +//! in-place cube subtraction, with the right-hand-side operand having delayed operations +template +template +inline +Cube& +Cube::operator-=(const OpCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Cube m(X); + + return (*this).operator-=(m); + } + + + +//! in-place cube element-wise multiplication, with the right-hand-side operand having delayed operations +template +template +inline +Cube& +Cube::operator%=(const OpCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Cube m(X); + + return (*this).operator%=(m); + } + + + +//! in-place cube element-wise division, with the right-hand-side operand having delayed operations +template +template +inline +Cube& +Cube::operator/=(const OpCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Cube m(X); + + return (*this).operator/=(m); + } + + + +//! create a cube from eOpCube, ie. run the previously delayed unary operations +template +template +inline +Cube::Cube(const eOpCube& X) + : n_rows(X.get_n_rows()) + , n_cols(X.get_n_cols()) + , n_elem_slice(X.get_n_elem_slice()) + , n_slices(X.get_n_slices()) + , n_elem(X.get_n_elem()) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + init_cold(); + + eop_type::apply(*this, X); + } + + + +//! create a cube from eOpCube, ie. run the previously delayed unary operations +template +template +inline +Cube& +Cube::operator=(const eOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = ( X.P.has_subview && X.P.is_alias(*this) ); + + if(bad_alias) { Cube tmp(X); steal_mem(tmp); return *this; } + + init_warm(X.get_n_rows(), X.get_n_cols(), X.get_n_slices()); + + eop_type::apply(*this, X); + + return *this; + } + + + +//! in-place cube addition, with the right-hand-side operand having delayed operations +template +template +inline +Cube& +Cube::operator+=(const eOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = ( X.P.has_subview && X.P.is_alias(*this) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator+=(tmp); } + + eop_type::apply_inplace_plus(*this, X); + + return *this; + } + + + +//! in-place cube subtraction, with the right-hand-side operand having delayed operations +template +template +inline +Cube& +Cube::operator-=(const eOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = ( X.P.has_subview && X.P.is_alias(*this) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator-=(tmp); } + + eop_type::apply_inplace_minus(*this, X); + + return *this; + } + + + +//! in-place cube element-wise multiplication, with the right-hand-side operand having delayed operations +template +template +inline +Cube& +Cube::operator%=(const eOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = ( X.P.has_subview && X.P.is_alias(*this) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator%=(tmp); } + + eop_type::apply_inplace_schur(*this, X); + + return *this; + } + + + +//! in-place cube element-wise division, with the right-hand-side operand having delayed operations +template +template +inline +Cube& +Cube::operator/=(const eOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = ( X.P.has_subview && X.P.is_alias(*this) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator/=(tmp); } + + eop_type::apply_inplace_div(*this, X); + + return *this; + } + + + +template +template +inline +Cube::Cube(const mtOpCube& X) + : n_rows(0) + , n_cols(0) + , n_elem_slice(0) + , n_slices(0) + , n_elem(0) + , n_alloc(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + op_type::apply(*this, X); + } + + + +template +template +inline +Cube& +Cube::operator=(const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + op_type::apply(*this, X); + + return *this; + } + + + +template +template +inline +Cube& +Cube::operator+=(const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + const Cube m(X); + + return (*this).operator+=(m); + } + + + +template +template +inline +Cube& +Cube::operator-=(const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + const Cube m(X); + + return (*this).operator-=(m); + } + + + +template +template +inline +Cube& +Cube::operator%=(const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + const Cube m(X); + + return (*this).operator%=(m); + } + + + +template +template +inline +Cube& +Cube::operator/=(const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + const Cube m(X); + + return (*this).operator/=(m); + } + + + +//! create a cube from GlueCube, ie. run the previously delayed binary operations +template +template +inline +Cube::Cube(const GlueCube& X) + : n_rows(0) + , n_cols(0) + , n_elem_slice(0) + , n_slices(0) + , n_elem(0) + , n_alloc(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + this->operator=(X); + } + + + +//! create a cube from GlueCube, ie. run the previously delayed binary operations +template +template +inline +Cube& +Cube::operator=(const GlueCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + glue_type::apply(*this, X); + + return *this; + } + + +//! in-place cube addition, with the right-hand-side operands having delayed operations +template +template +inline +Cube& +Cube::operator+=(const GlueCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Cube m(X); + + return (*this).operator+=(m); + } + + + +//! in-place cube subtraction, with the right-hand-side operands having delayed operations +template +template +inline +Cube& +Cube::operator-=(const GlueCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Cube m(X); + + return (*this).operator-=(m); + } + + + +//! in-place cube element-wise multiplication, with the right-hand-side operands having delayed operations +template +template +inline +Cube& +Cube::operator%=(const GlueCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Cube m(X); + + return (*this).operator%=(m); + } + + + +//! in-place cube element-wise division, with the right-hand-side operands having delayed operations +template +template +inline +Cube& +Cube::operator/=(const GlueCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Cube m(X); + + return (*this).operator/=(m); + } + + + +//! create a cube from eGlueCube, ie. run the previously delayed binary operations +template +template +inline +Cube::Cube(const eGlueCube& X) + : n_rows(X.get_n_rows()) + , n_cols(X.get_n_cols()) + , n_elem_slice(X.get_n_elem_slice()) + , n_slices(X.get_n_slices()) + , n_elem(X.get_n_elem()) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + init_cold(); + + eglue_type::apply(*this, X); + } + + + +//! create a cube from eGlueCube, ie. run the previously delayed binary operations +template +template +inline +Cube& +Cube::operator=(const eGlueCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = ( (X.P1.has_subview && X.P1.is_alias(*this)) || (X.P2.has_subview && X.P2.is_alias(*this)) ); + + if(bad_alias) { Cube tmp(X); steal_mem(tmp); return *this; } + + init_warm(X.get_n_rows(), X.get_n_cols(), X.get_n_slices()); + + eglue_type::apply(*this, X); + + return *this; + } + + + +//! in-place cube addition, with the right-hand-side operands having delayed operations +template +template +inline +Cube& +Cube::operator+=(const eGlueCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = ( (X.P1.has_subview && X.P1.is_alias(*this)) || (X.P2.has_subview && X.P2.is_alias(*this)) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator+=(tmp); } + + eglue_type::apply_inplace_plus(*this, X); + + return *this; + } + + + +//! in-place cube subtraction, with the right-hand-side operands having delayed operations +template +template +inline +Cube& +Cube::operator-=(const eGlueCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = ( (X.P1.has_subview && X.P1.is_alias(*this)) || (X.P2.has_subview && X.P2.is_alias(*this)) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator-=(tmp); } + + eglue_type::apply_inplace_minus(*this, X); + + return *this; + } + + + +//! in-place cube element-wise multiplication, with the right-hand-side operands having delayed operations +template +template +inline +Cube& +Cube::operator%=(const eGlueCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = ( (X.P1.has_subview && X.P1.is_alias(*this)) || (X.P2.has_subview && X.P2.is_alias(*this)) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator%=(tmp); } + + eglue_type::apply_inplace_schur(*this, X); + + return *this; + } + + + +//! in-place cube element-wise division, with the right-hand-side operands having delayed operations +template +template +inline +Cube& +Cube::operator/=(const eGlueCube& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = ( (X.P1.has_subview && X.P1.is_alias(*this)) || (X.P2.has_subview && X.P2.is_alias(*this)) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator/=(tmp); } + + eglue_type::apply_inplace_div(*this, X); + + return *this; + } + + + +template +template +inline +Cube::Cube(const mtGlueCube& X) + : n_rows(0) + , n_cols(0) + , n_elem_slice(0) + , n_slices(0) + , n_elem(0) + , n_alloc(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + glue_type::apply(*this, X); + } + + + +template +template +inline +Cube& +Cube::operator=(const mtGlueCube& X) + { + arma_extra_debug_sigprint(); + + glue_type::apply(*this, X); + + return *this; + } + + + +template +template +inline +Cube& +Cube::operator+=(const mtGlueCube& X) + { + arma_extra_debug_sigprint(); + + const Cube m(X); + + return (*this).operator+=(m); + } + + + +template +template +inline +Cube& +Cube::operator-=(const mtGlueCube& X) + { + arma_extra_debug_sigprint(); + + const Cube m(X); + + return (*this).operator-=(m); + } + + + +template +template +inline +Cube& +Cube::operator%=(const mtGlueCube& X) + { + arma_extra_debug_sigprint(); + + const Cube m(X); + + return (*this).operator%=(m); + } + + + +template +template +inline +Cube& +Cube::operator/=(const mtGlueCube& X) + { + arma_extra_debug_sigprint(); + + const Cube m(X); + + return (*this).operator/=(m); + } + + + +//! linear element accessor (treats the cube as a vector); no bounds check; assumes memory is aligned +template +arma_inline +const eT& +Cube::at_alt(const uword i) const + { + const eT* mem_aligned = mem; + + memory::mark_as_aligned(mem_aligned); + + return mem_aligned[i]; + } + + + +//! linear element accessor (treats the cube as a vector); bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +eT& +Cube::operator() (const uword i) + { + arma_debug_check_bounds( (i >= n_elem), "Cube::operator(): index out of bounds" ); + + return access::rw(mem[i]); + } + + + +//! linear element accessor (treats the cube as a vector); bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +const eT& +Cube::operator() (const uword i) const + { + arma_debug_check_bounds( (i >= n_elem), "Cube::operator(): index out of bounds" ); + + return mem[i]; + } + + +//! linear element accessor (treats the cube as a vector); no bounds check. +template +arma_inline +eT& +Cube::operator[] (const uword i) + { + return access::rw(mem[i]); + } + + + +//! linear element accessor (treats the cube as a vector); no bounds check +template +arma_inline +const eT& +Cube::operator[] (const uword i) const + { + return mem[i]; + } + + + +//! linear element accessor (treats the cube as a vector); no bounds check. +template +arma_inline +eT& +Cube::at(const uword i) + { + return access::rw(mem[i]); + } + + + +//! linear element accessor (treats the cube as a vector); no bounds check +template +arma_inline +const eT& +Cube::at(const uword i) const + { + return mem[i]; + } + + + +//! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +eT& +Cube::operator() (const uword in_row, const uword in_col, const uword in_slice) + { + arma_debug_check_bounds + ( + (in_row >= n_rows) || + (in_col >= n_cols) || + (in_slice >= n_slices) + , + "Cube::operator(): index out of bounds" + ); + + return access::rw(mem[in_slice*n_elem_slice + in_col*n_rows + in_row]); + } + + + +//! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +const eT& +Cube::operator() (const uword in_row, const uword in_col, const uword in_slice) const + { + arma_debug_check_bounds + ( + (in_row >= n_rows) || + (in_col >= n_cols) || + (in_slice >= n_slices) + , + "Cube::operator(): index out of bounds" + ); + + return mem[in_slice*n_elem_slice + in_col*n_rows + in_row]; + } + + + +#if defined(__cpp_multidimensional_subscript) + + //! element accessor; no bounds check + template + arma_inline + eT& + Cube::operator[] (const uword in_row, const uword in_col, const uword in_slice) + { + return access::rw( mem[in_slice*n_elem_slice + in_col*n_rows + in_row] ); + } + + + + //! element accessor; no bounds check + template + arma_inline + const eT& + Cube::operator[] (const uword in_row, const uword in_col, const uword in_slice) const + { + return mem[in_slice*n_elem_slice + in_col*n_rows + in_row]; + } + +#endif + + + +//! element accessor; no bounds check +template +arma_inline +eT& +Cube::at(const uword in_row, const uword in_col, const uword in_slice) + { + return access::rw( mem[in_slice*n_elem_slice + in_col*n_rows + in_row] ); + } + + + +//! element accessor; no bounds check +template +arma_inline +const eT& +Cube::at(const uword in_row, const uword in_col, const uword in_slice) const + { + return mem[in_slice*n_elem_slice + in_col*n_rows + in_row]; + } + + + +//! prefix ++ +template +arma_inline +const Cube& +Cube::operator++() + { + Cube_aux::prefix_pp(*this); + + return *this; + } + + + +//! postfix ++ (must not return the object by reference) +template +arma_inline +void +Cube::operator++(int) + { + Cube_aux::postfix_pp(*this); + } + + + +//! prefix -- +template +arma_inline +const Cube& +Cube::operator--() + { + Cube_aux::prefix_mm(*this); + return *this; + } + + + +//! postfix -- (must not return the object by reference) +template +arma_inline +void +Cube::operator--(int) + { + Cube_aux::postfix_mm(*this); + } + + + +//! returns true if the cube has no elements +template +arma_inline +bool +Cube::is_empty() const + { + return (n_elem == 0); + } + + + +template +inline +bool +Cube::internal_is_finite() const + { + arma_extra_debug_sigprint(); + + return arrayops::is_finite(memptr(), n_elem); + } + + + +template +inline +bool +Cube::internal_has_inf() const + { + arma_extra_debug_sigprint(); + + return arrayops::has_inf(memptr(), n_elem); + } + + + +template +inline +bool +Cube::internal_has_nan() const + { + arma_extra_debug_sigprint(); + + return arrayops::has_nan(memptr(), n_elem); + } + + + +template +inline +bool +Cube::internal_has_nonfinite() const + { + arma_extra_debug_sigprint(); + + return (arrayops::is_finite(memptr(), n_elem) == false); + } + + + +//! returns true if the given index is currently in range +template +arma_inline +bool +Cube::in_range(const uword i) const + { + return (i < n_elem); + } + + + +//! returns true if the given start and end indices are currently in range +template +arma_inline +bool +Cube::in_range(const span& x) const + { + arma_extra_debug_sigprint(); + + if(x.whole) + { + return true; + } + else + { + const uword a = x.a; + const uword b = x.b; + + return ( (a <= b) && (b < n_elem) ); + } + } + + + +//! returns true if the given location is currently in range +template +arma_inline +bool +Cube::in_range(const uword in_row, const uword in_col, const uword in_slice) const + { + return ( (in_row < n_rows) && (in_col < n_cols) && (in_slice < n_slices) ); + } + + + +template +inline +bool +Cube::in_range(const span& row_span, const span& col_span, const span& slice_span) const + { + arma_extra_debug_sigprint(); + + const uword in_row1 = row_span.a; + const uword in_row2 = row_span.b; + + const uword in_col1 = col_span.a; + const uword in_col2 = col_span.b; + + const uword in_slice1 = slice_span.a; + const uword in_slice2 = slice_span.b; + + + const bool rows_ok = row_span.whole ? true : ( (in_row1 <= in_row2) && (in_row2 < n_rows) ); + const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2) && (in_col2 < n_cols) ); + const bool slices_ok = slice_span.whole ? true : ( (in_slice1 <= in_slice2) && (in_slice2 < n_slices) ); + + + return ( rows_ok && cols_ok && slices_ok ); + } + + + +template +inline +bool +Cube::in_range(const uword in_row, const uword in_col, const uword in_slice, const SizeCube& s) const + { + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + const uword l_n_slices = n_slices; + + if( + ( in_row >= l_n_rows) || ( in_col >= l_n_cols) || ( in_slice >= l_n_slices) + || ((in_row + s.n_rows) > l_n_rows) || ((in_col + s.n_cols) > l_n_cols) || ((in_slice + s.n_slices) > l_n_slices) + ) + { + return false; + } + else + { + return true; + } + } + + + +//! returns a pointer to array of eTs used by the cube +template +arma_inline +eT* +Cube::memptr() + { + return const_cast(mem); + } + + + +//! returns a pointer to array of eTs used by the cube +template +arma_inline +const eT* +Cube::memptr() const + { + return mem; + } + + + +//! returns a pointer to array of eTs used by the specified slice in the cube +template +arma_inline +eT* +Cube::slice_memptr(const uword uslice) + { + return const_cast( &mem[ uslice*n_elem_slice ] ); + } + + + +//! returns a pointer to array of eTs used by the specified slice in the cube +template +arma_inline +const eT* +Cube::slice_memptr(const uword uslice) const + { + return &mem[ uslice*n_elem_slice ]; + } + + + +//! returns a pointer to array of eTs used by the specified slice in the cube +template +arma_inline +eT* +Cube::slice_colptr(const uword uslice, const uword col) + { + return const_cast( &mem[ uslice*n_elem_slice + col*n_rows] ); + } + + + +//! returns a pointer to array of eTs used by the specified slice in the cube +template +arma_inline +const eT* +Cube::slice_colptr(const uword uslice, const uword col) const + { + return &mem[ uslice*n_elem_slice + col*n_rows ]; + } + + + +//! change the cube to have user specified dimensions (data is not preserved) +template +inline +Cube& +Cube::set_size(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + init_warm(new_n_rows, new_n_cols, new_n_slices); + + return *this; + } + + + +//! change the cube to have user specified dimensions (data is preserved) +template +inline +Cube& +Cube::reshape(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + op_reshape::apply_cube_inplace((*this), new_n_rows, new_n_cols, new_n_slices); + + return *this; + } + + + +//! change the cube to have user specified dimensions (data is preserved) +template +inline +Cube& +Cube::resize(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + op_resize::apply_cube_inplace((*this), new_n_rows, new_n_cols, new_n_slices); + + return *this; + } + + + +template +inline +Cube& +Cube::set_size(const SizeCube& s) + { + arma_extra_debug_sigprint(); + + init_warm(s.n_rows, s.n_cols, s.n_slices); + + return *this; + } + + + +template +inline +Cube& +Cube::reshape(const SizeCube& s) + { + arma_extra_debug_sigprint(); + + op_reshape::apply_cube_inplace((*this), s.n_rows, s.n_cols, s.n_slices); + + return *this; + } + + + +template +inline +Cube& +Cube::resize(const SizeCube& s) + { + arma_extra_debug_sigprint(); + + op_resize::apply_cube_inplace((*this), s.n_rows, s.n_cols, s.n_slices); + + return *this; + } + + + +//! change the cube (without preserving data) to have the same dimensions as the given cube +template +template +inline +Cube& +Cube::copy_size(const Cube& m) + { + arma_extra_debug_sigprint(); + + init_warm(m.n_rows, m.n_cols, m.n_slices); + + return *this; + } + + + +//! apply a functor to each element +template +template +inline +Cube& +Cube::for_each(functor F) + { + arma_extra_debug_sigprint(); + + eT* data = memptr(); + + const uword N = n_elem; + + uword ii, jj; + + for(ii=0, jj=1; jj < N; ii+=2, jj+=2) + { + F(data[ii]); + F(data[jj]); + } + + if(ii < N) + { + F(data[ii]); + } + + return *this; + } + + + +template +template +inline +const Cube& +Cube::for_each(functor F) const + { + arma_extra_debug_sigprint(); + + const eT* data = memptr(); + + const uword N = n_elem; + + uword ii, jj; + + for(ii=0, jj=1; jj < N; ii+=2, jj+=2) + { + F(data[ii]); + F(data[jj]); + } + + if(ii < N) + { + F(data[ii]); + } + + return *this; + } + + + +//! transform each element in the cube using a functor +template +template +inline +Cube& +Cube::transform(functor F) + { + arma_extra_debug_sigprint(); + + eT* out_mem = memptr(); + + const uword N = n_elem; + + uword ii, jj; + + for(ii=0, jj=1; jj < N; ii+=2, jj+=2) + { + eT tmp_ii = out_mem[ii]; + eT tmp_jj = out_mem[jj]; + + tmp_ii = eT( F(tmp_ii) ); + tmp_jj = eT( F(tmp_jj) ); + + out_mem[ii] = tmp_ii; + out_mem[jj] = tmp_jj; + } + + if(ii < N) + { + out_mem[ii] = eT( F(out_mem[ii]) ); + } + + return *this; + } + + + +//! imbue (fill) the cube with values provided by a functor +template +template +inline +Cube& +Cube::imbue(functor F) + { + arma_extra_debug_sigprint(); + + eT* out_mem = memptr(); + + const uword N = n_elem; + + uword ii, jj; + + for(ii=0, jj=1; jj < N; ii+=2, jj+=2) + { + const eT tmp_ii = eT( F() ); + const eT tmp_jj = eT( F() ); + + out_mem[ii] = tmp_ii; + out_mem[jj] = tmp_jj; + } + + if(ii < N) + { + out_mem[ii] = eT( F() ); + } + + return *this; + } + + + +template +inline +Cube& +Cube::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + arrayops::replace(memptr(), n_elem, old_val, new_val); + + return *this; + } + + + +template +inline +Cube& +Cube::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + arrayops::clean(memptr(), n_elem, threshold); + + return *this; + } + + + +template +inline +Cube& +Cube::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "Cube::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "Cube::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "Cube::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + arrayops::clamp(memptr(), n_elem, min_val, max_val); + + return *this; + } + + + +//! fill the cube with the specified value +template +inline +Cube& +Cube::fill(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_set( memptr(), val, n_elem ); + + return *this; + } + + + +template +inline +Cube& +Cube::zeros() + { + arma_extra_debug_sigprint(); + + arrayops::fill_zeros(memptr(), n_elem); + + return *this; + } + + + +template +inline +Cube& +Cube::zeros(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + set_size(new_n_rows, new_n_cols, new_n_slices); + + return (*this).zeros(); + } + + + +template +inline +Cube& +Cube::zeros(const SizeCube& s) + { + arma_extra_debug_sigprint(); + + return (*this).zeros(s.n_rows, s.n_cols, s.n_slices); + } + + + +template +inline +Cube& +Cube::ones() + { + arma_extra_debug_sigprint(); + + return (*this).fill(eT(1)); + } + + + +template +inline +Cube& +Cube::ones(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + set_size(new_n_rows, new_n_cols, new_n_slices); + + return (*this).fill(eT(1)); + } + + + +template +inline +Cube& +Cube::ones(const SizeCube& s) + { + arma_extra_debug_sigprint(); + + return (*this).ones(s.n_rows, s.n_cols, s.n_slices); + } + + + +template +inline +Cube& +Cube::randu() + { + arma_extra_debug_sigprint(); + + arma_rng::randu::fill( memptr(), n_elem ); + + return *this; + } + + + +template +inline +Cube& +Cube::randu(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + set_size(new_n_rows, new_n_cols, new_n_slices); + + return (*this).randu(); + } + + + +template +inline +Cube& +Cube::randu(const SizeCube& s) + { + arma_extra_debug_sigprint(); + + return (*this).randu(s.n_rows, s.n_cols, s.n_slices); + } + + + +template +inline +Cube& +Cube::randn() + { + arma_extra_debug_sigprint(); + + arma_rng::randn::fill( memptr(), n_elem ); + + return *this; + } + + + +template +inline +Cube& +Cube::randn(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + set_size(new_n_rows, new_n_cols, new_n_slices); + + return (*this).randn(); + } + + + +template +inline +Cube& +Cube::randn(const SizeCube& s) + { + arma_extra_debug_sigprint(); + + return (*this).randn(s.n_rows, s.n_cols, s.n_slices); + } + + + +template +inline +void +Cube::reset() + { + arma_extra_debug_sigprint(); + + init_warm(0,0,0); + } + + + +template +inline +void +Cube::soft_reset() + { + arma_extra_debug_sigprint(); + + // don't change the size if the cube has a fixed size + if(mem_state <= 1) + { + reset(); + } + else + { + zeros(); + } + } + + + +template +template +inline +void +Cube::set_real(const BaseCube::pod_type,T1>& X) + { + arma_extra_debug_sigprint(); + + Cube_aux::set_real(*this, X); + } + + + +template +template +inline +void +Cube::set_imag(const BaseCube::pod_type,T1>& X) + { + arma_extra_debug_sigprint(); + + Cube_aux::set_imag(*this, X); + } + + + +template +inline +eT +Cube::min() const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Cube::min(): object has no elements"); + + return Datum::nan; + } + + return op_min::direct_min(memptr(), n_elem); + } + + + +template +inline +eT +Cube::max() const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Cube::max(): object has no elements"); + + return Datum::nan; + } + + return op_max::direct_max(memptr(), n_elem); + } + + + +template +inline +eT +Cube::min(uword& index_of_min_val) const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Cube::min(): object has no elements"); + + index_of_min_val = uword(0); + + return Datum::nan; + } + + return op_min::direct_min(memptr(), n_elem, index_of_min_val); + } + + + +template +inline +eT +Cube::max(uword& index_of_max_val) const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Cube::max(): object has no elements"); + + index_of_max_val = uword(0); + + return Datum::nan; + } + + return op_max::direct_max(memptr(), n_elem, index_of_max_val); + } + + + +template +inline +eT +Cube::min(uword& row_of_min_val, uword& col_of_min_val, uword& slice_of_min_val) const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Cube::min(): object has no elements"); + + row_of_min_val = uword(0); + col_of_min_val = uword(0); + slice_of_min_val = uword(0); + + return Datum::nan; + } + + uword i; + + eT val = op_min::direct_min(memptr(), n_elem, i); + + const uword in_slice = i / n_elem_slice; + const uword offset = in_slice * n_elem_slice; + const uword j = i - offset; + + row_of_min_val = j % n_rows; + col_of_min_val = j / n_rows; + slice_of_min_val = in_slice; + + return val; + } + + + +template +inline +eT +Cube::max(uword& row_of_max_val, uword& col_of_max_val, uword& slice_of_max_val) const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Cube::max(): object has no elements"); + + row_of_max_val = uword(0); + col_of_max_val = uword(0); + slice_of_max_val = uword(0); + + return Datum::nan; + } + + uword i; + + eT val = op_max::direct_max(memptr(), n_elem, i); + + const uword in_slice = i / n_elem_slice; + const uword offset = in_slice * n_elem_slice; + const uword j = i - offset; + + row_of_max_val = j % n_rows; + col_of_max_val = j / n_rows; + slice_of_max_val = in_slice; + + return val; + } + + + +//! save the cube to a file +template +inline +bool +Cube::save(const std::string name, const file_type type) const + { + arma_extra_debug_sigprint(); + + bool save_okay = false; + + switch(type) + { + case raw_ascii: + save_okay = diskio::save_raw_ascii(*this, name); + break; + + case arma_ascii: + save_okay = diskio::save_arma_ascii(*this, name); + break; + + case raw_binary: + save_okay = diskio::save_raw_binary(*this, name); + break; + + case arma_binary: + save_okay = diskio::save_arma_binary(*this, name); + break; + + case ppm_binary: + save_okay = diskio::save_ppm_binary(*this, name); + break; + + case hdf5_binary: + return (*this).save(hdf5_name(name)); + break; + + case hdf5_binary_trans: // kept for compatibility with earlier versions of Armadillo + return (*this).save(hdf5_name(name, std::string(), hdf5_opts::trans)); + break; + + default: + arma_debug_warn_level(1, "Cube::save(): unsupported file type"); + save_okay = false; + } + + if(save_okay == false) { arma_debug_warn_level(3, "Cube::save(): write failed; file: ", name); } + + return save_okay; + } + + + +template +inline +bool +Cube::save(const hdf5_name& spec, const file_type type) const + { + arma_extra_debug_sigprint(); + + // handling of hdf5_binary_trans kept for compatibility with earlier versions of Armadillo + + if( (type != hdf5_binary) && (type != hdf5_binary_trans) ) + { + arma_stop_runtime_error("Cube::save(): unsupported file type for hdf5_name()"); + return false; + } + + const bool do_trans = bool(spec.opts.flags & hdf5_opts::flag_trans ) || (type == hdf5_binary_trans); + const bool append = bool(spec.opts.flags & hdf5_opts::flag_append ); + const bool replace = bool(spec.opts.flags & hdf5_opts::flag_replace); + + if(append && replace) + { + arma_stop_runtime_error("Cube::save(): only one of 'append' or 'replace' options can be used"); + return false; + } + + bool save_okay = false; + std::string err_msg; + + if(do_trans) + { + Cube tmp; + + op_strans_cube::apply_noalias(tmp, (*this)); + + save_okay = diskio::save_hdf5_binary(tmp, spec, err_msg); + } + else + { + save_okay = diskio::save_hdf5_binary(*this, spec, err_msg); + } + + if(save_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "Cube::save(): ", err_msg, "; file: ", spec.filename); + } + else + { + arma_debug_warn_level(3, "Cube::save(): write failed; file: ", spec.filename); + } + } + + return save_okay; + } + + + +//! save the cube to a stream +template +inline +bool +Cube::save(std::ostream& os, const file_type type) const + { + arma_extra_debug_sigprint(); + + bool save_okay = false; + + switch(type) + { + case raw_ascii: + save_okay = diskio::save_raw_ascii(*this, os); + break; + + case arma_ascii: + save_okay = diskio::save_arma_ascii(*this, os); + break; + + case raw_binary: + save_okay = diskio::save_raw_binary(*this, os); + break; + + case arma_binary: + save_okay = diskio::save_arma_binary(*this, os); + break; + + case ppm_binary: + save_okay = diskio::save_ppm_binary(*this, os); + break; + + default: + arma_debug_warn_level(1, "Cube::save(): unsupported file type"); + save_okay = false; + } + + if(save_okay == false) { arma_debug_warn_level(3, "Cube::save(): stream write failed"); } + + return save_okay; + } + + + +//! load a cube from a file +template +inline +bool +Cube::load(const std::string name, const file_type type) + { + arma_extra_debug_sigprint(); + + bool load_okay = false; + std::string err_msg; + + switch(type) + { + case auto_detect: + load_okay = diskio::load_auto_detect(*this, name, err_msg); + break; + + case raw_ascii: + load_okay = diskio::load_raw_ascii(*this, name, err_msg); + break; + + case arma_ascii: + load_okay = diskio::load_arma_ascii(*this, name, err_msg); + break; + + case raw_binary: + load_okay = diskio::load_raw_binary(*this, name, err_msg); + break; + + case arma_binary: + load_okay = diskio::load_arma_binary(*this, name, err_msg); + break; + + case ppm_binary: + load_okay = diskio::load_ppm_binary(*this, name, err_msg); + break; + + case hdf5_binary: + return (*this).load(hdf5_name(name)); + break; + + case hdf5_binary_trans: // kept for compatibility with earlier versions of Armadillo + return (*this).load(hdf5_name(name, std::string(), hdf5_opts::trans)); + break; + + default: + arma_debug_warn_level(1, "Cube::load(): unsupported file type"); + load_okay = false; + } + + if(load_okay == false) + { + (*this).soft_reset(); + + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "Cube::load(): ", err_msg, "; file: ", name); + } + else + { + arma_debug_warn_level(3, "Cube::load(): read failed; file: ", name); + } + } + + return load_okay; + } + + + +template +inline +bool +Cube::load(const hdf5_name& spec, const file_type type) + { + arma_extra_debug_sigprint(); + + if( (type != hdf5_binary) && (type != hdf5_binary_trans) ) + { + arma_stop_runtime_error("Cube::load(): unsupported file type for hdf5_name()"); + return false; + } + + bool load_okay = false; + std::string err_msg; + + const bool do_trans = bool(spec.opts.flags & hdf5_opts::flag_trans) || (type == hdf5_binary_trans); + + if(do_trans) + { + Cube tmp; + + load_okay = diskio::load_hdf5_binary(tmp, spec, err_msg); + + if(load_okay) { op_strans_cube::apply_noalias((*this), tmp); } + } + else + { + load_okay = diskio::load_hdf5_binary(*this, spec, err_msg); + } + + + if(load_okay == false) + { + (*this).soft_reset(); + + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "Cube::load(): ", err_msg, "; file: ", spec.filename); + } + else + { + arma_debug_warn_level(3, "Cube::load(): read failed; file: ", spec.filename); + } + } + + return load_okay; + } + + + +//! load a cube from a stream +template +inline +bool +Cube::load(std::istream& is, const file_type type) + { + arma_extra_debug_sigprint(); + + bool load_okay = false; + std::string err_msg; + + switch(type) + { + case auto_detect: + load_okay = diskio::load_auto_detect(*this, is, err_msg); + break; + + case raw_ascii: + load_okay = diskio::load_raw_ascii(*this, is, err_msg); + break; + + case arma_ascii: + load_okay = diskio::load_arma_ascii(*this, is, err_msg); + break; + + case raw_binary: + load_okay = diskio::load_raw_binary(*this, is, err_msg); + break; + + case arma_binary: + load_okay = diskio::load_arma_binary(*this, is, err_msg); + break; + + case ppm_binary: + load_okay = diskio::load_ppm_binary(*this, is, err_msg); + break; + + default: + arma_debug_warn_level(1, "Cube::load(): unsupported file type"); + load_okay = false; + } + + if(load_okay == false) + { + (*this).soft_reset(); + + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "Cube::load(): ", err_msg); + } + else + { + arma_debug_warn_level(3, "Cube::load(): stream read failed"); + } + } + + return load_okay; + } + + + +template +inline +bool +Cube::quiet_save(const std::string name, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(name, type); + } + + + +template +inline +bool +Cube::quiet_save(const hdf5_name& spec, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(spec, type); + } + + + +template +inline +bool +Cube::quiet_save(std::ostream& os, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(os, type); + } + + + +template +inline +bool +Cube::quiet_load(const std::string name, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(name, type); + } + + + +template +inline +bool +Cube::quiet_load(const hdf5_name& spec, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(spec, type); + } + + + +template +inline +bool +Cube::quiet_load(std::istream& is, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(is, type); + } + + + +template +inline +typename Cube::iterator +Cube::begin() + { + arma_extra_debug_sigprint(); + + return memptr(); + } + + + +template +inline +typename Cube::const_iterator +Cube::begin() const + { + arma_extra_debug_sigprint(); + + return memptr(); + } + + + +template +inline +typename Cube::const_iterator +Cube::cbegin() const + { + arma_extra_debug_sigprint(); + + return memptr(); + } + + + +template +inline +typename Cube::iterator +Cube::end() + { + arma_extra_debug_sigprint(); + + return memptr() + n_elem; + } + + + +template +inline +typename Cube::const_iterator +Cube::end() const + { + arma_extra_debug_sigprint(); + + return memptr() + n_elem; + } + + + +template +inline +typename Cube::const_iterator +Cube::cend() const + { + arma_extra_debug_sigprint(); + + return memptr() + n_elem; + } + + + +template +inline +typename Cube::slice_iterator +Cube::begin_slice(const uword slice_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (slice_num >= n_slices), "begin_slice(): index out of bounds" ); + + return slice_memptr(slice_num); + } + + + +template +inline +typename Cube::const_slice_iterator +Cube::begin_slice(const uword slice_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (slice_num >= n_slices), "begin_slice(): index out of bounds" ); + + return slice_memptr(slice_num); + } + + + +template +inline +typename Cube::slice_iterator +Cube::end_slice(const uword slice_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (slice_num >= n_slices), "end_slice(): index out of bounds" ); + + return slice_memptr(slice_num) + n_elem_slice; + } + + + +template +inline +typename Cube::const_slice_iterator +Cube::end_slice(const uword slice_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (slice_num >= n_slices), "end_slice(): index out of bounds" ); + + return slice_memptr(slice_num) + n_elem_slice; + } + + + +//! resets this cube to an empty matrix +template +inline +void +Cube::clear() + { + reset(); + } + + + +//! returns true if the cube has no elements +template +inline +bool +Cube::empty() const + { + return (n_elem == 0); + } + + + +//! returns the number of elements in this cube +template +inline +uword +Cube::size() const + { + return n_elem; + } + + + +template +inline +eT& +Cube::front() + { + arma_debug_check( (n_elem == 0), "Cube::front(): cube is empty" ); + + return access::rw(mem[0]); + } + + + +template +inline +const eT& +Cube::front() const + { + arma_debug_check( (n_elem == 0), "Cube::front(): cube is empty" ); + + return mem[0]; + } + + + +template +inline +eT& +Cube::back() + { + arma_debug_check( (n_elem == 0), "Cube::back(): cube is empty" ); + + return access::rw(mem[n_elem-1]); + } + + + +template +inline +const eT& +Cube::back() const + { + arma_debug_check( (n_elem == 0), "Cube::back(): cube is empty" ); + + return mem[n_elem-1]; + } + + + +template +inline +void +Cube::swap(Cube& B) + { + Cube& A = (*this); + + arma_extra_debug_sigprint(arma_str::format("A = %x B = %x") % &A % &B); + + if( (A.mem_state == 0) && (B.mem_state == 0) && (A.n_elem > Cube_prealloc::mem_n_elem) && (B.n_elem > Cube_prealloc::mem_n_elem) ) + { + A.delete_mat(); + B.delete_mat(); + + std::swap( access::rw(A.n_rows), access::rw(B.n_rows) ); + std::swap( access::rw(A.n_cols), access::rw(B.n_cols) ); + std::swap( access::rw(A.n_elem_slice), access::rw(B.n_elem_slice) ); + std::swap( access::rw(A.n_slices), access::rw(B.n_slices) ); + std::swap( access::rw(A.n_elem), access::rw(B.n_elem) ); + std::swap( access::rw(A.mem), access::rw(B.mem) ); + + A.create_mat(); + B.create_mat(); + } + else + if( (A.mem_state == 0) && (B.mem_state == 0) && (A.n_elem <= Cube_prealloc::mem_n_elem) && (B.n_elem <= Cube_prealloc::mem_n_elem) ) + { + A.delete_mat(); + B.delete_mat(); + + std::swap( access::rw(A.n_rows), access::rw(B.n_rows) ); + std::swap( access::rw(A.n_cols), access::rw(B.n_cols) ); + std::swap( access::rw(A.n_elem_slice), access::rw(B.n_elem_slice) ); + std::swap( access::rw(A.n_slices), access::rw(B.n_slices) ); + std::swap( access::rw(A.n_elem), access::rw(B.n_elem) ); + + const uword N = (std::max)(A.n_elem, B.n_elem); + + eT* A_mem = A.memptr(); + eT* B_mem = B.memptr(); + + for(uword i=0; i C = A; + + A.steal_mem(B); + B.steal_mem(C); + } + else + { + Cube C = B; + + B.steal_mem(A); + A.steal_mem(C); + } + } + } + + + +//! try to steal the memory from a given cube; +//! if memory can't be stolen, copy the given cube +template +inline +void +Cube::steal_mem(Cube& x) + { + arma_extra_debug_sigprint(); + + (*this).steal_mem(x, false); + } + + + +template +inline +void +Cube::steal_mem(Cube& x, const bool is_move) + { + arma_extra_debug_sigprint(); + + if(this == &x) { return; } + + if( (mem_state <= 1) && ( (x.n_alloc > Cube_prealloc::mem_n_elem) || (x.mem_state == 1) || (is_move && (x.mem_state == 2)) ) ) + { + arma_extra_debug_print("Cube::steal_mem(): stealing memory"); + + reset(); + + const uword x_n_slices = x.n_slices; + + access::rw(n_rows) = x.n_rows; + access::rw(n_cols) = x.n_cols; + access::rw(n_elem_slice) = x.n_elem_slice; + access::rw(n_slices) = x_n_slices; + access::rw(n_elem) = x.n_elem; + access::rw(n_alloc) = x.n_alloc; + access::rw(mem_state) = x.mem_state; + access::rw(mem) = x.mem; + + if(x_n_slices > Cube_prealloc::mat_ptrs_size) + { + arma_extra_debug_print("Cube::steal_mem(): stealing mat_ptrs array"); + + mat_ptrs = x.mat_ptrs; + x.mat_ptrs = nullptr; + } + else + { + arma_extra_debug_print("Cube::steal_mem(): copying mat_ptrs array"); + + mat_ptrs = mat_ptrs_local; + + for(uword i=0; i < x_n_slices; ++i) + { + mat_ptrs[i] = raw_mat_ptr_type(x.mat_ptrs[i]); // cast required by std::atomic + x.mat_ptrs[i] = nullptr; + } + } + + access::rw(x.n_rows) = 0; + access::rw(x.n_cols) = 0; + access::rw(x.n_elem_slice) = 0; + access::rw(x.n_slices) = 0; + access::rw(x.n_elem) = 0; + access::rw(x.n_alloc) = 0; + access::rw(x.mem_state) = 0; + access::rw(x.mem) = nullptr; + } + else + { + arma_extra_debug_print("Cube::steal_mem(): copying memory"); + + (*this).operator=(x); + + if( (is_move) && (x.mem_state == 0) && (x.n_alloc <= Cube_prealloc::mem_n_elem) ) + { + x.reset(); + } + } + } + + + +// +// Cube::fixed + + + +template +template +arma_inline +void +Cube::fixed::mem_setup() + { + arma_extra_debug_sigprint(); + + if(fixed_n_elem > 0) + { + access::rw(Cube::n_rows) = fixed_n_rows; + access::rw(Cube::n_cols) = fixed_n_cols; + access::rw(Cube::n_elem_slice) = fixed_n_rows * fixed_n_cols; + access::rw(Cube::n_slices) = fixed_n_slices; + access::rw(Cube::n_elem) = fixed_n_elem; + access::rw(Cube::n_alloc) = 0; + access::rw(Cube::mem_state) = 3; + access::rw(Cube::mem) = (fixed_n_elem > Cube_prealloc::mem_n_elem) ? mem_local_extra : mem_local; + Cube::mat_ptrs = (fixed_n_slices > Cube_prealloc::mat_ptrs_size) ? mat_ptrs_local_extra : mat_ptrs_local; + + create_mat(); + } + else + { + access::rw(Cube::n_rows) = 0; + access::rw(Cube::n_cols) = 0; + access::rw(Cube::n_elem_slice) = 0; + access::rw(Cube::n_slices) = 0; + access::rw(Cube::n_elem) = 0; + access::rw(Cube::n_alloc) = 0; + access::rw(Cube::mem_state) = 3; + access::rw(Cube::mem) = nullptr; + Cube::mat_ptrs = nullptr; + } + } + + + +template +template +inline +Cube::fixed::fixed() + { + arma_extra_debug_sigprint_this(this); + + mem_setup(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Cube::fixed::constructor: zeroing memory"); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(mem_local[0]); + + arrayops::fill_zeros(mem_use, fixed_n_elem); + } + } + + + +template +template +inline +Cube::fixed::fixed(const fixed& X) + { + arma_extra_debug_sigprint_this(this); + + mem_setup(); + + eT* dest = (use_extra) ? mem_local_extra : mem_local; + const eT* src = (use_extra) ? X.mem_local_extra : X.mem_local; + + arrayops::copy( dest, src, fixed_n_elem ); + } + + + +template +template +inline +Cube::fixed::fixed(const fill::scalar_holder f) + { + arma_extra_debug_sigprint_this(this); + + mem_setup(); + + (*this).fill(f.scalar); + } + + + +template +template +template +inline +Cube::fixed::fixed(const fill::fill_class&) + { + arma_extra_debug_sigprint_this(this); + + mem_setup(); + + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } + + arma_static_check( (is_same_type::yes), "Cube::fixed::fixed(): unsupported fill type" ); + } + + + +template +template +template +inline +Cube::fixed::fixed(const BaseCube& A) + { + arma_extra_debug_sigprint_this(this); + + mem_setup(); + + Cube::operator=(A.get_ref()); + } + + + +template +template +template +inline +Cube::fixed::fixed(const BaseCube& A, const BaseCube& B) + { + arma_extra_debug_sigprint_this(this); + + mem_setup(); + + Cube::init(A,B); + } + + + +template +template +inline +Cube& +Cube::fixed::operator=(const fixed& X) + { + arma_extra_debug_sigprint(); + + eT* dest = (use_extra) ? mem_local_extra : mem_local; + const eT* src = (use_extra) ? X.mem_local_extra : X.mem_local; + + arrayops::copy( dest, src, fixed_n_elem ); + + return *this; + } + + + +template +template +arma_inline +eT& +Cube::fixed::operator[] (const uword i) + { + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + +template +template +arma_inline +const eT& +Cube::fixed::operator[] (const uword i) const + { + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + +template +template +arma_inline +eT& +Cube::fixed::at(const uword i) + { + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + +template +template +arma_inline +const eT& +Cube::fixed::at(const uword i) const + { + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + +template +template +arma_inline +eT& +Cube::fixed::operator() (const uword i) + { + arma_debug_check_bounds( (i >= fixed_n_elem), "Cube::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + +template +template +arma_inline +const eT& +Cube::fixed::operator() (const uword i) const + { + arma_debug_check_bounds( (i >= fixed_n_elem), "Cube::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + +#if defined(__cpp_multidimensional_subscript) + + template + template + arma_inline + eT& + Cube::fixed::operator[] (const uword in_row, const uword in_col, const uword in_slice) + { + const uword i = in_slice*fixed_n_elem_slice + in_col*fixed_n_rows + in_row; + + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + + template + template + arma_inline + const eT& + Cube::fixed::operator[] (const uword in_row, const uword in_col, const uword in_slice) const + { + const uword i = in_slice*fixed_n_elem_slice + in_col*fixed_n_rows + in_row; + + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + +#endif + + + +template +template +arma_inline +eT& +Cube::fixed::at(const uword in_row, const uword in_col, const uword in_slice) + { + const uword i = in_slice*fixed_n_elem_slice + in_col*fixed_n_rows + in_row; + + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + +template +template +arma_inline +const eT& +Cube::fixed::at(const uword in_row, const uword in_col, const uword in_slice) const + { + const uword i = in_slice*fixed_n_elem_slice + in_col*fixed_n_rows + in_row; + + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + +template +template +arma_inline +eT& +Cube::fixed::operator() (const uword in_row, const uword in_col, const uword in_slice) + { + arma_debug_check_bounds + ( + (in_row >= fixed_n_rows ) || + (in_col >= fixed_n_cols ) || + (in_slice >= fixed_n_slices) + , + "operator(): index out of bounds" + ); + + const uword i = in_slice*fixed_n_elem_slice + in_col*fixed_n_rows + in_row; + + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + +template +template +arma_inline +const eT& +Cube::fixed::operator() (const uword in_row, const uword in_col, const uword in_slice) const + { + arma_debug_check_bounds + ( + (in_row >= fixed_n_rows ) || + (in_col >= fixed_n_cols ) || + (in_slice >= fixed_n_slices) + , + "Cube::operator(): index out of bounds" + ); + + const uword i = in_slice*fixed_n_elem_slice + in_col*fixed_n_rows + in_row; + + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + +// +// Cube_aux + + + +//! prefix ++ +template +arma_inline +void +Cube_aux::prefix_pp(Cube& x) + { + eT* memptr = x.memptr(); + const uword n_elem = x.n_elem; + + uword i,j; + + for(i=0, j=1; j +arma_inline +void +Cube_aux::prefix_pp(Cube< std::complex >& x) + { + x += T(1); + } + + + +//! postfix ++ +template +arma_inline +void +Cube_aux::postfix_pp(Cube& x) + { + eT* memptr = x.memptr(); + const uword n_elem = x.n_elem; + + uword i,j; + + for(i=0, j=1; j +arma_inline +void +Cube_aux::postfix_pp(Cube< std::complex >& x) + { + x += T(1); + } + + + +//! prefix -- +template +arma_inline +void +Cube_aux::prefix_mm(Cube& x) + { + eT* memptr = x.memptr(); + const uword n_elem = x.n_elem; + + uword i,j; + + for(i=0, j=1; j +arma_inline +void +Cube_aux::prefix_mm(Cube< std::complex >& x) + { + x -= T(1); + } + + + +//! postfix -- +template +arma_inline +void +Cube_aux::postfix_mm(Cube& x) + { + eT* memptr = x.memptr(); + const uword n_elem = x.n_elem; + + uword i,j; + + for(i=0, j=1; j +arma_inline +void +Cube_aux::postfix_mm(Cube< std::complex >& x) + { + x -= T(1); + } + + + +template +inline +void +Cube_aux::set_real(Cube& out, const BaseCube& X) + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp(X.get_ref()); + const Cube& A = tmp.M; + + arma_debug_assert_same_size( out, A, "Cube::set_real()" ); + + out = A; + } + + + +template +inline +void +Cube_aux::set_imag(Cube&, const BaseCube&) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +Cube_aux::set_real(Cube< std::complex >& out, const BaseCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const ProxyCube P(X.get_ref()); + + const uword local_n_rows = P.get_n_rows(); + const uword local_n_cols = P.get_n_cols(); + const uword local_n_slices = P.get_n_slices(); + + arma_debug_assert_same_size + ( + out.n_rows, out.n_cols, out.n_slices, + local_n_rows, local_n_cols, local_n_slices, + "Cube::set_real()" + ); + + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + const uword N = out.n_elem; + + for(uword i=0; i +inline +void +Cube_aux::set_imag(Cube< std::complex >& out, const BaseCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const ProxyCube P(X.get_ref()); + + const uword local_n_rows = P.get_n_rows(); + const uword local_n_cols = P.get_n_cols(); + const uword local_n_slices = P.get_n_slices(); + + arma_debug_assert_same_size + ( + out.n_rows, out.n_cols, out.n_slices, + local_n_rows, local_n_cols, local_n_slices, + "Cube::set_imag()" + ); + + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + const uword N = out.n_elem; + + for(uword i=0; i +class GenCube + : public BaseCube< eT, GenCube > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool use_at = false; + static constexpr bool is_simple = (is_same_type::value) || (is_same_type::value); + + arma_aligned const uword n_rows; + arma_aligned const uword n_cols; + arma_aligned const uword n_slices; + + arma_inline GenCube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices); + arma_inline ~GenCube(); + + arma_inline eT operator[] (const uword i) const; + arma_inline eT at (const uword r, const uword c, const uword s) const; + arma_inline eT at_alt (const uword i) const; + + inline void apply (Cube& out) const; + inline void apply_inplace_plus (Cube& out) const; + inline void apply_inplace_minus(Cube& out) const; + inline void apply_inplace_schur(Cube& out) const; + inline void apply_inplace_div (Cube& out) const; + + inline void apply(subview_cube& out) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/GenCube_meat.hpp b/src/armadillo/include/armadillo_bits/GenCube_meat.hpp new file mode 100644 index 0000000..61735f3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/GenCube_meat.hpp @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup GenCube +//! @{ + + + +template +arma_inline +GenCube::GenCube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices) + : n_rows (in_n_rows ) + , n_cols (in_n_cols ) + , n_slices(in_n_slices) + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +GenCube::~GenCube() + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +eT +GenCube::operator[](const uword) const + { + if(is_same_type::yes) { return eT(0); } + else if(is_same_type::yes) { return eT(1); } + + return eT(0); // prevent pedantic compiler warnings + } + + + +template +arma_inline +eT +GenCube::at(const uword, const uword, const uword) const + { + if(is_same_type::yes) { return eT(0); } + else if(is_same_type::yes) { return eT(1); } + + return eT(0); // prevent pedantic compiler warnings + } + + + +template +arma_inline +eT +GenCube::at_alt(const uword) const + { + if(is_same_type::yes) { return eT(0); } + else if(is_same_type::yes) { return eT(1); } + + return eT(0); // prevent pedantic compiler warnings + } + + + +template +inline +void +GenCube::apply(Cube& out) const + { + arma_extra_debug_sigprint(); + + // NOTE: we're assuming that the cube has already been set to the correct size; + // this is done by either the Cube contructor or operator=() + + if(is_same_type::yes) { out.zeros(); } + else if(is_same_type::yes) { out.ones(); } + } + + + +template +inline +void +GenCube::apply_inplace_plus(Cube& out) const + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "addition"); + + if(is_same_type::yes) + { + arrayops::inplace_plus(out.memptr(), eT(1), out.n_elem); + } + } + + + + +template +inline +void +GenCube::apply_inplace_minus(Cube& out) const + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "subtraction"); + + if(is_same_type::yes) + { + arrayops::inplace_minus(out.memptr(), eT(1), out.n_elem); + } + } + + + + +template +inline +void +GenCube::apply_inplace_schur(Cube& out) const + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "element-wise multiplication"); + + if(is_same_type::yes) + { + arrayops::inplace_mul(out.memptr(), eT(0), out.n_elem); + // NOTE: not using arrayops::fill_zeros(), as 'out' may have NaN elements + } + } + + + + +template +inline +void +GenCube::apply_inplace_div(Cube& out) const + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "element-wise division"); + + if(is_same_type::yes) + { + arrayops::inplace_div(out.memptr(), eT(0), out.n_elem); + } + } + + + +template +inline +void +GenCube::apply(subview_cube& out) const + { + arma_extra_debug_sigprint(); + + // NOTE: we're assuming that the subcube has the same dimensions as the GenCube object + // this is checked by subview_cube::operator=() + + if(is_same_type::yes) { out.zeros(); } + else if(is_same_type::yes) { out.ones(); } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Gen_bones.hpp b/src/armadillo/include/armadillo_bits/Gen_bones.hpp new file mode 100644 index 0000000..172e5b9 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Gen_bones.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Gen +//! @{ + + +//! support class for generator functions (zeros, ones, eye) +template +class Gen + : public Base< typename T1::elem_type, Gen > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool use_at = (is_same_type::value); + static constexpr bool is_simple = (is_same_type::value) || (is_same_type::value); + + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; + + arma_aligned const uword n_rows; + arma_aligned const uword n_cols; + + arma_inline Gen(const uword in_n_rows, const uword in_n_cols); + arma_inline ~Gen(); + + arma_inline elem_type operator[] (const uword ii) const; + arma_inline elem_type at (const uword r, const uword c) const; + arma_inline elem_type at_alt (const uword ii) const; + + inline void apply (Mat& out) const; + inline void apply_inplace_plus (Mat& out) const; + inline void apply_inplace_minus(Mat& out) const; + inline void apply_inplace_schur(Mat& out) const; + inline void apply_inplace_div (Mat& out) const; + + inline void apply(subview& out) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Gen_meat.hpp b/src/armadillo/include/armadillo_bits/Gen_meat.hpp new file mode 100644 index 0000000..96b8940 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Gen_meat.hpp @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Gen +//! @{ + + + +template +arma_inline +Gen::Gen(const uword in_n_rows, const uword in_n_cols) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +Gen::~Gen() + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +typename T1::elem_type +Gen::operator[](const uword ii) const + { + typedef typename T1::elem_type eT; + + if(is_same_type::yes) { return eT(0); } + else if(is_same_type::yes) { return eT(1); } + else if(is_same_type::yes) { return ((ii % n_rows) == (ii / n_rows)) ? eT(1) : eT(0); } + + return eT(0); // prevent pedantic compiler warnings + } + + + +template +arma_inline +typename T1::elem_type +Gen::at(const uword r, const uword c) const + { + typedef typename T1::elem_type eT; + + if(is_same_type::yes) { return eT(0); } + else if(is_same_type::yes) { return eT(1); } + else if(is_same_type::yes) { return (r == c) ? eT(1) : eT(0); } + + return eT(0); // prevent pedantic compiler warnings + } + + + +template +arma_inline +typename T1::elem_type +Gen::at_alt(const uword ii) const + { + return operator[](ii); + } + + + +template +inline +void +Gen::apply(Mat& out) const + { + arma_extra_debug_sigprint(); + + // NOTE: we're assuming that the matrix has already been set to the correct size; + // this is done by either the Mat contructor or operator=() + + if(is_same_type::yes) { out.zeros(); } + else if(is_same_type::yes) { out.ones(); } + else if(is_same_type::yes) { out.eye(); } + } + + + +template +inline +void +Gen::apply_inplace_plus(Mat& out) const + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "addition"); + + typedef typename T1::elem_type eT; + + if(is_same_type::yes) + { + arrayops::inplace_plus(out.memptr(), eT(1), out.n_elem); + } + else + if(is_same_type::yes) + { + const uword N = (std::min)(n_rows, n_cols); + + for(uword ii=0; ii < N; ++ii) { out.at(ii,ii) += eT(1); } + } + } + + + + +template +inline +void +Gen::apply_inplace_minus(Mat& out) const + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "subtraction"); + + typedef typename T1::elem_type eT; + + if(is_same_type::yes) + { + arrayops::inplace_minus(out.memptr(), eT(1), out.n_elem); + } + else + if(is_same_type::yes) + { + const uword N = (std::min)(n_rows, n_cols); + + for(uword ii=0; ii < N; ++ii) { out.at(ii,ii) -= eT(1); } + } + } + + + + +template +inline +void +Gen::apply_inplace_schur(Mat& out) const + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "element-wise multiplication"); + + typedef typename T1::elem_type eT; + + if(is_same_type::yes) + { + arrayops::inplace_mul(out.memptr(), eT(0), out.n_elem); + // NOTE: not using arrayops::fill_zeros(), as 'out' may have NaN elements + } + else + if(is_same_type::yes) + { + for(uword c=0; c < n_cols; ++c) + for(uword r=0; r < n_rows; ++r) + { + if(r != c) { out.at(r,c) *= eT(0); } + } + } + } + + + + +template +inline +void +Gen::apply_inplace_div(Mat& out) const + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "element-wise division"); + + typedef typename T1::elem_type eT; + + if(is_same_type::yes) + { + arrayops::inplace_div(out.memptr(), eT(0), out.n_elem); + } + else + if(is_same_type::yes) + { + for(uword c=0; c < n_cols; ++c) + for(uword r=0; r < n_rows; ++r) + { + if(r != c) { out.at(r,c) /= eT(0); } + } + } + } + + + +template +inline +void +Gen::apply(subview& out) const + { + arma_extra_debug_sigprint(); + + // NOTE: we're assuming that the submatrix has the same dimensions as the Gen object + // this is checked by subview::operator=() + + if(is_same_type::yes) { out.zeros(); } + else if(is_same_type::yes) { out.ones(); } + else if(is_same_type::yes) { out.eye(); } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/GlueCube_bones.hpp b/src/armadillo/include/armadillo_bits/GlueCube_bones.hpp new file mode 100644 index 0000000..75173ef --- /dev/null +++ b/src/armadillo/include/armadillo_bits/GlueCube_bones.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup GlueCube +//! @{ + + + +//! analog of the Glue class, intended for Cube objects +template +class GlueCube : public BaseCube< typename T1::elem_type, GlueCube > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + inline GlueCube(const BaseCube& in_A, const BaseCube& in_B); + inline ~GlueCube(); + + const T1& A; //!< first operand; must be derived from BaseCube + const T2& B; //!< second operand; must be derived from BaseCube + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/GlueCube_meat.hpp b/src/armadillo/include/armadillo_bits/GlueCube_meat.hpp new file mode 100644 index 0000000..4195650 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/GlueCube_meat.hpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup GlueCube +//! @{ + + + +template +inline +GlueCube::GlueCube(const BaseCube& in_A, const BaseCube& in_B) + : A(in_A.get_ref()) + , B(in_B.get_ref()) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +GlueCube::~GlueCube() + { + arma_extra_debug_sigprint(); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Glue_bones.hpp b/src/armadillo/include/armadillo_bits/Glue_bones.hpp new file mode 100644 index 0000000..197ae74 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Glue_bones.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Glue +//! @{ + + + +template +struct Glue_traits {}; + + +template +struct Glue_traits + { + static constexpr bool is_row = glue_type::template traits::is_row; + static constexpr bool is_col = glue_type::template traits::is_col; + static constexpr bool is_xvec = glue_type::template traits::is_xvec; + }; + +template +struct Glue_traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + }; + + +template +class Glue + : public Base< typename T1::elem_type, Glue > + , public Glue_traits::value> + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + inline Glue(const T1& in_A, const T2& in_B); + inline Glue(const T1& in_A, const T2& in_B, const uword in_aux_uword); + inline ~Glue(); + + const T1& A; //!< first operand; must be derived from Base + const T2& B; //!< second operand; must be derived from Base + uword aux_uword; //!< storage of auxiliary data, uword format + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Glue_meat.hpp b/src/armadillo/include/armadillo_bits/Glue_meat.hpp new file mode 100644 index 0000000..713fb16 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Glue_meat.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Glue +//! @{ + + + +template +inline +Glue::Glue(const T1& in_A, const T2& in_B) + : A(in_A) + , B(in_B) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Glue::Glue(const T1& in_A, const T2& in_B, const uword in_aux_uword) + : A(in_A) + , B(in_B) + , aux_uword(in_aux_uword) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Glue::~Glue() + { + arma_extra_debug_sigprint(); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/MapMat_bones.hpp b/src/armadillo/include/armadillo_bits/MapMat_bones.hpp new file mode 100644 index 0000000..7ab46b0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/MapMat_bones.hpp @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup MapMat +//! @{ + + + +// this class is for internal use only; subject to change and/or removal without notice +template +class MapMat + { + public: + + typedef eT elem_type; //!< the type of elements stored in the matrix + typedef typename get_pod_type::result pod_type; //!< if eT is std::complex, pod_type is T; otherwise pod_type is eT + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + const uword n_rows; //!< number of rows (read-only) + const uword n_cols; //!< number of columns (read-only) + const uword n_elem; //!< number of elements (read-only) + + + private: + + typedef typename std::map map_type; + + arma_aligned map_type* map_ptr; + + + public: + + inline ~MapMat(); + inline MapMat(); + + inline explicit MapMat(const uword in_n_rows, const uword in_n_cols); + inline explicit MapMat(const SizeMat& s); + + inline MapMat(const MapMat& x); + inline void operator=(const MapMat& x); + + inline explicit MapMat(const SpMat& x); + inline void operator=(const SpMat& x); + + inline MapMat(MapMat&& x); + inline void operator=(MapMat&& x); + + inline void reset(); + inline void set_size(const uword in_n_rows); + inline void set_size(const uword in_n_rows, const uword in_n_cols); + inline void set_size(const SizeMat& s); + + inline void zeros(); + inline void zeros(const uword in_n_rows); + inline void zeros(const uword in_n_rows, const uword in_n_cols); + inline void zeros(const SizeMat& s); + + inline void eye(); + inline void eye(const uword in_n_rows, const uword in_n_cols); + inline void eye(const SizeMat& s); + + inline void speye(); + inline void speye(const uword in_n_rows, const uword in_n_cols); + inline void speye(const SizeMat& s); + + arma_warn_unused arma_inline MapMat_val operator[](const uword index); + arma_warn_unused inline eT operator[](const uword index) const; + + arma_warn_unused arma_inline MapMat_val operator()(const uword index); + arma_warn_unused inline eT operator()(const uword index) const; + + arma_warn_unused arma_inline MapMat_val at(const uword in_row, const uword in_col); + arma_warn_unused inline eT at(const uword in_row, const uword in_col) const; + + arma_warn_unused arma_inline MapMat_val operator()(const uword in_row, const uword in_col); + arma_warn_unused inline eT operator()(const uword in_row, const uword in_col) const; + + arma_warn_unused inline bool is_empty() const; + arma_warn_unused inline bool is_vec() const; + arma_warn_unused inline bool is_rowvec() const; + arma_warn_unused inline bool is_colvec() const; + arma_warn_unused inline bool is_square() const; + + + inline void sprandu(const uword in_n_rows, const uword in_n_cols, const double density); + + inline void print(const std::string& extra_text) const; + + inline uword get_n_nonzero() const; + inline void get_locval_format(umat& locs, Col& vals) const; + + + private: + + inline void init_cold(); + inline void init_warm(const uword in_n_rows, const uword in_n_cols); + + arma_inline void set_val(const uword index, const eT& in_val); + inline void erase_val(const uword index); + + + friend class SpMat; + friend class MapMat_val; + friend class SpMat_MapMat_val; + friend class SpSubview_MapMat_val; + }; + + + +template +class MapMat_val + { + private: + + arma_aligned MapMat& parent; + + arma_aligned const uword index; + + inline MapMat_val(MapMat& in_parent, const uword in_index); + + friend class MapMat; + + + public: + + arma_inline operator eT() const; + + arma_inline typename get_pod_type::result real() const; + arma_inline typename get_pod_type::result imag() const; + + arma_inline void operator= (const MapMat_val& x); + arma_inline void operator= (const eT in_val); + arma_inline void operator+=(const eT in_val); + arma_inline void operator-=(const eT in_val); + arma_inline void operator*=(const eT in_val); + arma_inline void operator/=(const eT in_val); + + arma_inline void operator++(); + arma_inline void operator++(int); + + arma_inline void operator--(); + arma_inline void operator--(int); + }; + + + +template +class SpMat_MapMat_val + { + private: + + arma_aligned SpMat& s_parent; + arma_aligned MapMat& m_parent; + + arma_aligned const uword row; + arma_aligned const uword col; + + inline SpMat_MapMat_val(SpMat& in_s_parent, MapMat& in_m_parent, const uword in_row, const uword in_col); + + friend class SpMat; + friend class MapMat; + friend class SpSubview_MapMat_val; + + + public: + + inline operator eT() const; + + inline typename get_pod_type::result real() const; + inline typename get_pod_type::result imag() const; + + inline SpMat_MapMat_val& operator= (const SpMat_MapMat_val& x); + + inline SpMat_MapMat_val& operator= (const eT in_val); + inline SpMat_MapMat_val& operator+=(const eT in_val); + inline SpMat_MapMat_val& operator-=(const eT in_val); + inline SpMat_MapMat_val& operator*=(const eT in_val); + inline SpMat_MapMat_val& operator/=(const eT in_val); + + inline SpMat_MapMat_val& operator++(); + arma_warn_unused inline eT operator++(int); + + inline SpMat_MapMat_val& operator--(); + arma_warn_unused inline eT operator--(int); + + inline void set(const eT in_val); + inline void add(const eT in_val); + inline void sub(const eT in_val); + inline void mul(const eT in_val); + inline void div(const eT in_val); + }; + + + +template +class SpSubview_MapMat_val : public SpMat_MapMat_val + { + private: + + arma_inline SpSubview_MapMat_val(SpSubview& in_sv_parent, MapMat& in_m_parent, const uword in_row, const uword in_col); + + arma_aligned SpSubview& sv_parent; + + friend class SpMat; + friend class MapMat; + friend class SpSubview; + friend class SpMat_MapMat_val; + + + public: + + inline SpSubview_MapMat_val& operator= (const SpSubview_MapMat_val& x); + + inline SpSubview_MapMat_val& operator= (const eT in_val); + inline SpSubview_MapMat_val& operator+=(const eT in_val); + inline SpSubview_MapMat_val& operator-=(const eT in_val); + inline SpSubview_MapMat_val& operator*=(const eT in_val); + inline SpSubview_MapMat_val& operator/=(const eT in_val); + + inline SpSubview_MapMat_val& operator++(); + arma_warn_unused inline eT operator++(int); + + inline SpSubview_MapMat_val& operator--(); + arma_warn_unused inline eT operator--(int); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/MapMat_meat.hpp b/src/armadillo/include/armadillo_bits/MapMat_meat.hpp new file mode 100644 index 0000000..e67a307 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/MapMat_meat.hpp @@ -0,0 +1,1778 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup MapMat +//! @{ + + + +template +inline +MapMat::~MapMat() + { + arma_extra_debug_sigprint_this(this); + + if(map_ptr) { (*map_ptr).clear(); delete map_ptr; } + + // try to expose buggy user code that accesses deleted objects + if(arma_config::debug) { map_ptr = nullptr; } + + arma_type_check(( is_supported_elem_type::value == false )); + } + + + +template +inline +MapMat::MapMat() + : n_rows (0) + , n_cols (0) + , n_elem (0) + , map_ptr(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + } + + + +template +inline +MapMat::MapMat(const uword in_n_rows, const uword in_n_cols) + : n_rows (in_n_rows) + , n_cols (in_n_cols) + , n_elem (in_n_rows * in_n_cols) + , map_ptr(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + } + + + +template +inline +MapMat::MapMat(const SizeMat& s) + : n_rows (s.n_rows) + , n_cols (s.n_cols) + , n_elem (s.n_rows * s.n_cols) + , map_ptr(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + } + + + +template +inline +MapMat::MapMat(const MapMat& x) + : n_rows (0) + , n_cols (0) + , n_elem (0) + , map_ptr(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + (*this).operator=(x); + } + + + +template +inline +void +MapMat::operator=(const MapMat& x) + { + arma_extra_debug_sigprint(); + + if(this == &x) { return; } + + access::rw(n_rows) = x.n_rows; + access::rw(n_cols) = x.n_cols; + access::rw(n_elem) = x.n_elem; + + (*map_ptr) = *(x.map_ptr); + } + + + +template +inline +MapMat::MapMat(const SpMat& x) + : n_rows (0) + , n_cols (0) + , n_elem (0) + , map_ptr(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + (*this).operator=(x); + } + + + +template +inline +void +MapMat::operator=(const SpMat& x) + { + arma_extra_debug_sigprint(); + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + + (*this).zeros(x_n_rows, x_n_cols); + + if(x.n_nonzero == 0) { return; } + + const eT* x_values = x.values; + const uword* x_row_indices = x.row_indices; + const uword* x_col_ptrs = x.col_ptrs; + + map_type& map_ref = (*map_ptr); + + for(uword col = 0; col < x_n_cols; ++col) + { + const uword start = x_col_ptrs[col ]; + const uword end = x_col_ptrs[col + 1]; + + for(uword i = start; i < end; ++i) + { + const uword row = x_row_indices[i]; + const eT val = x_values[i]; + + const uword index = (x_n_rows * col) + row; + + map_ref.emplace_hint(map_ref.cend(), index, val); + } + } + } + + + +template +inline +MapMat::MapMat(MapMat&& x) + : n_rows (x.n_rows ) + , n_cols (x.n_cols ) + , n_elem (x.n_elem ) + , map_ptr(x.map_ptr) + { + arma_extra_debug_sigprint_this(this); + + access::rw(x.n_rows) = 0; + access::rw(x.n_cols) = 0; + access::rw(x.n_elem) = 0; + access::rw(x.map_ptr) = nullptr; + } + + + +template +inline +void +MapMat::operator=(MapMat&& x) + { + arma_extra_debug_sigprint(); + + if(this == &x) { return; } + + reset(); + + if(map_ptr) { delete map_ptr; } + + access::rw(n_rows) = x.n_rows; + access::rw(n_cols) = x.n_cols; + access::rw(n_elem) = x.n_elem; + access::rw(map_ptr) = x.map_ptr; + + access::rw(x.n_rows) = 0; + access::rw(x.n_cols) = 0; + access::rw(x.n_elem) = 0; + access::rw(x.map_ptr) = nullptr; + } + + + +template +inline +void +MapMat::reset() + { + arma_extra_debug_sigprint(); + + access::rw(n_rows) = 0; + access::rw(n_cols) = 0; + access::rw(n_elem) = 0; + + if((*map_ptr).empty() == false) { (*map_ptr).clear(); } + } + + + +template +inline +void +MapMat::set_size(const uword in_n_rows) + { + arma_extra_debug_sigprint(); + + init_warm(in_n_rows, 1); + } + + + +template +inline +void +MapMat::set_size(const uword in_n_rows, const uword in_n_cols) + { + arma_extra_debug_sigprint(); + + init_warm(in_n_rows, in_n_cols); + } + + + +template +inline +void +MapMat::set_size(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + init_warm(s.n_rows, s.n_cols); + } + + + +template +inline +void +MapMat::zeros() + { + arma_extra_debug_sigprint(); + + (*map_ptr).clear(); + } + + + +template +inline +void +MapMat::zeros(const uword in_n_rows) + { + arma_extra_debug_sigprint(); + + init_warm(in_n_rows, 1); + + (*map_ptr).clear(); + } + + + +template +inline +void +MapMat::zeros(const uword in_n_rows, const uword in_n_cols) + { + arma_extra_debug_sigprint(); + + init_warm(in_n_rows, in_n_cols); + + (*map_ptr).clear(); + } + + + +template +inline +void +MapMat::zeros(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + init_warm(s.n_rows, s.n_cols); + + (*map_ptr).clear(); + } + + + +template +inline +void +MapMat::eye() + { + arma_extra_debug_sigprint(); + + (*this).eye(n_rows, n_cols); + } + + + +template +inline +void +MapMat::eye(const uword in_n_rows, const uword in_n_cols) + { + arma_extra_debug_sigprint(); + + zeros(in_n_rows, in_n_cols); + + const uword N = (std::min)(in_n_rows, in_n_cols); + + map_type& map_ref = (*map_ptr); + + for(uword i=0; i +inline +void +MapMat::eye(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + (*this).eye(s.n_rows, s.n_cols); + } + + + +template +inline +void +MapMat::speye() + { + arma_extra_debug_sigprint(); + + (*this).eye(); + } + + + +template +inline +void +MapMat::speye(const uword in_n_rows, const uword in_n_cols) + { + arma_extra_debug_sigprint(); + + (*this).eye(in_n_rows, in_n_cols); + } + + + +template +inline +void +MapMat::speye(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + (*this).eye(s); + } + + + +template +arma_inline +MapMat_val +MapMat::operator[](const uword index) + { + return MapMat_val(*this, index); + } + + + +template +inline +eT +MapMat::operator[](const uword index) const + { + map_type& map_ref = (*map_ptr); + + typename map_type::const_iterator it = map_ref.find(index); + typename map_type::const_iterator it_end = map_ref.end(); + + return (it != it_end) ? eT((*it).second) : eT(0); + } + + + +template +arma_inline +MapMat_val +MapMat::operator()(const uword index) + { + arma_debug_check_bounds( (index >= n_elem), "MapMat::operator(): index out of bounds" ); + + return MapMat_val(*this, index); + } + + + +template +inline +eT +MapMat::operator()(const uword index) const + { + arma_debug_check_bounds( (index >= n_elem), "MapMat::operator(): index out of bounds" ); + + map_type& map_ref = (*map_ptr); + + typename map_type::const_iterator it = map_ref.find(index); + typename map_type::const_iterator it_end = map_ref.end(); + + return (it != it_end) ? eT((*it).second) : eT(0); + } + + + +template +arma_inline +MapMat_val +MapMat::at(const uword in_row, const uword in_col) + { + const uword index = (n_rows * in_col) + in_row; + + return MapMat_val(*this, index); + } + + + +template +inline +eT +MapMat::at(const uword in_row, const uword in_col) const + { + const uword index = (n_rows * in_col) + in_row; + + map_type& map_ref = (*map_ptr); + + typename map_type::const_iterator it = map_ref.find(index); + typename map_type::const_iterator it_end = map_ref.end(); + + return (it != it_end) ? eT((*it).second) : eT(0); + } + + + +template +arma_inline +MapMat_val +MapMat::operator()(const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "MapMat::operator(): index out of bounds" ); + + const uword index = (n_rows * in_col) + in_row; + + return MapMat_val(*this, index); + } + + + +template +inline +eT +MapMat::operator()(const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "MapMat::operator(): index out of bounds" ); + + const uword index = (n_rows * in_col) + in_row; + + map_type& map_ref = (*map_ptr); + + typename map_type::const_iterator it = map_ref.find(index); + typename map_type::const_iterator it_end = map_ref.end(); + + return (it != it_end) ? eT((*it).second) : eT(0); + } + + + +template +inline +bool +MapMat::is_empty() const + { + return (n_elem == 0); + } + + + +template +inline +bool +MapMat::is_vec() const + { + return ( (n_rows == 1) || (n_cols == 1) ); + } + + + +template +inline +bool +MapMat::is_rowvec() const + { + return (n_rows == 1); + } + + + +//! returns true if the object can be interpreted as a column vector +template +inline +bool +MapMat::is_colvec() const + { + return (n_cols == 1); + } + + + +template +inline +bool +MapMat::is_square() const + { + return (n_rows == n_cols); + } + + + +// this function is for debugging purposes only +template +inline +void +MapMat::sprandu(const uword in_n_rows, const uword in_n_cols, const double density) + { + arma_extra_debug_sigprint(); + + zeros(in_n_rows, in_n_cols); + + const uword N = uword(density * double(n_elem)); + + const Col vals(N, fill::randu); + const Col indx = linspace< Col >(0, ((n_elem > 0) ? uword(n_elem-1) : uword(0)) , N); + + const eT* vals_mem = vals.memptr(); + const uword* indx_mem = indx.memptr(); + + map_type& map_ref = (*map_ptr); + + for(uword i=0; i < N; ++i) + { + const uword index = indx_mem[i]; + const eT val = vals_mem[i]; + + map_ref.emplace_hint(map_ref.cend(), index, val); + } + } + + + +// this function is for debugging purposes only +template +inline +void +MapMat::print(const std::string& extra_text) const + { + arma_extra_debug_sigprint(); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + map_type& map_ref = (*map_ptr); + + const uword n_nonzero = uword(map_ref.size()); + + const double density = (n_elem > 0) ? ((double(n_nonzero) / double(n_elem))*double(100)) : double(0); + + get_cout_stream() + << "[matrix size: " << n_rows << 'x' << n_cols << "; n_nonzero: " << n_nonzero + << "; density: " << density << "%]\n\n"; + + if(n_nonzero > 0) + { + typename map_type::const_iterator it = map_ref.begin(); + + for(uword i=0; i < n_nonzero; ++i) + { + const std::pair& entry = (*it); + + const uword index = entry.first; + const eT val = entry.second; + + const uword row = index % n_rows; + const uword col = index / n_rows; + + get_cout_stream() << '(' << row << ", " << col << ") "; + get_cout_stream() << val << '\n'; + + ++it; + } + } + + get_cout_stream().flush(); + } + + + +template +inline +uword +MapMat::get_n_nonzero() const + { + arma_extra_debug_sigprint(); + + return uword((*map_ptr).size()); + } + + + +template +inline +void +MapMat::get_locval_format(umat& locs, Col& vals) const + { + arma_extra_debug_sigprint(); + + map_type& map_ref = (*map_ptr); + + typename map_type::const_iterator it = map_ref.begin(); + + const uword N = uword(map_ref.size()); + + locs.set_size(2,N); + vals.set_size(N); + + eT* vals_mem = vals.memptr(); + + for(uword i=0; i& entry = (*it); + + const uword index = entry.first; + const eT val = entry.second; + + const uword row = index % n_rows; + const uword col = index / n_rows; + + uword* locs_colptr = locs.colptr(i); + + locs_colptr[0] = row; + locs_colptr[1] = col; + + vals_mem[i] = val; + + ++it; + } + } + + + +template +inline +void +MapMat::init_cold() + { + arma_extra_debug_sigprint(); + + // ensure that n_elem can hold the result of (n_rows * n_cols) + + #if defined(ARMA_64BIT_WORD) + const char* error_message = "MapMat(): requested size is too large"; + #else + const char* error_message = "MapMat(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; + #endif + + arma_debug_check + ( + ( + ( (n_rows > ARMA_MAX_UHWORD) || (n_cols > ARMA_MAX_UHWORD) ) + ? ( (double(n_rows) * double(n_cols)) > double(ARMA_MAX_UWORD) ) + : false + ), + error_message + ); + + map_ptr = new (std::nothrow) map_type; + + arma_check_bad_alloc( (map_ptr == nullptr), "MapMat(): out of memory" ); + } + + + +template +inline +void +MapMat::init_warm(const uword in_n_rows, const uword in_n_cols) + { + arma_extra_debug_sigprint(); + + if( (n_rows == in_n_rows) && (n_cols == in_n_cols)) { return; } + + // ensure that n_elem can hold the result of (n_rows * n_cols) + + #if defined(ARMA_64BIT_WORD) + const char* error_message = "MapMat(): requested size is too large"; + #else + const char* error_message = "MapMat(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; + #endif + + arma_debug_check + ( + ( + ( (in_n_rows > ARMA_MAX_UHWORD) || (in_n_cols > ARMA_MAX_UHWORD) ) + ? ( (double(in_n_rows) * double(in_n_cols)) > double(ARMA_MAX_UWORD) ) + : false + ), + error_message + ); + + const uword new_n_elem = in_n_rows * in_n_cols; + + access::rw(n_rows) = in_n_rows; + access::rw(n_cols) = in_n_cols; + access::rw(n_elem) = new_n_elem; + + if(new_n_elem == 0) { (*map_ptr).clear(); } + } + + + +template +arma_inline +void +MapMat::set_val(const uword index, const eT& in_val) + { + arma_extra_debug_sigprint(); + + if(in_val != eT(0)) + { + map_type& map_ref = (*map_ptr); + + if( (map_ref.empty() == false) && (index > uword(map_ref.crbegin()->first)) ) + { + map_ref.emplace_hint(map_ref.cend(), index, in_val); + } + else + { + map_ref.operator[](index) = in_val; + } + } + else + { + (*this).erase_val(index); + } + } + + + +template +inline +void +MapMat::erase_val(const uword index) + { + arma_extra_debug_sigprint(); + + map_type& map_ref = (*map_ptr); + + typename map_type::iterator it = map_ref.find(index); + typename map_type::iterator it_end = map_ref.end(); + + if(it != it_end) { map_ref.erase(it); } + } + + + + + + +// MapMat_val + + + +template +arma_inline +MapMat_val::MapMat_val(MapMat& in_parent, const uword in_index) + : parent(in_parent) + , index (in_index ) + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +MapMat_val::operator eT() const + { + arma_extra_debug_sigprint(); + + const MapMat& const_parent = parent; + + return const_parent.operator[](index); + } + + + +template +arma_inline +typename get_pod_type::result +MapMat_val::real() const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const MapMat& const_parent = parent; + + return T( access::tmp_real( const_parent.operator[](index) ) ); + } + + + +template +arma_inline +typename get_pod_type::result +MapMat_val::imag() const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const MapMat& const_parent = parent; + + return T( access::tmp_imag( const_parent.operator[](index) ) ); + } + + + +template +arma_inline +void +MapMat_val::operator=(const MapMat_val& x) + { + arma_extra_debug_sigprint(); + + const eT in_val = eT(x); + + parent.set_val(index, in_val); + } + + + +template +arma_inline +void +MapMat_val::operator=(const eT in_val) + { + arma_extra_debug_sigprint(); + + parent.set_val(index, in_val); + } + + + +template +arma_inline +void +MapMat_val::operator+=(const eT in_val) + { + arma_extra_debug_sigprint(); + + typename MapMat::map_type& map_ref = *(parent.map_ptr); + + if(in_val != eT(0)) + { + eT& val = map_ref.operator[](index); // creates the element if it doesn't exist + + val += in_val; + + if(val == eT(0)) { map_ref.erase(index); } + } + } + + + +template +arma_inline +void +MapMat_val::operator-=(const eT in_val) + { + arma_extra_debug_sigprint(); + + typename MapMat::map_type& map_ref = *(parent.map_ptr); + + if(in_val != eT(0)) + { + eT& val = map_ref.operator[](index); // creates the element if it doesn't exist + + val -= in_val; + + if(val == eT(0)) { map_ref.erase(index); } + } + } + + + +template +arma_inline +void +MapMat_val::operator*=(const eT in_val) + { + arma_extra_debug_sigprint(); + + typename MapMat::map_type& map_ref = *(parent.map_ptr); + + typename MapMat::map_type::iterator it = map_ref.find(index); + typename MapMat::map_type::iterator it_end = map_ref.end(); + + if(it != it_end) + { + if(in_val != eT(0)) + { + eT& val = (*it).second; + + val *= in_val; + + if(val == eT(0)) { map_ref.erase(it); } + } + else + { + map_ref.erase(it); + } + } + } + + + +template +arma_inline +void +MapMat_val::operator/=(const eT in_val) + { + arma_extra_debug_sigprint(); + + typename MapMat::map_type& map_ref = *(parent.map_ptr); + + typename MapMat::map_type::iterator it = map_ref.find(index); + typename MapMat::map_type::iterator it_end = map_ref.end(); + + if(it != it_end) + { + eT& val = (*it).second; + + val /= in_val; + + if(val == eT(0)) { map_ref.erase(it); } + } + else + { + // silly operation, but included for completness + + const eT val = eT(0) / in_val; + + if(val != eT(0)) { parent.set_val(index, val); } + } + } + + + +template +arma_inline +void +MapMat_val::operator++() + { + arma_extra_debug_sigprint(); + + typename MapMat::map_type& map_ref = *(parent.map_ptr); + + eT& val = map_ref.operator[](index); // creates the element if it doesn't exist + + val += eT(1); // can't use ++, as eT can be std::complex + + if(val == eT(0)) { map_ref.erase(index); } + } + + + +template +arma_inline +void +MapMat_val::operator++(int) + { + arma_extra_debug_sigprint(); + + (*this).operator++(); + } + + + +template +arma_inline +void +MapMat_val::operator--() + { + arma_extra_debug_sigprint(); + + typename MapMat::map_type& map_ref = *(parent.map_ptr); + + eT& val = map_ref.operator[](index); // creates the element if it doesn't exist + + val -= eT(1); // can't use --, as eT can be std::complex + + if(val == eT(0)) { map_ref.erase(index); } + } + + + +template +arma_inline +void +MapMat_val::operator--(int) + { + arma_extra_debug_sigprint(); + + (*this).operator--(); + } + + + + + +// SpMat_MapMat_val + + + +template +arma_inline +SpMat_MapMat_val::SpMat_MapMat_val(SpMat& in_s_parent, MapMat& in_m_parent, const uword in_row, const uword in_col) + : s_parent(in_s_parent) + , m_parent(in_m_parent) + , row (in_row ) + , col (in_col ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpMat_MapMat_val::operator eT() const + { + arma_extra_debug_sigprint(); + + const SpMat& const_s_parent = s_parent; // declare as const for clarity of intent + + return const_s_parent.get_value(row,col); + } + + + +template +inline +typename get_pod_type::result +SpMat_MapMat_val::real() const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const SpMat& const_s_parent = s_parent; // declare as const for clarity of intent + + return T( access::tmp_real( const_s_parent.get_value(row,col) ) ); + } + + + +template +inline +typename get_pod_type::result +SpMat_MapMat_val::imag() const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const SpMat& const_s_parent = s_parent; // declare as const for clarity of intent + + return T( access::tmp_imag( const_s_parent.get_value(row,col) ) ); + } + + + +template +inline +SpMat_MapMat_val& +SpMat_MapMat_val::operator=(const SpMat_MapMat_val& x) + { + arma_extra_debug_sigprint(); + + const eT in_val = eT(x); + + return (*this).operator=(in_val); + } + + + +template +inline +SpMat_MapMat_val& +SpMat_MapMat_val::operator=(const eT in_val) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_OPENMP) + { + #pragma omp critical (arma_SpMat_cache) + { + (*this).set(in_val); + } + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + const std::lock_guard lock(s_parent.cache_mutex); + + (*this).set(in_val); + } + #else + { + (*this).set(in_val); + } + #endif + + return *this; + } + + + +template +inline +SpMat_MapMat_val& +SpMat_MapMat_val::operator+=(const eT in_val) + { + arma_extra_debug_sigprint(); + + if(in_val == eT(0)) { return *this; } + + #if defined(ARMA_USE_OPENMP) + { + #pragma omp critical (arma_SpMat_cache) + { + (*this).add(in_val); + } + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + const std::lock_guard lock(s_parent.cache_mutex); + + (*this).add(in_val); + } + #else + { + (*this).add(in_val); + } + #endif + + return *this; + } + + + +template +inline +SpMat_MapMat_val& +SpMat_MapMat_val::operator-=(const eT in_val) + { + arma_extra_debug_sigprint(); + + if(in_val == eT(0)) { return *this; } + + #if defined(ARMA_USE_OPENMP) + { + #pragma omp critical (arma_SpMat_cache) + { + (*this).sub(in_val); + } + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + const std::lock_guard lock(s_parent.cache_mutex); + + (*this).sub(in_val); + } + #else + { + (*this).sub(in_val); + } + #endif + + return *this; + } + + + +template +inline +SpMat_MapMat_val& +SpMat_MapMat_val::operator*=(const eT in_val) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_OPENMP) + { + #pragma omp critical (arma_SpMat_cache) + { + (*this).mul(in_val); + } + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + const std::lock_guard lock(s_parent.cache_mutex); + + (*this).mul(in_val); + } + #else + { + (*this).mul(in_val); + } + #endif + + return *this; + } + + + +template +inline +SpMat_MapMat_val& +SpMat_MapMat_val::operator/=(const eT in_val) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_OPENMP) + { + #pragma omp critical (arma_SpMat_cache) + { + (*this).div(in_val); + } + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + const std::lock_guard lock(s_parent.cache_mutex); + + (*this).div(in_val); + } + #else + { + (*this).div(in_val); + } + #endif + + return *this; + } + + + +template +inline +SpMat_MapMat_val& +SpMat_MapMat_val::operator++() + { + arma_extra_debug_sigprint(); + + return (*this).operator+=( eT(1) ); + } + + + +template +inline +eT +SpMat_MapMat_val::operator++(int) + { + arma_extra_debug_sigprint(); + + const eT old_val = eT(*this); + + (*this).operator+=( eT(1) ); + + return old_val; + } + + + +template +inline +SpMat_MapMat_val& +SpMat_MapMat_val::operator--() + { + arma_extra_debug_sigprint(); + + return (*this).operator-=( eT(1) ); + } + + + +template +inline +eT +SpMat_MapMat_val::operator--(int) + { + arma_extra_debug_sigprint(); + + const eT old_val = eT(*this); + + (*this).operator-=( eT(1) ); + + return old_val; + } + + + +template +inline +void +SpMat_MapMat_val::set(const eT in_val) + { + arma_extra_debug_sigprint(); + + const bool done = (s_parent.sync_state == 0) ? s_parent.try_set_value_csc(row, col, in_val) : false; + + if(done == false) + { + s_parent.sync_cache_simple(); + + const uword index = (m_parent.n_rows * col) + row; + + m_parent.set_val(index, in_val); + + s_parent.sync_state = 1; + + access::rw(s_parent.n_nonzero) = m_parent.get_n_nonzero(); + } + } + + + +template +inline +void +SpMat_MapMat_val::add(const eT in_val) + { + arma_extra_debug_sigprint(); + + const bool done = (s_parent.sync_state == 0) ? s_parent.try_add_value_csc(row, col, in_val) : false; + + if(done == false) + { + s_parent.sync_cache_simple(); + + const uword index = (m_parent.n_rows * col) + row; + + typename MapMat::map_type& map_ref = *(m_parent.map_ptr); + + eT& val = map_ref.operator[](index); // creates the element if it doesn't exist + + val += in_val; + + if(val == eT(0)) { map_ref.erase(index); } + + s_parent.sync_state = 1; + + access::rw(s_parent.n_nonzero) = m_parent.get_n_nonzero(); + } + } + + + +template +inline +void +SpMat_MapMat_val::sub(const eT in_val) + { + arma_extra_debug_sigprint(); + + const bool done = (s_parent.sync_state == 0) ? s_parent.try_sub_value_csc(row, col, in_val) : false; + + if(done == false) + { + s_parent.sync_cache_simple(); + + const uword index = (m_parent.n_rows * col) + row; + + typename MapMat::map_type& map_ref = *(m_parent.map_ptr); + + eT& val = map_ref.operator[](index); // creates the element if it doesn't exist + + val -= in_val; + + if(val == eT(0)) { map_ref.erase(index); } + + s_parent.sync_state = 1; + + access::rw(s_parent.n_nonzero) = m_parent.get_n_nonzero(); + } + } + + + +template +inline +void +SpMat_MapMat_val::mul(const eT in_val) + { + arma_extra_debug_sigprint(); + + const bool done = (s_parent.sync_state == 0) ? s_parent.try_mul_value_csc(row, col, in_val) : false; + + if(done == false) + { + s_parent.sync_cache_simple(); + + const uword index = (m_parent.n_rows * col) + row; + + typename MapMat::map_type& map_ref = *(m_parent.map_ptr); + + typename MapMat::map_type::iterator it = map_ref.find(index); + typename MapMat::map_type::iterator it_end = map_ref.end(); + + if(it != it_end) + { + if(in_val != eT(0)) + { + eT& val = (*it).second; + + val *= in_val; + + if(val == eT(0)) { map_ref.erase(it); } + } + else + { + map_ref.erase(it); + } + + s_parent.sync_state = 1; + + access::rw(s_parent.n_nonzero) = m_parent.get_n_nonzero(); + } + else + { + // element not found, ie. it's zero; zero multiplied by anything is zero, except for nan and inf + if(arma_isfinite(in_val) == false) + { + const eT result = eT(0) * in_val; + + if(result != eT(0)) // paranoia, in case compiling with -ffast-math + { + m_parent.set_val(index, result); + + s_parent.sync_state = 1; + + access::rw(s_parent.n_nonzero) = m_parent.get_n_nonzero(); + } + } + } + } + } + + + +template +inline +void +SpMat_MapMat_val::div(const eT in_val) + { + arma_extra_debug_sigprint(); + + const bool done = (s_parent.sync_state == 0) ? s_parent.try_div_value_csc(row, col, in_val) : false; + + if(done == false) + { + s_parent.sync_cache_simple(); + + const uword index = (m_parent.n_rows * col) + row; + + typename MapMat::map_type& map_ref = *(m_parent.map_ptr); + + typename MapMat::map_type::iterator it = map_ref.find(index); + typename MapMat::map_type::iterator it_end = map_ref.end(); + + if(it != it_end) + { + eT& val = (*it).second; + + val /= in_val; + + if(val == eT(0)) { map_ref.erase(it); } + + s_parent.sync_state = 1; + + access::rw(s_parent.n_nonzero) = m_parent.get_n_nonzero(); + } + else + { + // element not found, ie. it's zero; zero divided by anything is zero, except for zero and nan + if( (in_val == eT(0)) || (arma_isnan(in_val)) ) + { + const eT result = eT(0) / in_val; + + if(result != eT(0)) // paranoia, in case compiling with -ffast-math + { + m_parent.set_val(index, result); + + s_parent.sync_state = 1; + + access::rw(s_parent.n_nonzero) = m_parent.get_n_nonzero(); + } + } + } + } + } + + + + +// SpSubview_MapMat_val + + + +template +arma_inline +SpSubview_MapMat_val::SpSubview_MapMat_val(SpSubview& in_sv_parent, MapMat& in_m_parent, const uword in_row, const uword in_col) + : SpMat_MapMat_val(access::rw(in_sv_parent.m), in_m_parent, in_row, in_col) + , sv_parent(in_sv_parent) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpSubview_MapMat_val& +SpSubview_MapMat_val::operator=(const SpSubview_MapMat_val& x) + { + arma_extra_debug_sigprint(); + + const eT in_val = eT(x); + + return (*this).operator=(in_val); + } + + + +template +inline +SpSubview_MapMat_val& +SpSubview_MapMat_val::operator=(const eT in_val) + { + arma_extra_debug_sigprint(); + + const uword old_n_nonzero = sv_parent.m.n_nonzero; + + SpMat_MapMat_val::operator=(in_val); + + if(sv_parent.m.n_nonzero > old_n_nonzero) { access::rw(sv_parent.n_nonzero)++; } + if(sv_parent.m.n_nonzero < old_n_nonzero) { access::rw(sv_parent.n_nonzero)--; } + + return *this; + } + + + +template +inline +SpSubview_MapMat_val& +SpSubview_MapMat_val::operator+=(const eT in_val) + { + arma_extra_debug_sigprint(); + + const uword old_n_nonzero = sv_parent.m.n_nonzero; + + SpMat_MapMat_val::operator+=(in_val); + + if(sv_parent.m.n_nonzero > old_n_nonzero) { access::rw(sv_parent.n_nonzero)++; } + if(sv_parent.m.n_nonzero < old_n_nonzero) { access::rw(sv_parent.n_nonzero)--; } + + return *this; + } + + + +template +inline +SpSubview_MapMat_val& +SpSubview_MapMat_val::operator-=(const eT in_val) + { + arma_extra_debug_sigprint(); + + const uword old_n_nonzero = sv_parent.m.n_nonzero; + + SpMat_MapMat_val::operator-=(in_val); + + if(sv_parent.m.n_nonzero > old_n_nonzero) { access::rw(sv_parent.n_nonzero)++; } + if(sv_parent.m.n_nonzero < old_n_nonzero) { access::rw(sv_parent.n_nonzero)--; } + + return *this; + } + + + +template +inline +SpSubview_MapMat_val& +SpSubview_MapMat_val::operator*=(const eT in_val) + { + arma_extra_debug_sigprint(); + + const uword old_n_nonzero = sv_parent.m.n_nonzero; + + SpMat_MapMat_val::operator*=(in_val); + + if(sv_parent.m.n_nonzero > old_n_nonzero) { access::rw(sv_parent.n_nonzero)++; } + if(sv_parent.m.n_nonzero < old_n_nonzero) { access::rw(sv_parent.n_nonzero)--; } + + return *this; + } + + + +template +inline +SpSubview_MapMat_val& +SpSubview_MapMat_val::operator/=(const eT in_val) + { + arma_extra_debug_sigprint(); + + const uword old_n_nonzero = sv_parent.m.n_nonzero; + + SpMat_MapMat_val::operator/=(in_val); + + if(sv_parent.m.n_nonzero > old_n_nonzero) { access::rw(sv_parent.n_nonzero)++; } + if(sv_parent.m.n_nonzero < old_n_nonzero) { access::rw(sv_parent.n_nonzero)--; } + + return *this; + } + + + +template +inline +SpSubview_MapMat_val& +SpSubview_MapMat_val::operator++() + { + arma_extra_debug_sigprint(); + + const uword old_n_nonzero = sv_parent.m.n_nonzero; + + SpMat_MapMat_val::operator++(); + + if(sv_parent.m.n_nonzero > old_n_nonzero) { access::rw(sv_parent.n_nonzero)++; } + if(sv_parent.m.n_nonzero < old_n_nonzero) { access::rw(sv_parent.n_nonzero)--; } + + return *this; + } + + + +template +inline +eT +SpSubview_MapMat_val::operator++(int) + { + arma_extra_debug_sigprint(); + + const uword old_n_nonzero = sv_parent.m.n_nonzero; + + const eT old_val = SpMat_MapMat_val::operator++(int(0)); + + if(sv_parent.m.n_nonzero > old_n_nonzero) { access::rw(sv_parent.n_nonzero)++; } + if(sv_parent.m.n_nonzero < old_n_nonzero) { access::rw(sv_parent.n_nonzero)--; } + + return old_val; + } + + + +template +inline +SpSubview_MapMat_val& +SpSubview_MapMat_val::operator--() + { + arma_extra_debug_sigprint(); + + const uword old_n_nonzero = sv_parent.m.n_nonzero; + + SpMat_MapMat_val::operator--(); + + if(sv_parent.m.n_nonzero > old_n_nonzero) { access::rw(sv_parent.n_nonzero)++; } + if(sv_parent.m.n_nonzero < old_n_nonzero) { access::rw(sv_parent.n_nonzero)--; } + + return *this; + } + + + +template +inline +eT +SpSubview_MapMat_val::operator--(int) + { + arma_extra_debug_sigprint(); + + const uword old_n_nonzero = sv_parent.m.n_nonzero; + + const eT old_val = SpMat_MapMat_val::operator--(int(0)); + + if(sv_parent.m.n_nonzero > old_n_nonzero) { access::rw(sv_parent.n_nonzero)++; } + if(sv_parent.m.n_nonzero < old_n_nonzero) { access::rw(sv_parent.n_nonzero)--; } + + return old_val; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Mat_bones.hpp b/src/armadillo/include/armadillo_bits/Mat_bones.hpp new file mode 100644 index 0000000..baa41da --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Mat_bones.hpp @@ -0,0 +1,945 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Mat +//! @{ + + + +//! Dense matrix class + +template +class Mat : public Base< eT, Mat > + { + public: + + typedef eT elem_type; //!< the type of elements stored in the matrix + typedef typename get_pod_type::result pod_type; //!< if eT is std::complex, pod_type is T; otherwise pod_type is eT + + const uword n_rows; //!< number of rows (read-only) + const uword n_cols; //!< number of columns (read-only) + const uword n_elem; //!< number of elements (read-only) + const uword n_alloc; //!< number of allocated elements (read-only); NOTE: n_alloc can be 0, even if n_elem > 0 + const uhword vec_state; //!< 0: matrix layout; 1: column vector layout; 2: row vector layout + const uhword mem_state; + + // mem_state = 0: normal matrix which manages its own memory + // mem_state = 1: use auxiliary memory until a size change + // mem_state = 2: use auxiliary memory and don't allow the number of elements to be changed + // mem_state = 3: fixed size (eg. via template based size specification) + + arma_aligned const eT* const mem; //!< pointer to the memory used for storing elements (memory is read-only) + + + protected: + + arma_align_mem eT mem_local[ arma_config::mat_prealloc ]; // local storage, for small vectors and matrices + + + public: + + static constexpr bool is_col = false; + static constexpr bool is_row = false; + static constexpr bool is_xvec = false; + + inline ~Mat(); + inline Mat(); + + inline explicit Mat(const uword in_n_rows, const uword in_n_cols); + inline explicit Mat(const SizeMat& s); + + template inline explicit Mat(const uword in_n_rows, const uword in_n_cols, const arma_initmode_indicator&); + template inline explicit Mat(const SizeMat& s, const arma_initmode_indicator&); + + template inline Mat(const uword in_n_rows, const uword in_n_cols, const fill::fill_class& f); + template inline Mat(const SizeMat& s, const fill::fill_class& f); + + inline Mat(const uword in_n_rows, const uword in_n_cols, const fill::scalar_holder f); + inline Mat(const SizeMat& s, const fill::scalar_holder f); + + arma_cold inline Mat(const char* text); + arma_cold inline Mat& operator=(const char* text); + + arma_cold inline Mat(const std::string& text); + arma_cold inline Mat& operator=(const std::string& text); + + inline Mat(const std::vector& x); + inline Mat& operator=(const std::vector& x); + + inline Mat(const std::initializer_list& list); + inline Mat& operator=(const std::initializer_list& list); + + inline Mat(const std::initializer_list< std::initializer_list >& list); + inline Mat& operator=(const std::initializer_list< std::initializer_list >& list); + + inline Mat(Mat&& m); + inline Mat& operator=(Mat&& m); + + inline Mat( eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols, const bool copy_aux_mem = true, const bool strict = false); + inline Mat(const eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols); + + inline Mat& operator= (const eT val); + inline Mat& operator+=(const eT val); + inline Mat& operator-=(const eT val); + inline Mat& operator*=(const eT val); + inline Mat& operator/=(const eT val); + + inline Mat(const Mat& m); + inline Mat& operator= (const Mat& m); + inline Mat& operator+=(const Mat& m); + inline Mat& operator-=(const Mat& m); + inline Mat& operator*=(const Mat& m); + inline Mat& operator%=(const Mat& m); + inline Mat& operator/=(const Mat& m); + + template inline Mat(const BaseCube& X); + template inline Mat& operator= (const BaseCube& X); + template inline Mat& operator+=(const BaseCube& X); + template inline Mat& operator-=(const BaseCube& X); + template inline Mat& operator*=(const BaseCube& X); + template inline Mat& operator%=(const BaseCube& X); + template inline Mat& operator/=(const BaseCube& X); + + template + inline explicit Mat(const Base& A, const Base& B); + + inline explicit Mat(const subview& X, const bool use_colmem); // only to be used by the quasi_unwrap class + + inline Mat(const subview& X); + inline Mat& operator= (const subview& X); + inline Mat& operator+=(const subview& X); + inline Mat& operator-=(const subview& X); + inline Mat& operator*=(const subview& X); + inline Mat& operator%=(const subview& X); + inline Mat& operator/=(const subview& X); + + inline Mat(const subview_row_strans& X); // subview_row_strans can only be generated by the Proxy class + inline Mat(const subview_row_htrans& X); // subview_row_htrans can only be generated by the Proxy class + inline Mat(const xvec_htrans& X); // xvec_htrans can only be generated by the Proxy class + + template + inline Mat(const xtrans_mat& X); // xtrans_mat can only be generated by the Proxy class + + inline Mat(const subview_cube& X); + inline Mat& operator= (const subview_cube& X); + inline Mat& operator+=(const subview_cube& X); + inline Mat& operator-=(const subview_cube& X); + inline Mat& operator*=(const subview_cube& X); + inline Mat& operator%=(const subview_cube& X); + inline Mat& operator/=(const subview_cube& X); + + inline Mat(const diagview& X); + inline Mat& operator= (const diagview& X); + inline Mat& operator+=(const diagview& X); + inline Mat& operator-=(const diagview& X); + inline Mat& operator*=(const diagview& X); + inline Mat& operator%=(const diagview& X); + inline Mat& operator/=(const diagview& X); + + template inline Mat(const subview_elem1& X); + template inline Mat& operator= (const subview_elem1& X); + template inline Mat& operator+=(const subview_elem1& X); + template inline Mat& operator-=(const subview_elem1& X); + template inline Mat& operator*=(const subview_elem1& X); + template inline Mat& operator%=(const subview_elem1& X); + template inline Mat& operator/=(const subview_elem1& X); + + template inline Mat(const subview_elem2& X); + template inline Mat& operator= (const subview_elem2& X); + template inline Mat& operator+=(const subview_elem2& X); + template inline Mat& operator-=(const subview_elem2& X); + template inline Mat& operator*=(const subview_elem2& X); + template inline Mat& operator%=(const subview_elem2& X); + template inline Mat& operator/=(const subview_elem2& X); + + // Operators on sparse matrices (and subviews) + template inline explicit Mat(const SpBase& m); + template inline Mat& operator= (const SpBase& m); + template inline Mat& operator+=(const SpBase& m); + template inline Mat& operator-=(const SpBase& m); + template inline Mat& operator*=(const SpBase& m); + template inline Mat& operator%=(const SpBase& m); + template inline Mat& operator/=(const SpBase& m); + + inline explicit Mat(const SpSubview& X); + inline Mat& operator= (const SpSubview& X); + + inline explicit Mat(const spdiagview& X); + inline Mat& operator= (const spdiagview& X); + inline Mat& operator+=(const spdiagview& X); + inline Mat& operator-=(const spdiagview& X); + inline Mat& operator*=(const spdiagview& X); + inline Mat& operator%=(const spdiagview& X); + inline Mat& operator/=(const spdiagview& X); + + + arma_frown("use braced initialiser list instead") inline mat_injector operator<<(const eT val); + arma_frown("use braced initialiser list instead") inline mat_injector operator<<(const injector_end_of_row<>& x); + + + arma_inline subview_row row(const uword row_num); + arma_inline const subview_row row(const uword row_num) const; + + inline subview_row operator()(const uword row_num, const span& col_span); + inline const subview_row operator()(const uword row_num, const span& col_span) const; + + + arma_inline subview_col col(const uword col_num); + arma_inline const subview_col col(const uword col_num) const; + + inline subview_col operator()(const span& row_span, const uword col_num); + inline const subview_col operator()(const span& row_span, const uword col_num) const; + + inline Col unsafe_col(const uword col_num); + inline const Col unsafe_col(const uword col_num) const; + + + arma_inline subview rows(const uword in_row1, const uword in_row2); + arma_inline const subview rows(const uword in_row1, const uword in_row2) const; + + arma_inline subview_cols cols(const uword in_col1, const uword in_col2); + arma_inline const subview_cols cols(const uword in_col1, const uword in_col2) const; + + inline subview rows(const span& row_span); + inline const subview rows(const span& row_span) const; + + arma_inline subview_cols cols(const span& col_span); + arma_inline const subview_cols cols(const span& col_span) const; + + + arma_inline subview submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2); + arma_inline const subview submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const; + + arma_inline subview submat(const uword in_row1, const uword in_col1, const SizeMat& s); + arma_inline const subview submat(const uword in_row1, const uword in_col1, const SizeMat& s) const; + + inline subview submat (const span& row_span, const span& col_span); + inline const subview submat (const span& row_span, const span& col_span) const; + + inline subview operator()(const span& row_span, const span& col_span); + inline const subview operator()(const span& row_span, const span& col_span) const; + + inline subview operator()(const uword in_row1, const uword in_col1, const SizeMat& s); + inline const subview operator()(const uword in_row1, const uword in_col1, const SizeMat& s) const; + + inline subview head_rows(const uword N); + inline const subview head_rows(const uword N) const; + + inline subview tail_rows(const uword N); + inline const subview tail_rows(const uword N) const; + + inline subview_cols head_cols(const uword N); + inline const subview_cols head_cols(const uword N) const; + + inline subview_cols tail_cols(const uword N); + inline const subview_cols tail_cols(const uword N) const; + + template arma_inline subview_elem1 elem(const Base& a); + template arma_inline const subview_elem1 elem(const Base& a) const; + + template arma_inline subview_elem1 operator()(const Base& a); + template arma_inline const subview_elem1 operator()(const Base& a) const; + + + template arma_inline subview_elem2 elem(const Base& ri, const Base& ci); + template arma_inline const subview_elem2 elem(const Base& ri, const Base& ci) const; + + template arma_inline subview_elem2 submat(const Base& ri, const Base& ci); + template arma_inline const subview_elem2 submat(const Base& ri, const Base& ci) const; + + template arma_inline subview_elem2 operator()(const Base& ri, const Base& ci); + template arma_inline const subview_elem2 operator()(const Base& ri, const Base& ci) const; + + + template arma_inline subview_elem2 rows(const Base& ri); + template arma_inline const subview_elem2 rows(const Base& ri) const; + + template arma_inline subview_elem2 cols(const Base& ci); + template arma_inline const subview_elem2 cols(const Base& ci) const; + + + arma_inline subview_each1< Mat, 0 > each_col(); + arma_inline subview_each1< Mat, 1 > each_row(); + + arma_inline const subview_each1< Mat, 0 > each_col() const; + arma_inline const subview_each1< Mat, 1 > each_row() const; + + template inline subview_each2< Mat, 0, T1 > each_col(const Base& indices); + template inline subview_each2< Mat, 1, T1 > each_row(const Base& indices); + + template inline const subview_each2< Mat, 0, T1 > each_col(const Base& indices) const; + template inline const subview_each2< Mat, 1, T1 > each_row(const Base& indices) const; + + inline Mat& each_col(const std::function< void( Col&) >& F); + inline const Mat& each_col(const std::function< void(const Col&) >& F) const; + + inline Mat& each_row(const std::function< void( Row&) >& F); + inline const Mat& each_row(const std::function< void(const Row&) >& F) const; + + + arma_inline diagview diag(const sword in_id = 0); + arma_inline const diagview diag(const sword in_id = 0) const; + + + inline void swap_rows(const uword in_row1, const uword in_row2); + inline void swap_cols(const uword in_col1, const uword in_col2); + + inline void shed_row(const uword row_num); + inline void shed_col(const uword col_num); + + inline void shed_rows(const uword in_row1, const uword in_row2); + inline void shed_cols(const uword in_col1, const uword in_col2); + + template inline void shed_rows(const Base& indices); + template inline void shed_cols(const Base& indices); + + arma_deprecated inline void insert_rows(const uword row_num, const uword N, const bool set_to_zero); + arma_deprecated inline void insert_cols(const uword col_num, const uword N, const bool set_to_zero); + + inline void insert_rows(const uword row_num, const uword N); + inline void insert_cols(const uword col_num, const uword N); + + template inline void insert_rows(const uword row_num, const Base& X); + template inline void insert_cols(const uword col_num, const Base& X); + + + template inline Mat(const Gen& X); + template inline Mat& operator= (const Gen& X); + template inline Mat& operator+=(const Gen& X); + template inline Mat& operator-=(const Gen& X); + template inline Mat& operator*=(const Gen& X); + template inline Mat& operator%=(const Gen& X); + template inline Mat& operator/=(const Gen& X); + + template inline Mat(const Op& X); + template inline Mat& operator= (const Op& X); + template inline Mat& operator+=(const Op& X); + template inline Mat& operator-=(const Op& X); + template inline Mat& operator*=(const Op& X); + template inline Mat& operator%=(const Op& X); + template inline Mat& operator/=(const Op& X); + + template inline Mat(const eOp& X); + template inline Mat& operator= (const eOp& X); + template inline Mat& operator+=(const eOp& X); + template inline Mat& operator-=(const eOp& X); + template inline Mat& operator*=(const eOp& X); + template inline Mat& operator%=(const eOp& X); + template inline Mat& operator/=(const eOp& X); + + template inline Mat(const mtOp& X); + template inline Mat& operator= (const mtOp& X); + template inline Mat& operator+=(const mtOp& X); + template inline Mat& operator-=(const mtOp& X); + template inline Mat& operator*=(const mtOp& X); + template inline Mat& operator%=(const mtOp& X); + template inline Mat& operator/=(const mtOp& X); + + template inline Mat(const CubeToMatOp& X); + template inline Mat& operator= (const CubeToMatOp& X); + template inline Mat& operator+=(const CubeToMatOp& X); + template inline Mat& operator-=(const CubeToMatOp& X); + template inline Mat& operator*=(const CubeToMatOp& X); + template inline Mat& operator%=(const CubeToMatOp& X); + template inline Mat& operator/=(const CubeToMatOp& X); + + template inline Mat(const SpToDOp& X); + template inline Mat& operator= (const SpToDOp& X); + template inline Mat& operator+=(const SpToDOp& X); + template inline Mat& operator-=(const SpToDOp& X); + template inline Mat& operator*=(const SpToDOp& X); + template inline Mat& operator%=(const SpToDOp& X); + template inline Mat& operator/=(const SpToDOp& X); + + template inline Mat(const Glue& X); + template inline Mat& operator= (const Glue& X); + template inline Mat& operator+=(const Glue& X); + template inline Mat& operator-=(const Glue& X); + template inline Mat& operator*=(const Glue& X); + template inline Mat& operator%=(const Glue& X); + template inline Mat& operator/=(const Glue& X); + + template inline Mat& operator+=(const Glue& X); + template inline Mat& operator-=(const Glue& X); + + template inline Mat(const eGlue& X); + template inline Mat& operator= (const eGlue& X); + template inline Mat& operator+=(const eGlue& X); + template inline Mat& operator-=(const eGlue& X); + template inline Mat& operator*=(const eGlue& X); + template inline Mat& operator%=(const eGlue& X); + template inline Mat& operator/=(const eGlue& X); + + template inline Mat(const mtGlue& X); + template inline Mat& operator= (const mtGlue& X); + template inline Mat& operator+=(const mtGlue& X); + template inline Mat& operator-=(const mtGlue& X); + template inline Mat& operator*=(const mtGlue& X); + template inline Mat& operator%=(const mtGlue& X); + template inline Mat& operator/=(const mtGlue& X); + + template inline Mat(const SpToDGlue& X); + template inline Mat& operator= (const SpToDGlue& X); + template inline Mat& operator+=(const SpToDGlue& X); + template inline Mat& operator-=(const SpToDGlue& X); + template inline Mat& operator*=(const SpToDGlue& X); + template inline Mat& operator%=(const SpToDGlue& X); + template inline Mat& operator/=(const SpToDGlue& X); + + + arma_warn_unused arma_inline const eT& at_alt (const uword ii) const; + + arma_warn_unused arma_inline eT& operator[] (const uword ii); + arma_warn_unused arma_inline const eT& operator[] (const uword ii) const; + arma_warn_unused arma_inline eT& at (const uword ii); + arma_warn_unused arma_inline const eT& at (const uword ii) const; + arma_warn_unused arma_inline eT& operator() (const uword ii); + arma_warn_unused arma_inline const eT& operator() (const uword ii) const; + + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline eT& operator[] (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator[] (const uword in_row, const uword in_col) const; + #endif + + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col) const; + + arma_inline const Mat& operator++(); + arma_inline void operator++(int); + + arma_inline const Mat& operator--(); + arma_inline void operator--(int); + + arma_warn_unused arma_inline bool is_empty() const; + arma_warn_unused arma_inline bool is_vec() const; + arma_warn_unused arma_inline bool is_rowvec() const; + arma_warn_unused arma_inline bool is_colvec() const; + arma_warn_unused arma_inline bool is_square() const; + + arma_warn_unused inline bool internal_is_finite() const; + arma_warn_unused inline bool internal_has_inf() const; + arma_warn_unused inline bool internal_has_nan() const; + arma_warn_unused inline bool internal_has_nonfinite() const; + + arma_warn_unused inline bool is_sorted(const char* direction = "ascend") const; + arma_warn_unused inline bool is_sorted(const char* direction, const uword dim) const; + + template + arma_warn_unused inline bool is_sorted_helper(const comparator& comp, const uword dim) const; + + arma_warn_unused arma_inline bool in_range(const uword ii) const; + arma_warn_unused arma_inline bool in_range(const span& x ) const; + + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const span& col_span) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const span& col_span) const; + + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const SizeMat& s) const; + + arma_warn_unused arma_inline eT* colptr(const uword in_col); + arma_warn_unused arma_inline const eT* colptr(const uword in_col) const; + + arma_warn_unused arma_inline eT* memptr(); + arma_warn_unused arma_inline const eT* memptr() const; + + + template + inline Mat& copy_size(const Base& X); + + inline Mat& set_size(const uword new_n_elem); + inline Mat& set_size(const uword new_n_rows, const uword new_n_cols); + inline Mat& set_size(const SizeMat& s); + + inline Mat& resize(const uword new_n_elem); + inline Mat& resize(const uword new_n_rows, const uword new_n_cols); + inline Mat& resize(const SizeMat& s); + + inline Mat& reshape(const uword new_n_rows, const uword new_n_cols); + inline Mat& reshape(const SizeMat& s); + + arma_deprecated inline void reshape(const uword new_n_rows, const uword new_n_cols, const uword dim); //!< NOTE: don't use this form: it will be removed + + + template inline Mat& for_each(functor F); + template inline const Mat& for_each(functor F) const; + + template inline Mat& transform(functor F); + template inline Mat& imbue(functor F); + + + inline Mat& replace(const eT old_val, const eT new_val); + + inline Mat& clean(const pod_type threshold); + + inline Mat& clamp(const eT min_val, const eT max_val); + + inline Mat& fill(const eT val); + + template + inline Mat& fill(const fill::fill_class& f); + + inline Mat& zeros(); + inline Mat& zeros(const uword new_n_elem); + inline Mat& zeros(const uword new_n_rows, const uword new_n_cols); + inline Mat& zeros(const SizeMat& s); + + inline Mat& ones(); + inline Mat& ones(const uword new_n_elem); + inline Mat& ones(const uword new_n_rows, const uword new_n_cols); + inline Mat& ones(const SizeMat& s); + + inline Mat& randu(); + inline Mat& randu(const uword new_n_elem); + inline Mat& randu(const uword new_n_rows, const uword new_n_cols); + inline Mat& randu(const SizeMat& s); + + inline Mat& randn(); + inline Mat& randn(const uword new_n_elem); + inline Mat& randn(const uword new_n_rows, const uword new_n_cols); + inline Mat& randn(const SizeMat& s); + + inline Mat& eye(); + inline Mat& eye(const uword new_n_rows, const uword new_n_cols); + inline Mat& eye(const SizeMat& s); + + arma_cold inline void reset(); + arma_cold inline void soft_reset(); + + + template inline void set_real(const Base& X); + template inline void set_imag(const Base& X); + + + arma_warn_unused inline eT min() const; + arma_warn_unused inline eT max() const; + + inline eT min(uword& index_of_min_val) const; + inline eT max(uword& index_of_max_val) const; + + inline eT min(uword& row_of_min_val, uword& col_of_min_val) const; + inline eT max(uword& row_of_max_val, uword& col_of_max_val) const; + + + arma_cold inline bool save(const std::string name, const file_type type = arma_binary) const; + arma_cold inline bool save(const hdf5_name& spec, const file_type type = hdf5_binary) const; + arma_cold inline bool save(const csv_name& spec, const file_type type = csv_ascii) const; + arma_cold inline bool save( std::ostream& os, const file_type type = arma_binary) const; + + arma_cold inline bool load(const std::string name, const file_type type = auto_detect); + arma_cold inline bool load(const hdf5_name& spec, const file_type type = hdf5_binary); + arma_cold inline bool load(const csv_name& spec, const file_type type = csv_ascii); + arma_cold inline bool load( std::istream& is, const file_type type = auto_detect); + + arma_deprecated inline bool quiet_save(const std::string name, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save(const hdf5_name& spec, const file_type type = hdf5_binary) const; + arma_deprecated inline bool quiet_save(const csv_name& spec, const file_type type = csv_ascii) const; + arma_deprecated inline bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; + + arma_deprecated inline bool quiet_load(const std::string name, const file_type type = auto_detect); + arma_deprecated inline bool quiet_load(const hdf5_name& spec, const file_type type = hdf5_binary); + arma_deprecated inline bool quiet_load(const csv_name& spec, const file_type type = csv_ascii); + arma_deprecated inline bool quiet_load( std::istream& is, const file_type type = auto_detect); + + + // for container-like functionality + + typedef eT value_type; + typedef uword size_type; + + typedef eT* iterator; + typedef const eT* const_iterator; + + typedef eT* col_iterator; + typedef const eT* const_col_iterator; + + class const_row_iterator; + + class row_iterator + { + public: + + inline row_iterator(); + inline row_iterator(const row_iterator& X); + inline row_iterator(Mat& in_M, const uword in_row, const uword in_col); + + arma_warn_unused inline eT& operator* (); + + inline row_iterator& operator++(); + arma_warn_unused inline row_iterator operator++(int); + + inline row_iterator& operator--(); + arma_warn_unused inline row_iterator operator--(int); + + arma_warn_unused inline bool operator!=(const row_iterator& X) const; + arma_warn_unused inline bool operator==(const row_iterator& X) const; + arma_warn_unused inline bool operator!=(const const_row_iterator& X) const; + arma_warn_unused inline bool operator==(const const_row_iterator& X) const; + + typedef std::bidirectional_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef eT* pointer; + typedef eT& reference; + + arma_aligned Mat* M; + arma_aligned uword current_row; + arma_aligned uword current_col; + }; + + + class const_row_iterator + { + public: + + inline const_row_iterator(); + inline const_row_iterator(const row_iterator& X); + inline const_row_iterator(const const_row_iterator& X); + inline const_row_iterator(const Mat& in_M, const uword in_row, const uword in_col); + + arma_warn_unused inline const eT& operator*() const; + + inline const_row_iterator& operator++(); + arma_warn_unused inline const_row_iterator operator++(int); + + inline const_row_iterator& operator--(); + arma_warn_unused inline const_row_iterator operator--(int); + + arma_warn_unused inline bool operator!=(const row_iterator& X) const; + arma_warn_unused inline bool operator==(const row_iterator& X) const; + arma_warn_unused inline bool operator!=(const const_row_iterator& X) const; + arma_warn_unused inline bool operator==(const const_row_iterator& X) const; + + typedef std::bidirectional_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef const eT* pointer; + typedef const eT& reference; + + arma_aligned const Mat* M; + arma_aligned uword current_row; + arma_aligned uword current_col; + }; + + + class const_row_col_iterator; + + class row_col_iterator + { + public: + + inline row_col_iterator(); + inline row_col_iterator(const row_col_iterator& in_it); + inline row_col_iterator(Mat& in_M, const uword row = 0, const uword col = 0); + + arma_warn_unused inline eT& operator*(); + + inline row_col_iterator& operator++(); + arma_warn_unused inline row_col_iterator operator++(int); + + inline row_col_iterator& operator--(); + arma_warn_unused inline row_col_iterator operator--(int); + + arma_warn_unused inline uword row() const; + arma_warn_unused inline uword col() const; + + arma_warn_unused inline bool operator==(const row_col_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const row_col_iterator& rhs) const; + arma_warn_unused inline bool operator==(const const_row_col_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_row_col_iterator& rhs) const; + + typedef std::bidirectional_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef eT* pointer; + typedef eT& reference; + + arma_aligned Mat* M; + arma_aligned eT* current_ptr; + arma_aligned uword current_col; + arma_aligned uword current_row; + }; + + + class const_row_col_iterator + { + public: + + inline const_row_col_iterator(); + inline const_row_col_iterator(const row_col_iterator& in_it); + inline const_row_col_iterator(const const_row_col_iterator& in_it); + inline const_row_col_iterator(const Mat& in_M, const uword row = 0, const uword col = 0); + + arma_warn_unused inline const eT& operator*() const; + + inline const_row_col_iterator& operator++(); + arma_warn_unused inline const_row_col_iterator operator++(int); + + inline const_row_col_iterator& operator--(); + arma_warn_unused inline const_row_col_iterator operator--(int); + + arma_warn_unused inline uword row() const; + arma_warn_unused inline uword col() const; + + arma_warn_unused inline bool operator==(const const_row_col_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_row_col_iterator& rhs) const; + arma_warn_unused inline bool operator==(const row_col_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const row_col_iterator& rhs) const; + + // So that we satisfy the STL iterator types. + typedef std::bidirectional_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef const eT* pointer; + typedef const eT& reference; + + arma_aligned const Mat* M; + arma_aligned const eT* current_ptr; + arma_aligned uword current_col; + arma_aligned uword current_row; + }; + + + inline iterator begin(); + inline const_iterator begin() const; + inline const_iterator cbegin() const; + + inline iterator end(); + inline const_iterator end() const; + inline const_iterator cend() const; + + inline col_iterator begin_col(const uword col_num); + inline const_col_iterator begin_col(const uword col_num) const; + + inline col_iterator end_col (const uword col_num); + inline const_col_iterator end_col (const uword col_num) const; + + inline row_iterator begin_row(const uword row_num); + inline const_row_iterator begin_row(const uword row_num) const; + + inline row_iterator end_row (const uword row_num); + inline const_row_iterator end_row (const uword row_num) const; + + inline row_col_iterator begin_row_col(); + inline const_row_col_iterator begin_row_col() const; + + inline row_col_iterator end_row_col(); + inline const_row_col_iterator end_row_col() const; + + + inline void clear(); + inline bool empty() const; + inline uword size() const; + + arma_warn_unused inline eT& front(); + arma_warn_unused inline const eT& front() const; + + arma_warn_unused inline eT& back(); + arma_warn_unused inline const eT& back() const; + + inline void swap(Mat& B); + + inline void steal_mem(Mat& X); //!< don't use this unless you're writing code internal to Armadillo + inline void steal_mem(Mat& X, const bool is_move); //!< don't use this unless you're writing code internal to Armadillo + + inline void steal_mem_col(Mat& X, const uword max_n_rows); + + + template class fixed; + + + protected: + + inline void init_cold(); + inline void init_warm(uword in_n_rows, uword in_n_cols); + + arma_cold inline void init(const std::string& text); + + inline void init(const std::initializer_list& list); + inline void init(const std::initializer_list< std::initializer_list >& list); + + template + inline void init(const Base& A, const Base& B); + + inline Mat(const char junk, const eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols); + + inline Mat(const arma_vec_indicator&, const uhword in_vec_state); + inline Mat(const arma_vec_indicator&, const uword in_n_rows, const uword in_n_cols, const uhword in_vec_state); + + inline Mat(const arma_fixed_indicator&, const uword in_n_rows, const uword in_n_cols, const uhword in_vec_state, const eT* in_mem); + + + friend class Cube; + friend class subview_cube; + friend class glue_join; + friend class op_strans; + friend class op_htrans; + friend class op_resize; + friend class op_mean; + friend class op_max; + friend class op_min; + + + public: + + #if defined(ARMA_EXTRA_MAT_PROTO) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_MAT_PROTO) + #endif + }; + + + +template +template +class Mat::fixed : public Mat + { + private: + + static constexpr uword fixed_n_elem = fixed_n_rows * fixed_n_cols; + static constexpr bool use_extra = (fixed_n_elem > arma_config::mat_prealloc); + + arma_align_mem eT mem_local_extra[ (use_extra) ? fixed_n_elem : 1 ]; + + + public: + + typedef fixed Mat_fixed_type; + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_col = (fixed_n_cols == 1); + static constexpr bool is_row = (fixed_n_rows == 1); + static constexpr bool is_xvec = false; + + static const uword n_rows; // value provided below the class definition + static const uword n_cols; // value provided below the class definition + static const uword n_elem; // value provided below the class definition + + arma_inline fixed(); + arma_inline fixed(const fixed& X); + + inline fixed(const fill::scalar_holder f); + template inline fixed(const fill::fill_class& f); + template inline fixed(const Base& A); + template inline fixed(const Base& A, const Base& B); + + inline fixed(const eT* aux_mem); + + inline fixed(const char* text); + inline fixed(const std::string& text); + + using Mat::operator=; + using Mat::operator(); + + inline fixed(const std::initializer_list& list); + inline Mat& operator=(const std::initializer_list& list); + + inline fixed(const std::initializer_list< std::initializer_list >& list); + inline Mat& operator=(const std::initializer_list< std::initializer_list >& list); + + arma_inline Mat& operator=(const fixed& X); + + #if defined(ARMA_GOOD_COMPILER) + template inline Mat& operator=(const eOp& X); + template inline Mat& operator=(const eGlue& X); + #endif + + arma_warn_unused arma_inline const Op< Mat_fixed_type, op_htrans > t() const; + arma_warn_unused arma_inline const Op< Mat_fixed_type, op_htrans > ht() const; + arma_warn_unused arma_inline const Op< Mat_fixed_type, op_strans > st() const; + + arma_warn_unused arma_inline const eT& at_alt (const uword i) const; + + arma_warn_unused arma_inline eT& operator[] (const uword i); + arma_warn_unused arma_inline const eT& operator[] (const uword i) const; + arma_warn_unused arma_inline eT& at (const uword i); + arma_warn_unused arma_inline const eT& at (const uword i) const; + arma_warn_unused arma_inline eT& operator() (const uword i); + arma_warn_unused arma_inline const eT& operator() (const uword i) const; + + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline eT& operator[] (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator[] (const uword in_row, const uword in_col) const; + #endif + + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col) const; + + arma_warn_unused arma_inline eT* colptr(const uword in_col); + arma_warn_unused arma_inline const eT* colptr(const uword in_col) const; + + arma_warn_unused arma_inline eT* memptr(); + arma_warn_unused arma_inline const eT* memptr() const; + + arma_warn_unused arma_inline bool is_vec() const; + + inline const Mat& fill(const eT val); + inline const Mat& zeros(); + inline const Mat& ones(); + }; + + + +// these definitions are outside of the class due to bizarre C++ rules; +// C++17 has inline variables to address this shortcoming + +template +template +const uword Mat::fixed::n_rows = fixed_n_rows; + +template +template +const uword Mat::fixed::n_cols = fixed_n_cols; + +template +template +const uword Mat::fixed::n_elem = fixed_n_rows * fixed_n_cols; + + + +class Mat_aux + { + public: + + template inline static void prefix_pp(Mat& x); + template inline static void prefix_pp(Mat< std::complex >& x); + + template inline static void postfix_pp(Mat& x); + template inline static void postfix_pp(Mat< std::complex >& x); + + template inline static void prefix_mm(Mat& x); + template inline static void prefix_mm(Mat< std::complex >& x); + + template inline static void postfix_mm(Mat& x); + template inline static void postfix_mm(Mat< std::complex >& x); + + template inline static void set_real(Mat& out, const Base& X); + template inline static void set_real(Mat< std::complex >& out, const Base< T,T1>& X); + + template inline static void set_imag(Mat& out, const Base& X); + template inline static void set_imag(Mat< std::complex >& out, const Base< T,T1>& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Mat_meat.hpp b/src/armadillo/include/armadillo_bits/Mat_meat.hpp new file mode 100644 index 0000000..cb92ffd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Mat_meat.hpp @@ -0,0 +1,10169 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Mat +//! @{ + + +template +inline +Mat::~Mat() + { + arma_extra_debug_sigprint_this(this); + + if(n_alloc > 0) + { + arma_extra_debug_print("Mat::destructor: releasing memory"); + memory::release( access::rw(mem) ); + } + + // try to expose buggy user code that accesses deleted objects + if(arma_config::debug) { access::rw(mem) = nullptr; } + + arma_type_check(( is_supported_elem_type::value == false )); + } + + + +template +inline +Mat::Mat() + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + } + + + +//! construct the matrix to have user specified dimensions +template +inline +Mat::Mat(const uword in_n_rows, const uword in_n_cols) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem(in_n_rows*in_n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Mat::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +template +inline +Mat::Mat(const SizeMat& s) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem(s.n_rows*s.n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Mat::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +//! internal use only +template +template +inline +Mat::Mat(const uword in_n_rows, const uword in_n_cols, const arma_initmode_indicator&) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem(in_n_rows*in_n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(do_zeros) + { + arma_extra_debug_print("Mat::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +//! internal use only +template +template +inline +Mat::Mat(const SizeMat& s, const arma_initmode_indicator&) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem(s.n_rows*s.n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(do_zeros) + { + arma_extra_debug_print("Mat::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +//! construct the matrix to have user specified dimensions and fill with specified pattern +template +template +inline +Mat::Mat(const uword in_n_rows, const uword in_n_cols, const fill::fill_class& f) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem(in_n_rows*in_n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + (*this).fill(f); + } + + + +template +template +inline +Mat::Mat(const SizeMat& s, const fill::fill_class& f) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem(s.n_rows*s.n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + (*this).fill(f); + } + + + +//! construct the matrix to have user specified dimensions and fill with specified value +template +inline +Mat::Mat(const uword in_n_rows, const uword in_n_cols, const fill::scalar_holder f) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem(in_n_rows*in_n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + (*this).fill(f.scalar); + } + + + +template +inline +Mat::Mat(const SizeMat& s, const fill::scalar_holder f) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem(s.n_rows*s.n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + (*this).fill(f.scalar); + } + + + +//! constructor used by Row and Col classes +template +inline +Mat::Mat(const arma_vec_indicator&, const uhword in_vec_state) + : n_rows( (in_vec_state == 2) ? 1 : 0 ) + , n_cols( (in_vec_state == 1) ? 1 : 0 ) + , n_elem(0) + , n_alloc(0) + , vec_state(in_vec_state) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + } + + + +//! constructor used by Row and Col classes +template +inline +Mat::Mat(const arma_vec_indicator&, const uword in_n_rows, const uword in_n_cols, const uhword in_vec_state) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem(in_n_rows*in_n_cols) + , n_alloc() + , vec_state(in_vec_state) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + } + + + +template +inline +Mat::Mat(const arma_fixed_indicator&, const uword in_n_rows, const uword in_n_cols, const uhword in_vec_state, const eT* in_mem) + : n_rows (in_n_rows) + , n_cols (in_n_cols) + , n_elem (in_n_rows*in_n_cols) + , n_alloc (0) + , vec_state (in_vec_state) + , mem_state (3) + , mem (in_mem) + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +void +Mat::init_cold() + { + arma_extra_debug_sigprint( arma_str::format("n_rows = %u, n_cols = %u") % n_rows % n_cols ); + + // ensure that n_elem can hold the result of (n_rows * n_cols) + + #if defined(ARMA_64BIT_WORD) + const char* error_message = "Mat::init(): requested size is too large"; + #else + const char* error_message = "Mat::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; + #endif + + arma_debug_check + ( + ( + ( (n_rows > ARMA_MAX_UHWORD) || (n_cols > ARMA_MAX_UHWORD) ) + ? ( (double(n_rows) * double(n_cols)) > double(ARMA_MAX_UWORD) ) + : false + ), + error_message + ); + + if(n_elem <= arma_config::mat_prealloc) + { + if(n_elem > 0) { arma_extra_debug_print("Mat::init(): using local memory"); } + + access::rw(mem) = (n_elem == 0) ? nullptr : mem_local; + access::rw(n_alloc) = 0; + } + else + { + arma_extra_debug_print("Mat::init(): acquiring memory"); + + access::rw(mem) = memory::acquire(n_elem); + access::rw(n_alloc) = n_elem; + } + } + + + +template +inline +void +Mat::init_warm(uword in_n_rows, uword in_n_cols) + { + arma_extra_debug_sigprint( arma_str::format("in_n_rows = %u, in_n_cols = %u") % in_n_rows % in_n_cols ); + + if( (n_rows == in_n_rows) && (n_cols == in_n_cols) ) { return; } + + bool err_state = false; + char* err_msg = nullptr; + + const uhword t_vec_state = vec_state; + const uhword t_mem_state = mem_state; + + const char* error_message_1 = "Mat::init(): size is fixed and hence cannot be changed"; + const char* error_message_2 = "Mat::init(): requested size is not compatible with column vector layout"; + const char* error_message_3 = "Mat::init(): requested size is not compatible with row vector layout"; + + arma_debug_set_error( err_state, err_msg, (t_mem_state == 3), error_message_1 ); + + if(t_vec_state > 0) + { + if( (in_n_rows == 0) && (in_n_cols == 0) ) + { + if(t_vec_state == 1) { in_n_cols = 1; } + if(t_vec_state == 2) { in_n_rows = 1; } + } + else + { + if(t_vec_state == 1) { arma_debug_set_error( err_state, err_msg, (in_n_cols != 1), error_message_2 ); } + if(t_vec_state == 2) { arma_debug_set_error( err_state, err_msg, (in_n_rows != 1), error_message_3 ); } + } + } + + // ensure that n_elem can hold the result of (n_rows * n_cols) + + #if defined(ARMA_64BIT_WORD) + const char* error_message_4 = "Mat::init(): requested size is too large"; + #else + const char* error_message_4 = "Mat::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; + #endif + + arma_debug_set_error + ( + err_state, + err_msg, + ( + ( (in_n_rows > ARMA_MAX_UHWORD) || (in_n_cols > ARMA_MAX_UHWORD) ) + ? ( (double(in_n_rows) * double(in_n_cols)) > double(ARMA_MAX_UWORD) ) + : false + ), + error_message_4 + ); + + arma_debug_check(err_state, err_msg); + + const uword old_n_elem = n_elem; + const uword new_n_elem = in_n_rows * in_n_cols; + + if(old_n_elem == new_n_elem) + { + arma_extra_debug_print("Mat::init(): reusing memory"); + access::rw(n_rows) = in_n_rows; + access::rw(n_cols) = in_n_cols; + return; + } + + arma_debug_check( (t_mem_state == 2), "Mat::init(): mismatch between size of auxiliary memory and requested size" ); + + if(new_n_elem <= arma_config::mat_prealloc) + { + if(n_alloc > 0) + { + arma_extra_debug_print("Mat::init(): releasing memory"); + memory::release( access::rw(mem) ); + } + + if(new_n_elem > 0) { arma_extra_debug_print("Mat::init(): using local memory"); } + + access::rw(mem) = (new_n_elem == 0) ? nullptr : mem_local; + access::rw(n_alloc) = 0; + } + else // condition: new_n_elem > arma_config::mat_prealloc + { + if(new_n_elem > n_alloc) + { + if(n_alloc > 0) + { + arma_extra_debug_print("Mat::init(): releasing memory"); + memory::release( access::rw(mem) ); + + // in case memory::acquire() throws an exception + access::rw(mem) = nullptr; + access::rw(n_rows) = 0; + access::rw(n_cols) = 0; + access::rw(n_elem) = 0; + access::rw(n_alloc) = 0; + } + + arma_extra_debug_print("Mat::init(): acquiring memory"); + access::rw(mem) = memory::acquire(new_n_elem); + access::rw(n_alloc) = new_n_elem; + } + else // condition: new_n_elem <= n_alloc + { + arma_extra_debug_print("Mat::init(): reusing memory"); + } + } + + access::rw(n_rows) = in_n_rows; + access::rw(n_cols) = in_n_cols; + access::rw(n_elem) = new_n_elem; + access::rw(mem_state) = 0; + } + + + +//! create the matrix from a textual description +template +inline +Mat::Mat(const char* text) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init( std::string(text) ); + } + + + +//! create the matrix from a textual description +template +inline +Mat& +Mat::operator=(const char* text) + { + arma_extra_debug_sigprint(); + + init( std::string(text) ); + + return *this; + } + + + +//! create the matrix from a textual description +template +inline +Mat::Mat(const std::string& text) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init(text); + } + + + +//! create the matrix from a textual description +template +inline +Mat& +Mat::operator=(const std::string& text) + { + arma_extra_debug_sigprint(); + + init(text); + + return *this; + } + + + +//! internal function to create the matrix from a textual description +template +inline +void +Mat::init(const std::string& text_orig) + { + arma_extra_debug_sigprint(); + + const bool replace_commas = (is_cx::yes) ? false : ( text_orig.find(',') != std::string::npos ); + + std::string text_mod; + + if(replace_commas) { text_mod = text_orig; std::replace(text_mod.begin(), text_mod.end(), ',', ' '); } + + const std::string& text = (replace_commas) ? text_mod : text_orig; + + // + // work out the size + + uword t_n_rows = 0; + uword t_n_cols = 0; + + bool has_semicolon = false; + bool has_token = false; + + std::string token; + + std::string::size_type line_start = 0; + std::string::size_type line_end = 0; + std::string::size_type line_len = 0; + + std::stringstream line_stream; + + while( line_start < text.length() ) + { + line_end = text.find(';', line_start); + + if(line_end == std::string::npos) + { + has_semicolon = false; + line_end = text.length()-1; + line_len = line_end - line_start + 1; + } + else + { + has_semicolon = true; + line_len = line_end - line_start; // omit the ';' character + } + + line_stream.clear(); + line_stream.str( text.substr(line_start,line_len) ); + + has_token = false; + + uword line_n_cols = 0; + + while(line_stream >> token) { has_token = true; ++line_n_cols; } + + if(t_n_rows == 0) + { + t_n_cols = line_n_cols; + } + else + { + if(has_semicolon || has_token) { arma_check( (line_n_cols != t_n_cols), "Mat::init(): inconsistent number of columns in given string"); } + } + + ++t_n_rows; + + line_start = line_end+1; + } + + // if the last line was empty, ignore it + if( (has_semicolon == false) && (has_token == false) && (t_n_rows >= 1) ) { --t_n_rows; } + + Mat& x = (*this); + x.set_size(t_n_rows, t_n_cols); + + if(x.is_empty()) { return; } + + line_start = 0; + line_end = 0; + line_len = 0; + + uword urow = 0; + + while( line_start < text.length() ) + { + line_end = text.find(';', line_start); + + if(line_end == std::string::npos) + { + line_end = text.length()-1; + line_len = line_end - line_start + 1; + } + else + { + line_len = line_end - line_start; // omit the ';' character + } + + line_stream.clear(); + line_stream.str( text.substr(line_start,line_len) ); + + uword ucol = 0; + while(line_stream >> token) + { + diskio::convert_token( x.at(urow,ucol), token ); + ++ucol; + } + + ++urow; + line_start = line_end+1; + } + } + + + +//! create the matrix from std::vector +template +inline +Mat::Mat(const std::vector& x) + : n_rows(uword(x.size())) + , n_cols(1) + , n_elem(uword(x.size())) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(n_elem > 0) { arrayops::copy( memptr(), &(x[0]), n_elem ); } + } + + + +//! create the matrix from std::vector +template +inline +Mat& +Mat::operator=(const std::vector& x) + { + arma_extra_debug_sigprint(); + + init_warm(uword(x.size()), 1); + + if(x.size() > 0) { arrayops::copy( memptr(), &(x[0]), uword(x.size()) ); } + + return *this; + } + + + +template +inline +Mat::Mat(const std::initializer_list& list) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init(list); + } + + + +template +inline +Mat& +Mat::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + init(list); + + return *this; + } + + + +template +inline +Mat::Mat(const std::initializer_list< std::initializer_list >& list) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init(list); + } + + + +template +inline +Mat& +Mat::operator=(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); + + init(list); + + return *this; + } + + + +template +inline +Mat::Mat(Mat&& X) + : n_rows (X.n_rows ) + , n_cols (X.n_cols ) + , n_elem (X.n_elem ) + , n_alloc (X.n_alloc) + , vec_state(0 ) + , mem_state(0 ) + , mem ( ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + + if( (X.n_alloc > arma_config::mat_prealloc) || (X.mem_state == 1) || (X.mem_state == 2) ) + { + access::rw(mem_state) = X.mem_state; + access::rw(mem) = X.mem; + + access::rw(X.n_rows) = 0; + access::rw(X.n_cols) = 0; + access::rw(X.n_elem) = 0; + access::rw(X.n_alloc) = 0; + access::rw(X.mem_state) = 0; + access::rw(X.mem) = nullptr; + } + else // condition: (X.n_alloc <= arma_config::mat_prealloc) || (X.mem_state == 0) || (X.mem_state == 3) + { + init_cold(); + + arrayops::copy( memptr(), X.mem, X.n_elem ); + + if( (X.mem_state == 0) && (X.n_alloc <= arma_config::mat_prealloc) ) + { + access::rw(X.n_rows) = 0; + access::rw(X.n_cols) = 0; + access::rw(X.n_elem) = 0; + access::rw(X.mem) = nullptr; + } + } + } + + + +template +inline +Mat& +Mat::operator=(Mat&& X) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + + (*this).steal_mem(X, true); + + return *this; + } + + + +//! Set the matrix to be equal to the specified scalar. +//! NOTE: the size of the matrix will be 1x1 +template +inline +Mat& +Mat::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + init_warm(1,1); + + access::rw(mem[0]) = val; + + return *this; + } + + + +//! In-place addition of a scalar to all elements of the matrix +template +inline +Mat& +Mat::operator+=(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_plus( memptr(), val, n_elem ); + + return *this; + } + + + +//! In-place subtraction of a scalar from all elements of the matrix +template +inline +Mat& +Mat::operator-=(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_minus( memptr(), val, n_elem ); + + return *this; + } + + + +//! In-place multiplication of all elements of the matrix with a scalar +template +inline +Mat& +Mat::operator*=(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_mul( memptr(), val, n_elem ); + + return *this; + } + + + +//! In-place division of all elements of the matrix with a scalar +template +inline +Mat& +Mat::operator/=(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_div( memptr(), val, n_elem ); + + return *this; + } + + + +//! construct a matrix from a given matrix +template +inline +Mat::Mat(const Mat& in_mat) + : n_rows(in_mat.n_rows) + , n_cols(in_mat.n_cols) + , n_elem(in_mat.n_elem) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat); + + init_cold(); + + arrayops::copy( memptr(), in_mat.mem, in_mat.n_elem ); + } + + + +//! construct a matrix from a given matrix +template +inline +Mat& +Mat::operator=(const Mat& in_mat) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat); + + if(this != &in_mat) + { + init_warm(in_mat.n_rows, in_mat.n_cols); + + arrayops::copy( memptr(), in_mat.mem, in_mat.n_elem ); + } + + return *this; + } + + + +template +inline +void +Mat::init(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + const uword N = uword(list.size()); + + set_size(1, N); + + if(N > 0) { arrayops::copy( memptr(), list.begin(), N ); } + } + + + +template +inline +void +Mat::init(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); + + uword x_n_rows = uword(list.size()); + uword x_n_cols = 0; + + auto it = list.begin(); + auto it_end = list.end(); + + for(; it != it_end; ++it) { x_n_cols = (std::max)(x_n_cols, uword((*it).size())); } + + Mat& t = (*this); + + if(t.mem_state == 3) + { + arma_debug_check( ((x_n_rows != t.n_rows) || (x_n_cols != t.n_cols)), "Mat::init(): size mismatch between fixed size matrix and initialiser list" ); + } + else + { + t.set_size(x_n_rows, x_n_cols); + } + + uword row_num = 0; + + auto row_it = list.begin(); + auto row_it_end = list.end(); + + for(; row_it != row_it_end; ++row_it) + { + uword col_num = 0; + + auto col_it = (*row_it).begin(); + auto col_it_end = (*row_it).end(); + + for(; col_it != col_it_end; ++col_it) + { + t.at(row_num, col_num) = (*col_it); + ++col_num; + } + + for(uword c=col_num; c < x_n_cols; ++c) + { + t.at(row_num, c) = eT(0); + } + + ++row_num; + } + } + + + +//! for constructing a complex matrix out of two non-complex matrices +template +template +inline +void +Mat::init + ( + const Base::pod_type, T1>& X, + const Base::pod_type, T2>& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type T; + + arma_type_check(( is_cx::no )); //!< compile-time abort if eT is not std::complex + arma_type_check(( is_cx< T>::yes )); //!< compile-time abort if T is std::complex + + arma_type_check(( is_same_type< std::complex, eT >::no )); //!< compile-time abort if types are not compatible + + const Proxy PX(X.get_ref()); + const Proxy PY(Y.get_ref()); + + arma_debug_assert_same_size(PX, PY, "Mat()"); + + const uword local_n_rows = PX.get_n_rows(); + const uword local_n_cols = PX.get_n_cols(); + + init_warm(local_n_rows, local_n_cols); + + eT* out_mem = (*this).memptr(); + + const bool use_at = ( Proxy::use_at || Proxy::use_at ); + + if(use_at == false) + { + typedef typename Proxy::ea_type ea_type1; + typedef typename Proxy::ea_type ea_type2; + + const uword N = n_elem; + + ea_type1 A = PX.get_ea(); + ea_type2 B = PY.get_ea(); + + for(uword ii=0; ii < N; ++ii) + { + out_mem[ii] = std::complex(A[ii], B[ii]); + } + } + else + { + for(uword ucol=0; ucol < local_n_cols; ++ucol) + for(uword urow=0; urow < local_n_rows; ++urow) + { + *out_mem = std::complex(PX.at(urow,ucol), PY.at(urow,ucol)); + out_mem++; + } + } + } + + + +//! swap the contents of this matrix, denoted as matrix A, with given matrix B +template +inline +void +Mat::swap(Mat& B) + { + Mat& A = (*this); + + arma_extra_debug_sigprint(arma_str::format("A = %x B = %x") % &A % &B); + + bool layout_ok = false; + + if(A.vec_state == B.vec_state) + { + layout_ok = true; + } + else + { + const uhword A_vec_state = A.vec_state; + const uhword B_vec_state = B.vec_state; + + const bool A_absorbs_B = (A_vec_state == 0) || ( (A_vec_state == 1) && (B.n_cols == 1) ) || ( (A_vec_state == 2) && (B.n_rows == 1) ); + const bool B_absorbs_A = (B_vec_state == 0) || ( (B_vec_state == 1) && (A.n_cols == 1) ) || ( (B_vec_state == 2) && (A.n_rows == 1) ); + + layout_ok = A_absorbs_B && B_absorbs_A; + } + + const uhword A_mem_state = A.mem_state; + const uhword B_mem_state = B.mem_state; + + if( (A_mem_state == 0) && (B_mem_state == 0) && layout_ok ) + { + const uword A_n_elem = A.n_elem; + const uword B_n_elem = B.n_elem; + + const bool A_use_local_mem = (A.n_alloc <= arma_config::mat_prealloc); + const bool B_use_local_mem = (B.n_alloc <= arma_config::mat_prealloc); + + if( (A_use_local_mem == false) && (B_use_local_mem == false) ) + { + std::swap( access::rw(A.mem), access::rw(B.mem) ); + } + else + if( (A_use_local_mem == true) && (B_use_local_mem == true) ) + { + eT* A_mem_local = &(A.mem_local[0]); + eT* B_mem_local = &(B.mem_local[0]); + + access::rw(A.mem) = A_mem_local; + access::rw(B.mem) = B_mem_local; + + const uword N = (std::max)(A_n_elem, B_n_elem); + + for(uword ii=0; ii < N; ++ii) { std::swap( A_mem_local[ii], B_mem_local[ii] ); } + } + else + if( (A_use_local_mem == true) && (B_use_local_mem == false) ) + { + eT* A_mem_local = &(A.mem_local[0]); + eT* B_mem_local = &(B.mem_local[0]); + + arrayops::copy(B_mem_local, A_mem_local, A_n_elem); + + access::rw(A.mem) = B.mem; + access::rw(B.mem) = B_mem_local; + } + else + if( (A_use_local_mem == false) && (B_use_local_mem == true) ) + { + eT* A_mem_local = &(A.mem_local[0]); + eT* B_mem_local = &(B.mem_local[0]); + + arrayops::copy(A_mem_local, B_mem_local, B_n_elem); + + access::rw(B.mem) = A.mem; + access::rw(A.mem) = A_mem_local; + } + + std::swap( access::rw(A.n_rows), access::rw(B.n_rows) ); + std::swap( access::rw(A.n_cols), access::rw(B.n_cols) ); + std::swap( access::rw(A.n_elem), access::rw(B.n_elem) ); + std::swap( access::rw(A.n_alloc), access::rw(B.n_alloc) ); + } + else + if( (A_mem_state <= 2) && (B_mem_state <= 2) && (A.n_elem == B.n_elem) && layout_ok ) + { + std::swap( access::rw(A.n_rows), access::rw(B.n_rows) ); + std::swap( access::rw(A.n_cols), access::rw(B.n_cols) ); + + const uword N = A.n_elem; + + eT* A_mem = A.memptr(); + eT* B_mem = B.memptr(); + + for(uword ii=0; ii < N; ++ii) { std::swap(A_mem[ii], B_mem[ii]); } + } + else + if( (A.n_rows == B.n_rows) && (A.n_cols == B.n_cols) ) + { + const uword N = A.n_elem; + + eT* A_mem = A.memptr(); + eT* B_mem = B.memptr(); + + for(uword ii=0; ii < N; ++ii) { std::swap(A_mem[ii], B_mem[ii]); } + } + else + { + // generic swap to handle remaining cases + + if(A.n_elem <= B.n_elem) + { + Mat C = A; + + A.steal_mem(B); + B.steal_mem(C); + } + else + { + Mat C = B; + + B.steal_mem(A); + A.steal_mem(C); + } + } + } + + + +//! try to steal the memory from a given matrix; +//! if memory can't be stolen, copy the given matrix +template +inline +void +Mat::steal_mem(Mat& x) + { + arma_extra_debug_sigprint(); + + (*this).steal_mem(x, false); + } + + + +template +inline +void +Mat::steal_mem(Mat& x, const bool is_move) + { + arma_extra_debug_sigprint(); + + if(this == &x) { return; } + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + const uword x_n_elem = x.n_elem; + const uword x_n_alloc = x.n_alloc; + const uhword x_vec_state = x.vec_state; + const uhword x_mem_state = x.mem_state; + + const uhword t_vec_state = vec_state; + const uhword t_mem_state = mem_state; + + const bool layout_ok = (t_vec_state == x_vec_state) || ((t_vec_state == 1) && (x_n_cols == 1)) || ((t_vec_state == 2) && (x_n_rows == 1)); + + if( layout_ok && (t_mem_state <= 1) && ( (x_n_alloc > arma_config::mat_prealloc) || (x_mem_state == 1) || (is_move && (x_mem_state == 2)) ) ) + { + arma_extra_debug_print("Mat::steal_mem(): stealing memory"); + + reset(); + + access::rw(n_rows) = x_n_rows; + access::rw(n_cols) = x_n_cols; + access::rw(n_elem) = x_n_elem; + access::rw(n_alloc) = x_n_alloc; + access::rw(mem_state) = x_mem_state; + access::rw(mem) = x.mem; + + access::rw(x.n_rows) = (x_vec_state == 2) ? 1 : 0; + access::rw(x.n_cols) = (x_vec_state == 1) ? 1 : 0; + access::rw(x.n_elem) = 0; + access::rw(x.n_alloc) = 0; + access::rw(x.mem_state) = 0; + access::rw(x.mem) = nullptr; + } + else + { + arma_extra_debug_print("Mat::steal_mem(): copying memory"); + + (*this).operator=(x); + + if( (is_move) && (x_mem_state == 0) && (x_n_alloc <= arma_config::mat_prealloc) ) + { + access::rw(x.n_rows) = (x_vec_state == 2) ? 1 : 0; + access::rw(x.n_cols) = (x_vec_state == 1) ? 1 : 0; + access::rw(x.n_elem) = 0; + access::rw(x.mem) = nullptr; + } + } + } + + + +template +inline +void +Mat::steal_mem_col(Mat& x, const uword max_n_rows) + { + arma_extra_debug_sigprint(); + + const uword x_n_elem = x.n_elem; + const uword x_n_alloc = x.n_alloc; + const uhword x_mem_state = x.mem_state; + + const uhword t_vec_state = vec_state; + const uhword t_mem_state = mem_state; + + const uword alt_n_rows = (std::min)(x.n_rows, max_n_rows); + + if((x_n_elem == 0) || (alt_n_rows == 0)) + { + (*this).set_size(0,1); + + return; + } + + if( (this != &x) && (t_vec_state <= 1) && (t_mem_state <= 1) && (x_mem_state <= 1) ) + { + if( (x_mem_state == 0) && ((x_n_alloc <= arma_config::mat_prealloc) || (alt_n_rows <= arma_config::mat_prealloc)) ) + { + (*this).set_size(alt_n_rows, uword(1)); + + arrayops::copy( (*this).memptr(), x.memptr(), alt_n_rows ); + } + else + { + reset(); + + access::rw(n_rows) = alt_n_rows; + access::rw(n_cols) = 1; + access::rw(n_elem) = alt_n_rows; + access::rw(n_alloc) = x_n_alloc; + access::rw(mem_state) = x_mem_state; + access::rw(mem) = x.mem; + + access::rw(x.n_rows) = 0; + access::rw(x.n_cols) = 0; + access::rw(x.n_elem) = 0; + access::rw(x.n_alloc) = 0; + access::rw(x.mem_state) = 0; + access::rw(x.mem) = nullptr; + } + } + else + { + Mat tmp(alt_n_rows, 1, arma_nozeros_indicator()); + + arrayops::copy( tmp.memptr(), x.memptr(), alt_n_rows ); + + steal_mem(tmp); + } + } + + + +//! construct a matrix from a given auxiliary array of eTs. +//! if copy_aux_mem is true, new memory is allocated and the array is copied. +//! if copy_aux_mem is false, the auxiliary array is used directly (without allocating memory and copying). +//! the default is to copy the array. + +template +inline +Mat::Mat(eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols, const bool copy_aux_mem, const bool strict) + : n_rows ( aux_n_rows ) + , n_cols ( aux_n_cols ) + , n_elem ( aux_n_rows*aux_n_cols ) + , n_alloc ( 0 ) + , vec_state( 0 ) + , mem_state( copy_aux_mem ? 0 : ( strict ? 2 : 1 ) ) + , mem ( copy_aux_mem ? nullptr : aux_mem ) + { + arma_extra_debug_sigprint_this(this); + + if(copy_aux_mem) + { + init_cold(); + + arrayops::copy( memptr(), aux_mem, n_elem ); + } + } + + + +//! construct a matrix from a given auxiliary read-only array of eTs. +//! the array is copied. +template +inline +Mat::Mat(const eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols) + : n_rows(aux_n_rows) + , n_cols(aux_n_cols) + , n_elem(aux_n_rows*aux_n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + arrayops::copy( memptr(), aux_mem, n_elem ); + } + + + +//! DANGEROUS! Construct a temporary matrix, using auxiliary memory. +//! This constructor is NOT intended for usage by user code. +//! Its sole purpose is to be used by the Cube class. + +template +inline +Mat::Mat(const char junk, const eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols) + : n_rows (aux_n_rows ) + , n_cols (aux_n_cols ) + , n_elem (aux_n_rows*aux_n_cols) + , n_alloc (0 ) + , vec_state(0 ) + , mem_state(3 ) + , mem (aux_mem ) + { + arma_extra_debug_sigprint_this(this); + + arma_ignore(junk); + } + + + +//! in-place matrix addition +template +inline +Mat& +Mat::operator+=(const Mat& m) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(*this, m, "addition"); + + arrayops::inplace_plus( memptr(), m.memptr(), n_elem ); + + return *this; + } + + + +//! in-place matrix subtraction +template +inline +Mat& +Mat::operator-=(const Mat& m) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(*this, m, "subtraction"); + + arrayops::inplace_minus( memptr(), m.memptr(), n_elem ); + + return *this; + } + + + +//! in-place matrix multiplication +template +inline +Mat& +Mat::operator*=(const Mat& m) + { + arma_extra_debug_sigprint(); + + glue_times::apply_inplace(*this, m); + + return *this; + } + + + +//! in-place element-wise matrix multiplication +template +inline +Mat& +Mat::operator%=(const Mat& m) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(*this, m, "element-wise multiplication"); + + arrayops::inplace_mul( memptr(), m.memptr(), n_elem ); + + return *this; + } + + + +//! in-place element-wise matrix division +template +inline +Mat& +Mat::operator/=(const Mat& m) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(*this, m, "element-wise division"); + + arrayops::inplace_div( memptr(), m.memptr(), n_elem ); + + return *this; + } + + + +template +template +inline +Mat::Mat(const BaseCube& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(X); + } + + + +template +template +inline +Mat& +Mat::operator=(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + Mat& out = *this; + + const unwrap_cube tmp(X.get_ref()); + const Cube& in = tmp.M; + + arma_debug_assert_cube_as_mat(out, in, "copy into matrix", false); + + const uword in_n_rows = in.n_rows; + const uword in_n_cols = in.n_cols; + const uword in_n_slices = in.n_slices; + + const uword out_vec_state = out.vec_state; + + if(in_n_slices == 1) + { + out.set_size(in_n_rows, in_n_cols); + + for(uword ucol=0; ucol < in_n_cols; ++ucol) + { + arrayops::copy( out.colptr(ucol), in.slice_colptr(0, ucol), in_n_rows ); + } + } + else + { + if(out_vec_state == 0) + { + if(in_n_cols == 1) + { + out.set_size(in_n_rows, in_n_slices); + + for(uword i=0; i < in_n_slices; ++i) + { + arrayops::copy( out.colptr(i), in.slice_colptr(i, 0), in_n_rows ); + } + } + else + if(in_n_rows == 1) + { + out.set_size(in_n_cols, in_n_slices); + + for(uword slice=0; slice < in_n_slices; ++slice) + { + eT* out_colptr = out.colptr(slice); + + uword i,j; + for(i=0, j=1; j < in_n_cols; i+=2, j+=2) + { + const eT tmp_i = in.at(0, i, slice); + const eT tmp_j = in.at(0, j, slice); + + out_colptr[i] = tmp_i; + out_colptr[j] = tmp_j; + } + + if(i < in_n_cols) + { + out_colptr[i] = in.at(0, i, slice); + } + } + } + } + else + { + out.set_size(in_n_slices); + + eT* out_mem = out.memptr(); + + for(uword i=0; i +template +inline +Mat& +Mat::operator+=(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + Mat& out = *this; + + const unwrap_cube tmp(X.get_ref()); + const Cube& in = tmp.M; + + arma_debug_assert_cube_as_mat(out, in, "addition", true); + + const uword in_n_rows = in.n_rows; + const uword in_n_cols = in.n_cols; + const uword in_n_slices = in.n_slices; + + const uword out_n_rows = out.n_rows; + const uword out_n_cols = out.n_cols; + const uword out_vec_state = out.vec_state; + + if(in_n_slices == 1) + { + for(uword ucol=0; ucol < in_n_cols; ++ucol) + { + arrayops::inplace_plus( out.colptr(ucol), in.slice_colptr(0, ucol), in_n_rows ); + } + } + else + { + if(out_vec_state == 0) + { + if( (in_n_rows == out_n_rows) && (in_n_cols == 1) && (in_n_slices == out_n_cols) ) + { + for(uword i=0; i < in_n_slices; ++i) + { + arrayops::inplace_plus( out.colptr(i), in.slice_colptr(i, 0), in_n_rows ); + } + } + else + if( (in_n_rows == 1) && (in_n_cols == out_n_rows) && (in_n_slices == out_n_cols) ) + { + for(uword slice=0; slice < in_n_slices; ++slice) + { + eT* out_colptr = out.colptr(slice); + + uword i,j; + for(i=0, j=1; j < in_n_cols; i+=2, j+=2) + { + const eT tmp_i = in.at(0, i, slice); + const eT tmp_j = in.at(0, j, slice); + + out_colptr[i] += tmp_i; + out_colptr[j] += tmp_j; + } + + if(i < in_n_cols) + { + out_colptr[i] += in.at(0, i, slice); + } + } + } + } + else + { + eT* out_mem = out.memptr(); + + for(uword i=0; i +template +inline +Mat& +Mat::operator-=(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + Mat& out = *this; + + const unwrap_cube tmp(X.get_ref()); + const Cube& in = tmp.M; + + arma_debug_assert_cube_as_mat(out, in, "subtraction", true); + + const uword in_n_rows = in.n_rows; + const uword in_n_cols = in.n_cols; + const uword in_n_slices = in.n_slices; + + const uword out_n_rows = out.n_rows; + const uword out_n_cols = out.n_cols; + const uword out_vec_state = out.vec_state; + + if(in_n_slices == 1) + { + for(uword ucol=0; ucol < in_n_cols; ++ucol) + { + arrayops::inplace_minus( out.colptr(ucol), in.slice_colptr(0, ucol), in_n_rows ); + } + } + else + { + if(out_vec_state == 0) + { + if( (in_n_rows == out_n_rows) && (in_n_cols == 1) && (in_n_slices == out_n_cols) ) + { + for(uword i=0; i < in_n_slices; ++i) + { + arrayops::inplace_minus( out.colptr(i), in.slice_colptr(i, 0), in_n_rows ); + } + } + else + if( (in_n_rows == 1) && (in_n_cols == out_n_rows) && (in_n_slices == out_n_cols) ) + { + for(uword slice=0; slice < in_n_slices; ++slice) + { + eT* out_colptr = out.colptr(slice); + + uword i,j; + for(i=0, j=1; j < in_n_cols; i+=2, j+=2) + { + const eT tmp_i = in.at(0, i, slice); + const eT tmp_j = in.at(0, j, slice); + + out_colptr[i] -= tmp_i; + out_colptr[j] -= tmp_j; + } + + if(i < in_n_cols) + { + out_colptr[i] -= in.at(0, i, slice); + } + } + } + } + else + { + eT* out_mem = out.memptr(); + + for(uword i=0; i +template +inline +Mat& +Mat::operator*=(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + const Mat B(X); + + (*this).operator*=(B); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator%=(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + Mat& out = *this; + + const unwrap_cube tmp(X.get_ref()); + const Cube& in = tmp.M; + + arma_debug_assert_cube_as_mat(out, in, "element-wise multiplication", true); + + const uword in_n_rows = in.n_rows; + const uword in_n_cols = in.n_cols; + const uword in_n_slices = in.n_slices; + + const uword out_n_rows = out.n_rows; + const uword out_n_cols = out.n_cols; + const uword out_vec_state = out.vec_state; + + if(in_n_slices == 1) + { + for(uword ucol=0; ucol < in_n_cols; ++ucol) + { + arrayops::inplace_mul( out.colptr(ucol), in.slice_colptr(0, ucol), in_n_rows ); + } + } + else + { + if(out_vec_state == 0) + { + if( (in_n_rows == out_n_rows) && (in_n_cols == 1) && (in_n_slices == out_n_cols) ) + { + for(uword i=0; i < in_n_slices; ++i) + { + arrayops::inplace_mul( out.colptr(i), in.slice_colptr(i, 0), in_n_rows ); + } + } + else + if( (in_n_rows == 1) && (in_n_cols == out_n_rows) && (in_n_slices == out_n_cols) ) + { + for(uword slice=0; slice < in_n_slices; ++slice) + { + eT* out_colptr = out.colptr(slice); + + uword i,j; + for(i=0, j=1; j < in_n_cols; i+=2, j+=2) + { + const eT tmp_i = in.at(0, i, slice); + const eT tmp_j = in.at(0, j, slice); + + out_colptr[i] *= tmp_i; + out_colptr[j] *= tmp_j; + } + + if(i < in_n_cols) + { + out_colptr[i] *= in.at(0, i, slice); + } + } + } + } + else + { + eT* out_mem = out.memptr(); + + for(uword i=0; i +template +inline +Mat& +Mat::operator/=(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + Mat& out = *this; + + const unwrap_cube tmp(X.get_ref()); + const Cube& in = tmp.M; + + arma_debug_assert_cube_as_mat(out, in, "element-wise division", true); + + const uword in_n_rows = in.n_rows; + const uword in_n_cols = in.n_cols; + const uword in_n_slices = in.n_slices; + + const uword out_n_rows = out.n_rows; + const uword out_n_cols = out.n_cols; + const uword out_vec_state = out.vec_state; + + if(in_n_slices == 1) + { + for(uword ucol=0; ucol < in_n_cols; ++ucol) + { + arrayops::inplace_div( out.colptr(ucol), in.slice_colptr(0, ucol), in_n_rows ); + } + } + else + { + if(out_vec_state == 0) + { + if( (in_n_rows == out_n_rows) && (in_n_cols == 1) && (in_n_slices == out_n_cols) ) + { + for(uword i=0; i < in_n_slices; ++i) + { + arrayops::inplace_div( out.colptr(i), in.slice_colptr(i, 0), in_n_rows ); + } + } + else + if( (in_n_rows == 1) && (in_n_cols == out_n_rows) && (in_n_slices == out_n_cols) ) + { + for(uword slice=0; slice < in_n_slices; ++slice) + { + eT* out_colptr = out.colptr(slice); + + uword i,j; + for(i=0, j=1; j < in_n_cols; i+=2, j+=2) + { + const eT tmp_i = in.at(0, i, slice); + const eT tmp_j = in.at(0, j, slice); + + out_colptr[i] /= tmp_i; + out_colptr[j] /= tmp_j; + } + + if(i < in_n_cols) + { + out_colptr[i] /= in.at(0, i, slice); + } + } + } + } + else + { + eT* out_mem = out.memptr(); + + for(uword i=0; i +template +inline +Mat::Mat + ( + const Base::pod_type,T1>& A, + const Base::pod_type,T2>& B + ) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init(A,B); + } + + + +template +inline +Mat::Mat(const subview& X, const bool use_colmem) + : n_rows(X.n_rows) + , n_cols(X.n_cols) + , n_elem(X.n_elem) + , n_alloc(0) + , vec_state(0) + , mem_state(use_colmem ? 3 : 0) + , mem (use_colmem ? X.colptr(0) : nullptr) + { + arma_extra_debug_sigprint_this(this); + + if(use_colmem) + { + arma_extra_debug_print("Mat::Mat(): using existing memory in a submatrix"); + } + else + { + init_cold(); + + subview::extract(*this, X); + } + } + + + +//! construct a matrix from subview (eg. construct a matrix from a delayed submatrix operation) +template +inline +Mat::Mat(const subview& X) + : n_rows(X.n_rows) + , n_cols(X.n_cols) + , n_elem(X.n_elem) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + subview::extract(*this, X); + } + + + +//! construct a matrix from subview (eg. construct a matrix from a delayed submatrix operation) +template +inline +Mat& +Mat::operator=(const subview& X) + { + arma_extra_debug_sigprint(); + + const bool alias = (this == &(X.m)); + + if(alias == false) + { + init_warm(X.n_rows, X.n_cols); + + subview::extract(*this, X); + } + else + { + Mat tmp(X); + + steal_mem(tmp); + } + + return *this; + } + + +//! in-place matrix addition (using a submatrix on the right-hand-side) +template +inline +Mat& +Mat::operator+=(const subview& X) + { + arma_extra_debug_sigprint(); + + subview::plus_inplace(*this, X); + + return *this; + } + + +//! in-place matrix subtraction (using a submatrix on the right-hand-side) +template +inline +Mat& +Mat::operator-=(const subview& X) + { + arma_extra_debug_sigprint(); + + subview::minus_inplace(*this, X); + + return *this; + } + + + +//! in-place matrix mutiplication (using a submatrix on the right-hand-side) +template +inline +Mat& +Mat::operator*=(const subview& X) + { + arma_extra_debug_sigprint(); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +//! in-place element-wise matrix mutiplication (using a submatrix on the right-hand-side) +template +inline +Mat& +Mat::operator%=(const subview& X) + { + arma_extra_debug_sigprint(); + + subview::schur_inplace(*this, X); + + return *this; + } + + + +//! in-place element-wise matrix division (using a submatrix on the right-hand-side) +template +inline +Mat& +Mat::operator/=(const subview& X) + { + arma_extra_debug_sigprint(); + + subview::div_inplace(*this, X); + + return *this; + } + + + +template +inline +Mat::Mat(const subview_row_strans& X) + : n_rows(X.n_rows) + , n_cols(X.n_cols) + , n_elem(X.n_elem) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + X.extract(*this); + } + + + +template +inline +Mat::Mat(const subview_row_htrans& X) + : n_rows(X.n_rows) + , n_cols(X.n_cols) + , n_elem(X.n_elem) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + X.extract(*this); + } + + + +template +inline +Mat::Mat(const xvec_htrans& X) + : n_rows(X.n_rows) + , n_cols(X.n_cols) + , n_elem(X.n_elem) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + X.extract(*this); + } + + + +template +template +inline +Mat::Mat(const xtrans_mat& X) + : n_rows(X.n_rows) + , n_cols(X.n_cols) + , n_elem(X.n_elem) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + X.extract(*this); + } + + + +//! construct a matrix from a subview_cube instance +template +inline +Mat::Mat(const subview_cube& x) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + this->operator=(x); + } + + + +//! construct a matrix from a subview_cube instance +template +inline +Mat& +Mat::operator=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + subview_cube::extract(*this, X); + + return *this; + } + + + +//! in-place matrix addition (using a single-slice subcube on the right-hand-side) +template +inline +Mat& +Mat::operator+=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + subview_cube::plus_inplace(*this, X); + + return *this; + } + + + +//! in-place matrix subtraction (using a single-slice subcube on the right-hand-side) +template +inline +Mat& +Mat::operator-=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + subview_cube::minus_inplace(*this, X); + + return *this; + } + + + +//! in-place matrix mutiplication (using a single-slice subcube on the right-hand-side) +template +inline +Mat& +Mat::operator*=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + const Mat tmp(X); + + glue_times::apply_inplace(*this, tmp); + + return *this; + } + + + +//! in-place element-wise matrix mutiplication (using a single-slice subcube on the right-hand-side) +template +inline +Mat& +Mat::operator%=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + subview_cube::schur_inplace(*this, X); + + return *this; + } + + + +//! in-place element-wise matrix division (using a single-slice subcube on the right-hand-side) +template +inline +Mat& +Mat::operator/=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + subview_cube::div_inplace(*this, X); + + return *this; + } + + + +//! construct a matrix from diagview (eg. construct a matrix from a delayed diag operation) +template +inline +Mat::Mat(const diagview& X) + : n_rows(X.n_rows) + , n_cols(X.n_cols) + , n_elem(X.n_elem) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + diagview::extract(*this, X); + } + + + +//! construct a matrix from diagview (eg. construct a matrix from a delayed diag operation) +template +inline +Mat& +Mat::operator=(const diagview& X) + { + arma_extra_debug_sigprint(); + + const bool alias = (this == &(X.m)); + + if(alias == false) + { + init_warm(X.n_rows, X.n_cols); + + diagview::extract(*this, X); + } + else + { + Mat tmp(X); + + steal_mem(tmp); + } + + return *this; + } + + + +//! in-place matrix addition (using a diagview on the right-hand-side) +template +inline +Mat& +Mat::operator+=(const diagview& X) + { + arma_extra_debug_sigprint(); + + diagview::plus_inplace(*this, X); + + return *this; + } + + + +//! in-place matrix subtraction (using a diagview on the right-hand-side) +template +inline +Mat& +Mat::operator-=(const diagview& X) + { + arma_extra_debug_sigprint(); + + diagview::minus_inplace(*this, X); + + return *this; + } + + + +//! in-place matrix mutiplication (using a diagview on the right-hand-side) +template +inline +Mat& +Mat::operator*=(const diagview& X) + { + arma_extra_debug_sigprint(); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +//! in-place element-wise matrix mutiplication (using a diagview on the right-hand-side) +template +inline +Mat& +Mat::operator%=(const diagview& X) + { + arma_extra_debug_sigprint(); + + diagview::schur_inplace(*this, X); + + return *this; + } + + + +//! in-place element-wise matrix division (using a diagview on the right-hand-side) +template +inline +Mat& +Mat::operator/=(const diagview& X) + { + arma_extra_debug_sigprint(); + + diagview::div_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat::Mat(const subview_elem1& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + this->operator=(X); + } + + + +template +template +inline +Mat& +Mat::operator=(const subview_elem1& X) + { + arma_extra_debug_sigprint(); + + subview_elem1::extract(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator+=(const subview_elem1& X) + { + arma_extra_debug_sigprint(); + + subview_elem1::plus_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator-=(const subview_elem1& X) + { + arma_extra_debug_sigprint(); + + subview_elem1::minus_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator*=(const subview_elem1& X) + { + arma_extra_debug_sigprint(); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator%=(const subview_elem1& X) + { + arma_extra_debug_sigprint(); + + subview_elem1::schur_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator/=(const subview_elem1& X) + { + arma_extra_debug_sigprint(); + + subview_elem1::div_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat::Mat(const subview_elem2& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + this->operator=(X); + } + + + +template +template +inline +Mat& +Mat::operator=(const subview_elem2& X) + { + arma_extra_debug_sigprint(); + + subview_elem2::extract(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator+=(const subview_elem2& X) + { + arma_extra_debug_sigprint(); + + subview_elem2::plus_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator-=(const subview_elem2& X) + { + arma_extra_debug_sigprint(); + + subview_elem2::minus_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator*=(const subview_elem2& X) + { + arma_extra_debug_sigprint(); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator%=(const subview_elem2& X) + { + arma_extra_debug_sigprint(); + + subview_elem2::schur_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator/=(const subview_elem2& X) + { + arma_extra_debug_sigprint(); + + subview_elem2::div_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat::Mat(const SpBase& m) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(m); + } + + + +template +template +inline +Mat& +Mat::operator=(const SpBase& m) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(m.get_ref()); + const SpMat& x = U.M; + + const uword x_n_cols = x.n_cols; + + (*this).zeros(x.n_rows, x_n_cols); + + if(x.n_nonzero == 0) { return *this; } + + const eT* x_values = x.values; + const uword* x_row_indices = x.row_indices; + const uword* x_col_ptrs = x.col_ptrs; + + for(uword x_col = 0; x_col < x_n_cols; ++x_col) + { + const uword start = x_col_ptrs[x_col ]; + const uword end = x_col_ptrs[x_col + 1]; + + for(uword i = start; i < end; ++i) + { + const uword x_row = x_row_indices[i]; + const eT x_val = x_values[i]; + + at(x_row, x_col) = x_val; + } + } + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator+=(const SpBase& m) + { + arma_extra_debug_sigprint(); + + const SpProxy p(m.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "addition"); + + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type it_end = p.end(); + + for(; it != it_end; ++it) { at(it.row(), it.col()) += (*it); } + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator-=(const SpBase& m) + { + arma_extra_debug_sigprint(); + + const SpProxy p(m.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "subtraction"); + + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type it_end = p.end(); + + for(; it != it_end; ++it) { at(it.row(), it.col()) -= (*it); } + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator*=(const SpBase& m) + { + arma_extra_debug_sigprint(); + + Mat z = (*this) * m.get_ref(); + + steal_mem(z); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator%=(const SpBase& m) + { + arma_extra_debug_sigprint(); + + const SpProxy p(m.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "element-wise multiplication"); + + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type it_end = p.end(); + + // We have to zero everything that isn't being used. + arrayops::fill_zeros(memptr(), (it.col() * n_rows) + it.row()); + + while(it != it_end) + { + const uword cur_loc = (it.col() * n_rows) + it.row(); + + access::rw(mem[cur_loc]) *= (*it); + + ++it; + + const uword next_loc = (it == it_end) + ? (p.get_n_cols() * n_rows) + : (it.col() * n_rows) + it.row(); + + arrayops::fill_zeros(memptr() + cur_loc + 1, (next_loc - cur_loc - 1)); + } + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator/=(const SpBase& m) + { + arma_extra_debug_sigprint(); + + // NOTE: use of this function is not advised; it is implemented only for completeness + + const SpProxy p(m.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "element-wise division"); + + for(uword c = 0; c < n_cols; ++c) + for(uword r = 0; r < n_rows; ++r) + { + at(r, c) /= p.at(r, c); + } + + return *this; + } + + + +template +inline +Mat::Mat(const SpSubview& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(X); + } + + + +template +inline +Mat& +Mat::operator=(const SpSubview& X) + { + arma_extra_debug_sigprint(); + + (*this).zeros(X.n_rows, X.n_cols); + + if(X.n_nonzero == 0) { return *this; } + + if(X.n_rows == X.m.n_rows) + { + X.m.sync(); + + const uword sv_col_start = X.aux_col1; + const uword sv_col_end = X.aux_col1 + X.n_cols - 1; + + const eT* m_values = X.m.values; + const uword* m_row_indices = X.m.row_indices; + const uword* m_col_ptrs = X.m.col_ptrs; + + for(uword m_col = sv_col_start; m_col <= sv_col_end; ++m_col) + { + const uword m_col_adjusted = m_col - sv_col_start; + + const uword start = m_col_ptrs[m_col ]; + const uword end = m_col_ptrs[m_col + 1]; + + for(uword ii = start; ii < end; ++ii) + { + const uword m_row = m_row_indices[ii]; + const eT m_val = m_values[ii]; + + at(m_row, m_col_adjusted) = m_val; + } + } + } + else + { + typename SpSubview::const_iterator it = X.begin(); + typename SpSubview::const_iterator it_end = X.end(); + + for(; it != it_end; ++it) { at(it.row(), it.col()) = (*it); } + } + + return *this; + } + + + +template +inline +Mat::Mat(const spdiagview& X) + : n_rows(X.n_rows) + , n_cols(X.n_cols) + , n_elem(X.n_elem) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + spdiagview::extract(*this, X); + } + + + +template +inline +Mat& +Mat::operator=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + init_warm(X.n_rows, X.n_cols); + + spdiagview::extract(*this, X); + + return *this; + } + + + +template +inline +Mat& +Mat::operator+=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const Mat tmp(X); + + (*this).operator+=(tmp); + + return *this; + } + + + +template +inline +Mat& +Mat::operator-=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const Mat tmp(X); + + (*this).operator-=(tmp); + + return *this; + } + + + +template +inline +Mat& +Mat::operator*=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const Mat tmp(X); + + (*this).operator*=(tmp); + + return *this; + } + + + +template +inline +Mat& +Mat::operator%=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const Mat tmp(X); + + (*this).operator%=(tmp); + + return *this; + } + + + +template +inline +Mat& +Mat::operator/=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const Mat tmp(X); + + (*this).operator/=(tmp); + + return *this; + } + + + +template +inline +mat_injector< Mat > +Mat::operator<<(const eT val) + { + return mat_injector< Mat >(*this, val); + } + + + +template +inline +mat_injector< Mat > +Mat::operator<<(const injector_end_of_row<>& x) + { + return mat_injector< Mat >(*this, x); + } + + + +//! creation of subview (row vector) +template +arma_inline +subview_row +Mat::row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( row_num >= n_rows, "Mat::row(): index out of bounds" ); + + return subview_row(*this, row_num); + } + + + +//! creation of subview (row vector) +template +arma_inline +const subview_row +Mat::row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( row_num >= n_rows, "Mat::row(): index out of bounds" ); + + return subview_row(*this, row_num); + } + + + +template +inline +subview_row +Mat::operator()(const uword row_num, const span& col_span) + { + arma_extra_debug_sigprint(); + + const bool col_all = col_span.whole; + + const uword local_n_cols = n_cols; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + (row_num >= n_rows) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "Mat::operator(): indices out of bounds or incorrectly used" + ); + + return subview_row(*this, row_num, in_col1, submat_n_cols); + } + + + +template +inline +const subview_row +Mat::operator()(const uword row_num, const span& col_span) const + { + arma_extra_debug_sigprint(); + + const bool col_all = col_span.whole; + + const uword local_n_cols = n_cols; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + (row_num >= n_rows) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "Mat::operator(): indices out of bounds or incorrectly used" + ); + + return subview_row(*this, row_num, in_col1, submat_n_cols); + } + + + +//! creation of subview (column vector) +template +arma_inline +subview_col +Mat::col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= n_cols, "Mat::col(): index out of bounds" ); + + return subview_col(*this, col_num); + } + + + +//! creation of subview (column vector) +template +arma_inline +const subview_col +Mat::col(const uword col_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= n_cols, "Mat::col(): index out of bounds" ); + + return subview_col(*this, col_num); + } + + + +template +inline +subview_col +Mat::operator()(const span& row_span, const uword col_num) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + + const uword local_n_rows = n_rows; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + arma_debug_check_bounds + ( + (col_num >= n_cols) + || + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + , + "Mat::operator(): indices out of bounds or incorrectly used" + ); + + return subview_col(*this, col_num, in_row1, submat_n_rows); + } + + + +template +inline +const subview_col +Mat::operator()(const span& row_span, const uword col_num) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + + const uword local_n_rows = n_rows; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + arma_debug_check_bounds + ( + (col_num >= n_cols) + || + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + , + "Mat::operator(): indices out of bounds or incorrectly used" + ); + + return subview_col(*this, col_num, in_row1, submat_n_rows); + } + + + +//! create a Col object which uses memory from an existing matrix object. +//! this approach is currently not alias safe +//! and does not take into account that the parent matrix object could be deleted. +//! if deleted memory is accessed by the created Col object, +//! it will cause memory corruption and/or a crash +template +inline +Col +Mat::unsafe_col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= n_cols, "Mat::unsafe_col(): index out of bounds" ); + + return Col(colptr(col_num), n_rows, false, true); + } + + + +//! create a Col object which uses memory from an existing matrix object. +//! this approach is currently not alias safe +//! and does not take into account that the parent matrix object could be deleted. +//! if deleted memory is accessed by the created Col object, +//! it will cause memory corruption and/or a crash +template +inline +const Col +Mat::unsafe_col(const uword col_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= n_cols, "Mat::unsafe_col(): index out of bounds" ); + + typedef const Col out_type; + + return out_type(const_cast(colptr(col_num)), n_rows, false, true); + } + + + +//! creation of subview (submatrix comprised of specified row vectors) +template +arma_inline +subview +Mat::rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "Mat::rows(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + return subview(*this, in_row1, 0, subview_n_rows, n_cols ); + } + + + +//! creation of subview (submatrix comprised of specified row vectors) +template +arma_inline +const subview +Mat::rows(const uword in_row1, const uword in_row2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "Mat::rows(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + return subview(*this, in_row1, 0, subview_n_rows, n_cols ); + } + + + +//! creation of subview (submatrix comprised of specified column vectors) +template +arma_inline +subview_cols +Mat::cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "Mat::cols(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return subview_cols(*this, in_col1, subview_n_cols); + } + + + +//! creation of subview (submatrix comprised of specified column vectors) +template +arma_inline +const subview_cols +Mat::cols(const uword in_col1, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "Mat::cols(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return subview_cols(*this, in_col1, subview_n_cols); + } + + + +//! creation of subview (submatrix comprised of specified row vectors) +template +inline +subview +Mat::rows(const span& row_span) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + + const uword local_n_rows = n_rows; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + , + "Mat::rows(): indices out of bounds or incorrectly used" + ); + + return subview(*this, in_row1, 0, submat_n_rows, n_cols); + } + + + +//! creation of subview (submatrix comprised of specified row vectors) +template +inline +const subview +Mat::rows(const span& row_span) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + + const uword local_n_rows = n_rows; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + , + "Mat::rows(): indices out of bounds or incorrectly used" + ); + + return subview(*this, in_row1, 0, submat_n_rows, n_cols); + } + + + +//! creation of subview (submatrix comprised of specified column vectors) +template +arma_inline +subview_cols +Mat::cols(const span& col_span) + { + arma_extra_debug_sigprint(); + + const bool col_all = col_span.whole; + + const uword local_n_cols = n_cols; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "Mat::cols(): indices out of bounds or incorrectly used" + ); + + return subview_cols(*this, in_col1, submat_n_cols); + } + + + +//! creation of subview (submatrix comprised of specified column vectors) +template +arma_inline +const subview_cols +Mat::cols(const span& col_span) const + { + arma_extra_debug_sigprint(); + + const bool col_all = col_span.whole; + + const uword local_n_cols = n_cols; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "Mat::cols(): indices out of bounds or incorrectly used" + ); + + return subview_cols(*this, in_col1, submat_n_cols); + } + + + +//! creation of subview (submatrix) +template +arma_inline +subview +Mat::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), + "Mat::submat(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return subview(*this, in_row1, in_col1, subview_n_rows, subview_n_cols); + } + + + +//! creation of subview (generic submatrix) +template +arma_inline +const subview +Mat::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), + "Mat::submat(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return subview(*this, in_row1, in_col1, subview_n_rows, subview_n_cols); + } + + + +//! creation of subview (submatrix) +template +arma_inline +subview +Mat::submat(const uword in_row1, const uword in_col1, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + arma_debug_check_bounds + ( + ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), + "Mat::submat(): indices or size out of bounds" + ); + + return subview(*this, in_row1, in_col1, s_n_rows, s_n_cols); + } + + + +//! creation of subview (submatrix) +template +arma_inline +const subview +Mat::submat(const uword in_row1, const uword in_col1, const SizeMat& s) const + { + arma_extra_debug_sigprint(); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + arma_debug_check_bounds + ( + ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), + "Mat::submat(): indices or size out of bounds" + ); + + return subview(*this, in_row1, in_col1, s_n_rows, s_n_cols); + } + + + +//! creation of subview (submatrix) +template +inline +subview +Mat::submat(const span& row_span, const span& col_span) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "Mat::submat(): indices out of bounds or incorrectly used" + ); + + return subview(*this, in_row1, in_col1, submat_n_rows, submat_n_cols); + } + + + +//! creation of subview (generic submatrix) +template +inline +const subview +Mat::submat(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "Mat::submat(): indices out of bounds or incorrectly used" + ); + + return subview(*this, in_row1, in_col1, submat_n_rows, submat_n_cols); + } + + + +template +inline +subview +Mat::operator()(const span& row_span, const span& col_span) + { + arma_extra_debug_sigprint(); + + return (*this).submat(row_span, col_span); + } + + + +template +inline +const subview +Mat::operator()(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + return (*this).submat(row_span, col_span); + } + + + +template +inline +subview +Mat::operator()(const uword in_row1, const uword in_col1, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).submat(in_row1, in_col1, s); + } + + + +template +inline +const subview +Mat::operator()(const uword in_row1, const uword in_col1, const SizeMat& s) const + { + arma_extra_debug_sigprint(); + + return (*this).submat(in_row1, in_col1, s); + } + + + +template +inline +subview +Mat::head_rows(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_rows), "Mat::head_rows(): size out of bounds" ); + + return subview(*this, 0, 0, N, n_cols); + } + + + +template +inline +const subview +Mat::head_rows(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_rows), "Mat::head_rows(): size out of bounds" ); + + return subview(*this, 0, 0, N, n_cols); + } + + + +template +inline +subview +Mat::tail_rows(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_rows), "Mat::tail_rows(): size out of bounds" ); + + const uword start_row = n_rows - N; + + return subview(*this, start_row, 0, N, n_cols); + } + + + +template +inline +const subview +Mat::tail_rows(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_rows), "Mat::tail_rows(): size out of bounds" ); + + const uword start_row = n_rows - N; + + return subview(*this, start_row, 0, N, n_cols); + } + + + +template +inline +subview_cols +Mat::head_cols(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_cols), "Mat::head_cols(): size out of bounds" ); + + return subview_cols(*this, 0, N); + } + + + +template +inline +const subview_cols +Mat::head_cols(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_cols), "Mat::head_cols(): size out of bounds" ); + + return subview_cols(*this, 0, N); + } + + + +template +inline +subview_cols +Mat::tail_cols(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_cols), "Mat::tail_cols(): size out of bounds" ); + + const uword start_col = n_cols - N; + + return subview_cols(*this, start_col, N); + } + + + +template +inline +const subview_cols +Mat::tail_cols(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_cols), "Mat::tail_cols(): size out of bounds" ); + + const uword start_col = n_cols - N; + + return subview_cols(*this, start_col, N); + } + + + +template +template +arma_inline +subview_elem1 +Mat::elem(const Base& a) + { + arma_extra_debug_sigprint(); + + return subview_elem1(*this, a); + } + + + +template +template +arma_inline +const subview_elem1 +Mat::elem(const Base& a) const + { + arma_extra_debug_sigprint(); + + return subview_elem1(*this, a); + } + + + +template +template +arma_inline +subview_elem1 +Mat::operator()(const Base& a) + { + arma_extra_debug_sigprint(); + + return subview_elem1(*this, a); + } + + + +template +template +arma_inline +const subview_elem1 +Mat::operator()(const Base& a) const + { + arma_extra_debug_sigprint(); + + return subview_elem1(*this, a); + } + + + +template +template +arma_inline +subview_elem2 +Mat::elem(const Base& ri, const Base& ci) + { + arma_extra_debug_sigprint(); + + return subview_elem2(*this, ri, ci, false, false); + } + + + +template +template +arma_inline +const subview_elem2 +Mat::elem(const Base& ri, const Base& ci) const + { + arma_extra_debug_sigprint(); + + return subview_elem2(*this, ri, ci, false, false); + } + + + +template +template +arma_inline +subview_elem2 +Mat::submat(const Base& ri, const Base& ci) + { + arma_extra_debug_sigprint(); + + return subview_elem2(*this, ri, ci, false, false); + } + + + +template +template +arma_inline +const subview_elem2 +Mat::submat(const Base& ri, const Base& ci) const + { + arma_extra_debug_sigprint(); + + return subview_elem2(*this, ri, ci, false, false); + } + + + +template +template +arma_inline +subview_elem2 +Mat::operator()(const Base& ri, const Base& ci) + { + arma_extra_debug_sigprint(); + + return subview_elem2(*this, ri, ci, false, false); + } + + + +template +template +arma_inline +const subview_elem2 +Mat::operator()(const Base& ri, const Base& ci) const + { + arma_extra_debug_sigprint(); + + return subview_elem2(*this, ri, ci, false, false); + } + + + +template +template +arma_inline +subview_elem2 +Mat::rows(const Base& ri) + { + arma_extra_debug_sigprint(); + + return subview_elem2(*this, ri, ri, false, true); + } + + + +template +template +arma_inline +const subview_elem2 +Mat::rows(const Base& ri) const + { + arma_extra_debug_sigprint(); + + return subview_elem2(*this, ri, ri, false, true); + } + + + +template +template +arma_inline +subview_elem2 +Mat::cols(const Base& ci) + { + arma_extra_debug_sigprint(); + + return subview_elem2(*this, ci, ci, true, false); + } + + + +template +template +arma_inline +const subview_elem2 +Mat::cols(const Base& ci) const + { + arma_extra_debug_sigprint(); + + return subview_elem2(*this, ci, ci, true, false); + } + + + +template +arma_inline +subview_each1< Mat, 0 > +Mat::each_col() + { + arma_extra_debug_sigprint(); + + return subview_each1< Mat, 0>(*this); + } + + + +template +arma_inline +subview_each1< Mat, 1 > +Mat::each_row() + { + arma_extra_debug_sigprint(); + + return subview_each1< Mat, 1>(*this); + } + + + +template +arma_inline +const subview_each1< Mat, 0 > +Mat::each_col() const + { + arma_extra_debug_sigprint(); + + return subview_each1< Mat, 0>(*this); + } + + + +template +arma_inline +const subview_each1< Mat, 1 > +Mat::each_row() const + { + arma_extra_debug_sigprint(); + + return subview_each1< Mat, 1>(*this); + } + + + +template +template +inline +subview_each2< Mat, 0, T1 > +Mat::each_col(const Base& indices) + { + arma_extra_debug_sigprint(); + + return subview_each2< Mat, 0, T1 >(*this, indices); + } + + + +template +template +inline +subview_each2< Mat, 1, T1 > +Mat::each_row(const Base& indices) + { + arma_extra_debug_sigprint(); + + return subview_each2< Mat, 1, T1 >(*this, indices); + } + + + +template +template +inline +const subview_each2< Mat, 0, T1 > +Mat::each_col(const Base& indices) const + { + arma_extra_debug_sigprint(); + + return subview_each2< Mat, 0, T1 >(*this, indices); + } + + + +template +template +inline +const subview_each2< Mat, 1, T1 > +Mat::each_row(const Base& indices) const + { + arma_extra_debug_sigprint(); + + return subview_each2< Mat, 1, T1 >(*this, indices); + } + + + +//! apply a lambda function to each column, where each column is interpreted as a column vector +template +inline +Mat& +Mat::each_col(const std::function< void(Col&) >& F) + { + arma_extra_debug_sigprint(); + + for(uword ii=0; ii < n_cols; ++ii) + { + Col tmp(colptr(ii), n_rows, false, true); + F(tmp); + } + + return *this; + } + + + +template +inline +const Mat& +Mat::each_col(const std::function< void(const Col&) >& F) const + { + arma_extra_debug_sigprint(); + + for(uword ii=0; ii < n_cols; ++ii) + { + const Col tmp(const_cast(colptr(ii)), n_rows, false, true); + F(tmp); + } + + return *this; + } + + + +//! apply a lambda function to each row, where each row is interpreted as a row vector +template +inline +Mat& +Mat::each_row(const std::function< void(Row&) >& F) + { + arma_extra_debug_sigprint(); + + podarray array1(n_cols); + podarray array2(n_cols); + + Row tmp1( array1.memptr(), n_cols, false, true ); + Row tmp2( array2.memptr(), n_cols, false, true ); + + eT* tmp1_mem = tmp1.memptr(); + eT* tmp2_mem = tmp2.memptr(); + + uword ii, jj; + + for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) + { + for(uword col_id = 0; col_id < n_cols; ++col_id) + { + const eT* col_mem = colptr(col_id); + + tmp1_mem[col_id] = col_mem[ii]; + tmp2_mem[col_id] = col_mem[jj]; + } + + F(tmp1); + F(tmp2); + + for(uword col_id = 0; col_id < n_cols; ++col_id) + { + eT* col_mem = colptr(col_id); + + col_mem[ii] = tmp1_mem[col_id]; + col_mem[jj] = tmp2_mem[col_id]; + } + } + + if(ii < n_rows) + { + tmp1 = (*this).row(ii); + + F(tmp1); + + (*this).row(ii) = tmp1; + } + + return *this; + } + + + +template +inline +const Mat& +Mat::each_row(const std::function< void(const Row&) >& F) const + { + arma_extra_debug_sigprint(); + + podarray array1(n_cols); + podarray array2(n_cols); + + Row tmp1( array1.memptr(), n_cols, false, true ); + Row tmp2( array2.memptr(), n_cols, false, true ); + + eT* tmp1_mem = tmp1.memptr(); + eT* tmp2_mem = tmp2.memptr(); + + uword ii, jj; + + for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) + { + for(uword col_id = 0; col_id < n_cols; ++col_id) + { + const eT* col_mem = colptr(col_id); + + tmp1_mem[col_id] = col_mem[ii]; + tmp2_mem[col_id] = col_mem[jj]; + } + + F(tmp1); + F(tmp2); + } + + if(ii < n_rows) + { + tmp1 = (*this).row(ii); + + F(tmp1); + } + + return *this; + } + + + +//! creation of diagview (diagonal) +template +arma_inline +diagview +Mat::diag(const sword in_id) + { + arma_extra_debug_sigprint(); + + const uword row_offset = (in_id < 0) ? uword(-in_id) : 0; + const uword col_offset = (in_id > 0) ? uword( in_id) : 0; + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "Mat::diag(): requested diagonal out of bounds" + ); + + const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset); + + return diagview(*this, row_offset, col_offset, len); + } + + + +//! creation of diagview (diagonal) +template +arma_inline +const diagview +Mat::diag(const sword in_id) const + { + arma_extra_debug_sigprint(); + + const uword row_offset = uword( (in_id < 0) ? -in_id : 0 ); + const uword col_offset = uword( (in_id > 0) ? in_id : 0 ); + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "Mat::diag(): requested diagonal out of bounds" + ); + + const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset); + + return diagview(*this, row_offset, col_offset, len); + } + + + +template +inline +void +Mat::swap_rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + arma_debug_check_bounds + ( + (in_row1 >= local_n_rows) || (in_row2 >= local_n_rows), + "Mat::swap_rows(): index out of bounds" + ); + + if(n_elem > 0) + { + for(uword ucol=0; ucol < local_n_cols; ++ucol) + { + const uword offset = ucol * local_n_rows; + const uword pos1 = in_row1 + offset; + const uword pos2 = in_row2 + offset; + + std::swap( access::rw(mem[pos1]), access::rw(mem[pos2]) ); + } + } + } + + + +template +inline +void +Mat::swap_cols(const uword in_colA, const uword in_colB) + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + arma_debug_check_bounds + ( + (in_colA >= local_n_cols) || (in_colB >= local_n_cols), + "Mat::swap_cols(): index out of bounds" + ); + + if(n_elem > 0) + { + eT* ptrA = colptr(in_colA); + eT* ptrB = colptr(in_colB); + + eT tmp_i; + eT tmp_j; + + uword iq,jq; + for(iq=0, jq=1; jq < local_n_rows; iq+=2, jq+=2) + { + tmp_i = ptrA[iq]; + tmp_j = ptrA[jq]; + + ptrA[iq] = ptrB[iq]; + ptrA[jq] = ptrB[jq]; + + ptrB[iq] = tmp_i; + ptrB[jq] = tmp_j; + } + + if(iq < local_n_rows) + { + std::swap( ptrA[iq], ptrB[iq] ); + } + } + } + + + +//! remove specified row +template +inline +void +Mat::shed_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( row_num >= n_rows, "Mat::shed_row(): index out of bounds" ); + + shed_rows(row_num, row_num); + } + + + +//! remove specified column +template +inline +void +Mat::shed_col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= n_cols, "Mat::shed_col(): index out of bounds" ); + + shed_cols(col_num, col_num); + } + + + +//! remove specified rows +template +inline +void +Mat::shed_rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "Mat::shed_rows(): indices out of bounds or incorrectly used" + ); + + const uword n_keep_front = in_row1; + const uword n_keep_back = n_rows - (in_row2 + 1); + + Mat X(n_keep_front + n_keep_back, n_cols, arma_nozeros_indicator()); + + if(n_keep_front > 0) + { + X.rows( 0, (n_keep_front-1) ) = rows( 0, (in_row1-1) ); + } + + if(n_keep_back > 0) + { + X.rows( n_keep_front, (n_keep_front+n_keep_back-1) ) = rows( (in_row2+1), (n_rows-1) ); + } + + steal_mem(X); + } + + + +//! remove specified columns +template +inline +void +Mat::shed_cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "Mat::shed_cols(): indices out of bounds or incorrectly used" + ); + + const uword n_keep_front = in_col1; + const uword n_keep_back = n_cols - (in_col2 + 1); + + Mat X(n_rows, n_keep_front + n_keep_back, arma_nozeros_indicator()); + + if(n_keep_front > 0) + { + X.cols( 0, (n_keep_front-1) ) = cols( 0, (in_col1-1) ); + } + + if(n_keep_back > 0) + { + X.cols( n_keep_front, (n_keep_front+n_keep_back-1) ) = cols( (in_col2+1), (n_cols-1) ); + } + + steal_mem(X); + } + + + +//! remove specified rows +template +template +inline +void +Mat::shed_rows(const Base& indices) + { + arma_extra_debug_sigprint(); + + const unwrap_check_mixed U(indices.get_ref(), *this); + const Mat& tmp1 = U.M; + + arma_debug_check( ((tmp1.is_vec() == false) && (tmp1.is_empty() == false)), "Mat::shed_rows(): list of indices must be a vector" ); + + if(tmp1.is_empty()) { return; } + + const Col tmp2(const_cast(tmp1.memptr()), tmp1.n_elem, false, false); + + const Col& rows_to_shed = (tmp2.is_sorted("strictascend") == false) + ? Col(unique(tmp2)) + : Col(const_cast(tmp2.memptr()), tmp2.n_elem, false, false); + + const uword* rows_to_shed_mem = rows_to_shed.memptr(); + const uword N = rows_to_shed.n_elem; + + if(arma_config::debug) + { + for(uword i=0; i= n_rows), "Mat::shed_rows(): indices out of bounds" ); + } + } + + Col tmp3(n_rows, arma_nozeros_indicator()); + + uword* tmp3_mem = tmp3.memptr(); + + uword i = 0; + uword count = 0; + + for(uword j=0; j < n_rows; ++j) + { + if(i < N) + { + if( j != rows_to_shed_mem[i] ) + { + tmp3_mem[count] = j; + ++count; + } + else + { + ++i; + } + } + else + { + tmp3_mem[count] = j; + ++count; + } + } + + const Col rows_to_keep(tmp3.memptr(), count, false, false); + + Mat X = (*this).rows(rows_to_keep); + + steal_mem(X); + } + + + +//! remove specified columns +template +template +inline +void +Mat::shed_cols(const Base& indices) + { + arma_extra_debug_sigprint(); + + const unwrap_check_mixed U(indices.get_ref(), *this); + const Mat& tmp1 = U.M; + + arma_debug_check( ((tmp1.is_vec() == false) && (tmp1.is_empty() == false)), "Mat::shed_cols(): list of indices must be a vector" ); + + if(tmp1.is_empty()) { return; } + + const Col tmp2(const_cast(tmp1.memptr()), tmp1.n_elem, false, false); + + const Col& cols_to_shed = (tmp2.is_sorted("strictascend") == false) + ? Col(unique(tmp2)) + : Col(const_cast(tmp2.memptr()), tmp2.n_elem, false, false); + + const uword* cols_to_shed_mem = cols_to_shed.memptr(); + const uword N = cols_to_shed.n_elem; + + if(arma_config::debug) + { + for(uword i=0; i= n_cols), "Mat::shed_cols(): indices out of bounds" ); + } + } + + Col tmp3(n_cols, arma_nozeros_indicator()); + + uword* tmp3_mem = tmp3.memptr(); + + uword i = 0; + uword count = 0; + + for(uword j=0; j < n_cols; ++j) + { + if(i < N) + { + if( j != cols_to_shed_mem[i] ) + { + tmp3_mem[count] = j; + ++count; + } + else + { + ++i; + } + } + else + { + tmp3_mem[count] = j; + ++count; + } + } + + const Col cols_to_keep(tmp3.memptr(), count, false, false); + + Mat X = (*this).cols(cols_to_keep); + + steal_mem(X); + } + + + +template +inline +void +Mat::insert_rows(const uword row_num, const uword N, const bool set_to_zero) + { + arma_extra_debug_sigprint(); + + arma_ignore(set_to_zero); + + (*this).insert_rows(row_num, N); + } + + + +template +inline +void +Mat::insert_rows(const uword row_num, const uword N) + { + arma_extra_debug_sigprint(); + + const uword t_n_rows = n_rows; + const uword t_n_cols = n_cols; + + const uword A_n_rows = row_num; + const uword B_n_rows = t_n_rows - row_num; + + // insertion at row_num == n_rows is in effect an append operation + arma_debug_check_bounds( (row_num > t_n_rows), "Mat::insert_rows(): index out of bounds" ); + + if(N == 0) { return; } + + Mat out(t_n_rows + N, t_n_cols, arma_nozeros_indicator()); + + if(A_n_rows > 0) + { + out.rows(0, A_n_rows-1) = rows(0, A_n_rows-1); + } + + if(B_n_rows > 0) + { + out.rows(row_num + N, t_n_rows + N - 1) = rows(row_num, t_n_rows-1); + } + + out.rows(row_num, row_num + N - 1).zeros(); + + steal_mem(out); + } + + + +template +inline +void +Mat::insert_cols(const uword col_num, const uword N, const bool set_to_zero) + { + arma_extra_debug_sigprint(); + + arma_ignore(set_to_zero); + + (*this).insert_cols(col_num, N); + } + + + +template +inline +void +Mat::insert_cols(const uword col_num, const uword N) + { + arma_extra_debug_sigprint(); + + const uword t_n_rows = n_rows; + const uword t_n_cols = n_cols; + + const uword A_n_cols = col_num; + const uword B_n_cols = t_n_cols - col_num; + + // insertion at col_num == n_cols is in effect an append operation + arma_debug_check_bounds( (col_num > t_n_cols), "Mat::insert_cols(): index out of bounds" ); + + if(N == 0) { return; } + + Mat out(t_n_rows, t_n_cols + N, arma_nozeros_indicator()); + + if(A_n_cols > 0) + { + out.cols(0, A_n_cols-1) = cols(0, A_n_cols-1); + } + + if(B_n_cols > 0) + { + out.cols(col_num + N, t_n_cols + N - 1) = cols(col_num, t_n_cols-1); + } + + out.cols(col_num, col_num + N - 1).zeros(); + + steal_mem(out); + } + + + +//! insert the given object at the specified row position; +//! the given object must have the same number of columns as the matrix +template +template +inline +void +Mat::insert_rows(const uword row_num, const Base& X) + { + arma_extra_debug_sigprint(); + + const unwrap tmp(X.get_ref()); + const Mat& C = tmp.M; + + const uword C_n_rows = C.n_rows; + const uword C_n_cols = C.n_cols; + + const uword t_n_rows = n_rows; + const uword t_n_cols = n_cols; + + const uword A_n_rows = row_num; + const uword B_n_rows = t_n_rows - row_num; + + bool err_state = false; + char* err_msg = nullptr; + + const char* error_message_1 = "Mat::insert_rows(): index out of bounds"; + const char* error_message_2 = "Mat::insert_rows(): given object has an incompatible number of columns"; + + // insertion at row_num == n_rows is in effect an append operation + + arma_debug_set_error + ( + err_state, + err_msg, + (row_num > t_n_rows), + error_message_1 + ); + + arma_debug_set_error + ( + err_state, + err_msg, + ( (C_n_cols != t_n_cols) && ( (t_n_rows > 0) || (t_n_cols > 0) ) && ( (C_n_rows > 0) || (C_n_cols > 0) ) ), + error_message_2 + ); + + arma_debug_check_bounds(err_state, err_msg); + + if(C_n_rows > 0) + { + Mat out( t_n_rows + C_n_rows, (std::max)(t_n_cols, C_n_cols), arma_nozeros_indicator() ); + + if(t_n_cols > 0) + { + if(A_n_rows > 0) + { + out.rows(0, A_n_rows-1) = rows(0, A_n_rows-1); + } + + if( (t_n_cols > 0) && (B_n_rows > 0) ) + { + out.rows(row_num + C_n_rows, t_n_rows + C_n_rows - 1) = rows(row_num, t_n_rows - 1); + } + } + + if(C_n_cols > 0) + { + out.rows(row_num, row_num + C_n_rows - 1) = C; + } + + steal_mem(out); + } + } + + + +//! insert the given object at the specified column position; +//! the given object must have the same number of rows as the matrix +template +template +inline +void +Mat::insert_cols(const uword col_num, const Base& X) + { + arma_extra_debug_sigprint(); + + const unwrap tmp(X.get_ref()); + const Mat& C = tmp.M; + + const uword C_n_rows = C.n_rows; + const uword C_n_cols = C.n_cols; + + const uword t_n_rows = n_rows; + const uword t_n_cols = n_cols; + + const uword A_n_cols = col_num; + const uword B_n_cols = t_n_cols - col_num; + + bool err_state = false; + char* err_msg = nullptr; + + const char* error_message_1 = "Mat::insert_cols(): index out of bounds"; + const char* error_message_2 = "Mat::insert_cols(): given object has an incompatible number of rows"; + + // insertion at col_num == n_cols is in effect an append operation + + arma_debug_set_error + ( + err_state, + err_msg, + (col_num > t_n_cols), + error_message_1 + ); + + arma_debug_set_error + ( + err_state, + err_msg, + ( (C_n_rows != t_n_rows) && ( (t_n_rows > 0) || (t_n_cols > 0) ) && ( (C_n_rows > 0) || (C_n_cols > 0) ) ), + error_message_2 + ); + + arma_debug_check_bounds(err_state, err_msg); + + if(C_n_cols > 0) + { + Mat out( (std::max)(t_n_rows, C_n_rows), t_n_cols + C_n_cols, arma_nozeros_indicator() ); + + if(t_n_rows > 0) + { + if(A_n_cols > 0) + { + out.cols(0, A_n_cols-1) = cols(0, A_n_cols-1); + } + + if(B_n_cols > 0) + { + out.cols(col_num + C_n_cols, t_n_cols + C_n_cols - 1) = cols(col_num, t_n_cols - 1); + } + } + + if(C_n_rows > 0) + { + out.cols(col_num, col_num + C_n_cols - 1) = C; + } + + steal_mem(out); + } + } + + + +template +template +inline +Mat::Mat(const Gen& X) + : n_rows(X.n_rows) + , n_cols(X.n_cols) + , n_elem(n_rows*n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + init_cold(); + + X.apply(*this); + } + + + +template +template +inline +Mat& +Mat::operator=(const Gen& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + init_warm(X.n_rows, X.n_cols); + + X.apply(*this); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator+=(const Gen& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + X.apply_inplace_plus(*this); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator-=(const Gen& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + X.apply_inplace_minus(*this); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator*=(const Gen& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Mat tmp(X); + + return (*this).operator*=(tmp); + } + + + +template +template +inline +Mat& +Mat::operator%=(const Gen& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + X.apply_inplace_schur(*this); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator/=(const Gen& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + X.apply_inplace_div(*this); + + return *this; + } + + + +//! create a matrix from Op, ie. run the previously delayed unary operations +template +template +inline +Mat::Mat(const Op& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + op_type::apply(*this, X); + } + + + +//! create a matrix from Op, ie. run the previously delayed unary operations +template +template +inline +Mat& +Mat::operator=(const Op& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + op_type::apply(*this, X); + + return *this; + } + + + +//! in-place matrix addition, with the right-hand-side operand having delayed operations +template +template +inline +Mat& +Mat::operator+=(const Op& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Mat m(X); + + return (*this).operator+=(m); + } + + + +//! in-place matrix subtraction, with the right-hand-side operand having delayed operations +template +template +inline +Mat& +Mat::operator-=(const Op& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Mat m(X); + + return (*this).operator-=(m); + } + + + +//! in-place matrix multiplication, with the right-hand-side operand having delayed operations +template +template +inline +Mat& +Mat::operator*=(const Op& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +//! in-place matrix element-wise multiplication, with the right-hand-side operand having delayed operations +template +template +inline +Mat& +Mat::operator%=(const Op& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Mat m(X); + + return (*this).operator%=(m); + } + + + +//! in-place matrix element-wise division, with the right-hand-side operand having delayed operations +template +template +inline +Mat& +Mat::operator/=(const Op& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Mat m(X); + + return (*this).operator/=(m); + } + + + +//! create a matrix from eOp, ie. run the previously delayed unary operations +template +template +inline +Mat::Mat(const eOp& X) + : n_rows(X.get_n_rows()) + , n_cols(X.get_n_cols()) + , n_elem(X.get_n_elem()) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + init_cold(); + + eop_type::apply(*this, X); + } + + + +//! create a matrix from eOp, ie. run the previously delayed unary operations +template +template +inline +Mat& +Mat::operator=(const eOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias) { Mat tmp(X); steal_mem(tmp); return *this; } + + init_warm(X.get_n_rows(), X.get_n_cols()); + + eop_type::apply(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator+=(const eOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias) { const Mat tmp(X); return (*this).operator+=(tmp); } + + eop_type::apply_inplace_plus(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator-=(const eOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias) { const Mat tmp(X); return (*this).operator-=(tmp); } + + eop_type::apply_inplace_minus(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator*=(const eOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator%=(const eOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias) { const Mat tmp(X); return (*this).operator%=(tmp); } + + eop_type::apply_inplace_schur(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator/=(const eOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias) { const Mat tmp(X); return (*this).operator/=(tmp); } + + eop_type::apply_inplace_div(*this, X); + + return *this; + } + + + +template +template +inline +Mat::Mat(const mtOp& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + op_type::apply(*this, X); + } + + + +template +template +inline +Mat& +Mat::operator=(const mtOp& X) + { + arma_extra_debug_sigprint(); + + op_type::apply(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator+=(const mtOp& X) + { + arma_extra_debug_sigprint(); + + const Mat m(X); + + return (*this).operator+=(m); + } + + + +template +template +inline +Mat& +Mat::operator-=(const mtOp& X) + { + arma_extra_debug_sigprint(); + + const Mat m(X); + + return (*this).operator-=(m); + } + + + +template +template +inline +Mat& +Mat::operator*=(const mtOp& X) + { + arma_extra_debug_sigprint(); + + const Mat m(X); + + return (*this).operator*=(m); + } + + + +template +template +inline +Mat& +Mat::operator%=(const mtOp& X) + { + arma_extra_debug_sigprint(); + + const Mat m(X); + + return (*this).operator%=(m); + } + + + +template +template +inline +Mat& +Mat::operator/=(const mtOp& X) + { + arma_extra_debug_sigprint(); + + const Mat m(X); + + return (*this).operator/=(m); + } + + + +template +template +inline +Mat::Mat(const CubeToMatOp& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + op_type::apply(*this, X); + } + + + +template +template +inline +Mat& +Mat::operator=(const CubeToMatOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + op_type::apply(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator+=(const CubeToMatOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + (*this) = (*this) + X; + + return (*this); + } + + + +template +template +inline +Mat& +Mat::operator-=(const CubeToMatOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + (*this) = (*this) - X; + + return (*this); + } + + + +template +template +inline +Mat& +Mat::operator*=(const CubeToMatOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator%=(const CubeToMatOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + (*this) = (*this) % X; + + return (*this); + } + + + +template +template +inline +Mat& +Mat::operator/=(const CubeToMatOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + (*this) = (*this) / X; + + return (*this); + } + + + +template +template +inline +Mat::Mat(const SpToDOp& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + op_type::apply(*this, X); + } + + + +//! create a matrix from an SpToDOp, ie. run the previously delayed unary operations +template +template +inline +Mat& +Mat::operator=(const SpToDOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + op_type::apply(*this, X); + + return *this; + } + + + +//! in-place matrix addition, with the right-hand-side operand having delayed operations +template +template +inline +Mat& +Mat::operator+=(const SpToDOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Mat m(X); + + return (*this).operator+=(m); + } + + + +//! in-place matrix subtraction, with the right-hand-side operand having delayed operations +template +template +inline +Mat& +Mat::operator-=(const SpToDOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Mat m(X); + + return (*this).operator-=(m); + } + + + +//! in-place matrix multiplication, with the right-hand-side operand having delayed operations +template +template +inline +Mat& +Mat::operator*=(const SpToDOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +//! in-place matrix element-wise multiplication, with the right-hand-side operand having delayed operations +template +template +inline +Mat& +Mat::operator%=(const SpToDOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Mat m(X); + + return (*this).operator%=(m); + } + + + +//! in-place matrix element-wise division, with the right-hand-side operand having delayed operations +template +template +inline +Mat& +Mat::operator/=(const SpToDOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const Mat m(X); + + return (*this).operator/=(m); + } + + + +//! create a matrix from Glue, ie. run the previously delayed binary operations +template +template +inline +Mat::Mat(const Glue& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + glue_type::apply(*this, X); + } + + + +//! create a matrix from Glue, ie. run the previously delayed binary operations +template +template +inline +Mat& +Mat::operator=(const Glue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + glue_type::apply(*this, X); + + return *this; + } + + + +//! in-place matrix addition, with the right-hand-side operands having delayed operations +template +template +inline +Mat& +Mat::operator+=(const Glue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator+=(m); + } + + + +//! in-place matrix subtraction, with the right-hand-side operands having delayed operations +template +template +inline +Mat& +Mat::operator-=(const Glue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator-=(m); + } + + + +//! in-place matrix multiplications, with the right-hand-side operands having delayed operations +template +template +inline +Mat& +Mat::operator*=(const Glue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +//! in-place matrix element-wise multiplication, with the right-hand-side operands having delayed operations +template +template +inline +Mat& +Mat::operator%=(const Glue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator%=(m); + } + + + +//! in-place matrix element-wise division, with the right-hand-side operands having delayed operations +template +template +inline +Mat& +Mat::operator/=(const Glue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator/=(m); + } + + + +template +template +inline +Mat& +Mat::operator+=(const Glue& X) + { + arma_extra_debug_sigprint(); + + glue_times::apply_inplace_plus(*this, X, sword(+1)); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator-=(const Glue& X) + { + arma_extra_debug_sigprint(); + + glue_times::apply_inplace_plus(*this, X, sword(-1)); + + return *this; + } + + + +//! create a matrix from eGlue, ie. run the previously delayed binary operations +template +template +inline +Mat::Mat(const eGlue& X) + : n_rows(X.get_n_rows()) + , n_cols(X.get_n_cols()) + , n_elem(X.get_n_elem()) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + init_cold(); + + eglue_type::apply(*this, X); + } + + + +//! create a matrix from eGlue, ie. run the previously delayed binary operations +template +template +inline +Mat& +Mat::operator=(const eGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias) { Mat tmp(X); steal_mem(tmp); return *this; } + + init_warm(X.get_n_rows(), X.get_n_cols()); + + eglue_type::apply(*this, X); + + return *this; + } + + + +//! in-place matrix addition, with the right-hand-side operands having delayed operations +template +template +inline +Mat& +Mat::operator+=(const eGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias) { const Mat tmp(X); return (*this).operator+=(tmp); } + + eglue_type::apply_inplace_plus(*this, X); + + return *this; + } + + + +//! in-place matrix subtraction, with the right-hand-side operands having delayed operations +template +template +inline +Mat& +Mat::operator-=(const eGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias) { const Mat tmp(X); return (*this).operator-=(tmp); } + + eglue_type::apply_inplace_minus(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator*=(const eGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator%=(const eGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias) { const Mat tmp(X); return (*this).operator%=(tmp); } + + eglue_type::apply_inplace_schur(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator/=(const eGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias) { const Mat tmp(X); return (*this).operator/=(tmp); } + + eglue_type::apply_inplace_div(*this, X); + + return *this; + } + + + +template +template +inline +Mat::Mat(const mtGlue& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + glue_type::apply(*this, X); + } + + + +template +template +inline +Mat& +Mat::operator=(const mtGlue& X) + { + arma_extra_debug_sigprint(); + + glue_type::apply(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator+=(const mtGlue& X) + { + arma_extra_debug_sigprint(); + + const Mat m(X); + + return (*this).operator+=(m); + } + + + +template +template +inline +Mat& +Mat::operator-=(const mtGlue& X) + { + arma_extra_debug_sigprint(); + + const Mat m(X); + + return (*this).operator-=(m); + } + + + +template +template +inline +Mat& +Mat::operator*=(const mtGlue& X) + { + arma_extra_debug_sigprint(); + + const Mat m(X); + + glue_times::apply_inplace(*this, m); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator%=(const mtGlue& X) + { + arma_extra_debug_sigprint(); + + const Mat m(X); + + return (*this).operator%=(m); + } + + + +template +template +inline +Mat& +Mat::operator/=(const mtGlue& X) + { + arma_extra_debug_sigprint(); + + const Mat m(X); + + return (*this).operator/=(m); + } + + + +template +template +inline +Mat::Mat(const SpToDGlue& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + glue_type::apply(*this, X); + } + + + +template +template +inline +Mat& +Mat::operator=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + glue_type::apply(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator+=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator+=(m); + } + + + +template +template +inline +Mat& +Mat::operator-=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator-=(m); + } + + + +template +template +inline +Mat& +Mat::operator*=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator%=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator%=(m); + } + + + +template +template +inline +Mat& +Mat::operator/=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator/=(m); + } + + + +//! linear element accessor (treats the matrix as a vector); no bounds check; assumes memory is aligned +template +arma_inline +const eT& +Mat::at_alt(const uword ii) const + { + const eT* mem_aligned = mem; + + memory::mark_as_aligned(mem_aligned); + + return mem_aligned[ii]; + } + + + +//! linear element accessor (treats the matrix as a vector); bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +eT& +Mat::operator() (const uword ii) + { + arma_debug_check_bounds( (ii >= n_elem), "Mat::operator(): index out of bounds" ); + + return access::rw(mem[ii]); + } + + + +//! linear element accessor (treats the matrix as a vector); bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +const eT& +Mat::operator() (const uword ii) const + { + arma_debug_check_bounds( (ii >= n_elem), "Mat::operator(): index out of bounds" ); + + return mem[ii]; + } + + +//! linear element accessor (treats the matrix as a vector); no bounds check. +template +arma_inline +eT& +Mat::operator[] (const uword ii) + { + return access::rw(mem[ii]); + } + + + +//! linear element accessor (treats the matrix as a vector); no bounds check +template +arma_inline +const eT& +Mat::operator[] (const uword ii) const + { + return mem[ii]; + } + + + +//! linear element accessor (treats the matrix as a vector); no bounds check. +template +arma_inline +eT& +Mat::at(const uword ii) + { + return access::rw(mem[ii]); + } + + + +//! linear element accessor (treats the matrix as a vector); no bounds check +template +arma_inline +const eT& +Mat::at(const uword ii) const + { + return mem[ii]; + } + + + +//! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +eT& +Mat::operator() (const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "Mat::operator(): index out of bounds" ); + + return access::rw(mem[in_row + in_col*n_rows]); + } + + + +//! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +const eT& +Mat::operator() (const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "Mat::operator(): index out of bounds" ); + + return mem[in_row + in_col*n_rows]; + } + + + +//! element accessor; no bounds check +template +arma_inline +eT& +Mat::at(const uword in_row, const uword in_col) + { + return access::rw( mem[in_row + in_col*n_rows] ); + } + + + +//! element accessor; no bounds check +template +arma_inline +const eT& +Mat::at(const uword in_row, const uword in_col) const + { + return mem[in_row + in_col*n_rows]; + } + + + +#if defined(__cpp_multidimensional_subscript) + + //! element accessor; no bounds check + template + arma_inline + eT& + Mat::operator[] (const uword in_row, const uword in_col) + { + return access::rw( mem[in_row + in_col*n_rows] ); + } + + + + //! element accessor; no bounds check + template + arma_inline + const eT& + Mat::operator[] (const uword in_row, const uword in_col) const + { + return mem[in_row + in_col*n_rows]; + } + +#endif + + + +//! prefix ++ +template +arma_inline +const Mat& +Mat::operator++() + { + Mat_aux::prefix_pp(*this); + + return *this; + } + + + +//! postfix ++ (must not return the object by reference) +template +arma_inline +void +Mat::operator++(int) + { + Mat_aux::postfix_pp(*this); + } + + + +//! prefix -- +template +arma_inline +const Mat& +Mat::operator--() + { + Mat_aux::prefix_mm(*this); + + return *this; + } + + + +//! postfix -- (must not return the object by reference) +template +arma_inline +void +Mat::operator--(int) + { + Mat_aux::postfix_mm(*this); + } + + + +//! returns true if the matrix has no elements +template +arma_inline +bool +Mat::is_empty() const + { + return (n_elem == 0); + } + + + +//! returns true if the object can be interpreted as a column or row vector +template +arma_inline +bool +Mat::is_vec() const + { + return ( (n_rows == 1) || (n_cols == 1) ); + } + + + +//! returns true if the object can be interpreted as a row vector +template +arma_inline +bool +Mat::is_rowvec() const + { + return (n_rows == 1); + } + + + +//! returns true if the object can be interpreted as a column vector +template +arma_inline +bool +Mat::is_colvec() const + { + return (n_cols == 1); + } + + + +//! returns true if the object has the same number of non-zero rows and columnns +template +arma_inline +bool +Mat::is_square() const + { + return (n_rows == n_cols); + } + + + +template +inline +bool +Mat::internal_is_finite() const + { + arma_extra_debug_sigprint(); + + return arrayops::is_finite(memptr(), n_elem); + } + + + +template +inline +bool +Mat::internal_has_inf() const + { + arma_extra_debug_sigprint(); + + return arrayops::has_inf(memptr(), n_elem); + } + + + +template +inline +bool +Mat::internal_has_nan() const + { + arma_extra_debug_sigprint(); + + return arrayops::has_nan(memptr(), n_elem); + } + + + +template +inline +bool +Mat::internal_has_nonfinite() const + { + arma_extra_debug_sigprint(); + + return (arrayops::is_finite(memptr(), n_elem) == false); + } + + + +template +inline +bool +Mat::is_sorted(const char* direction) const + { + arma_extra_debug_sigprint(); + + return (*this).is_sorted(direction, (((vec_state == 2) || (n_rows == 1)) ? uword(1) : uword(0))); + } + + + +template +inline +bool +Mat::is_sorted(const char* direction, const uword dim) const + { + arma_extra_debug_sigprint(); + + const char sig1 = (direction != nullptr) ? direction[0] : char(0); + + // direction is one of: + // "ascend" + // "descend" + // "strictascend" + // "strictdescend" + + arma_debug_check( ((sig1 != 'a') && (sig1 != 'd') && (sig1 != 's')), "Mat::is_sorted(): unknown sort direction" ); + + // "strictascend" + // "strictdescend" + // 0123456 + + const char sig2 = (sig1 == 's') ? direction[6] : char(0); + + if(sig1 == 's') { arma_debug_check( ((sig2 != 'a') && (sig2 != 'd')), "Mat::is_sorted(): unknown sort direction" ); } + + arma_debug_check( (dim > 1), "Mat::is_sorted(): parameter 'dim' must be 0 or 1" ); + + if(sig1 == 'a') + { + // case: ascend + + // deliberately using the opposite direction comparator, + // as we need to handle the case of two elements being equal + + arma_gt_comparator comparator; + + return (*this).is_sorted_helper(comparator, dim); + } + else + if(sig1 == 'd') + { + // case: descend + + // deliberately using the opposite direction comparator, + // as we need to handle the case of two elements being equal + + arma_lt_comparator comparator; + + return (*this).is_sorted_helper(comparator, dim); + } + else + if((sig1 == 's') && (sig2 == 'a')) + { + // case: strict ascend + + arma_geq_comparator comparator; + + return (*this).is_sorted_helper(comparator, dim); + } + else + if((sig1 == 's') && (sig2 == 'd')) + { + // case: strict descend + + arma_leq_comparator comparator; + + return (*this).is_sorted_helper(comparator, dim); + } + + return true; + } + + + +template +template +inline +bool +Mat::is_sorted_helper(const comparator& comp, const uword dim) const + { + arma_extra_debug_sigprint(); + + if(n_elem <= 1) { return true; } + + const uword local_n_cols = n_cols; + const uword local_n_rows = n_rows; + + if(dim == 0) + { + if(local_n_rows <= 1u) { return true; } + + const uword local_n_rows_m1 = local_n_rows - 1; + + for(uword c=0; c < local_n_cols; ++c) + { + const eT* coldata = colptr(c); + + for(uword r=0; r < local_n_rows_m1; ++r) + { + const eT val1 = (*coldata); coldata++; + const eT val2 = (*coldata); + + if(comp(val1,val2)) { return false; } + } + } + } + else + if(dim == 1) + { + if(local_n_cols <= 1u) { return true; } + + const uword local_n_cols_m1 = local_n_cols - 1; + + if(local_n_rows == 1) + { + const eT* rowdata = memptr(); + + for(uword c=0; c < local_n_cols_m1; ++c) + { + const eT val1 = (*rowdata); rowdata++; + const eT val2 = (*rowdata); + + if(comp(val1,val2)) { return false; } + } + } + else + { + for(uword r=0; r < local_n_rows; ++r) + for(uword c=0; c < local_n_cols_m1; ++c) + { + const eT val1 = at(r,c ); + const eT val2 = at(r,c+1); + + if(comp(val1,val2)) { return false; } + } + } + } + + return true; + } + + + +//! returns true if the given index is currently in range +template +arma_inline +bool +Mat::in_range(const uword ii) const + { + return (ii < n_elem); + } + + + +//! returns true if the given start and end indices are currently in range +template +arma_inline +bool +Mat::in_range(const span& x) const + { + arma_extra_debug_sigprint(); + + if(x.whole) + { + return true; + } + else + { + const uword a = x.a; + const uword b = x.b; + + return ( (a <= b) && (b < n_elem) ); + } + } + + + +//! returns true if the given location is currently in range +template +arma_inline +bool +Mat::in_range(const uword in_row, const uword in_col) const + { + return ( (in_row < n_rows) && (in_col < n_cols) ); + } + + + +template +arma_inline +bool +Mat::in_range(const span& row_span, const uword in_col) const + { + arma_extra_debug_sigprint(); + + if(row_span.whole) + { + return (in_col < n_cols); + } + else + { + const uword in_row1 = row_span.a; + const uword in_row2 = row_span.b; + + return ( (in_row1 <= in_row2) && (in_row2 < n_rows) && (in_col < n_cols) ); + } + } + + + +template +arma_inline +bool +Mat::in_range(const uword in_row, const span& col_span) const + { + arma_extra_debug_sigprint(); + + if(col_span.whole) + { + return (in_row < n_rows); + } + else + { + const uword in_col1 = col_span.a; + const uword in_col2 = col_span.b; + + return ( (in_row < n_rows) && (in_col1 <= in_col2) && (in_col2 < n_cols) ); + } + } + + + +template +arma_inline +bool +Mat::in_range(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + const uword in_row1 = row_span.a; + const uword in_row2 = row_span.b; + + const uword in_col1 = col_span.a; + const uword in_col2 = col_span.b; + + const bool rows_ok = row_span.whole ? true : ( (in_row1 <= in_row2) && (in_row2 < n_rows) ); + const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2) && (in_col2 < n_cols) ); + + return ( rows_ok && cols_ok ); + } + + + +template +arma_inline +bool +Mat::in_range(const uword in_row, const uword in_col, const SizeMat& s) const + { + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + + if( (in_row >= l_n_rows) || (in_col >= l_n_cols) || ((in_row + s.n_rows) > l_n_rows) || ((in_col + s.n_cols) > l_n_cols) ) + { + return false; + } + else + { + return true; + } + } + + + +//! returns a pointer to array of eTs for a specified column; no bounds check +template +arma_inline +eT* +Mat::colptr(const uword in_col) + { + return & access::rw(mem[in_col*n_rows]); + } + + + +//! returns a pointer to array of eTs for a specified column; no bounds check +template +arma_inline +const eT* +Mat::colptr(const uword in_col) const + { + return & mem[in_col*n_rows]; + } + + + +//! returns a pointer to array of eTs used by the matrix +template +arma_inline +eT* +Mat::memptr() + { + return const_cast(mem); + } + + + +//! returns a pointer to array of eTs used by the matrix +template +arma_inline +const eT* +Mat::memptr() const + { + return mem; + } + + + +//! change the matrix to have user specified dimensions (data is not preserved) +template +inline +Mat& +Mat::set_size(const uword new_n_elem) + { + arma_extra_debug_sigprint(); + + switch(vec_state) + { + case 0: + // fallthrough + case 1: + init_warm(new_n_elem, 1); + break; + + case 2: + init_warm(1, new_n_elem); + break; + + default: + ; + } + + return *this; + } + + + +//! change the matrix to have user specified dimensions (data is not preserved) +template +inline +Mat& +Mat::set_size(const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + init_warm(new_n_rows, new_n_cols); + + return *this; + } + + + +template +inline +Mat& +Mat::set_size(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + init_warm(s.n_rows, s.n_cols); + + return *this; + } + + + +//! change the matrix to have user specified dimensions (data is preserved) +template +inline +Mat& +Mat::resize(const uword new_n_elem) + { + arma_extra_debug_sigprint(); + + switch(vec_state) + { + case 0: + // fallthrough + case 1: + (*this).resize(new_n_elem, 1); + break; + + case 2: + (*this).resize(1, new_n_elem); + break; + + default: + ; + } + + return *this; + } + + + +//! change the matrix to have user specified dimensions (data is preserved) +template +inline +Mat& +Mat::resize(const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + op_resize::apply_mat_inplace((*this), new_n_rows, new_n_cols); + + return *this; + } + + + +template +inline +Mat& +Mat::resize(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + op_resize::apply_mat_inplace((*this), s.n_rows, s.n_cols); + + return *this; + } + + + +//! change the matrix to have user specified dimensions (data is preserved) +template +inline +Mat& +Mat::reshape(const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + op_reshape::apply_mat_inplace((*this), new_n_rows, new_n_cols); + + return *this; + } + + + +template +inline +Mat& +Mat::reshape(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + op_reshape::apply_mat_inplace((*this), s.n_rows, s.n_cols); + + return *this; + } + + + +//! NOTE: don't use this form; it's deprecated and will be removed +template +inline +void +Mat::reshape(const uword new_n_rows, const uword new_n_cols, const uword dim) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (dim > 1), "reshape(): parameter 'dim' must be 0 or 1" ); + + if(dim == 0) + { + op_reshape::apply_mat_inplace((*this), new_n_rows, new_n_cols); + } + else + if(dim == 1) + { + Mat tmp; + + op_strans::apply_mat_noalias(tmp, (*this)); + + op_reshape::apply_mat_noalias((*this), tmp, new_n_rows, new_n_cols); + } + } + + + +//! change the matrix (without preserving data) to have the same dimensions as the given expression +template +template +inline +Mat& +Mat::copy_size(const Base& X) + { + arma_extra_debug_sigprint(); + + const Proxy P(X.get_ref()); + + const uword X_n_rows = P.get_n_rows(); + const uword X_n_cols = P.get_n_cols(); + + init_warm(X_n_rows, X_n_cols); + + return *this; + } + + + +//! apply a functor to each element +template +template +inline +Mat& +Mat::for_each(functor F) + { + arma_extra_debug_sigprint(); + + eT* data = memptr(); + + const uword N = n_elem; + + uword ii, jj; + + for(ii=0, jj=1; jj < N; ii+=2, jj+=2) + { + F(data[ii]); + F(data[jj]); + } + + if(ii < N) + { + F(data[ii]); + } + + return *this; + } + + + +template +template +inline +const Mat& +Mat::for_each(functor F) const + { + arma_extra_debug_sigprint(); + + const eT* data = memptr(); + + const uword N = n_elem; + + uword ii, jj; + + for(ii=0, jj=1; jj < N; ii+=2, jj+=2) + { + F(data[ii]); + F(data[jj]); + } + + if(ii < N) + { + F(data[ii]); + } + + return *this; + } + + + +//! transform each element in the matrix using a functor +template +template +inline +Mat& +Mat::transform(functor F) + { + arma_extra_debug_sigprint(); + + eT* out_mem = memptr(); + + const uword N = n_elem; + + uword ii, jj; + + for(ii=0, jj=1; jj < N; ii+=2, jj+=2) + { + eT tmp_ii = out_mem[ii]; + eT tmp_jj = out_mem[jj]; + + tmp_ii = eT( F(tmp_ii) ); + tmp_jj = eT( F(tmp_jj) ); + + out_mem[ii] = tmp_ii; + out_mem[jj] = tmp_jj; + } + + if(ii < N) + { + out_mem[ii] = eT( F(out_mem[ii]) ); + } + + return *this; + } + + + +//! imbue (fill) the matrix with values provided by a functor +template +template +inline +Mat& +Mat::imbue(functor F) + { + arma_extra_debug_sigprint(); + + eT* out_mem = memptr(); + + const uword N = n_elem; + + uword ii, jj; + + for(ii=0, jj=1; jj < N; ii+=2, jj+=2) + { + const eT tmp_ii = eT( F() ); + const eT tmp_jj = eT( F() ); + + out_mem[ii] = tmp_ii; + out_mem[jj] = tmp_jj; + } + + if(ii < N) + { + out_mem[ii] = eT( F() ); + } + + return *this; + } + + + +template +inline +Mat& +Mat::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + arrayops::replace(memptr(), n_elem, old_val, new_val); + + return *this; + } + + + +template +inline +Mat& +Mat::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + arrayops::clean(memptr(), n_elem, threshold); + + return *this; + } + + + +template +inline +Mat& +Mat::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "Mat::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "Mat::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "Mat::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + arrayops::clamp(memptr(), n_elem, min_val, max_val); + + return *this; + } + + + +//! fill the matrix with the specified value +template +inline +Mat& +Mat::fill(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_set( memptr(), val, n_elem ); + + return *this; + } + + + +//! fill the matrix with the specified pattern +template +template +inline +Mat& +Mat::fill(const fill::fill_class&) + { + arma_extra_debug_sigprint(); + + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).eye(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } + + return *this; + } + + + +template +inline +Mat& +Mat::zeros() + { + arma_extra_debug_sigprint(); + + arrayops::fill_zeros(memptr(), n_elem); + + return *this; + } + + + +template +inline +Mat& +Mat::zeros(const uword new_n_elem) + { + arma_extra_debug_sigprint(); + + set_size(new_n_elem); + + return (*this).zeros(); + } + + + +template +inline +Mat& +Mat::zeros(const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + set_size(new_n_rows, new_n_cols); + + return (*this).zeros(); + } + + + +template +inline +Mat& +Mat::zeros(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).zeros(s.n_rows, s.n_cols); + } + + + +template +inline +Mat& +Mat::ones() + { + arma_extra_debug_sigprint(); + + return fill(eT(1)); + } + + + +template +inline +Mat& +Mat::ones(const uword new_n_elem) + { + arma_extra_debug_sigprint(); + + set_size(new_n_elem); + + return fill(eT(1)); + } + + + +template +inline +Mat& +Mat::ones(const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + set_size(new_n_rows, new_n_cols); + + return fill(eT(1)); + } + + + +template +inline +Mat& +Mat::ones(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).ones(s.n_rows, s.n_cols); + } + + + +template +inline +Mat& +Mat::randu() + { + arma_extra_debug_sigprint(); + + arma_rng::randu::fill( memptr(), n_elem ); + + return *this; + } + + + +template +inline +Mat& +Mat::randu(const uword new_n_elem) + { + arma_extra_debug_sigprint(); + + set_size(new_n_elem); + + return (*this).randu(); + } + + + +template +inline +Mat& +Mat::randu(const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + set_size(new_n_rows, new_n_cols); + + return (*this).randu(); + } + + + +template +inline +Mat& +Mat::randu(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).randu(s.n_rows, s.n_cols); + } + + + +template +inline +Mat& +Mat::randn() + { + arma_extra_debug_sigprint(); + + arma_rng::randn::fill( memptr(), n_elem ); + + return *this; + } + + + +template +inline +Mat& +Mat::randn(const uword new_n_elem) + { + arma_extra_debug_sigprint(); + + set_size(new_n_elem); + + return (*this).randn(); + } + + + +template +inline +Mat& +Mat::randn(const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + set_size(new_n_rows, new_n_cols); + + return (*this).randn(); + } + + + +template +inline +Mat& +Mat::randn(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).randn(s.n_rows, s.n_cols); + } + + + +template +inline +Mat& +Mat::eye() + { + arma_extra_debug_sigprint(); + + (*this).zeros(); + + const uword N = (std::min)(n_rows, n_cols); + + for(uword ii=0; ii +inline +Mat& +Mat::eye(const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + set_size(new_n_rows, new_n_cols); + + return (*this).eye(); + } + + + +template +inline +Mat& +Mat::eye(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).eye(s.n_rows, s.n_cols); + } + + + +template +inline +void +Mat::reset() + { + arma_extra_debug_sigprint(); + + const uword new_n_rows = (vec_state == 2) ? 1 : 0; + const uword new_n_cols = (vec_state == 1) ? 1 : 0; + + init_warm(new_n_rows, new_n_cols); + } + + + +template +inline +void +Mat::soft_reset() + { + arma_extra_debug_sigprint(); + + // don't change the size if the matrix has a fixed size or is a cube slice + if(mem_state <= 1) + { + reset(); + } + else + { + zeros(); + } + } + + + +template +template +inline +void +Mat::set_real(const Base::pod_type,T1>& X) + { + arma_extra_debug_sigprint(); + + Mat_aux::set_real(*this, X); + } + + + +template +template +inline +void +Mat::set_imag(const Base::pod_type,T1>& X) + { + arma_extra_debug_sigprint(); + + Mat_aux::set_imag(*this, X); + } + + + +template +inline +eT +Mat::min() const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Mat::min(): object has no elements"); + + return Datum::nan; + } + + return op_min::direct_min(memptr(), n_elem); + } + + + +template +inline +eT +Mat::max() const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Mat::max(): object has no elements"); + + return Datum::nan; + } + + return op_max::direct_max(memptr(), n_elem); + } + + + +template +inline +eT +Mat::min(uword& index_of_min_val) const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Mat::min(): object has no elements"); + + index_of_min_val = uword(0); + + return Datum::nan; + } + + return op_min::direct_min(memptr(), n_elem, index_of_min_val); + } + + + +template +inline +eT +Mat::max(uword& index_of_max_val) const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Mat::max(): object has no elements"); + + index_of_max_val = uword(0); + + return Datum::nan; + } + + return op_max::direct_max(memptr(), n_elem, index_of_max_val); + } + + + +template +inline +eT +Mat::min(uword& row_of_min_val, uword& col_of_min_val) const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Mat::min(): object has no elements"); + + row_of_min_val = uword(0); + col_of_min_val = uword(0); + + return Datum::nan; + } + + uword iq; + + eT val = op_min::direct_min(memptr(), n_elem, iq); + + row_of_min_val = iq % n_rows; + col_of_min_val = iq / n_rows; + + return val; + } + + + +template +inline +eT +Mat::max(uword& row_of_max_val, uword& col_of_max_val) const + { + arma_extra_debug_sigprint(); + + if(n_elem == 0) + { + arma_debug_check(true, "Mat::max(): object has no elements"); + + row_of_max_val = uword(0); + col_of_max_val = uword(0); + + return Datum::nan; + } + + uword iq; + + eT val = op_max::direct_max(memptr(), n_elem, iq); + + row_of_max_val = iq % n_rows; + col_of_max_val = iq / n_rows; + + return val; + } + + + +//! save the matrix to a file +template +inline +bool +Mat::save(const std::string name, const file_type type) const + { + arma_extra_debug_sigprint(); + + bool save_okay = false; + + switch(type) + { + case raw_ascii: + save_okay = diskio::save_raw_ascii(*this, name); + break; + + case arma_ascii: + save_okay = diskio::save_arma_ascii(*this, name); + break; + + case csv_ascii: + return (*this).save(csv_name(name), type); + break; + + case ssv_ascii: + return (*this).save(csv_name(name), type); + break; + + case coord_ascii: + save_okay = diskio::save_coord_ascii(*this, name); + break; + + case raw_binary: + save_okay = diskio::save_raw_binary(*this, name); + break; + + case arma_binary: + save_okay = diskio::save_arma_binary(*this, name); + break; + + case pgm_binary: + save_okay = diskio::save_pgm_binary(*this, name); + break; + + case hdf5_binary: + return (*this).save(hdf5_name(name)); + break; + + case hdf5_binary_trans: // kept for compatibility with earlier versions of Armadillo + return (*this).save(hdf5_name(name, std::string(), hdf5_opts::trans)); + break; + + default: + arma_debug_warn_level(1, "Mat::save(): unsupported file type"); + save_okay = false; + } + + if(save_okay == false) { arma_debug_warn_level(3, "Mat::save(): write failed; file: ", name); } + + return save_okay; + } + + + +template +inline +bool +Mat::save(const hdf5_name& spec, const file_type type) const + { + arma_extra_debug_sigprint(); + + // handling of hdf5_binary_trans kept for compatibility with earlier versions of Armadillo + + if( (type != hdf5_binary) && (type != hdf5_binary_trans) ) + { + arma_stop_runtime_error("Mat::save(): unsupported file type for hdf5_name()"); + return false; + } + + const bool do_trans = bool(spec.opts.flags & hdf5_opts::flag_trans ) || (type == hdf5_binary_trans); + const bool append = bool(spec.opts.flags & hdf5_opts::flag_append ); + const bool replace = bool(spec.opts.flags & hdf5_opts::flag_replace); + + if(append && replace) + { + arma_stop_runtime_error("Mat::save(): only one of 'append' or 'replace' options can be used"); + return false; + } + + bool save_okay = false; + + std::string err_msg; + + if(do_trans) + { + Mat tmp; + + op_strans::apply_mat_noalias(tmp, *this); + + save_okay = diskio::save_hdf5_binary(tmp, spec, err_msg); + } + else + { + save_okay = diskio::save_hdf5_binary(*this, spec, err_msg); + } + + if(save_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "Mat::save(): ", err_msg, "; file: ", spec.filename); + } + else + { + arma_debug_warn_level(3, "Mat::save(): write failed; file: ", spec.filename); + } + } + + return save_okay; + } + + + +template +inline +bool +Mat::save(const csv_name& spec, const file_type type) const + { + arma_extra_debug_sigprint(); + + if( (type != csv_ascii) && (type != ssv_ascii) ) + { + arma_stop_runtime_error("Mat::save(): unsupported file type for csv_name()"); + return false; + } + + const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans ); + const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header ); + const bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header) && (no_header == false); + const bool use_semicolon = bool(spec.opts.flags & csv_opts::flag_semicolon ) || (type == ssv_ascii); + + arma_extra_debug_print("Mat::save(csv_name): enabled flags:"); + + if(do_trans ) { arma_extra_debug_print("trans"); } + if(no_header ) { arma_extra_debug_print("no_header"); } + if(with_header ) { arma_extra_debug_print("with_header"); } + if(use_semicolon) { arma_extra_debug_print("semicolon"); } + + const char separator = (use_semicolon) ? char(';') : char(','); + + if(with_header) + { + if( (spec.header_ro.n_cols != 1) && (spec.header_ro.n_rows != 1) ) + { + arma_debug_warn_level(1, "Mat::save(): given header must have a vector layout"); + return false; + } + + for(uword i=0; i < spec.header_ro.n_elem; ++i) + { + const std::string& token = spec.header_ro.at(i); + + if(token.find(separator) != std::string::npos) + { + arma_debug_warn_level(1, "Mat::save(): token within the header contains the separator character: '", token, "'"); + return false; + } + } + + const uword save_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols; + + if(spec.header_ro.n_elem != save_n_cols) + { + arma_debug_warn_level(1, "Mat::save(): size mismatch between header and matrix"); + return false; + } + } + + bool save_okay = false; + + if(do_trans) + { + const Mat tmp = (*this).st(); + + save_okay = diskio::save_csv_ascii(tmp, spec.filename, spec.header_ro, with_header, separator); + } + else + { + save_okay = diskio::save_csv_ascii(*this, spec.filename, spec.header_ro, with_header, separator); + } + + if(save_okay == false) { arma_debug_warn_level(3, "Mat::save(): write failed; file: ", spec.filename); } + + return save_okay; + } + + + +//! save the matrix to a stream +template +inline +bool +Mat::save(std::ostream& os, const file_type type) const + { + arma_extra_debug_sigprint(); + + bool save_okay = false; + + switch(type) + { + case raw_ascii: + save_okay = diskio::save_raw_ascii(*this, os); + break; + + case arma_ascii: + save_okay = diskio::save_arma_ascii(*this, os); + break; + + case csv_ascii: + save_okay = diskio::save_csv_ascii(*this, os, char(',')); + break; + + case ssv_ascii: + save_okay = diskio::save_csv_ascii(*this, os, char(';')); + break; + + case coord_ascii: + save_okay = diskio::save_coord_ascii(*this, os); + break; + + case raw_binary: + save_okay = diskio::save_raw_binary(*this, os); + break; + + case arma_binary: + save_okay = diskio::save_arma_binary(*this, os); + break; + + case pgm_binary: + save_okay = diskio::save_pgm_binary(*this, os); + break; + + default: + arma_debug_warn_level(1, "Mat::save(): unsupported file type"); + save_okay = false; + } + + if(save_okay == false) { arma_debug_warn_level(3, "Mat::save(): stream write failed"); } + + return save_okay; + } + + + +//! load a matrix from a file +template +inline +bool +Mat::load(const std::string name, const file_type type) + { + arma_extra_debug_sigprint(); + + bool load_okay = false; + std::string err_msg; + + switch(type) + { + case auto_detect: + load_okay = diskio::load_auto_detect(*this, name, err_msg); + break; + + case raw_ascii: + load_okay = diskio::load_raw_ascii(*this, name, err_msg); + break; + + case arma_ascii: + load_okay = diskio::load_arma_ascii(*this, name, err_msg); + break; + + case csv_ascii: + return (*this).load(csv_name(name), type); + break; + + case ssv_ascii: + return (*this).load(csv_name(name), type); + break; + + case coord_ascii: + load_okay = diskio::load_coord_ascii(*this, name, err_msg); + break; + + case raw_binary: + load_okay = diskio::load_raw_binary(*this, name, err_msg); + break; + + case arma_binary: + load_okay = diskio::load_arma_binary(*this, name, err_msg); + break; + + case pgm_binary: + load_okay = diskio::load_pgm_binary(*this, name, err_msg); + break; + + case hdf5_binary: + return (*this).load(hdf5_name(name)); + break; + + case hdf5_binary_trans: // kept for compatibility with earlier versions of Armadillo + return (*this).load(hdf5_name(name, std::string(), hdf5_opts::trans)); + break; + + default: + arma_debug_warn_level(1, "Mat::load(): unsupported file type"); + load_okay = false; + } + + if(load_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "Mat::load(): ", err_msg, "; file: ", name); + } + else + { + arma_debug_warn_level(3, "Mat::load(): read failed; file: ", name); + } + } + + if(load_okay == false) { (*this).soft_reset(); } + + return load_okay; + } + + + +template +inline +bool +Mat::load(const hdf5_name& spec, const file_type type) + { + arma_extra_debug_sigprint(); + + if( (type != hdf5_binary) && (type != hdf5_binary_trans) ) + { + arma_stop_runtime_error("Mat::load(): unsupported file type for hdf5_name()"); + return false; + } + + bool load_okay = false; + std::string err_msg; + + const bool do_trans = bool(spec.opts.flags & hdf5_opts::flag_trans) || (type == hdf5_binary_trans); + + if(do_trans) + { + Mat tmp; + + load_okay = diskio::load_hdf5_binary(tmp, spec, err_msg); + + if(load_okay) { op_strans::apply_mat_noalias(*this, tmp); } + } + else + { + load_okay = diskio::load_hdf5_binary(*this, spec, err_msg); + } + + + if(load_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "Mat::load(): ", err_msg, "; file: ", spec.filename); + } + else + { + arma_debug_warn_level(3, "Mat::load(): read failed; file: ", spec.filename); + } + } + + if(load_okay == false) { (*this).soft_reset(); } + + return load_okay; + } + + + +template +inline +bool +Mat::load(const csv_name& spec, const file_type type) + { + arma_extra_debug_sigprint(); + + if( (type != csv_ascii) && (type != ssv_ascii) ) + { + arma_stop_runtime_error("Mat::load(): unsupported file type for csv_name()"); + return false; + } + + const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans ); + const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header ); + const bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header) && (no_header == false); + const bool use_semicolon = bool(spec.opts.flags & csv_opts::flag_semicolon ) || (type == ssv_ascii); + const bool strict = bool(spec.opts.flags & csv_opts::flag_strict ); + + arma_extra_debug_print("Mat::load(csv_name): enabled flags:"); + + if(do_trans ) { arma_extra_debug_print("trans"); } + if(no_header ) { arma_extra_debug_print("no_header"); } + if(with_header ) { arma_extra_debug_print("with_header"); } + if(use_semicolon) { arma_extra_debug_print("semicolon"); } + if(strict ) { arma_extra_debug_print("strict"); } + + const char separator = (use_semicolon) ? char(';') : char(','); + + bool load_okay = false; + std::string err_msg; + + if(do_trans) + { + Mat tmp_mat; + + load_okay = diskio::load_csv_ascii(tmp_mat, spec.filename, err_msg, spec.header_rw, with_header, separator, strict); + + if(load_okay) + { + (*this) = tmp_mat.st(); + + if(with_header) + { + // field::set_size() preserves data if the number of elements hasn't changed + spec.header_rw.set_size(spec.header_rw.n_elem, 1); + } + } + } + else + { + load_okay = diskio::load_csv_ascii(*this, spec.filename, err_msg, spec.header_rw, with_header, separator, strict); + } + + if(load_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "Mat::load(): ", err_msg, "; file: ", spec.filename); + } + else + { + arma_debug_warn_level(3, "Mat::load(): read failed; file: ", spec.filename); + } + } + else + { + const uword load_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols; + + if(with_header && (spec.header_rw.n_elem != load_n_cols)) + { + arma_debug_warn_level(3, "Mat::load(): size mismatch between header and matrix"); + } + } + + if(load_okay == false) + { + (*this).soft_reset(); + + if(with_header) { spec.header_rw.reset(); } + } + + return load_okay; + } + + + +//! load a matrix from a stream +template +inline +bool +Mat::load(std::istream& is, const file_type type) + { + arma_extra_debug_sigprint(); + + bool load_okay = false; + std::string err_msg; + + switch(type) + { + case auto_detect: + load_okay = diskio::load_auto_detect(*this, is, err_msg); + break; + + case raw_ascii: + load_okay = diskio::load_raw_ascii(*this, is, err_msg); + break; + + case arma_ascii: + load_okay = diskio::load_arma_ascii(*this, is, err_msg); + break; + + case csv_ascii: + load_okay = diskio::load_csv_ascii(*this, is, err_msg, char(','), false); + break; + + case ssv_ascii: + load_okay = diskio::load_csv_ascii(*this, is, err_msg, char(';'), false); + break; + + case coord_ascii: + load_okay = diskio::load_coord_ascii(*this, is, err_msg); + break; + + case raw_binary: + load_okay = diskio::load_raw_binary(*this, is, err_msg); + break; + + case arma_binary: + load_okay = diskio::load_arma_binary(*this, is, err_msg); + break; + + case pgm_binary: + load_okay = diskio::load_pgm_binary(*this, is, err_msg); + break; + + default: + arma_debug_warn_level(1, "Mat::load(): unsupported file type"); + load_okay = false; + } + + if(load_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "Mat::load(): ", err_msg); + } + else + { + arma_debug_warn_level(3, "Mat::load(): stream read failed"); + } + } + + if(load_okay == false) { (*this).soft_reset(); } + + return load_okay; + } + + + +template +inline +bool +Mat::quiet_save(const std::string name, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(name, type); + } + + + +template +inline +bool +Mat::quiet_save(const hdf5_name& spec, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(spec, type); + } + + + +template +inline +bool +Mat::quiet_save(const csv_name& spec, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(spec, type); + } + + + +template +inline +bool +Mat::quiet_save(std::ostream& os, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(os, type); + } + + + +template +inline +bool +Mat::quiet_load(const std::string name, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(name, type); + } + + + +template +inline +bool +Mat::quiet_load(const hdf5_name& spec, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(spec, type); + } + + + +template +inline +bool +Mat::quiet_load(const csv_name& spec, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(spec, type); + } + + + +template +inline +bool +Mat::quiet_load(std::istream& is, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(is, type); + } + + + +template +inline +Mat::row_iterator::row_iterator() + : M (nullptr) + , current_row(0 ) + , current_col(0 ) + { + arma_extra_debug_sigprint(); + + // NOTE: this instance of row_iterator is invalid (it does not point to a valid element) + } + + + +template +inline +Mat::row_iterator::row_iterator(const row_iterator& X) + : M (X.M ) + , current_row(X.current_row) + , current_col(X.current_col) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Mat::row_iterator::row_iterator(Mat& in_M, const uword in_row, const uword in_col) + : M (&in_M ) + , current_row(in_row) + , current_col(in_col) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eT& +Mat::row_iterator::operator*() + { + return M->at(current_row,current_col); + } + + + +template +inline +typename Mat::row_iterator& +Mat::row_iterator::operator++() + { + current_col++; + + if(current_col == M->n_cols) + { + current_col = 0; + current_row++; + } + + return *this; + } + + + +template +inline +typename Mat::row_iterator +Mat::row_iterator::operator++(int) + { + typename Mat::row_iterator temp(*this); + + ++(*this); + + return temp; + } + + + +template +inline +typename Mat::row_iterator& +Mat::row_iterator::operator--() + { + if(current_col > 0) + { + current_col--; + } + else + { + if(current_row > 0) + { + current_col = M->n_cols - 1; + current_row--; + } + } + + return *this; + } + + + +template +inline +typename Mat::row_iterator +Mat::row_iterator::operator--(int) + { + typename Mat::row_iterator temp(*this); + + --(*this); + + return temp; + } + + + +template +inline +bool +Mat::row_iterator::operator!=(const typename Mat::row_iterator& X) const + { + return ( (current_row != X.current_row) || (current_col != X.current_col) ); + } + + + +template +inline +bool +Mat::row_iterator::operator==(const typename Mat::row_iterator& X) const + { + return ( (current_row == X.current_row) && (current_col == X.current_col) ); + } + + + +template +inline +bool +Mat::row_iterator::operator!=(const typename Mat::const_row_iterator& X) const + { + return ( (current_row != X.current_row) || (current_col != X.current_col) ); + } + + + +template +inline +bool +Mat::row_iterator::operator==(const typename Mat::const_row_iterator& X) const + { + return ( (current_row == X.current_row) && (current_col == X.current_col) ); + } + + + +template +inline +Mat::const_row_iterator::const_row_iterator() + : M (nullptr) + , current_row(0 ) + , current_col(0 ) + { + arma_extra_debug_sigprint(); + + // NOTE: this instance of const_row_iterator is invalid (it does not point to a valid element) + } + + + +template +inline +Mat::const_row_iterator::const_row_iterator(const typename Mat::row_iterator& X) + : M (X.M ) + , current_row(X.current_row) + , current_col(X.current_col) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Mat::const_row_iterator::const_row_iterator(const typename Mat::const_row_iterator& X) + : M (X.M ) + , current_row(X.current_row) + , current_col(X.current_col) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Mat::const_row_iterator::const_row_iterator(const Mat& in_M, const uword in_row, const uword in_col) + : M (&in_M ) + , current_row(in_row) + , current_col(in_col) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +const eT& +Mat::const_row_iterator::operator*() const + { + return M->at(current_row,current_col); + } + + + +template +inline +typename Mat::const_row_iterator& +Mat::const_row_iterator::operator++() + { + current_col++; + + if(current_col == M->n_cols) + { + current_col = 0; + current_row++; + } + + return *this; + } + + + +template +inline +typename Mat::const_row_iterator +Mat::const_row_iterator::operator++(int) + { + typename Mat::const_row_iterator temp(*this); + + ++(*this); + + return temp; + } + + + +template +inline +typename Mat::const_row_iterator& +Mat::const_row_iterator::operator--() + { + if(current_col > 0) + { + current_col--; + } + else + { + if(current_row > 0) + { + current_col = M->n_cols - 1; + current_row--; + } + } + + return *this; + } + + + +template +inline +typename Mat::const_row_iterator +Mat::const_row_iterator::operator--(int) + { + typename Mat::const_row_iterator temp(*this); + + --(*this); + + return temp; + } + + + +template +inline +bool +Mat::const_row_iterator::operator!=(const typename Mat::row_iterator& X) const + { + return ( (current_row != X.current_row) || (current_col != X.current_col) ); + } + + + +template +inline +bool +Mat::const_row_iterator::operator==(const typename Mat::row_iterator& X) const + { + return ( (current_row == X.current_row) && (current_col == X.current_col) ); + } + + + +template +inline +bool +Mat::const_row_iterator::operator!=(const typename Mat::const_row_iterator& X) const + { + return ( (current_row != X.current_row) || (current_col != X.current_col) ); + } + + + +template +inline +bool +Mat::const_row_iterator::operator==(const typename Mat::const_row_iterator& X) const + { + return ( (current_row == X.current_row) && (current_col == X.current_col) ); + } + + + +template +inline +Mat::row_col_iterator::row_col_iterator() + : M (nullptr) + , current_ptr(nullptr) + , current_col(0 ) + , current_row(0 ) + { + arma_extra_debug_sigprint(); + // Technically this iterator is invalid (it does not point to a valid element) + } + + + +template +inline +Mat::row_col_iterator::row_col_iterator(const row_col_iterator& in_it) + : M (in_it.M ) + , current_ptr(in_it.current_ptr) + , current_col(in_it.current_col) + , current_row(in_it.current_row) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Mat::row_col_iterator::row_col_iterator(Mat& in_M, const uword in_row, const uword in_col) + : M (&in_M ) + , current_ptr(&in_M.at(in_row,in_col)) + , current_col(in_col ) + , current_row(in_row ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eT& +Mat::row_col_iterator::operator*() + { + return *current_ptr; + } + + + +template +inline +typename Mat::row_col_iterator& +Mat::row_col_iterator::operator++() + { + if(current_col < M->n_cols) + { + current_ptr++; + current_row++; + + // Check to see if we moved a column. + if(current_row == M->n_rows) + { + current_col++; + current_row = 0; + } + } + + return *this; + } + + + +template +inline +typename Mat::row_col_iterator +Mat::row_col_iterator::operator++(int) + { + typename Mat::row_col_iterator temp(*this); + + ++(*this); + + return temp; + } + + + +template +inline typename Mat::row_col_iterator& +Mat::row_col_iterator::operator--() + { + if(current_row > 0) + { + current_ptr--; + current_row--; + } + else + if(current_col > 0) + { + current_ptr--; + current_col--; + current_row = M->n_rows - 1; + } + + return *this; + } + + + +template +inline +typename Mat::row_col_iterator +Mat::row_col_iterator::operator--(int) + { + typename Mat::row_col_iterator temp(*this); + + --(*this); + + return temp; + } + + + +template +inline +uword +Mat::row_col_iterator::row() const + { + return current_row; + } + + + +template +inline +uword +Mat::row_col_iterator::col() const + { + return current_col; + } + + + +template +inline +bool +Mat::row_col_iterator::operator==(const row_col_iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +Mat::row_col_iterator::operator!=(const row_col_iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +template +inline +bool +Mat::row_col_iterator::operator==(const const_row_col_iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +Mat::row_col_iterator::operator!=(const const_row_col_iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +template +inline +Mat::const_row_col_iterator::const_row_col_iterator() + : M (nullptr) + , current_ptr(nullptr) + , current_col(0 ) + , current_row(0 ) + { + arma_extra_debug_sigprint(); + // Technically this iterator is invalid (it does not point to a valid element) + } + + + +template +inline +Mat::const_row_col_iterator::const_row_col_iterator(const row_col_iterator& in_it) + : M (in_it.M ) + , current_ptr(in_it.current_ptr) + , current_col(in_it.col() ) + , current_row(in_it.row() ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Mat::const_row_col_iterator::const_row_col_iterator(const const_row_col_iterator& in_it) + : M (in_it.M ) + , current_ptr(in_it.current_ptr) + , current_col(in_it.col() ) + , current_row(in_it.row() ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Mat::const_row_col_iterator::const_row_col_iterator(const Mat& in_M, const uword in_row, const uword in_col) + : M (&in_M ) + , current_ptr(&in_M.at(in_row,in_col)) + , current_col(in_col ) + , current_row(in_row ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +const eT& +Mat::const_row_col_iterator::operator*() const + { + return *current_ptr; + } + + + +template +inline +typename Mat::const_row_col_iterator& +Mat::const_row_col_iterator::operator++() + { + if(current_col < M->n_cols) + { + current_ptr++; + current_row++; + + // Check to see if we moved a column. + if(current_row == M->n_rows) + { + current_col++; + current_row = 0; + } + } + + return *this; + } + + + +template +inline +typename Mat::const_row_col_iterator +Mat::const_row_col_iterator::operator++(int) + { + typename Mat::const_row_col_iterator temp(*this); + + ++(*this); + + return temp; + } + + + +template +inline +typename Mat::const_row_col_iterator& +Mat::const_row_col_iterator::operator--() + { + if(current_row > 0) + { + current_ptr--; + current_row--; + } + else + if(current_col > 0) + { + current_ptr--; + current_col--; + current_row = M->n_rows - 1; + } + + return *this; + } + + + +template +inline +typename Mat::const_row_col_iterator +Mat::const_row_col_iterator::operator--(int) + { + typename Mat::const_row_col_iterator temp(*this); + + --(*this); + + return temp; + } + + + +template +inline +uword +Mat::const_row_col_iterator::row() const + { + return current_row; + } + + + +template +inline +uword +Mat::const_row_col_iterator::col() const + { + return current_col; + } + + + +template +inline +bool +Mat::const_row_col_iterator::operator==(const const_row_col_iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +Mat::const_row_col_iterator::operator!=(const const_row_col_iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +template +inline +bool +Mat::const_row_col_iterator::operator==(const row_col_iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +Mat::const_row_col_iterator::operator!=(const row_col_iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +template +inline +typename Mat::iterator +Mat::begin() + { + arma_extra_debug_sigprint(); + + return memptr(); + } + + + +template +inline +typename Mat::const_iterator +Mat::begin() const + { + arma_extra_debug_sigprint(); + + return memptr(); + } + + + +template +inline +typename Mat::const_iterator +Mat::cbegin() const + { + arma_extra_debug_sigprint(); + + return memptr(); + } + + + +template +inline +typename Mat::iterator +Mat::end() + { + arma_extra_debug_sigprint(); + + return memptr() + n_elem; + } + + + +template +inline +typename Mat::const_iterator +Mat::end() const + { + arma_extra_debug_sigprint(); + + return memptr() + n_elem; + } + + + +template +inline +typename Mat::const_iterator +Mat::cend() const + { + arma_extra_debug_sigprint(); + + return memptr() + n_elem; + } + + + +template +inline +typename Mat::col_iterator +Mat::begin_col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (col_num >= n_cols), "Mat::begin_col(): index out of bounds" ); + + return colptr(col_num); + } + + + +template +inline +typename Mat::const_col_iterator +Mat::begin_col(const uword col_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (col_num >= n_cols), "Mat::begin_col(): index out of bounds" ); + + return colptr(col_num); + } + + + +template +inline +typename Mat::col_iterator +Mat::end_col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (col_num >= n_cols), "Mat::end_col(): index out of bounds" ); + + return colptr(col_num) + n_rows; + } + + + +template +inline +typename Mat::const_col_iterator +Mat::end_col(const uword col_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (col_num >= n_cols), "Mat::end_col(): index out of bounds" ); + + return colptr(col_num) + n_rows; + } + + + +template +inline +typename Mat::row_iterator +Mat::begin_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= n_rows), "Mat::begin_row(): index out of bounds" ); + + return typename Mat::row_iterator(*this, row_num, uword(0)); + } + + + +template +inline +typename Mat::const_row_iterator +Mat::begin_row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= n_rows), "Mat::begin_row(): index out of bounds" ); + + return typename Mat::const_row_iterator(*this, row_num, uword(0)); + } + + + +template +inline +typename Mat::row_iterator +Mat::end_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= n_rows), "Mat::end_row(): index out of bounds" ); + + return typename Mat::row_iterator(*this, (row_num + uword(1)), 0); + } + + + +template +inline +typename Mat::const_row_iterator +Mat::end_row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= n_rows), "Mat::end_row(): index out of bounds" ); + + return typename Mat::const_row_iterator(*this, (row_num + uword(1)), 0); + } + + + +template +inline +typename Mat::row_col_iterator +Mat::begin_row_col() + { + return row_col_iterator(*this); + } + + + +template +inline +typename Mat::const_row_col_iterator +Mat::begin_row_col() const + { + return const_row_col_iterator(*this); + } + + + +template +inline typename Mat::row_col_iterator +Mat::end_row_col() + { + return row_col_iterator(*this, 0, n_cols); + } + + + +template +inline typename Mat::const_row_col_iterator +Mat::end_row_col() const + { + return const_row_col_iterator(*this, 0, n_cols); + } + + + +//! resets this matrix to an empty matrix +template +inline +void +Mat::clear() + { + reset(); + } + + + +//! returns true if the matrix has no elements +template +inline +bool +Mat::empty() const + { + return (n_elem == 0); + } + + + +//! returns the number of elements in this matrix +template +inline +uword +Mat::size() const + { + return n_elem; + } + + + +template +inline +eT& +Mat::front() + { + arma_debug_check( (n_elem == 0), "Mat::front(): matrix is empty" ); + + return access::rw(mem[0]); + } + + + +template +inline +const eT& +Mat::front() const + { + arma_debug_check( (n_elem == 0), "Mat::front(): matrix is empty" ); + + return mem[0]; + } + + + +template +inline +eT& +Mat::back() + { + arma_debug_check( (n_elem == 0), "Mat::back(): matrix is empty" ); + + return access::rw(mem[n_elem-1]); + } + + + +template +inline +const eT& +Mat::back() const + { + arma_debug_check( (n_elem == 0), "Mat::back(): matrix is empty" ); + + return mem[n_elem-1]; + } + + + +template +template +arma_inline +Mat::fixed::fixed() + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Mat::fixed::constructor: zeroing memory"); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(0) ); + } + } + + + +template +template +arma_inline +Mat::fixed::fixed(const fixed& X) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + eT* dest = (use_extra) ? mem_local_extra : mem_local; + const eT* src = (use_extra) ? X.mem_local_extra : X.mem_local; + + arrayops::copy( dest, src, fixed_n_elem ); + } + + + +template +template +inline +Mat::fixed::fixed(const fill::scalar_holder f) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + (*this).fill(f.scalar); + } + + + +template +template +template +inline +Mat::fixed::fixed(const fill::fill_class&) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).eye(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } + } + + + +template +template +template +inline +Mat::fixed::fixed(const Base& A) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Mat::operator=(A.get_ref()); + } + + + +template +template +template +inline +Mat::fixed::fixed(const Base& A, const Base& B) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Mat::init(A,B); + } + + + +template +template +inline +Mat::fixed::fixed(const eT* aux_mem) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + eT* dest = (use_extra) ? mem_local_extra : mem_local; + + arrayops::copy( dest, aux_mem, fixed_n_elem ); + } + + + +template +template +inline +Mat::fixed::fixed(const char* text) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Mat::operator=(text); + } + + + +template +template +inline +Mat::fixed::fixed(const std::string& text) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Mat::operator=(text); + } + + + +template +template +inline +Mat::fixed::fixed(const std::initializer_list& list) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(list); + } + + + +template +template +inline +Mat& +Mat::fixed::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + const uword N = uword(list.size()); + + arma_debug_check( (N > fixed_n_elem), "Mat::fixed: initialiser list is too long" ); + + eT* this_mem = (*this).memptr(); + + arrayops::copy( this_mem, list.begin(), N ); + + for(uword iq=N; iq < fixed_n_elem; ++iq) { this_mem[iq] = eT(0); } + + return *this; + } + + + +template +template +inline +Mat::fixed::fixed(const std::initializer_list< std::initializer_list >& list) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Mat::init(list); + } + + + +template +template +inline +Mat& +Mat::fixed::operator=(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); + + Mat::init(list); + + return *this; + } + + + +template +template +arma_inline +Mat& +Mat::fixed::operator=(const fixed& X) + { + arma_extra_debug_sigprint(); + + if(this != &X) + { + eT* dest = (use_extra) ? mem_local_extra : mem_local; + const eT* src = (use_extra) ? X.mem_local_extra : X.mem_local; + + arrayops::copy( dest, src, fixed_n_elem ); + } + + return *this; + } + + + +#if defined(ARMA_GOOD_COMPILER) + + template + template + template + inline + Mat& + Mat::fixed::operator=(const eOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias) { const Mat tmp(X); (*this) = tmp; return *this; } + + arma_debug_assert_same_size(fixed_n_rows, fixed_n_cols, X.get_n_rows(), X.get_n_cols(), "Mat::fixed::operator="); + + eop_type::apply(*this, X); + + return *this; + } + + + + template + template + template + inline + Mat& + Mat::fixed::operator=(const eGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias) { const Mat tmp(X); (*this) = tmp; return *this; } + + arma_debug_assert_same_size(fixed_n_rows, fixed_n_cols, X.get_n_rows(), X.get_n_cols(), "Mat::fixed::operator="); + + eglue_type::apply(*this, X); + + return *this; + } + +#endif + + + +template +template +arma_inline +const Op< typename Mat::template fixed::Mat_fixed_type, op_htrans > +Mat::fixed::t() const + { + return Op< typename Mat::template fixed::Mat_fixed_type, op_htrans >(*this); + } + + + +template +template +arma_inline +const Op< typename Mat::template fixed::Mat_fixed_type, op_htrans > +Mat::fixed::ht() const + { + return Op< typename Mat::template fixed::Mat_fixed_type, op_htrans >(*this); + } + + + +template +template +arma_inline +const Op< typename Mat::template fixed::Mat_fixed_type, op_strans > +Mat::fixed::st() const + { + return Op< typename Mat::template fixed::Mat_fixed_type, op_strans >(*this); + } + + + +template +template +arma_inline +const eT& +Mat::fixed::at_alt(const uword ii) const + { + #if defined(ARMA_HAVE_ALIGNED_ATTRIBUTE) + + return (use_extra) ? mem_local_extra[ii] : mem_local[ii]; + + #else + const eT* mem_aligned = (use_extra) ? mem_local_extra : mem_local; + + memory::mark_as_aligned(mem_aligned); + + return mem_aligned[ii]; + #endif + } + + + +template +template +arma_inline +eT& +Mat::fixed::operator[] (const uword ii) + { + return (use_extra) ? mem_local_extra[ii] : mem_local[ii]; + } + + + +template +template +arma_inline +const eT& +Mat::fixed::operator[] (const uword ii) const + { + return (use_extra) ? mem_local_extra[ii] : mem_local[ii]; + } + + + +template +template +arma_inline +eT& +Mat::fixed::at(const uword ii) + { + return (use_extra) ? mem_local_extra[ii] : mem_local[ii]; + } + + + +template +template +arma_inline +const eT& +Mat::fixed::at(const uword ii) const + { + return (use_extra) ? mem_local_extra[ii] : mem_local[ii]; + } + + + +template +template +arma_inline +eT& +Mat::fixed::operator() (const uword ii) + { + arma_debug_check_bounds( (ii >= fixed_n_elem), "Mat::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[ii] : mem_local[ii]; + } + + + +template +template +arma_inline +const eT& +Mat::fixed::operator() (const uword ii) const + { + arma_debug_check_bounds( (ii >= fixed_n_elem), "Mat::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[ii] : mem_local[ii]; + } + + + +#if defined(__cpp_multidimensional_subscript) + + template + template + arma_inline + eT& + Mat::fixed::operator[] (const uword in_row, const uword in_col) + { + const uword iq = in_row + in_col*fixed_n_rows; + + return (use_extra) ? mem_local_extra[iq] : mem_local[iq]; + } + + + + template + template + arma_inline + const eT& + Mat::fixed::operator[] (const uword in_row, const uword in_col) const + { + const uword iq = in_row + in_col*fixed_n_rows; + + return (use_extra) ? mem_local_extra[iq] : mem_local[iq]; + } + +#endif + + + +template +template +arma_inline +eT& +Mat::fixed::at(const uword in_row, const uword in_col) + { + const uword iq = in_row + in_col*fixed_n_rows; + + return (use_extra) ? mem_local_extra[iq] : mem_local[iq]; + } + + + +template +template +arma_inline +const eT& +Mat::fixed::at(const uword in_row, const uword in_col) const + { + const uword iq = in_row + in_col*fixed_n_rows; + + return (use_extra) ? mem_local_extra[iq] : mem_local[iq]; + } + + + +template +template +arma_inline +eT& +Mat::fixed::operator() (const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row >= fixed_n_rows) || (in_col >= fixed_n_cols)), "Mat::operator(): index out of bounds" ); + + const uword iq = in_row + in_col*fixed_n_rows; + + return (use_extra) ? mem_local_extra[iq] : mem_local[iq]; + } + + + +template +template +arma_inline +const eT& +Mat::fixed::operator() (const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row >= fixed_n_rows) || (in_col >= fixed_n_cols)), "Mat::operator(): index out of bounds" ); + + const uword iq = in_row + in_col*fixed_n_rows; + + return (use_extra) ? mem_local_extra[iq] : mem_local[iq]; + } + + + +template +template +arma_inline +eT* +Mat::fixed::colptr(const uword in_col) + { + eT* mem_actual = (use_extra) ? mem_local_extra : mem_local; + + return & access::rw(mem_actual[in_col*fixed_n_rows]); + } + + + +template +template +arma_inline +const eT* +Mat::fixed::colptr(const uword in_col) const + { + const eT* mem_actual = (use_extra) ? mem_local_extra : mem_local; + + return & mem_actual[in_col*fixed_n_rows]; + } + + + +template +template +arma_inline +eT* +Mat::fixed::memptr() + { + return (use_extra) ? mem_local_extra : mem_local; + } + + + +template +template +arma_inline +const eT* +Mat::fixed::memptr() const + { + return (use_extra) ? mem_local_extra : mem_local; + } + + + +template +template +arma_inline +bool +Mat::fixed::is_vec() const + { + return ( (fixed_n_rows == 1) || (fixed_n_cols == 1) ); + } + + + +template +template +inline +const Mat& +Mat::fixed::fill(const eT val) + { + arma_extra_debug_sigprint(); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, val ); + + return *this; + } + + + +template +template +inline +const Mat& +Mat::fixed::zeros() + { + arma_extra_debug_sigprint(); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(0) ); + + return *this; + } + + + +template +template +inline +const Mat& +Mat::fixed::ones() + { + arma_extra_debug_sigprint(); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(1) ); + + return *this; + } + + + +//! prefix ++ +template +inline +void +Mat_aux::prefix_pp(Mat& x) + { + eT* memptr = x.memptr(); + const uword n_elem = x.n_elem; + + uword i,j; + + for(i=0, j=1; j +inline +void +Mat_aux::prefix_pp(Mat< std::complex >& x) + { + x += T(1); + } + + + +//! postfix ++ +template +inline +void +Mat_aux::postfix_pp(Mat& x) + { + eT* memptr = x.memptr(); + const uword n_elem = x.n_elem; + + uword i,j; + + for(i=0, j=1; j +inline +void +Mat_aux::postfix_pp(Mat< std::complex >& x) + { + x += T(1); + } + + + +//! prefix -- +template +inline +void +Mat_aux::prefix_mm(Mat& x) + { + eT* memptr = x.memptr(); + const uword n_elem = x.n_elem; + + uword i,j; + + for(i=0, j=1; j +inline +void +Mat_aux::prefix_mm(Mat< std::complex >& x) + { + x -= T(1); + } + + + +//! postfix -- +template +inline +void +Mat_aux::postfix_mm(Mat& x) + { + eT* memptr = x.memptr(); + const uword n_elem = x.n_elem; + + uword i,j; + + for(i=0, j=1; j +inline +void +Mat_aux::postfix_mm(Mat< std::complex >& x) + { + x -= T(1); + } + + + +template +inline +void +Mat_aux::set_real(Mat& out, const Base& X) + { + arma_extra_debug_sigprint(); + + const unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + arma_debug_assert_same_size( out, A, "Mat::set_real()" ); + + out = A; + } + + + +template +inline +void +Mat_aux::set_imag(Mat&, const Base&) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +Mat_aux::set_real(Mat< std::complex >& out, const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const Proxy P(X.get_ref()); + + const uword local_n_rows = P.get_n_rows(); + const uword local_n_cols = P.get_n_cols(); + + arma_debug_assert_same_size( out.n_rows, out.n_cols, local_n_rows, local_n_cols, "Mat::set_real()" ); + + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + const uword N = out.n_elem; + + for(uword i=0; i +inline +void +Mat_aux::set_imag(Mat< std::complex >& out, const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const Proxy P(X.get_ref()); + + const uword local_n_rows = P.get_n_rows(); + const uword local_n_cols = P.get_n_cols(); + + arma_debug_assert_same_size( out.n_rows, out.n_cols, local_n_rows, local_n_cols, "Mat::set_imag()" ); + + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + const uword N = out.n_elem; + + for(uword i=0; i +class OpCube : public BaseCube< typename T1::elem_type, OpCube > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + inline explicit OpCube(const BaseCube& in_m); + inline OpCube(const BaseCube& in_m, const elem_type in_aux); + inline OpCube(const BaseCube& in_m, const elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c); + inline OpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); + inline OpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c); + inline ~OpCube(); + + arma_aligned const T1& m; //!< the operand; must be derived from BaseCube + arma_aligned elem_type aux; //!< auxiliary data, using the element type as used by T1 + arma_aligned uword aux_uword_a; //!< auxiliary data, uword format + arma_aligned uword aux_uword_b; //!< auxiliary data, uword format + arma_aligned uword aux_uword_c; //!< auxiliary data, uword format + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/OpCube_meat.hpp b/src/armadillo/include/armadillo_bits/OpCube_meat.hpp new file mode 100644 index 0000000..a20d6b9 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/OpCube_meat.hpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup OpCube +//! @{ + + + +template +OpCube::OpCube(const BaseCube& in_m) + : m(in_m.get_ref()) + { + arma_extra_debug_sigprint(); + } + + + +template +OpCube::OpCube(const BaseCube& in_m, const typename T1::elem_type in_aux) + : m(in_m.get_ref()) + , aux(in_aux) + { + arma_extra_debug_sigprint(); + } + + +template +OpCube::OpCube(const BaseCube& in_m, const typename T1::elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c) + : m(in_m.get_ref()) + , aux(in_aux) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + , aux_uword_c(in_aux_uword_c) + { + arma_extra_debug_sigprint(); + } + + + + +template +OpCube::OpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b) + : m(in_m.get_ref()) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + { + arma_extra_debug_sigprint(); + } + + + +template +OpCube::OpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c) + : m(in_m.get_ref()) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + , aux_uword_c(in_aux_uword_c) + { + arma_extra_debug_sigprint(); + } + + + +template +OpCube::~OpCube() + { + arma_extra_debug_sigprint(); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Op_bones.hpp b/src/armadillo/include/armadillo_bits/Op_bones.hpp new file mode 100644 index 0000000..fa8c3ef --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Op_bones.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Op +//! @{ + + + +template +struct Op_traits {}; + + +template +struct Op_traits + { + static constexpr bool is_row = op_type::template traits::is_row; + static constexpr bool is_col = op_type::template traits::is_col; + static constexpr bool is_xvec = op_type::template traits::is_xvec; + }; + +template +struct Op_traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + }; + + +template +class Op + : public Base< typename T1::elem_type, Op > + , public Op_traits::value> + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + inline explicit Op(const T1& in_m); + inline Op(const T1& in_m, const elem_type in_aux); + inline Op(const T1& in_m, const elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b); + inline Op(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); + inline ~Op(); + + arma_aligned const T1& m; //!< the operand; must be derived from Base + arma_aligned elem_type aux; //!< auxiliary data, using the element type as used by T1 + arma_aligned uword aux_uword_a; //!< auxiliary data, uword format + arma_aligned uword aux_uword_b; //!< auxiliary data, uword format + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Op_meat.hpp b/src/armadillo/include/armadillo_bits/Op_meat.hpp new file mode 100644 index 0000000..cd08ff9 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Op_meat.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Op +//! @{ + + + +template +inline +Op::Op(const T1& in_m) + : m(in_m) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Op::Op(const T1& in_m, const typename T1::elem_type in_aux) + : m(in_m) + , aux(in_aux) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Op::Op(const T1& in_m, const typename T1::elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b) + : m(in_m) + , aux(in_aux) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Op::Op(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b) + : m(in_m) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Op::~Op() + { + arma_extra_debug_sigprint(); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Proxy.hpp b/src/armadillo/include/armadillo_bits/Proxy.hpp new file mode 100644 index 0000000..ca3f713 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Proxy.hpp @@ -0,0 +1,2537 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Proxy +//! @{ + + +// within each specialisation of the Proxy class: +// +// elem_type = the type of the elements obtained from object Q +// pod_type = the underlying type of elements if elem_type is std::complex +// stored_type = the type of the Q object +// ea_type = the type of the object that provides access to elements via operator[i] +// aligned_ea_type = the type of the object that provides access to elements via at_alt(i) +// +// use_at = boolean indicating whether at(row,col) must be used to get elements +// use_mp = boolean indicating whether OpenMP can be used while processing elements +// has_subview = boolean indicating whether the Q object has a subview +// +// is_row = boolean indicating whether the Q object can be treated a row vector +// is_col = boolean indicating whether the Q object can be treated a column vector +// is_xvec = boolean indicating whether the Q object is a vector with unknown orientation +// +// Q = object that can be unwrapped via the unwrap family of classes (ie. Q must be convertible to Mat) +// +// get_n_rows() = return the number of rows in Q +// get_n_cols() = return the number of columns in Q +// get_n_elem() = return the number of elements in Q +// +// operator[i] = linear element accessor; valid only if the 'use_at' boolean is false +// at(row,col) = access elements via (row,col); valid only if the 'use_at' boolean is true +// at_alt(i) = aligned linear element accessor; valid only if the 'use_at' boolean is false and is_aligned() returns true +// +// get_ea() = return the object that provides linear access to elements via operator[i] +// get_aligned_ea() = return the object that provides linear access to elements via at_alt(i); valid only if is_aligned() returns true +// +// is_alias(X) = return true/false indicating whether the Q object aliases matrix X +// has_overlap(X) = return true/false indicating whether the Q object has overlap with subview X +// is_aligned() = return true/false indicating whether the Q object has aligned memory + + + +template +struct Proxy_default + { + inline Proxy_default(const T1&) + { + arma_type_check(( is_arma_type::value == false )); + } + }; + + + +template +struct Proxy_fixed + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef T1 stored_type; + typedef const elem_type* ea_type; + typedef const T1& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; + + arma_aligned const T1& Q; + + inline explicit Proxy_fixed(const T1& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + //// this may require T1::n_elem etc to be declared as static constexpr inline variables (C++17) + //// see also the notes in Mat::fixed + //// https://en.cppreference.com/w/cpp/language/static + //// https://en.cppreference.com/w/cpp/language/inline + // + // static constexpr uword get_n_rows() { return T1::n_rows; } + // static constexpr uword get_n_cols() { return T1::n_cols; } + // static constexpr uword get_n_elem() { return T1::n_elem; } + + arma_inline uword get_n_rows() const { return is_row ? 1 : T1::n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : T1::n_cols; } + arma_inline uword get_n_elem() const { return T1::n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&Q) == void_ptr(&X)); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + arma_inline bool is_aligned() const + { + #if defined(ARMA_HAVE_ALIGNED_ATTRIBUTE) + return true; + #else + return memory::is_aligned(Q.memptr()); + #endif + } + }; + + + +template +struct Proxy_redirect {}; + +template +struct Proxy_redirect { typedef Proxy_default result; }; + +template +struct Proxy_redirect { typedef Proxy_fixed result; }; + + + +template +struct Proxy : public Proxy_redirect::value>::result + { + inline Proxy(const T1& A) + : Proxy_redirect::value>::result(A) + { + } + }; + + + +template +struct Proxy< Mat > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const eT* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const Mat& Q; + + inline explicit Proxy(const Mat& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&Q) == void_ptr(&X)) : false; } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< Col > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Col stored_type; + typedef const eT* ea_type; + typedef const Col& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const Col& Q; + + inline explicit Proxy(const Col& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&Q) == void_ptr(&X)) : false; } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< Row > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Row stored_type; + typedef const eT* ea_type; + typedef const Row& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const Row& Q; + + inline explicit Proxy(const Row& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + constexpr uword get_n_rows() const { return 1; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword, const uword c) const { return Q[c]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&Q) == void_ptr(&X)) : false; } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< Gen > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Gen stored_type; + typedef const Gen& ea_type; + typedef const Gen& aligned_ea_type; + + static constexpr bool use_at = Gen::use_at; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = Gen::is_row; + static constexpr bool is_col = Gen::is_col; + static constexpr bool is_xvec = Gen::is_xvec; + + arma_aligned const Gen& Q; + + inline explicit Proxy(const Gen& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return (is_row ? 1 : Q.n_rows); } + arma_inline uword get_n_cols() const { return (is_col ? 1 : Q.n_cols); } + arma_inline uword get_n_elem() const { return (is_row ? 1 : Q.n_rows) * (is_col ? 1 : Q.n_cols); } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + constexpr bool is_aligned() const { return Gen::is_simple; } + }; + + + +template +struct Proxy< eOp > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef eOp stored_type; + typedef const eOp& ea_type; + typedef const eOp& aligned_ea_type; + + static constexpr bool use_at = eOp::use_at; + static constexpr bool use_mp = eOp::use_mp; + static constexpr bool has_subview = eOp::has_subview; + + static constexpr bool is_row = eOp::is_row; + static constexpr bool is_col = eOp::is_col; + static constexpr bool is_xvec = eOp::is_xvec; + + arma_aligned const eOp& Q; + + inline explicit Proxy(const eOp& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.get_n_rows(); } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.get_n_cols(); } + arma_inline uword get_n_elem() const { return Q.get_n_elem(); } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return Q.P.is_alias(X); } + + template + arma_inline bool has_overlap(const subview& X) const { return Q.P.has_overlap(X); } + + arma_inline bool is_aligned() const { return Q.P.is_aligned(); } + }; + + + +template +struct Proxy< eGlue > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef eGlue stored_type; + typedef const eGlue& ea_type; + typedef const eGlue& aligned_ea_type; + + static constexpr bool use_at = eGlue::use_at; + static constexpr bool use_mp = eGlue::use_mp; + static constexpr bool has_subview = eGlue::has_subview; + + static constexpr bool is_row = eGlue::is_row; + static constexpr bool is_col = eGlue::is_col; + static constexpr bool is_xvec = eGlue::is_xvec; + + arma_aligned const eGlue& Q; + + inline explicit Proxy(const eGlue& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.get_n_rows(); } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.get_n_cols(); } + arma_inline uword get_n_elem() const { return Q.get_n_elem(); } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (Q.P1.is_alias(X) || Q.P2.is_alias(X)); } + + template + arma_inline bool has_overlap(const subview& X) const { return (Q.P1.has_overlap(X) || Q.P2.has_overlap(X)); } + + arma_inline bool is_aligned() const { return (Q.P1.is_aligned() && Q.P2.is_aligned()); } + }; + + + +template +struct Proxy< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = Op::is_row; + static constexpr bool is_col = Op::is_col; + static constexpr bool is_xvec = Op::is_xvec; + + arma_aligned const Mat Q; + + inline explicit Proxy(const Op& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< Glue > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = Glue::is_row; + static constexpr bool is_col = Glue::is_col; + static constexpr bool is_xvec = Glue::is_xvec; + + arma_aligned const Mat Q; + + inline explicit Proxy(const Glue& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< Glue > + { + typedef Glue this_Glue_type; + typedef Proxy< Glue > this_Proxy_type; + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef this_Glue_type stored_type; + typedef const this_Proxy_type& ea_type; + typedef const this_Proxy_type& aligned_ea_type; + + static constexpr bool use_at = (Proxy::use_at || Proxy::use_at ); + static constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp ); + static constexpr bool has_subview = (Proxy::has_subview || Proxy::has_subview); + + static constexpr bool is_row = this_Glue_type::is_row; + static constexpr bool is_col = this_Glue_type::is_col; + static constexpr bool is_xvec = this_Glue_type::is_xvec; + + arma_aligned const this_Glue_type& Q; + arma_aligned const Proxy P1; + arma_aligned const Proxy P2; + + arma_lt_comparator comparator; + + inline explicit Proxy(const this_Glue_type& X) + : Q (X ) + , P1(X.A) + , P2(X.B) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(P1, P2, "element-wise min()"); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : P1.get_n_rows(); } + arma_inline uword get_n_cols() const { return is_col ? 1 : P1.get_n_cols(); } + arma_inline uword get_n_elem() const { return P1.get_n_elem(); } + + arma_inline elem_type operator[] (const uword i) const { const elem_type A = P1[i]; const elem_type B = P2[i]; return comparator(A,B) ? A : B; } + arma_inline elem_type at (const uword r, const uword c) const { const elem_type A = P1.at(r,c); const elem_type B = P2.at(r,c); return comparator(A,B) ? A : B; } + arma_inline elem_type at_alt (const uword i) const { const elem_type A = P1.at_alt(i); const elem_type B = P2.at_alt(i); return comparator(A,B) ? A : B; } + + arma_inline ea_type get_ea() const { return *this; } + arma_inline aligned_ea_type get_aligned_ea() const { return *this; } + + template + arma_inline bool is_alias(const Mat& X) const { return (P1.is_alias(X) || P2.is_alias(X)); } + + template + arma_inline bool has_overlap(const subview& X) const { return (P1.has_overlap(X) || P2.has_overlap(X)); } + + arma_inline bool is_aligned() const { return (P1.is_aligned() && P2.is_aligned()); } + }; + + + +template +struct Proxy< Glue > + { + typedef Glue this_Glue_type; + typedef Proxy< Glue > this_Proxy_type; + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef this_Glue_type stored_type; + typedef const this_Proxy_type& ea_type; + typedef const this_Proxy_type& aligned_ea_type; + + static constexpr bool use_at = (Proxy::use_at || Proxy::use_at ); + static constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp ); + static constexpr bool has_subview = (Proxy::has_subview || Proxy::has_subview); + + static constexpr bool is_row = this_Glue_type::is_row; + static constexpr bool is_col = this_Glue_type::is_col; + static constexpr bool is_xvec = this_Glue_type::is_xvec; + + arma_aligned const this_Glue_type& Q; + arma_aligned const Proxy P1; + arma_aligned const Proxy P2; + + arma_gt_comparator comparator; + + inline explicit Proxy(const this_Glue_type& X) + : Q (X ) + , P1(X.A) + , P2(X.B) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(P1, P2, "element-wise max()"); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : P1.get_n_rows(); } + arma_inline uword get_n_cols() const { return is_col ? 1 : P1.get_n_cols(); } + arma_inline uword get_n_elem() const { return P1.get_n_elem(); } + + arma_inline elem_type operator[] (const uword i) const { const elem_type A = P1[i]; const elem_type B = P2[i]; return comparator(A,B) ? A : B; } + arma_inline elem_type at (const uword r, const uword c) const { const elem_type A = P1.at(r,c); const elem_type B = P2.at(r,c); return comparator(A,B) ? A : B; } + arma_inline elem_type at_alt (const uword i) const { const elem_type A = P1.at_alt(i); const elem_type B = P2.at_alt(i); return comparator(A,B) ? A : B; } + + arma_inline ea_type get_ea() const { return *this; } + arma_inline aligned_ea_type get_aligned_ea() const { return *this; } + + template + arma_inline bool is_alias(const Mat& X) const { return (P1.is_alias(X) || P2.is_alias(X)); } + + template + arma_inline bool has_overlap(const subview& X) const { return (P1.has_overlap(X) || P2.has_overlap(X)); } + + arma_inline bool is_aligned() const { return (P1.is_aligned() && P2.is_aligned()); } + }; + + + +template +struct Proxy< mtOp > + { + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = mtOp::is_row; + static constexpr bool is_col = mtOp::is_col; + static constexpr bool is_xvec = mtOp::is_xvec; + + arma_aligned const Mat Q; + + inline explicit Proxy(const mtOp& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r,c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< mtGlue > + { + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = mtGlue::is_row; + static constexpr bool is_col = mtGlue::is_col; + static constexpr bool is_xvec = mtGlue::is_xvec; + + arma_aligned const Mat Q; + + inline explicit Proxy(const mtGlue& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r,c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< CubeToMatOp > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = CubeToMatOp::is_row; + static constexpr bool is_col = CubeToMatOp::is_col; + static constexpr bool is_xvec = CubeToMatOp::is_xvec; + + arma_aligned const Mat Q; + + inline explicit Proxy(const CubeToMatOp& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< CubeToMatOp > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const unwrap_cube U; + arma_aligned const Mat Q; + + inline explicit Proxy(const CubeToMatOp& A) + : U(A.m) + , Q(const_cast(U.M.memptr()), U.M.n_elem, 1, false, true) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< SpToDOp > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = SpToDOp::is_row; + static constexpr bool is_col = SpToDOp::is_col; + static constexpr bool is_xvec = SpToDOp::is_xvec; + + arma_aligned const Mat Q; + + inline explicit Proxy(const SpToDOp& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< SpToDOp > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const unwrap_spmat U; + arma_aligned const Mat Q; + + inline explicit Proxy(const SpToDOp& A) + : U(A.m) + , Q(const_cast(U.M.values), U.M.n_nonzero, 1, false, true) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< SpToDGlue > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = SpToDGlue::is_row; + static constexpr bool is_col = SpToDGlue::is_col; + static constexpr bool is_xvec = SpToDGlue::is_xvec; + + arma_aligned const Mat Q; + + inline explicit Proxy(const SpToDGlue& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< subview > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef subview stored_type; + typedef const subview& ea_type; + typedef const subview& aligned_ea_type; + + static constexpr bool use_at = true; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const subview& Q; + + inline explicit Proxy(const subview& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + + template + arma_inline bool has_overlap(const subview& X) const { return Q.check_overlap(X); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy< subview_col > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef subview_col stored_type; + typedef const eT* ea_type; + typedef const subview_col& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const subview_col& Q; + + inline explicit Proxy(const subview_col& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.colmem; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + + template + arma_inline bool has_overlap(const subview& X) const { return Q.check_overlap(X); } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.colmem); } + }; + + + +template +struct Proxy< subview_cols > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const eT* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const subview_cols& sv; + arma_aligned const Mat Q; + + inline explicit Proxy(const subview_cols& A) + : sv(A) + , Q ( const_cast( A.colptr(0) ), A.n_rows, A.n_cols, false, false ) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r,c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(sv.m)) == void_ptr(&X)) : false; } + + template + arma_inline bool has_overlap(const subview& X) const { return sv.check_overlap(X); } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< subview_row > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef subview_row stored_type; + typedef const subview_row& ea_type; + typedef const subview_row& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const subview_row& Q; + + inline explicit Proxy(const subview_row& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + constexpr uword get_n_rows() const { return 1; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword, const uword c) const { return Q[c]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + + template + arma_inline bool has_overlap(const subview& X) const { return Q.check_overlap(X); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy< subview_elem1 > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef subview_elem1 stored_type; + typedef const Proxy< subview_elem1 >& ea_type; + typedef const Proxy< subview_elem1 >& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const subview_elem1& Q; + arma_aligned const Proxy R; + + inline explicit Proxy(const subview_elem1& A) + : Q(A) + , R(A.a.get_ref()) + { + arma_extra_debug_sigprint(); + + const bool R_is_vec = ((R.get_n_rows() == 1) || (R.get_n_cols() == 1)); + const bool R_is_empty = (R.get_n_elem() == 0); + + arma_debug_check( ((R_is_vec == false) && (R_is_empty == false)), "Mat::elem(): given object must be a vector" ); + } + + arma_inline uword get_n_rows() const { return R.get_n_elem(); } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return R.get_n_elem(); } + + arma_inline elem_type operator[] (const uword i) const { const uword ii = (Proxy::use_at) ? R.at(i,0) : R[i]; arma_debug_check_bounds( (ii >= Q.m.n_elem), "Mat::elem(): index out of bounds" ); return Q.m[ii]; } + arma_inline elem_type at (const uword r, const uword) const { const uword ii = (Proxy::use_at) ? R.at(r,0) : R[r]; arma_debug_check_bounds( (ii >= Q.m.n_elem), "Mat::elem(): index out of bounds" ); return Q.m[ii]; } + arma_inline elem_type at_alt (const uword i) const { const uword ii = (Proxy::use_at) ? R.at(i,0) : R[i]; arma_debug_check_bounds( (ii >= Q.m.n_elem), "Mat::elem(): index out of bounds" ); return Q.m[ii]; } + + arma_inline ea_type get_ea() const { return (*this); } + arma_inline aligned_ea_type get_aligned_ea() const { return (*this); } + + template + arma_inline bool is_alias(const Mat& X) const { return ( (void_ptr(&X) == void_ptr(&(Q.m))) || (R.is_alias(X)) ); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy< subview_elem2 > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const eT* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const Mat Q; + + inline explicit Proxy(const subview_elem2& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< diagview > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef diagview stored_type; + typedef const diagview& ea_type; + typedef const diagview& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const diagview& Q; + + inline explicit Proxy(const diagview& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q.at(r, 0); } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy_diagvec_mat + { + inline Proxy_diagvec_mat(const T1&) {} + }; + + + +template +struct Proxy_diagvec_mat< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef diagview stored_type; + typedef const diagview& ea_type; + typedef const diagview& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const Mat& R; + arma_aligned const diagview Q; + + inline explicit Proxy_diagvec_mat(const Op& A) + : R(A.m), Q( R.diag() ) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q.at(r, 0); } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&R) == void_ptr(&X)); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy_diagvec_expr + { + inline Proxy_diagvec_expr(const T1&) {} + }; + + + +template +struct Proxy_diagvec_expr< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const Mat Q; + + inline explicit Proxy_diagvec_expr(const Op& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q.at(r, 0); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy_diagvec_redirect {}; + +template +struct Proxy_diagvec_redirect< Op, true > { typedef Proxy_diagvec_mat < Op > result; }; + +template +struct Proxy_diagvec_redirect< Op, false> { typedef Proxy_diagvec_expr< Op > result; }; + + + +template +struct Proxy< Op > + : public Proxy_diagvec_redirect< Op, is_Mat::value >::result + { + typedef typename Proxy_diagvec_redirect< Op, is_Mat::value >::result Proxy_diagvec; + + inline explicit Proxy(const Op& A) + : Proxy_diagvec(A) + { + arma_extra_debug_sigprint(); + } + }; + + + +template +struct Proxy< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const Mat Q; + + inline explicit Proxy(const Op& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q.at(r, 0); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy_xtrans_default + { + inline Proxy_xtrans_default(const T1&) {} + }; + + + +template +struct Proxy_xtrans_default< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef xtrans_mat stored_type; + typedef const xtrans_mat& ea_type; + typedef const xtrans_mat& aligned_ea_type; + + static constexpr bool use_at = true; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + const unwrap U; + const xtrans_mat Q; + + inline explicit Proxy_xtrans_default(const Op& A) + : U(A.m) + , Q(U.M) + { + arma_extra_debug_sigprint(); + } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return void_ptr(&(U.M)) == void_ptr(&X); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy_xtrans_default< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef xtrans_mat stored_type; + typedef const xtrans_mat& ea_type; + typedef const xtrans_mat& aligned_ea_type; + + static constexpr bool use_at = true; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + const unwrap U; + const xtrans_mat Q; + + inline explicit Proxy_xtrans_default(const Op& A) + : U(A.m) + , Q(U.M) + { + arma_extra_debug_sigprint(); + } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return void_ptr(&(U.M)) == void_ptr(&X); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy_xtrans_vector + { + inline Proxy_xtrans_vector(const T1&) {} + }; + + + +template +struct Proxy_xtrans_vector< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = quasi_unwrap::has_subview; + + // NOTE: the Op class takes care of swapping row and col for op_htrans + static constexpr bool is_row = Op::is_row; + static constexpr bool is_col = Op::is_col; + static constexpr bool is_xvec = Op::is_xvec; + + arma_aligned const quasi_unwrap U; // avoid copy if T1 is a Row, Col or subview_col + arma_aligned const Mat Q; + + inline Proxy_xtrans_vector(const Op& A) + : U(A.m) + , Q(const_cast(U.M.memptr()), U.M.n_cols, U.M.n_rows, false, false) + { + arma_extra_debug_sigprint(); + } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return U.is_alias(X); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy_xtrans_vector< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = quasi_unwrap::has_subview; + + // NOTE: the Op class takes care of swapping row and col for op_strans + static constexpr bool is_row = Op::is_row; + static constexpr bool is_col = Op::is_col; + static constexpr bool is_xvec = Op::is_xvec; + + arma_aligned const quasi_unwrap U; // avoid copy if T1 is a Row, Col or subview_col + arma_aligned const Mat Q; + + inline Proxy_xtrans_vector(const Op& A) + : U(A.m) + , Q(const_cast(U.M.memptr()), U.M.n_cols, U.M.n_rows, false, false) + { + arma_extra_debug_sigprint(); + } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return U.is_alias(X); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy_xtrans_redirect {}; + +template +struct Proxy_xtrans_redirect { typedef Proxy_xtrans_default result; }; + +template +struct Proxy_xtrans_redirect { typedef Proxy_xtrans_vector result; }; + + + +template +struct Proxy< Op > + : public + Proxy_xtrans_redirect + < + Op, + ((is_cx::no) && ((Op::is_row) || (Op::is_col)) ) + >::result + { + typedef + typename + Proxy_xtrans_redirect + < + Op, + ((is_cx::no) && ((Op::is_row) || (Op::is_col)) ) + >::result + Proxy_xtrans; + + typedef typename Proxy_xtrans::elem_type elem_type; + typedef typename Proxy_xtrans::pod_type pod_type; + typedef typename Proxy_xtrans::stored_type stored_type; + typedef typename Proxy_xtrans::ea_type ea_type; + typedef typename Proxy_xtrans::aligned_ea_type aligned_ea_type; + + static constexpr bool use_at = Proxy_xtrans::use_at; + static constexpr bool use_mp = Proxy_xtrans::use_mp; + static constexpr bool has_subview = Proxy_xtrans::has_subview; + + static constexpr bool is_row = Proxy_xtrans::is_row; + static constexpr bool is_col = Proxy_xtrans::is_col; + static constexpr bool is_xvec = Proxy_xtrans::is_xvec; + + using Proxy_xtrans::Q; + + inline explicit Proxy(const Op& A) + : Proxy_xtrans(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Proxy_xtrans::get_ea(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Proxy_xtrans::get_aligned_ea(); } + + template + arma_inline bool is_alias(const Mat& X) const { return Proxy_xtrans::is_alias(X); } + + template + arma_inline bool has_overlap(const subview& X) const { return Proxy_xtrans::has_overlap(X); } + + arma_inline bool is_aligned() const { return Proxy_xtrans::is_aligned(); } + }; + + + +template +struct Proxy< Op > + : public + Proxy_xtrans_redirect + < + Op, + ( (Op::is_row) || (Op::is_col) ) + >::result + { + typedef + typename + Proxy_xtrans_redirect + < + Op, + ( (Op::is_row) || (Op::is_col) ) + >::result + Proxy_xtrans; + + typedef typename Proxy_xtrans::elem_type elem_type; + typedef typename Proxy_xtrans::pod_type pod_type; + typedef typename Proxy_xtrans::stored_type stored_type; + typedef typename Proxy_xtrans::ea_type ea_type; + typedef typename Proxy_xtrans::aligned_ea_type aligned_ea_type; + + static constexpr bool use_at = Proxy_xtrans::use_at; + static constexpr bool use_mp = Proxy_xtrans::use_mp; + static constexpr bool has_subview = Proxy_xtrans::has_subview; + + static constexpr bool is_row = Proxy_xtrans::is_row; + static constexpr bool is_col = Proxy_xtrans::is_col; + static constexpr bool is_xvec = Proxy_xtrans::is_xvec; + + using Proxy_xtrans::Q; + + inline explicit Proxy(const Op& A) + : Proxy_xtrans(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Proxy_xtrans::get_ea(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Proxy_xtrans::get_aligned_ea(); } + + template + arma_inline bool is_alias(const Mat& X) const { return Proxy_xtrans::is_alias(X); } + + template + arma_inline bool has_overlap(const subview& X) const { return Proxy_xtrans::has_overlap(X); } + + arma_inline bool is_aligned() const { return Proxy_xtrans::is_aligned(); } + }; + + + +template +struct Proxy_subview_row_htrans_cx + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef subview_row_htrans stored_type; + typedef const subview_row_htrans& ea_type; + typedef const subview_row_htrans& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const subview_row_htrans Q; + + inline explicit Proxy_subview_row_htrans_cx(const Op, op_htrans>& A) + : Q(A.m) + { + arma_extra_debug_sigprint(); + } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&(Q.sv_row.m)) == void_ptr(&X)); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + }; + + + +template +struct Proxy_subview_row_htrans_non_cx + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef subview_row_strans stored_type; + typedef const subview_row_strans& ea_type; + typedef const subview_row_strans& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const subview_row_strans Q; + + inline explicit Proxy_subview_row_htrans_non_cx(const Op, op_htrans>& A) + : Q(A.m) + { + arma_extra_debug_sigprint(); + } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&(Q.sv_row.m)) == void_ptr(&X)); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + }; + + + +template +struct Proxy_subview_row_htrans_redirect {}; + +template +struct Proxy_subview_row_htrans_redirect { typedef Proxy_subview_row_htrans_cx result; }; + +template +struct Proxy_subview_row_htrans_redirect { typedef Proxy_subview_row_htrans_non_cx result; }; + + + +template +struct Proxy< Op, op_htrans> > + : public + Proxy_subview_row_htrans_redirect + < + eT, + is_cx::yes + >::result + { + typedef + typename + Proxy_subview_row_htrans_redirect + < + eT, + is_cx::yes + >::result + Proxy_sv_row_ht; + + typedef typename Proxy_sv_row_ht::elem_type elem_type; + typedef typename Proxy_sv_row_ht::pod_type pod_type; + typedef typename Proxy_sv_row_ht::stored_type stored_type; + typedef typename Proxy_sv_row_ht::ea_type ea_type; + typedef typename Proxy_sv_row_ht::ea_type aligned_ea_type; + + static constexpr bool use_at = Proxy_sv_row_ht::use_at; + static constexpr bool use_mp = Proxy_sv_row_ht::use_mp; + static constexpr bool has_subview = Proxy_sv_row_ht::has_subview; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + using Proxy_sv_row_ht::Q; + + inline explicit Proxy(const Op, op_htrans>& A) + : Proxy_sv_row_ht(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return Proxy_sv_row_ht::is_alias(X); } + + template + arma_inline bool has_overlap(const subview& X) const { return Proxy_sv_row_ht::has_overlap(X); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy< Op, op_strans> > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef subview_row_strans stored_type; + typedef const subview_row_strans& ea_type; + typedef const subview_row_strans& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const subview_row_strans Q; + + inline explicit Proxy(const Op, op_strans>& A) + : Q(A.m) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&(Q.sv_row.m)) == void_ptr(&X)); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy< Op< Row< std::complex >, op_htrans> > + { + typedef typename std::complex eT; + + typedef typename std::complex elem_type; + typedef T pod_type; + typedef xvec_htrans stored_type; + typedef const xvec_htrans& ea_type; + typedef const xvec_htrans& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + const xvec_htrans Q; + const Row& src; + + inline explicit Proxy(const Op< Row< std::complex >, op_htrans>& A) + : Q (A.m.memptr(), A.m.n_rows, A.m.n_cols) + , src(A.m) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return void_ptr(&src) == void_ptr(&X); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy< Op< Col< std::complex >, op_htrans> > + { + typedef typename std::complex eT; + + typedef typename std::complex elem_type; + typedef T pod_type; + typedef xvec_htrans stored_type; + typedef const xvec_htrans& ea_type; + typedef const xvec_htrans& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + const xvec_htrans Q; + const Col& src; + + inline explicit Proxy(const Op< Col< std::complex >, op_htrans>& A) + : Q (A.m.memptr(), A.m.n_rows, A.m.n_cols) + , src(A.m) + { + arma_extra_debug_sigprint(); + } + + constexpr uword get_n_rows() const { return 1; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword, const uword c) const { return Q[c]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return void_ptr(&src) == void_ptr(&X); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy< Op< subview_col< std::complex >, op_htrans> > + { + typedef typename std::complex eT; + + typedef typename std::complex elem_type; + typedef T pod_type; + typedef xvec_htrans stored_type; + typedef const xvec_htrans& ea_type; + typedef const xvec_htrans& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + const xvec_htrans Q; + const subview_col& src; + + inline explicit Proxy(const Op< subview_col< std::complex >, op_htrans>& A) + : Q (A.m.colptr(0), A.m.n_rows, A.m.n_cols) + , src(A.m) + { + arma_extra_debug_sigprint(); + } + + constexpr uword get_n_rows() const { return 1; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword, const uword c) const { return Q[c]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return void_ptr(&src.m) == void_ptr(&X); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef eOp< Op, eop_scalar_times> stored_type; + typedef const eOp< Op, eop_scalar_times>& ea_type; + typedef const eOp< Op, eop_scalar_times>& aligned_ea_type; + + static constexpr bool use_at = eOp< Op, eop_scalar_times>::use_at; + static constexpr bool use_mp = eOp< Op, eop_scalar_times>::use_mp; + static constexpr bool has_subview = eOp< Op, eop_scalar_times>::has_subview; + + // NOTE: the Op class takes care of swapping row and col for op_htrans + static constexpr bool is_row = eOp< Op, eop_scalar_times>::is_row; + static constexpr bool is_col = eOp< Op, eop_scalar_times>::is_col; + static constexpr bool is_xvec = eOp< Op, eop_scalar_times>::is_xvec; + + arma_aligned const Op R; + arma_aligned const eOp< Op, eop_scalar_times > Q; + + inline explicit Proxy(const Op& A) + : R(A.m) + , Q(R, A.aux) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.get_n_rows(); } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.get_n_cols(); } + arma_inline uword get_n_elem() const { return Q.get_n_elem(); } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return Q.P.is_alias(X); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + arma_inline bool is_aligned() const { return Q.P.is_aligned(); } + }; + + + +template +struct Proxy< subview_row_strans > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef subview_row_strans stored_type; + typedef const subview_row_strans& ea_type; + typedef const subview_row_strans& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const subview_row_strans& Q; + + inline explicit Proxy(const subview_row_strans& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&(Q.sv_row.m)) == void_ptr(&X)); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy< subview_row_htrans > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef subview_row_htrans stored_type; + typedef const subview_row_htrans& ea_type; + typedef const subview_row_htrans& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const subview_row_htrans& Q; + + inline explicit Proxy(const subview_row_htrans& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&(Q.sv_row.m)) == void_ptr(&X)); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct Proxy< xtrans_mat > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const eT* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const Mat Q; + + inline explicit Proxy(const xtrans_mat& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r,c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< xvec_htrans > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const eT* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = true; + + arma_aligned const Mat Q; + + inline explicit Proxy(const xvec_htrans& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r,c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy_vectorise_col_mat + { + inline Proxy_vectorise_col_mat(const T1&) {} + }; + + + +template +struct Proxy_vectorise_col_mat< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const unwrap U; + arma_aligned const Mat Q; + + inline explicit Proxy_vectorise_col_mat(const Op& A) + : U(A.m) + , Q(const_cast(U.M.memptr()), U.M.n_elem, 1, false, false) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return ( void_ptr(&X) == void_ptr(&(U.M)) ); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy_vectorise_col_expr + { + inline Proxy_vectorise_col_expr(const T1&) {} + }; + + + +template +struct Proxy_vectorise_col_expr< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Op stored_type; + typedef typename Proxy::ea_type ea_type; + typedef typename Proxy::aligned_ea_type aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = Proxy::use_mp; + static constexpr bool has_subview = Proxy::has_subview; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const Op& Q; + arma_aligned const Proxy R; + + inline explicit Proxy_vectorise_col_expr(const Op& A) + : Q(A) + , R(A.m) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return R.get_n_elem(); } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return R.get_n_elem(); } + + arma_inline elem_type operator[] (const uword i) const { return R[i]; } + arma_inline elem_type at (const uword r, const uword) const { return R.at(r, 0); } + arma_inline elem_type at_alt (const uword i) const { return R.at_alt(i); } + + arma_inline ea_type get_ea() const { return R.get_ea(); } + arma_inline aligned_ea_type get_aligned_ea() const { return R.get_aligned_ea(); } + + template + arma_inline bool is_alias(const Mat& X) const { return R.is_alias(X); } + + template + arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } + + arma_inline bool is_aligned() const { return R.is_aligned(); } + }; + + + +template +struct Proxy_vectorise_col_redirect {}; + +template +struct Proxy_vectorise_col_redirect< Op, true > { typedef Proxy_vectorise_col_mat < Op > result; }; + +template +struct Proxy_vectorise_col_redirect< Op, false> { typedef Proxy_vectorise_col_expr< Op > result; }; + + + +template +struct Proxy< Op > + : public Proxy_vectorise_col_redirect< Op, (Proxy::use_at) >::result + { + typedef typename Proxy_vectorise_col_redirect< Op, (Proxy::use_at) >::result Proxy_vectorise_col; + + inline explicit Proxy(const Op& A) + : Proxy_vectorise_col(A) + { + arma_extra_debug_sigprint(); + } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/ProxyCube.hpp b/src/armadillo/include/armadillo_bits/ProxyCube.hpp new file mode 100644 index 0000000..ef63928 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/ProxyCube.hpp @@ -0,0 +1,488 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup ProxyCube +//! @{ + + + +template +struct ProxyCube + { + inline ProxyCube(const T1&) + { + arma_type_check(( is_arma_cube_type::value == false )); + } + }; + + + +// ea_type is the "element accessor" type, +// which can provide access to elements via operator[] + +template +struct ProxyCube< Cube > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Cube stored_type; + typedef const eT* ea_type; + typedef const Cube& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + arma_aligned const Cube& Q; + + inline explicit ProxyCube(const Cube& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem_slice() const { return Q.n_elem_slice; } + arma_inline uword get_n_slices() const { return Q.n_slices; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Cube& X) const { return (void_ptr(&Q) == void_ptr(&X)); } + + template + arma_inline bool has_overlap(const subview_cube& X) const { return is_alias(X.m); } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct ProxyCube< GenCube > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef GenCube stored_type; + typedef const GenCube& ea_type; + typedef const GenCube& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + arma_aligned const GenCube& Q; + + inline explicit ProxyCube(const GenCube& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem_slice() const { return Q.n_rows*Q.n_cols; } + arma_inline uword get_n_slices() const { return Q.n_slices; } + arma_inline uword get_n_elem() const { return Q.n_rows*Q.n_cols*Q.n_slices; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Cube&) const { return false; } + + template + constexpr bool has_overlap(const subview_cube&) const { return false; } + + constexpr bool is_aligned() const { return GenCube::is_simple; } + }; + + + +template +struct ProxyCube< OpCube > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Cube stored_type; + typedef const elem_type* ea_type; + typedef const Cube& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + arma_aligned const Cube Q; + + inline explicit ProxyCube(const OpCube& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem_slice() const { return Q.n_elem_slice; } + arma_inline uword get_n_slices() const { return Q.n_slices; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Cube&) const { return false; } + + template + constexpr bool has_overlap(const subview_cube&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct ProxyCube< GlueCube > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Cube stored_type; + typedef const elem_type* ea_type; + typedef const Cube& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + arma_aligned const Cube Q; + + inline explicit ProxyCube(const GlueCube& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem_slice() const { return Q.n_elem_slice; } + arma_inline uword get_n_slices() const { return Q.n_slices; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Cube&) const { return false; } + + template + constexpr bool has_overlap(const subview_cube&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct ProxyCube< subview_cube > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef subview_cube stored_type; + typedef const subview_cube& ea_type; + typedef const subview_cube& aligned_ea_type; + + static constexpr bool use_at = true; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + arma_aligned const subview_cube& Q; + + inline explicit ProxyCube(const subview_cube& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem_slice() const { return Q.n_elem_slice; } + arma_inline uword get_n_slices() const { return Q.n_slices; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Cube& X) const { return (void_ptr(&(Q.m)) == void_ptr(&X)); } + + template + arma_inline bool has_overlap(const subview_cube& X) const { return Q.check_overlap(X); } + + constexpr bool is_aligned() const { return false; } + }; + + + +template +struct ProxyCube< subview_cube_slices > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Cube stored_type; + typedef const eT* ea_type; + typedef const Cube& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + arma_aligned const Cube Q; + + inline explicit ProxyCube(const subview_cube_slices& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem_slice() const { return Q.n_elem_slice; } + arma_inline uword get_n_slices() const { return Q.n_slices; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Cube&) const { return false; } + + template + constexpr bool has_overlap(const subview_cube&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct ProxyCube< eOpCube > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef eOpCube stored_type; + typedef const eOpCube& ea_type; + typedef const eOpCube& aligned_ea_type; + + static constexpr bool use_at = eOpCube::use_at; + static constexpr bool use_mp = eOpCube::use_mp; + static constexpr bool has_subview = eOpCube::has_subview; + + arma_aligned const eOpCube& Q; + + inline explicit ProxyCube(const eOpCube& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.get_n_rows(); } + arma_inline uword get_n_cols() const { return Q.get_n_cols(); } + arma_inline uword get_n_elem_slice() const { return Q.get_n_elem_slice(); } + arma_inline uword get_n_slices() const { return Q.get_n_slices(); } + arma_inline uword get_n_elem() const { return Q.get_n_elem(); } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Cube& X) const { return Q.P.is_alias(X); } + + template + arma_inline bool has_overlap(const subview_cube& X) const { return Q.P.has_overlap(X); } + + arma_inline bool is_aligned() const { return Q.P.is_aligned(); } + }; + + + +template +struct ProxyCube< eGlueCube > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef eGlueCube stored_type; + typedef const eGlueCube& ea_type; + typedef const eGlueCube& aligned_ea_type; + + static constexpr bool use_at = eGlueCube::use_at; + static constexpr bool use_mp = eGlueCube::use_mp; + static constexpr bool has_subview = eGlueCube::has_subview; + + arma_aligned const eGlueCube& Q; + + inline explicit ProxyCube(const eGlueCube& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.get_n_rows(); } + arma_inline uword get_n_cols() const { return Q.get_n_cols(); } + arma_inline uword get_n_elem_slice() const { return Q.get_n_elem_slice(); } + arma_inline uword get_n_slices() const { return Q.get_n_slices(); } + arma_inline uword get_n_elem() const { return Q.get_n_elem(); } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q; } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Cube& X) const { return (Q.P1.is_alias(X) || Q.P2.is_alias(X)); } + + template + arma_inline bool has_overlap(const subview_cube& X) const { return (Q.P1.has_overlap(X) || Q.P2.has_overlap(X)); } + + arma_inline bool is_aligned() const { return Q.P1.is_aligned() && Q.P2.is_aligned(); } + }; + + + +template +struct ProxyCube< mtOpCube > + { + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Cube stored_type; + typedef const elem_type* ea_type; + typedef const Cube& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + arma_aligned const Cube Q; + + inline explicit ProxyCube(const mtOpCube& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem_slice() const { return Q.n_elem_slice; } + arma_inline uword get_n_slices() const { return Q.n_slices; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Cube&) const { return false; } + + template + constexpr bool has_overlap(const subview_cube&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct ProxyCube< mtGlueCube > + { + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Cube stored_type; + typedef const elem_type* ea_type; + typedef const Cube& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + arma_aligned const Cube Q; + + inline explicit ProxyCube(const mtGlueCube& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem_slice() const { return Q.n_elem_slice; } + arma_inline uword get_n_slices() const { return Q.n_slices; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Cube&) const { return false; } + + template + constexpr bool has_overlap(const subview_cube&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Row_bones.hpp b/src/armadillo/include/armadillo_bits/Row_bones.hpp new file mode 100644 index 0000000..96dff22 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Row_bones.hpp @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Row +//! @{ + +//! Class for row vectors (matrices with only one row) + +template +class Row : public Mat + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_col = false; + static constexpr bool is_row = true; + static constexpr bool is_xvec = false; + + inline Row(); + inline Row(const Row& X); + + inline explicit Row(const uword N); + inline explicit Row(const uword in_rows, const uword in_cols); + inline explicit Row(const SizeMat& s); + + template inline explicit Row(const uword N, const arma_initmode_indicator&); + template inline explicit Row(const uword in_rows, const uword in_cols, const arma_initmode_indicator&); + template inline explicit Row(const SizeMat& s, const arma_initmode_indicator&); + + template inline Row(const uword n_elem, const fill::fill_class& f); + template inline Row(const uword in_rows, const uword in_cols, const fill::fill_class& f); + template inline Row(const SizeMat& s, const fill::fill_class& f); + + inline Row(const uword N, const fill::scalar_holder f); + inline Row(const uword in_rows, const uword in_cols, const fill::scalar_holder f); + inline Row(const SizeMat& s, const fill::scalar_holder f); + + inline Row(const char* text); + inline Row& operator=(const char* text); + + inline Row(const std::string& text); + inline Row& operator=(const std::string& text); + + inline Row(const std::vector& x); + inline Row& operator=(const std::vector& x); + + inline Row(const std::initializer_list& list); + inline Row& operator=(const std::initializer_list& list); + + inline Row(Row&& m); + inline Row& operator=(Row&& m); + + // inline Row(Mat&& m); + // inline Row& operator=(Mat&& m); + + inline Row& operator=(const eT val); + inline Row& operator=(const Row& X); + + template inline Row(const Base& X); + template inline Row& operator=(const Base& X); + + template inline explicit Row(const SpBase& X); + template inline Row& operator=(const SpBase& X); + + inline Row( eT* aux_mem, const uword aux_length, const bool copy_aux_mem = true, const bool strict = false); + inline Row(const eT* aux_mem, const uword aux_length); + + template + inline explicit Row(const Base& A, const Base& B); + + template inline Row(const BaseCube& X); + template inline Row& operator=(const BaseCube& X); + + inline Row(const subview_cube& X); + inline Row& operator=(const subview_cube& X); + + arma_frown("use braced initialiser list instead") inline mat_injector operator<<(const eT val); + + arma_warn_unused arma_inline const Op,op_htrans> t() const; + arma_warn_unused arma_inline const Op,op_htrans> ht() const; + arma_warn_unused arma_inline const Op,op_strans> st() const; + + arma_warn_unused arma_inline const Op,op_strans> as_col() const; + + arma_inline subview_row col(const uword col_num); + arma_inline const subview_row col(const uword col_num) const; + + using Mat::cols; + using Mat::operator(); + + arma_inline subview_row cols(const uword in_col1, const uword in_col2); + arma_inline const subview_row cols(const uword in_col1, const uword in_col2) const; + + arma_inline subview_row subvec(const uword in_col1, const uword in_col2); + arma_inline const subview_row subvec(const uword in_col1, const uword in_col2) const; + + arma_inline subview_row cols(const span& col_span); + arma_inline const subview_row cols(const span& col_span) const; + + arma_inline subview_row subvec(const span& col_span); + arma_inline const subview_row subvec(const span& col_span) const; + + arma_inline subview_row operator()(const span& col_span); + arma_inline const subview_row operator()(const span& col_span) const; + + arma_inline subview_row subvec(const uword start_col, const SizeMat& s); + arma_inline const subview_row subvec(const uword start_col, const SizeMat& s) const; + + arma_inline subview_row head(const uword N); + arma_inline const subview_row head(const uword N) const; + + arma_inline subview_row tail(const uword N); + arma_inline const subview_row tail(const uword N) const; + + arma_inline subview_row head_cols(const uword N); + arma_inline const subview_row head_cols(const uword N) const; + + arma_inline subview_row tail_cols(const uword N); + arma_inline const subview_row tail_cols(const uword N) const; + + + inline void shed_col (const uword col_num); + inline void shed_cols(const uword in_col1, const uword in_col2); + + template inline void shed_cols(const Base& indices); + + arma_deprecated inline void insert_cols(const uword col_num, const uword N, const bool set_to_zero); + inline void insert_cols(const uword col_num, const uword N); + + template inline void insert_cols(const uword col_num, const Base& X); + + + arma_warn_unused arma_inline eT& at(const uword i); + arma_warn_unused arma_inline const eT& at(const uword i) const; + + arma_warn_unused arma_inline eT& at(const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at(const uword in_row, const uword in_col) const; + + + typedef eT* row_iterator; + typedef const eT* const_row_iterator; + + inline row_iterator begin_row(const uword row_num); + inline const_row_iterator begin_row(const uword row_num) const; + + inline row_iterator end_row (const uword row_num); + inline const_row_iterator end_row (const uword row_num) const; + + + template class fixed; + + + protected: + + inline Row(const arma_fixed_indicator&, const uword in_n_elem, const eT* in_mem); + + + public: + + #if defined(ARMA_EXTRA_ROW_PROTO) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_ROW_PROTO) + #endif + }; + + + +template +template +class Row::fixed : public Row + { + private: + + static constexpr bool use_extra = (fixed_n_elem > arma_config::mat_prealloc); + + arma_align_mem eT mem_local_extra[ (use_extra) ? fixed_n_elem : 1 ]; + + + public: + + typedef fixed Row_fixed_type; + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_col = false; + static constexpr bool is_row = true; + static constexpr bool is_xvec = false; + + static const uword n_rows; // value provided below the class definition + static const uword n_cols; // value provided below the class definition + static const uword n_elem; // value provided below the class definition + + arma_inline fixed(); + arma_inline fixed(const fixed& X); + inline fixed(const subview_cube& X); + + inline fixed(const fill::scalar_holder f); + template inline fixed(const fill::fill_class& f); + template inline fixed(const Base& A); + template inline fixed(const Base& A, const Base& B); + + inline fixed(const eT* aux_mem); + + inline fixed(const char* text); + inline fixed(const std::string& text); + + template inline Row& operator=(const Base& A); + + inline Row& operator=(const eT val); + inline Row& operator=(const char* text); + inline Row& operator=(const std::string& text); + inline Row& operator=(const subview_cube& X); + + using Row::operator(); + + inline fixed(const std::initializer_list& list); + inline Row& operator=(const std::initializer_list& list); + + arma_inline Row& operator=(const fixed& X); + + #if defined(ARMA_GOOD_COMPILER) + template inline Row& operator=(const eOp& X); + template inline Row& operator=(const eGlue& X); + #endif + + arma_warn_unused arma_inline const Op< Row_fixed_type, op_htrans > t() const; + arma_warn_unused arma_inline const Op< Row_fixed_type, op_htrans > ht() const; + arma_warn_unused arma_inline const Op< Row_fixed_type, op_strans > st() const; + + arma_warn_unused arma_inline const eT& at_alt (const uword i) const; + + arma_warn_unused arma_inline eT& operator[] (const uword i); + arma_warn_unused arma_inline const eT& operator[] (const uword i) const; + arma_warn_unused arma_inline eT& at (const uword i); + arma_warn_unused arma_inline const eT& at (const uword i) const; + arma_warn_unused arma_inline eT& operator() (const uword i); + arma_warn_unused arma_inline const eT& operator() (const uword i) const; + + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col) const; + + arma_warn_unused arma_inline eT* memptr(); + arma_warn_unused arma_inline const eT* memptr() const; + + inline const Row& fill(const eT val); + inline const Row& zeros(); + inline const Row& ones(); + }; + + + +// these definitions are outside of the class due to bizarre C++ rules; +// C++17 has inline variables to address this shortcoming + +template +template +const uword Row::fixed::n_rows = 1u; + +template +template +const uword Row::fixed::n_cols = fixed_n_elem; + +template +template +const uword Row::fixed::n_elem = fixed_n_elem; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/Row_meat.hpp b/src/armadillo/include/armadillo_bits/Row_meat.hpp new file mode 100644 index 0000000..f61a179 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/Row_meat.hpp @@ -0,0 +1,1888 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup Row +//! @{ + + +//! construct an empty row vector +template +inline +Row::Row() + : Mat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +Row::Row(const Row& X) + : Mat(arma_vec_indicator(), 1, X.n_elem, 2) + { + arma_extra_debug_sigprint(); + + arrayops::copy((*this).memptr(), X.memptr(), X.n_elem); + } + + + +//! construct a row vector with the specified number of n_elem +template +inline +Row::Row(const uword in_n_elem) + : Mat(arma_vec_indicator(), 1, in_n_elem, 2) + { + arma_extra_debug_sigprint(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +template +inline +Row::Row(const uword in_n_rows, const uword in_n_cols) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +template +inline +Row::Row(const SizeMat& s) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Row::Row(const uword in_n_elem, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), 1, in_n_elem, 2) + { + arma_extra_debug_sigprint(); + + if(do_zeros) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Row::Row(const uword in_n_rows, const uword in_n_cols, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + if(do_zeros) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Row::Row(const SizeMat& s, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + if(do_zeros) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +template +template +inline +Row::Row(const uword in_n_elem, const fill::fill_class& f) + : Mat(arma_vec_indicator(), 1, in_n_elem, 2) + { + arma_extra_debug_sigprint(); + + (*this).fill(f); + } + + + +template +template +inline +Row::Row(const uword in_n_rows, const uword in_n_cols, const fill::fill_class& f) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + (*this).fill(f); + } + + + +template +template +inline +Row::Row(const SizeMat& s, const fill::fill_class& f) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + (*this).fill(f); + } + + + +template +inline +Row::Row(const uword in_n_elem, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), 1, in_n_elem, 2) + { + arma_extra_debug_sigprint(); + + (*this).fill(f.scalar); + } + + + +template +inline +Row::Row(const uword in_n_rows, const uword in_n_cols, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + (*this).fill(f.scalar); + } + + + +template +inline +Row::Row(const SizeMat& s, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + (*this).fill(f.scalar); + } + + + +template +inline +Row::Row(const char* text) + : Mat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(); + + (*this).operator=(text); + } + + + +template +inline +Row& +Row::operator=(const char* text) + { + arma_extra_debug_sigprint(); + + Mat tmp(text); + + arma_debug_check( ((tmp.n_elem > 0) && (tmp.is_vec() == false)), "Mat::init(): requested size is not compatible with row vector layout" ); + + access::rw(tmp.n_rows) = 1; + access::rw(tmp.n_cols) = tmp.n_elem; + + (*this).steal_mem(tmp); + + return *this; + } + + + +template +inline +Row::Row(const std::string& text) + : Mat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(); + + (*this).operator=(text); + } + + + +template +inline +Row& +Row::operator=(const std::string& text) + { + arma_extra_debug_sigprint(); + + Mat tmp(text); + + arma_debug_check( ((tmp.n_elem > 0) && (tmp.is_vec() == false)), "Mat::init(): requested size is not compatible with row vector layout" ); + + access::rw(tmp.n_rows) = 1; + access::rw(tmp.n_cols) = tmp.n_elem; + + (*this).steal_mem(tmp); + + return *this; + } + + + +//! create a row vector from std::vector +template +inline +Row::Row(const std::vector& x) + : Mat(arma_vec_indicator(), 1, uword(x.size()), 2) + { + arma_extra_debug_sigprint_this(this); + + const uword N = uword(x.size()); + + if(N > 0) { arrayops::copy( Mat::memptr(), &(x[0]), N ); } + } + + + +//! create a row vector from std::vector +template +inline +Row& +Row::operator=(const std::vector& x) + { + arma_extra_debug_sigprint(); + + const uword N = uword(x.size()); + + Mat::init_warm(1, N); + + if(N > 0) { arrayops::copy( Mat::memptr(), &(x[0]), N ); } + + return *this; + } + + + +template +inline +Row::Row(const std::initializer_list& list) + : Mat(arma_vec_indicator(), 1, uword(list.size()), 2) + { + arma_extra_debug_sigprint_this(this); + + const uword N = uword(list.size()); + + if(N > 0) { arrayops::copy( Mat::memptr(), list.begin(), N ); } + } + + + +template +inline +Row& +Row::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + const uword N = uword(list.size()); + + Mat::init_warm(1, N); + + if(N > 0) { arrayops::copy( Mat::memptr(), list.begin(), N ); } + + return *this; + } + + + +template +inline +Row::Row(Row&& X) + : Mat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + + access::rw(Mat::n_rows) = 1; + access::rw(Mat::n_cols) = X.n_cols; + access::rw(Mat::n_elem) = X.n_elem; + access::rw(Mat::n_alloc) = X.n_alloc; + + if( (X.n_alloc > arma_config::mat_prealloc) || (X.mem_state == 1) || (X.mem_state == 2) ) + { + access::rw(Mat::mem_state) = X.mem_state; + access::rw(Mat::mem) = X.mem; + + access::rw(X.n_rows) = 1; + access::rw(X.n_cols) = 0; + access::rw(X.n_elem) = 0; + access::rw(X.n_alloc) = 0; + access::rw(X.mem_state) = 0; + access::rw(X.mem) = nullptr; + } + else // condition: (X.n_alloc <= arma_config::mat_prealloc) || (X.mem_state == 0) || (X.mem_state == 3) + { + (*this).init_cold(); + + arrayops::copy( (*this).memptr(), X.mem, X.n_elem ); + + if( (X.mem_state == 0) && (X.n_alloc <= arma_config::mat_prealloc) ) + { + access::rw(X.n_rows) = 1; + access::rw(X.n_cols) = 0; + access::rw(X.n_elem) = 0; + access::rw(X.mem) = nullptr; + } + } + } + + + +template +inline +Row& +Row::operator=(Row&& X) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + + (*this).steal_mem(X, true); + + return *this; + } + + + +// template +// inline +// Row::Row(Mat&& X) +// : Mat(arma_vec_indicator(), 2) +// { +// arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); +// +// if(X.n_rows != 1) { const Mat& XX = X; Mat::operator=(XX); return; } +// +// access::rw(Mat::n_rows) = 1; +// access::rw(Mat::n_cols) = X.n_cols; +// access::rw(Mat::n_elem) = X.n_elem; +// access::rw(Mat::n_alloc) = X.n_alloc; +// +// if( (X.n_alloc > arma_config::mat_prealloc) || (X.mem_state == 1) || (X.mem_state == 2) ) +// { +// access::rw(Mat::mem_state) = X.mem_state; +// access::rw(Mat::mem) = X.mem; +// +// access::rw(X.n_cols) = 0; +// access::rw(X.n_elem) = 0; +// access::rw(X.n_alloc) = 0; +// access::rw(X.mem_state) = 0; +// access::rw(X.mem) = nullptr; +// } +// else // condition: (X.n_alloc <= arma_config::mat_prealloc) || (X.mem_state == 0) || (X.mem_state == 3) +// { +// (*this).init_cold(); +// +// arrayops::copy( (*this).memptr(), X.mem, X.n_elem ); +// +// if( (X.mem_state == 0) && (X.n_alloc <= arma_config::mat_prealloc) ) +// { +// access::rw(X.n_cols) = 0; +// access::rw(X.n_elem) = 0; +// access::rw(X.mem) = nullptr; +// } +// } +// } +// +// +// +// template +// inline +// Row& +// Row::operator=(Mat&& X) +// { +// arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); +// +// if(X.n_rows != 1) { const Mat& XX = X; Mat::operator=(XX); return *this; } +// +// (*this).steal_mem(X, true); +// +// return *this; +// } + + + +template +inline +Row& +Row::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + Mat::operator=(val); + + return *this; + } + + + +template +inline +Row& +Row::operator=(const Row& X) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X); + + return *this; + } + + + +template +template +inline +Row::Row(const Base& X) + : Mat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X.get_ref()); + } + + + +template +template +inline +Row& +Row::operator=(const Base& X) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X.get_ref()); + + return *this; + } + + + +template +template +inline +Row::Row(const SpBase& X) + : Mat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X.get_ref()); + } + + + +template +template +inline +Row& +Row::operator=(const SpBase& X) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X.get_ref()); + + return *this; + } + + + +//! construct a row vector from a given auxiliary array +template +inline +Row::Row(eT* aux_mem, const uword aux_length, const bool copy_aux_mem, const bool strict) + : Mat(aux_mem, 1, aux_length, copy_aux_mem, strict) + { + arma_extra_debug_sigprint(); + + access::rw(Mat::vec_state) = 2; + } + + + +//! construct a row vector from a given auxiliary array +template +inline +Row::Row(const eT* aux_mem, const uword aux_length) + : Mat(aux_mem, 1, aux_length) + { + arma_extra_debug_sigprint(); + + access::rw(Mat::vec_state) = 2; + } + + + +template +template +inline +Row::Row + ( + const Base::pod_type, T1>& A, + const Base::pod_type, T2>& B + ) + { + arma_extra_debug_sigprint(); + + access::rw(Mat::vec_state) = 2; + + Mat::init(A,B); + } + + + +template +template +inline +Row::Row(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + access::rw(Mat::vec_state) = 2; + + Mat::operator=(X); + } + + + +template +template +inline +Row& +Row::operator=(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X); + + return *this; + } + + + +template +inline +Row::Row(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + access::rw(Mat::vec_state) = 2; + + Mat::operator=(X); + } + + + +template +inline +Row& +Row::operator=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + Mat::operator=(X); + + return *this; + } + + + +template +inline +mat_injector< Row > +Row::operator<<(const eT val) + { + return mat_injector< Row >(*this, val); + } + + + +template +arma_inline +const Op,op_htrans> +Row::t() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_htrans> +Row::ht() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_strans> +Row::st() const + { + return Op,op_strans>(*this); + } + + + +template +arma_inline +const Op,op_strans> +Row::as_col() const + { + return Op,op_strans>(*this); + } + + + +template +arma_inline +subview_row +Row::col(const uword in_col1) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (in_col1 >= Mat::n_cols), "Row::col(): indices out of bounds or incorrectly used" ); + + return subview_row(*this, 0, in_col1, 1); + } + + + +template +arma_inline +const subview_row +Row::col(const uword in_col1) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (in_col1 >= Mat::n_cols), "Row::col(): indices out of bounds or incorrectly used" ); + + return subview_row(*this, 0, in_col1, 1); + } + + + +template +arma_inline +subview_row +Row::cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::cols(): indices out of bounds or incorrectly used" ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return subview_row(*this, 0, in_col1, subview_n_cols); + } + + + +template +arma_inline +const subview_row +Row::cols(const uword in_col1, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::cols(): indices out of bounds or incorrectly used" ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return subview_row(*this, 0, in_col1, subview_n_cols); + } + + + +template +arma_inline +subview_row +Row::subvec(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::subvec(): indices out of bounds or incorrectly used" ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return subview_row(*this, 0, in_col1, subview_n_cols); + } + + + +template +arma_inline +const subview_row +Row::subvec(const uword in_col1, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::subvec(): indices out of bounds or incorrectly used" ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return subview_row(*this, 0, in_col1, subview_n_cols); + } + + + +template +arma_inline +subview_row +Row::cols(const span& col_span) + { + arma_extra_debug_sigprint(); + + return subvec(col_span); + } + + + +template +arma_inline +const subview_row +Row::cols(const span& col_span) const + { + arma_extra_debug_sigprint(); + + return subvec(col_span); + } + + + +template +arma_inline +subview_row +Row::subvec(const span& col_span) + { + arma_extra_debug_sigprint(); + + const bool col_all = col_span.whole; + + const uword local_n_cols = Mat::n_cols; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword subvec_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds( ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ), "Row::subvec(): indices out of bounds or incorrectly used" ); + + return subview_row(*this, 0, in_col1, subvec_n_cols); + } + + + +template +arma_inline +const subview_row +Row::subvec(const span& col_span) const + { + arma_extra_debug_sigprint(); + + const bool col_all = col_span.whole; + + const uword local_n_cols = Mat::n_cols; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword subvec_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds( ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ), "Row::subvec(): indices out of bounds or incorrectly used" ); + + return subview_row(*this, 0, in_col1, subvec_n_cols); + } + + + +template +arma_inline +subview_row +Row::operator()(const span& col_span) + { + arma_extra_debug_sigprint(); + + return subvec(col_span); + } + + + +template +arma_inline +const subview_row +Row::operator()(const span& col_span) const + { + arma_extra_debug_sigprint(); + + return subvec(col_span); + } + + + +template +arma_inline +subview_row +Row::subvec(const uword start_col, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (s.n_rows != 1), "Row::subvec(): given size does not specify a row vector" ); + + arma_debug_check_bounds( ( (start_col >= Mat::n_cols) || ((start_col + s.n_cols) > Mat::n_cols) ), "Row::subvec(): size out of bounds" ); + + return subview_row(*this, 0, start_col, s.n_cols); + } + + + +template +arma_inline +const subview_row +Row::subvec(const uword start_col, const SizeMat& s) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (s.n_rows != 1), "Row::subvec(): given size does not specify a row vector" ); + + arma_debug_check_bounds( ( (start_col >= Mat::n_cols) || ((start_col + s.n_cols) > Mat::n_cols) ), "Row::subvec(): size out of bounds" ); + + return subview_row(*this, 0, start_col, s.n_cols); + } + + + +template +arma_inline +subview_row +Row::head(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > Mat::n_cols), "Row::head(): size out of bounds" ); + + return subview_row(*this, 0, 0, N); + } + + + +template +arma_inline +const subview_row +Row::head(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > Mat::n_cols), "Row::head(): size out of bounds" ); + + return subview_row(*this, 0, 0, N); + } + + + +template +arma_inline +subview_row +Row::tail(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > Mat::n_cols), "Row::tail(): size out of bounds" ); + + const uword start_col = Mat::n_cols - N; + + return subview_row(*this, 0, start_col, N); + } + + + +template +arma_inline +const subview_row +Row::tail(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > Mat::n_cols), "Row::tail(): size out of bounds" ); + + const uword start_col = Mat::n_cols - N; + + return subview_row(*this, 0, start_col, N); + } + + + +template +arma_inline +subview_row +Row::head_cols(const uword N) + { + arma_extra_debug_sigprint(); + + return (*this).head(N); + } + + + +template +arma_inline +const subview_row +Row::head_cols(const uword N) const + { + arma_extra_debug_sigprint(); + + return (*this).head(N); + } + + + +template +arma_inline +subview_row +Row::tail_cols(const uword N) + { + arma_extra_debug_sigprint(); + + return (*this).tail(N); + } + + + +template +arma_inline +const subview_row +Row::tail_cols(const uword N) const + { + arma_extra_debug_sigprint(); + + return (*this).tail(N); + } + + + +//! remove specified columns +template +inline +void +Row::shed_col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= Mat::n_cols, "Row::shed_col(): index out of bounds" ); + + shed_cols(col_num, col_num); + } + + + +//! remove specified columns +template +inline +void +Row::shed_cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= Mat::n_cols), + "Row::shed_cols(): indices out of bounds or incorrectly used" + ); + + const uword n_keep_front = in_col1; + const uword n_keep_back = Mat::n_cols - (in_col2 + 1); + + Row X(n_keep_front + n_keep_back, arma_nozeros_indicator()); + + eT* X_mem = X.memptr(); + const eT* t_mem = (*this).memptr(); + + if(n_keep_front > 0) + { + arrayops::copy( X_mem, t_mem, n_keep_front ); + } + + if(n_keep_back > 0) + { + arrayops::copy( &(X_mem[n_keep_front]), &(t_mem[in_col2+1]), n_keep_back); + } + + Mat::steal_mem(X); + } + + + +//! remove specified columns +template +template +inline +void +Row::shed_cols(const Base& indices) + { + arma_extra_debug_sigprint(); + + Mat::shed_cols(indices); + } + + + +template +inline +void +Row::insert_cols(const uword col_num, const uword N, const bool set_to_zero) + { + arma_extra_debug_sigprint(); + + arma_ignore(set_to_zero); + + (*this).insert_cols(col_num, N); + } + + + +template +inline +void +Row::insert_cols(const uword col_num, const uword N) + { + arma_extra_debug_sigprint(); + + const uword t_n_cols = Mat::n_cols; + + const uword A_n_cols = col_num; + const uword B_n_cols = t_n_cols - col_num; + + // insertion at col_num == n_cols is in effect an append operation + arma_debug_check_bounds( (col_num > t_n_cols), "Row::insert_cols(): index out of bounds" ); + + if(N == 0) { return; } + + Row out(t_n_cols + N, arma_nozeros_indicator()); + + eT* out_mem = out.memptr(); + const eT* t_mem = (*this).memptr(); + + if(A_n_cols > 0) + { + arrayops::copy( out_mem, t_mem, A_n_cols ); + } + + if(B_n_cols > 0) + { + arrayops::copy( &(out_mem[col_num + N]), &(t_mem[col_num]), B_n_cols ); + } + + arrayops::fill_zeros( &(out_mem[col_num]), N ); + + Mat::steal_mem(out); + } + + + +//! insert the given object at the specified col position; +//! the given object must have one row +template +template +inline +void +Row::insert_cols(const uword col_num, const Base& X) + { + arma_extra_debug_sigprint(); + + Mat::insert_cols(col_num, X); + } + + + +template +arma_inline +eT& +Row::at(const uword i) + { + return access::rw(Mat::mem[i]); + } + + + +template +arma_inline +const eT& +Row::at(const uword i) const + { + return Mat::mem[i]; + } + + + +template +arma_inline +eT& +Row::at(const uword, const uword in_col) + { + return access::rw( Mat::mem[in_col] ); + } + + + +template +arma_inline +const eT& +Row::at(const uword, const uword in_col) const + { + return Mat::mem[in_col]; + } + + + +template +inline +typename Row::row_iterator +Row::begin_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Row::begin_row(): index out of bounds" ); + + return Mat::memptr(); + } + + + +template +inline +typename Row::const_row_iterator +Row::begin_row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Row::begin_row(): index out of bounds" ); + + return Mat::memptr(); + } + + + +template +inline +typename Row::row_iterator +Row::end_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Row::end_row(): index out of bounds" ); + + return Mat::memptr() + Mat::n_cols; + } + + + +template +inline +typename Row::const_row_iterator +Row::end_row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Row::end_row(): index out of bounds" ); + + return Mat::memptr() + Mat::n_cols; + } + + + +template +template +arma_inline +Row::fixed::fixed() + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Row::fixed::constructor: zeroing memory"); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(Mat::mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(0) ); + } + } + + + +template +template +arma_inline +Row::fixed::fixed(const fixed& X) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + eT* dest = (use_extra) ? mem_local_extra : Mat::mem_local; + const eT* src = (use_extra) ? X.mem_local_extra : X.mem_local; + + arrayops::copy( dest, src, fixed_n_elem ); + } + + + +template +template +arma_inline +Row::fixed::fixed(const subview_cube& X) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Row::operator=(X); + } + + + +template +template +inline +Row::fixed::fixed(const fill::scalar_holder f) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + (*this).fill(f.scalar); + } + + + +template +template +template +inline +Row::fixed::fixed(const fill::fill_class&) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).eye(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } + } + + + +template +template +template +arma_inline +Row::fixed::fixed(const Base& A) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Row::operator=(A.get_ref()); + } + + + +template +template +template +arma_inline +Row::fixed::fixed(const Base& A, const Base& B) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Row::init(A,B); + } + + + +template +template +inline +Row::fixed::fixed(const eT* aux_mem) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + eT* dest = (use_extra) ? mem_local_extra : Mat::mem_local; + + arrayops::copy( dest, aux_mem, fixed_n_elem ); + } + + + +template +template +inline +Row::fixed::fixed(const char* text) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Row::operator=(text); + } + + + +template +template +inline +Row::fixed::fixed(const std::string& text) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + Row::operator=(text); + } + + + +template +template +template +Row& +Row::fixed::operator=(const Base& A) + { + arma_extra_debug_sigprint(); + + Row::operator=(A.get_ref()); + + return *this; + } + + + +template +template +Row& +Row::fixed::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + Row::operator=(val); + + return *this; + } + + + +template +template +Row& +Row::fixed::operator=(const char* text) + { + arma_extra_debug_sigprint(); + + Row::operator=(text); + + return *this; + } + + + +template +template +Row& +Row::fixed::operator=(const std::string& text) + { + arma_extra_debug_sigprint(); + + Row::operator=(text); + + return *this; + } + + + +template +template +Row& +Row::fixed::operator=(const subview_cube& X) + { + arma_extra_debug_sigprint(); + + Row::operator=(X); + + return *this; + } + + + +template +template +inline +Row::fixed::fixed(const std::initializer_list& list) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(list); + } + + + +template +template +inline +Row& +Row::fixed::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + const uword N = uword(list.size()); + + arma_debug_check( (N > fixed_n_elem), "Row::fixed: initialiser list is too long" ); + + eT* this_mem = (*this).memptr(); + + arrayops::copy( this_mem, list.begin(), N ); + + for(uword iq=N; iq < fixed_n_elem; ++iq) { this_mem[iq] = eT(0); } + + return *this; + } + + + +template +template +arma_inline +Row& +Row::fixed::operator=(const fixed& X) + { + arma_extra_debug_sigprint(); + + if(this != &X) + { + eT* dest = (use_extra) ? mem_local_extra : Mat::mem_local; + const eT* src = (use_extra) ? X.mem_local_extra : X.mem_local; + + arrayops::copy( dest, src, fixed_n_elem ); + } + + return *this; + } + + + +#if defined(ARMA_GOOD_COMPILER) + + template + template + template + inline + Row& + Row::fixed::operator=(const eOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias == false) + { + arma_debug_assert_same_size(uword(1), fixed_n_elem, X.get_n_rows(), X.get_n_cols(), "Row::fixed::operator="); + + eop_type::apply(*this, X); + } + else + { + arma_extra_debug_print("bad_alias = true"); + + Row tmp(X); + + (*this) = tmp; + } + + return *this; + } + + + + template + template + template + inline + Row& + Row::fixed::operator=(const eGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias == false) + { + arma_debug_assert_same_size(uword(1), fixed_n_elem, X.get_n_rows(), X.get_n_cols(), "Row::fixed::operator="); + + eglue_type::apply(*this, X); + } + else + { + arma_extra_debug_print("bad_alias = true"); + + Row tmp(X); + + (*this) = tmp; + } + + return *this; + } + +#endif + + + +template +template +arma_inline +const Op< typename Row::template fixed::Row_fixed_type, op_htrans > +Row::fixed::t() const + { + return Op< typename Row::template fixed::Row_fixed_type, op_htrans >(*this); + } + + + +template +template +arma_inline +const Op< typename Row::template fixed::Row_fixed_type, op_htrans > +Row::fixed::ht() const + { + return Op< typename Row::template fixed::Row_fixed_type, op_htrans >(*this); + } + + + +template +template +arma_inline +const Op< typename Row::template fixed::Row_fixed_type, op_strans > +Row::fixed::st() const + { + return Op< typename Row::template fixed::Row_fixed_type, op_strans >(*this); + } + + + +template +template +arma_inline +const eT& +Row::fixed::at_alt(const uword ii) const + { + #if defined(ARMA_HAVE_ALIGNED_ATTRIBUTE) + + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + + #else + const eT* mem_aligned = (use_extra) ? mem_local_extra : Mat::mem_local; + + memory::mark_as_aligned(mem_aligned); + + return mem_aligned[ii]; + #endif + } + + + +template +template +arma_inline +eT& +Row::fixed::operator[] (const uword ii) + { + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +const eT& +Row::fixed::operator[] (const uword ii) const + { + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +eT& +Row::fixed::at(const uword ii) + { + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +const eT& +Row::fixed::at(const uword ii) const + { + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +eT& +Row::fixed::operator() (const uword ii) + { + arma_debug_check_bounds( (ii >= fixed_n_elem), "Row::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +const eT& +Row::fixed::operator() (const uword ii) const + { + arma_debug_check_bounds( (ii >= fixed_n_elem), "Row::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; + } + + + +template +template +arma_inline +eT& +Row::fixed::at(const uword, const uword in_col) + { + return (use_extra) ? mem_local_extra[in_col] : Mat::mem_local[in_col]; + } + + + +template +template +arma_inline +const eT& +Row::fixed::at(const uword, const uword in_col) const + { + return (use_extra) ? mem_local_extra[in_col] : Mat::mem_local[in_col]; + } + + + +template +template +arma_inline +eT& +Row::fixed::operator() (const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row > 0) || (in_col >= fixed_n_elem)), "Row::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[in_col] : Mat::mem_local[in_col]; + } + + + +template +template +arma_inline +const eT& +Row::fixed::operator() (const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row > 0) || (in_col >= fixed_n_elem)), "Row::operator(): index out of bounds" ); + + return (use_extra) ? mem_local_extra[in_col] : Mat::mem_local[in_col]; + } + + + +template +template +arma_inline +eT* +Row::fixed::memptr() + { + return (use_extra) ? mem_local_extra : Mat::mem_local; + } + + + +template +template +arma_inline +const eT* +Row::fixed::memptr() const + { + return (use_extra) ? mem_local_extra : Mat::mem_local; + } + + + +template +template +inline +const Row& +Row::fixed::fill(const eT val) + { + arma_extra_debug_sigprint(); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(Mat::mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, val ); + + return *this; + } + + + +template +template +inline +const Row& +Row::fixed::zeros() + { + arma_extra_debug_sigprint(); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(Mat::mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(0) ); + + return *this; + } + + + +template +template +inline +const Row& +Row::fixed::ones() + { + arma_extra_debug_sigprint(); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(Mat::mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(1) ); + + return *this; + } + + + +template +inline +Row::Row(const arma_fixed_indicator&, const uword in_n_elem, const eT* in_mem) + : Mat(arma_fixed_indicator(), 1, in_n_elem, 2, in_mem) + { + arma_extra_debug_sigprint_this(this); + } + + + +#if defined(ARMA_EXTRA_ROW_MEAT) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_ROW_MEAT) +#endif + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SizeCube_bones.hpp b/src/armadillo/include/armadillo_bits/SizeCube_bones.hpp new file mode 100644 index 0000000..96b26af --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SizeCube_bones.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SizeCube +//! @{ + + + +class SizeCube + { + public: + + const uword n_rows; + const uword n_cols; + const uword n_slices; + + inline explicit SizeCube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices); + + inline uword operator[](const uword dim) const; + inline uword operator()(const uword dim) const; + + inline bool operator==(const SizeCube& s) const; + inline bool operator!=(const SizeCube& s) const; + + inline SizeCube operator+(const SizeCube& s) const; + inline SizeCube operator-(const SizeCube& s) const; + + inline SizeCube operator+(const uword val) const; + inline SizeCube operator-(const uword val) const; + + inline SizeCube operator*(const uword val) const; + inline SizeCube operator/(const uword val) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SizeCube_meat.hpp b/src/armadillo/include/armadillo_bits/SizeCube_meat.hpp new file mode 100644 index 0000000..8354ca1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SizeCube_meat.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SizeCube +//! @{ + + + +inline +SizeCube::SizeCube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices) + : n_rows (in_n_rows ) + , n_cols (in_n_cols ) + , n_slices(in_n_slices) + { + arma_extra_debug_sigprint(); + } + + + +inline +uword +SizeCube::operator[](const uword dim) const + { + if(dim == 0) { return n_rows; } + if(dim == 1) { return n_cols; } + if(dim == 2) { return n_slices; } + + return uword(1); + } + + + +inline +uword +SizeCube::operator()(const uword dim) const + { + if(dim == 0) { return n_rows; } + if(dim == 1) { return n_cols; } + if(dim == 2) { return n_slices; } + + arma_debug_check_bounds(true, "size(): index out of bounds"); + + return uword(1); + } + + + +inline +bool +SizeCube::operator==(const SizeCube& s) const + { + if(n_rows != s.n_rows ) { return false; } + + if(n_cols != s.n_cols ) { return false; } + + if(n_slices != s.n_slices) { return false; } + + return true; + } + + + +inline +bool +SizeCube::operator!=(const SizeCube& s) const + { + if(n_rows != s.n_rows ) { return true; } + + if(n_cols != s.n_cols ) { return true; } + + if(n_slices != s.n_slices) { return true; } + + return false; + } + + + +inline +SizeCube +SizeCube::operator+(const SizeCube& s) const + { + return SizeCube( (n_rows + s.n_rows), (n_cols + s.n_cols), (n_slices + s.n_slices) ); + } + + + +inline +SizeCube +SizeCube::operator-(const SizeCube& s) const + { + const uword out_n_rows = (n_rows > s.n_rows ) ? (n_rows - s.n_rows ) : uword(0); + const uword out_n_cols = (n_cols > s.n_cols ) ? (n_cols - s.n_cols ) : uword(0); + const uword out_n_slices = (n_slices > s.n_slices) ? (n_slices - s.n_slices) : uword(0); + + return SizeCube(out_n_rows, out_n_cols, out_n_slices); + } + + + +inline +SizeCube +SizeCube::operator+(const uword val) const + { + return SizeCube( (n_rows + val), (n_cols + val), (n_slices + val) ); + } + + + +inline +SizeCube +SizeCube::operator-(const uword val) const + { + const uword out_n_rows = (n_rows > val) ? (n_rows - val) : uword(0); + const uword out_n_cols = (n_cols > val) ? (n_cols - val) : uword(0); + const uword out_n_slices = (n_slices > val) ? (n_slices - val) : uword(0); + + return SizeCube(out_n_rows, out_n_cols, out_n_slices); + } + + + +inline +SizeCube +SizeCube::operator*(const uword val) const + { + return SizeCube( (n_rows * val), (n_cols * val), (n_slices * val) ); + } + + + +inline +SizeCube +SizeCube::operator/(const uword val) const + { + return SizeCube( (n_rows / val), (n_cols / val), (n_slices / val) ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SizeMat_bones.hpp b/src/armadillo/include/armadillo_bits/SizeMat_bones.hpp new file mode 100644 index 0000000..6139d33 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SizeMat_bones.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SizeMat +//! @{ + + + +class SizeMat + { + public: + + const uword n_rows; + const uword n_cols; + + inline explicit SizeMat(const uword in_n_rows, const uword in_n_cols); + + inline uword operator[](const uword dim) const; + inline uword operator()(const uword dim) const; + + inline bool operator==(const SizeMat& s) const; + inline bool operator!=(const SizeMat& s) const; + + inline SizeMat operator+(const SizeMat& s) const; + inline SizeMat operator-(const SizeMat& s) const; + + inline SizeMat operator+(const uword val) const; + inline SizeMat operator-(const uword val) const; + + inline SizeMat operator*(const uword val) const; + inline SizeMat operator/(const uword val) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SizeMat_meat.hpp b/src/armadillo/include/armadillo_bits/SizeMat_meat.hpp new file mode 100644 index 0000000..e00fd4f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SizeMat_meat.hpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SizeMat +//! @{ + + + +inline +SizeMat::SizeMat(const uword in_n_rows, const uword in_n_cols) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + { + arma_extra_debug_sigprint(); + } + + + +inline +uword +SizeMat::operator[](const uword dim) const + { + if(dim == 0) { return n_rows; } + if(dim == 1) { return n_cols; } + + return uword(1); + } + + + +inline +uword +SizeMat::operator()(const uword dim) const + { + if(dim == 0) { return n_rows; } + if(dim == 1) { return n_cols; } + + arma_debug_check_bounds(true, "size(): index out of bounds"); + + return uword(1); + } + + + +inline +bool +SizeMat::operator==(const SizeMat& s) const + { + if(n_rows != s.n_rows) { return false; } + + if(n_cols != s.n_cols) { return false; } + + return true; + } + + + +inline +bool +SizeMat::operator!=(const SizeMat& s) const + { + if(n_rows != s.n_rows) { return true; } + + if(n_cols != s.n_cols) { return true; } + + return false; + } + + + +inline +SizeMat +SizeMat::operator+(const SizeMat& s) const + { + return SizeMat( (n_rows + s.n_rows), (n_cols + s.n_cols) ); + } + + + +inline +SizeMat +SizeMat::operator-(const SizeMat& s) const + { + const uword out_n_rows = (n_rows > s.n_rows) ? (n_rows - s.n_rows) : uword(0); + const uword out_n_cols = (n_cols > s.n_cols) ? (n_cols - s.n_cols) : uword(0); + + return SizeMat(out_n_rows, out_n_cols); + } + + + +inline +SizeMat +SizeMat::operator+(const uword val) const + { + return SizeMat( (n_rows + val), (n_cols + val) ); + } + + + +inline +SizeMat +SizeMat::operator-(const uword val) const + { + const uword out_n_rows = (n_rows > val) ? (n_rows - val) : uword(0); + const uword out_n_cols = (n_cols > val) ? (n_cols - val) : uword(0); + + return SizeMat(out_n_rows, out_n_cols); + } + + + +inline +SizeMat +SizeMat::operator*(const uword val) const + { + return SizeMat( (n_rows * val), (n_cols * val) ); + } + + + +inline +SizeMat +SizeMat::operator/(const uword val) const + { + return SizeMat( (n_rows / val), (n_cols / val) ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpBase_bones.hpp b/src/armadillo/include/armadillo_bits/SpBase_bones.hpp new file mode 100644 index 0000000..d16bf47 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpBase_bones.hpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpBase +//! @{ + + + +template +struct SpBase_eval_SpMat + { + arma_warn_unused inline const derived& eval() const; + }; + + +template +struct SpBase_eval_expr + { + arma_warn_unused inline SpMat eval() const; //!< force the immediate evaluation of a delayed expression + }; + + +template +struct SpBase_eval {}; + +template +struct SpBase_eval { typedef SpBase_eval_SpMat result; }; + +template +struct SpBase_eval { typedef SpBase_eval_expr result; }; + + + +template +struct SpBase + : public SpBase_eval::value>::result + { + arma_inline const derived& get_ref() const; + + arma_inline bool is_alias(const SpMat& X) const; + + arma_warn_unused inline const SpOp t() const; //!< Hermitian transpose + arma_warn_unused inline const SpOp ht() const; //!< Hermitian transpose + arma_warn_unused inline const SpOp st() const; //!< simple transpose + + arma_cold inline void print( const std::string extra_text = "") const; + arma_cold inline void print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_cold inline void raw_print( const std::string extra_text = "") const; + arma_cold inline void raw_print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_cold inline void print_dense( const std::string extra_text = "") const; + arma_cold inline void print_dense(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_cold inline void raw_print_dense( const std::string extra_text = "") const; + arma_cold inline void raw_print_dense(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_cold inline void brief_print( const std::string extra_text = "") const; + arma_cold inline void brief_print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_warn_unused inline elem_type min() const; + arma_warn_unused inline elem_type max() const; + + inline elem_type min(uword& index_of_min_val) const; + inline elem_type max(uword& index_of_max_val) const; + + inline elem_type min(uword& row_of_min_val, uword& col_of_min_val) const; + inline elem_type max(uword& row_of_max_val, uword& col_of_max_val) const; + + arma_warn_unused inline uword index_min() const; + arma_warn_unused inline uword index_max() const; + + arma_warn_unused inline bool is_symmetric() const; + arma_warn_unused inline bool is_symmetric(const typename get_pod_type::result tol) const; + + arma_warn_unused inline bool is_hermitian() const; + arma_warn_unused inline bool is_hermitian(const typename get_pod_type::result tol) const; + + arma_warn_unused inline bool is_zero(const typename get_pod_type::result tol = 0) const; + + arma_warn_unused inline bool is_trimatu() const; + arma_warn_unused inline bool is_trimatl() const; + arma_warn_unused inline bool is_diagmat() const; + arma_warn_unused inline bool is_empty() const; + arma_warn_unused inline bool is_square() const; + arma_warn_unused inline bool is_vec() const; + arma_warn_unused inline bool is_colvec() const; + arma_warn_unused inline bool is_rowvec() const; + arma_warn_unused inline bool is_finite() const; + + arma_warn_unused inline bool has_inf() const; + arma_warn_unused inline bool has_nan() const; + arma_warn_unused inline bool has_nonfinite() const; + + arma_warn_unused inline const SpOp as_col() const; + arma_warn_unused inline const SpOp as_row() const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpBase_meat.hpp b/src/armadillo/include/armadillo_bits/SpBase_meat.hpp new file mode 100644 index 0000000..4ebc424 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpBase_meat.hpp @@ -0,0 +1,883 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpBase +//! @{ + + + +template +arma_inline +const derived& +SpBase::get_ref() const + { + return static_cast(*this); + } + + + +template +arma_inline +bool +SpBase::is_alias(const SpMat& X) const + { + return (*this).get_ref().is_alias(X); + } + + + +template +inline +const SpOp +SpBase::t() const + { + return SpOp( (*this).get_ref() ); + } + + +template +inline +const SpOp +SpBase::ht() const + { + return SpOp( (*this).get_ref() ); + } + + + +template +inline +const SpOp +SpBase::st() const + { + return SpOp( (*this).get_ref() ); + } + + + +template +inline +void +SpBase::print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, true); + } + + + +template +inline +void +SpBase::print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, true); + } + + + +template +inline +void +SpBase::raw_print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, false); + } + + + +template +inline +void +SpBase::raw_print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, false); + } + + + +template +inline +void +SpBase::print_dense(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print_dense(get_cout_stream(), tmp.M, true); + } + + + +template +inline +void +SpBase::print_dense(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print_dense(user_stream, tmp.M, true); + } + + + +template +inline +void +SpBase::raw_print_dense(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print_dense(get_cout_stream(), tmp.M, false); + } + + + +template +inline +void +SpBase::raw_print_dense(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print_dense(user_stream, tmp.M, false); + } + + + +template +inline +void +SpBase::brief_print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::brief_print(get_cout_stream(), tmp.M); + } + + + +template +inline +void +SpBase::brief_print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::brief_print(user_stream, tmp.M); + } + + + +// +// extra functions defined in SpBase_eval_SpMat + +template +inline +const derived& +SpBase_eval_SpMat::eval() const + { + arma_extra_debug_sigprint(); + + return static_cast(*this); + } + + + +// +// extra functions defined in SpBase_eval_expr + +template +inline +SpMat +SpBase_eval_expr::eval() const + { + arma_extra_debug_sigprint(); + + return SpMat( static_cast(*this) ); + } + + + +template +inline +elem_type +SpBase::min() const + { + return spop_min::min( (*this).get_ref() ); + } + + + +template +inline +elem_type +SpBase::max() const + { + return spop_max::max( (*this).get_ref() ); + } + + + +template +inline +elem_type +SpBase::min(uword& index_of_min_val) const + { + const SpProxy P( (*this).get_ref() ); + + return spop_min::min_with_index(P, index_of_min_val); + } + + + +template +inline +elem_type +SpBase::max(uword& index_of_max_val) const + { + const SpProxy P( (*this).get_ref() ); + + return spop_max::max_with_index(P, index_of_max_val); + } + + + +template +inline +elem_type +SpBase::min(uword& row_of_min_val, uword& col_of_min_val) const + { + const SpProxy P( (*this).get_ref() ); + + uword index = 0; + + const elem_type val = spop_min::min_with_index(P, index); + + const uword local_n_rows = P.get_n_rows(); + + row_of_min_val = index % local_n_rows; + col_of_min_val = index / local_n_rows; + + return val; + } + + + +template +inline +elem_type +SpBase::max(uword& row_of_max_val, uword& col_of_max_val) const + { + const SpProxy P( (*this).get_ref() ); + + uword index = 0; + + const elem_type val = spop_max::max_with_index(P, index); + + const uword local_n_rows = P.get_n_rows(); + + row_of_max_val = index % local_n_rows; + col_of_max_val = index / local_n_rows; + + return val; + } + + + +template +inline +uword +SpBase::index_min() const + { + const SpProxy P( (*this).get_ref() ); + + uword index = 0; + + if(P.get_n_elem() == 0) + { + arma_debug_check(true, "index_min(): object has no elements"); + } + else + { + spop_min::min_with_index(P, index); + } + + return index; + } + + + +template +inline +uword +SpBase::index_max() const + { + const SpProxy P( (*this).get_ref() ); + + uword index = 0; + + if(P.get_n_elem() == 0) + { + arma_debug_check(true, "index_max(): object has no elements"); + } + else + { + spop_max::max_with_index(P, index); + } + + return index; + } + + + +template +inline +bool +SpBase::is_symmetric() const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + return tmp.M.is_symmetric(); + } + + + +template +inline +bool +SpBase::is_symmetric(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + return tmp.M.is_symmetric(tol); + } + + + +template +inline +bool +SpBase::is_hermitian() const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + return tmp.M.is_hermitian(); + } + + + +template +inline +bool +SpBase::is_hermitian(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + return tmp.M.is_hermitian(tol); + } + + + +template +inline +bool +SpBase::is_zero(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + arma_debug_check( (tol < T(0)), "is_zero(): parameter 'tol' must be >= 0" ); + + const SpProxy P( (*this).get_ref() ); + + if(P.get_n_elem() == 0) { return false; } + + if(P.get_n_nonzero() == 0) { return true; } + + if(is_SpMat::stored_type>::value) + { + const unwrap_spmat::stored_type> U(P.Q); + + return arrayops::is_zero(U.M.values, U.M.n_nonzero, tol); + } + + typename SpProxy::const_iterator_type it = P.begin(); + typename SpProxy::const_iterator_type it_end = P.end(); + + if(is_cx::yes) + { + while(it != it_end) + { + const elem_type val = (*it); + + const T val_real = access::tmp_real(val); + const T val_imag = access::tmp_imag(val); + + if(eop_aux::arma_abs(val_real) > tol) { return false; } + if(eop_aux::arma_abs(val_imag) > tol) { return false; } + + ++it; + } + } + else // not complex + { + while(it != it_end) + { + if(eop_aux::arma_abs(*it) > tol) { return false; } + + ++it; + } + } + + return true; + } + + + +template +inline +bool +SpBase::is_trimatu() const + { + arma_extra_debug_sigprint(); + + const SpProxy P( (*this).get_ref() ); + + if(P.get_n_rows() != P.get_n_cols()) { return false; } + + typename SpProxy::const_iterator_type it = P.begin(); + typename SpProxy::const_iterator_type it_end = P.end(); + + while(it != it_end) + { + if(it.row() > it.col()) { return false; } + ++it; + } + + return true; + } + + + +template +inline +bool +SpBase::is_trimatl() const + { + arma_extra_debug_sigprint(); + + const SpProxy P( (*this).get_ref() ); + + if(P.get_n_rows() != P.get_n_cols()) { return false; } + + typename SpProxy::const_iterator_type it = P.begin(); + typename SpProxy::const_iterator_type it_end = P.end(); + + while(it != it_end) + { + if(it.row() < it.col()) { return false; } + ++it; + } + + return true; + } + + + +template +inline +bool +SpBase::is_diagmat() const + { + arma_extra_debug_sigprint(); + + const SpProxy P( (*this).get_ref() ); + + typename SpProxy::const_iterator_type it = P.begin(); + typename SpProxy::const_iterator_type it_end = P.end(); + + while(it != it_end) + { + if(it.row() != it.col()) { return false; } + ++it; + } + + return true; + } + + + +template +inline +bool +SpBase::is_empty() const + { + arma_extra_debug_sigprint(); + + const SpProxy P( (*this).get_ref() ); + + return (P.get_n_elem() == uword(0)); + } + + + +template +inline +bool +SpBase::is_square() const + { + arma_extra_debug_sigprint(); + + const SpProxy P( (*this).get_ref() ); + + return (P.get_n_rows() == P.get_n_cols()); + } + + + +template +inline +bool +SpBase::is_vec() const + { + arma_extra_debug_sigprint(); + + if( (SpProxy::is_row) || (SpProxy::is_col) || (SpProxy::is_xvec) ) { return true; } + + const SpProxy P( (*this).get_ref() ); + + return ( (P.get_n_rows() == uword(1)) || (P.get_n_cols() == uword(1)) ); + } + + + +template +inline +bool +SpBase::is_colvec() const + { + arma_extra_debug_sigprint(); + + if(SpProxy::is_col) { return true; } + + const SpProxy P( (*this).get_ref() ); + + return (P.get_n_cols() == uword(1)); + } + + + +template +inline +bool +SpBase::is_rowvec() const + { + arma_extra_debug_sigprint(); + + if(SpProxy::is_row) { return true; } + + const SpProxy P( (*this).get_ref() ); + + return (P.get_n_rows() == uword(1)); + } + + + +template +inline +bool +SpBase::is_finite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "is_finite(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_SpMat::stored_type>::value) + { + const unwrap_spmat U( (*this).get_ref() ); + + return U.M.internal_is_finite(); + } + else + { + const SpProxy P( (*this).get_ref() ); + + typename SpProxy::const_iterator_type it = P.begin(); + typename SpProxy::const_iterator_type it_end = P.end(); + + while(it != it_end) + { + if(arma_isfinite(*it) == false) { return false; } + ++it; + } + } + + return true; + } + + + +template +inline +bool +SpBase::has_inf() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_inf(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_SpMat::stored_type>::value) + { + const unwrap_spmat U( (*this).get_ref() ); + + return U.M.internal_has_inf(); + } + else + { + const SpProxy P( (*this).get_ref() ); + + typename SpProxy::const_iterator_type it = P.begin(); + typename SpProxy::const_iterator_type it_end = P.end(); + + while(it != it_end) + { + if(arma_isinf(*it)) { return true; } + ++it; + } + } + + return false; + } + + + +template +inline +bool +SpBase::has_nan() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nan(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_SpMat::stored_type>::value) + { + const unwrap_spmat U( (*this).get_ref() ); + + return U.M.internal_has_nan(); + } + else + { + const SpProxy P( (*this).get_ref() ); + + typename SpProxy::const_iterator_type it = P.begin(); + typename SpProxy::const_iterator_type it_end = P.end(); + + while(it != it_end) + { + if(arma_isnan(*it)) { return true; } + ++it; + } + } + + return false; + } + + + +template +inline +bool +SpBase::has_nonfinite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nonfinite(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_SpMat::stored_type>::value) + { + const unwrap_spmat U( (*this).get_ref() ); + + return U.M.internal_has_nonfinite(); + } + else + { + const SpProxy P( (*this).get_ref() ); + + typename SpProxy::const_iterator_type it = P.begin(); + typename SpProxy::const_iterator_type it_end = P.end(); + + while(it != it_end) + { + if(arma_isfinite(*it) == false) { return true; } + ++it; + } + } + + return false; + } + + + +template +inline +const SpOp +SpBase::as_col() const + { + return SpOp( (*this).get_ref() ); + } + + + +template +inline +const SpOp +SpBase::as_row() const + { + return SpOp( (*this).get_ref() ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpCol_bones.hpp b/src/armadillo/include/armadillo_bits/SpCol_bones.hpp new file mode 100644 index 0000000..b49f1a5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpCol_bones.hpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpCol +//! @{ + + +//! Class for sparse column vectors (matrices with only one column) +template +class SpCol : public SpMat + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + + inline SpCol(); + inline explicit SpCol(const uword n_elem); + inline explicit SpCol(const uword in_rows, const uword in_cols); + inline explicit SpCol(const SizeMat& s); + + inline SpCol(const char* text); + inline SpCol& operator=(const char* text); + + inline SpCol(const std::string& text); + inline SpCol& operator=(const std::string& text); + + inline SpCol& operator=(const eT val); + + template inline SpCol(const Base& X); + template inline SpCol& operator=(const Base& X); + + template inline SpCol(const SpBase& X); + template inline SpCol& operator=(const SpBase& X); + + template + inline explicit SpCol(const SpBase& A, const SpBase& B); + + arma_warn_unused inline const SpOp,spop_htrans> t() const; + arma_warn_unused inline const SpOp,spop_htrans> ht() const; + arma_warn_unused inline const SpOp,spop_strans> st() const; + + inline void shed_row (const uword row_num); + inline void shed_rows(const uword in_row1, const uword in_row2); + + // inline void insert_rows(const uword row_num, const uword N, const bool set_to_zero = true); + + + typedef typename SpMat::iterator row_iterator; + typedef typename SpMat::const_iterator const_row_iterator; + + inline row_iterator begin_row(const uword row_num = 0); + inline const_row_iterator begin_row(const uword row_num = 0) const; + + inline row_iterator end_row (const uword row_num = 0); + inline const_row_iterator end_row (const uword row_num = 0) const; + + + #if defined(ARMA_EXTRA_SPCOL_PROTO) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPCOL_PROTO) + #endif + }; diff --git a/src/armadillo/include/armadillo_bits/SpCol_meat.hpp b/src/armadillo/include/armadillo_bits/SpCol_meat.hpp new file mode 100644 index 0000000..9b3c824 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpCol_meat.hpp @@ -0,0 +1,432 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpCol +//! @{ + + + +template +inline +SpCol::SpCol() + : SpMat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpCol::SpCol(const uword in_n_elem) + : SpMat(arma_vec_indicator(), in_n_elem, 1, 1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpCol::SpCol(const uword in_n_rows, const uword in_n_cols) + : SpMat(arma_vec_indicator(), in_n_rows, in_n_cols, 1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpCol::SpCol(const SizeMat& s) + : SpMat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + SpMat::init(s.n_rows, s.n_cols); + } + + + +template +inline +SpCol::SpCol(const char* text) + : SpMat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(); + + SpMat::init(std::string(text)); + } + + + +template +inline +SpCol& +SpCol::operator=(const char* text) + { + arma_extra_debug_sigprint(); + + SpMat::init(std::string(text)); + + return *this; + } + + + +template +inline +SpCol::SpCol(const std::string& text) + : SpMat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(); + + SpMat::init(text); + } + + + +template +inline +SpCol& +SpCol::operator=(const std::string& text) + { + arma_extra_debug_sigprint(); + + SpMat::init(text); + + return *this; + } + + + +template +inline +SpCol& +SpCol::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + SpMat::operator=(val); + + return *this; + } + + + +template +template +inline +SpCol::SpCol(const Base& X) + : SpMat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(); + + SpMat::operator=(X.get_ref()); + } + + + +template +template +inline +SpCol& +SpCol::operator=(const Base& X) + { + arma_extra_debug_sigprint(); + + SpMat::operator=(X.get_ref()); + + return *this; + } + + + +template +template +inline +SpCol::SpCol(const SpBase& X) + : SpMat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(); + + SpMat::operator=(X.get_ref()); + } + + + +template +template +inline +SpCol& +SpCol::operator=(const SpBase& X) + { + arma_extra_debug_sigprint(); + + SpMat::operator=(X.get_ref()); + + return *this; + } + + + +template +template +inline +SpCol::SpCol + ( + const SpBase::pod_type, T1>& A, + const SpBase::pod_type, T2>& B + ) + : SpMat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(); + + SpMat::init(A,B); + } + + + +template +inline +const SpOp,spop_htrans> +SpCol::t() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_htrans> +SpCol::ht() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_strans> +SpCol::st() const + { + return SpOp,spop_strans>(*this); + } + + + +//! remove specified row +template +inline +void +SpCol::shed_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( row_num >= SpMat::n_rows, "SpCol::shed_row(): out of bounds" ); + + shed_rows(row_num, row_num); + } + + + +//! remove specified rows +template +inline +void +SpCol::shed_rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= SpMat::n_rows), + "SpCol::shed_rows(): indices out of bounds or incorrectly used" + ); + + SpMat::sync_csc(); + + const uword diff = (in_row2 - in_row1 + 1); + + // This is easy because everything is in one column. + uword start = 0, end = 0; + bool start_found = false, end_found = false; + for(uword i = 0; i < SpMat::n_nonzero; ++i) + { + // Start position found? + if(SpMat::row_indices[i] >= in_row1 && !start_found) + { + start = i; + start_found = true; + } + + // End position found? + if(SpMat::row_indices[i] > in_row2) + { + end = i; + end_found = true; + break; + } + } + + if(!end_found) + { + end = SpMat::n_nonzero; + } + + // Now we can make the copy. + if(start != end) + { + const uword elem_diff = end - start; + + eT* new_values = memory::acquire (SpMat::n_nonzero - elem_diff); + uword* new_row_indices = memory::acquire(SpMat::n_nonzero - elem_diff); + + // Copy before the section we are dropping (if it exists). + if(start > 0) + { + arrayops::copy(new_values, SpMat::values, start); + arrayops::copy(new_row_indices, SpMat::row_indices, start); + } + + // Copy after the section we are dropping (if it exists). + if(end != SpMat::n_nonzero) + { + arrayops::copy(new_values + start, SpMat::values + end, (SpMat::n_nonzero - end)); + arrayops::copy(new_row_indices + start, SpMat::row_indices + end, (SpMat::n_nonzero - end)); + arrayops::inplace_minus(new_row_indices + start, diff, (SpMat::n_nonzero - end)); + } + + memory::release(SpMat::values); + memory::release(SpMat::row_indices); + + access::rw(SpMat::values) = new_values; + access::rw(SpMat::row_indices) = new_row_indices; + + access::rw(SpMat::n_nonzero) -= elem_diff; + access::rw(SpMat::col_ptrs[1]) -= elem_diff; + } + + access::rw(SpMat::n_rows) -= diff; + access::rw(SpMat::n_elem) -= diff; + + SpMat::invalidate_cache(); + } + + + +// //! insert N rows at the specified row position, +// //! optionally setting the elements of the inserted rows to zero +// template +// inline +// void +// SpCol::insert_rows(const uword row_num, const uword N, const bool set_to_zero) +// { +// arma_extra_debug_sigprint(); +// +// arma_debug_check(set_to_zero == false, "SpCol::insert_rows(): cannot set nonzero values"); +// +// arma_debug_check_bounds((row_num > SpMat::n_rows), "SpCol::insert_rows(): out of bounds"); +// +// for(uword row = 0; row < SpMat::n_rows; ++row) +// { +// if(SpMat::row_indices[row] >= row_num) +// { +// access::rw(SpMat::row_indices[row]) += N; +// } +// } +// +// access::rw(SpMat::n_rows) += N; +// access::rw(SpMat::n_elem) += N; +// } + + + +template +inline +typename SpCol::row_iterator +SpCol::begin_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= SpMat::n_rows), "SpCol::begin_row(): index out of bounds" ); + + SpMat::sync_csc(); + + return row_iterator(*this, row_num, 0); + } + + + +template +inline +typename SpCol::const_row_iterator +SpCol::begin_row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= SpMat::n_rows), "SpCol::begin_row(): index out of bounds" ); + + SpMat::sync_csc(); + + return const_row_iterator(*this, row_num, 0); + } + + + +template +inline +typename SpCol::row_iterator +SpCol::end_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= SpMat::n_rows), "SpCol::end_row(): index out of bounds" ); + + SpMat::sync_csc(); + + return row_iterator(*this, row_num + 1, 0); + } + + + +template +inline +typename SpCol::const_row_iterator +SpCol::end_row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (row_num >= SpMat::n_rows), "SpCol::end_row(): index out of bounds" ); + + SpMat::sync_csc(); + + return const_row_iterator(*this, row_num + 1, 0); + } + + + +#if defined(ARMA_EXTRA_SPCOL_MEAT) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPCOL_MEAT) +#endif + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpGlue_bones.hpp b/src/armadillo/include/armadillo_bits/SpGlue_bones.hpp new file mode 100644 index 0000000..3c5432d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpGlue_bones.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpGlue +//! @{ + + + +template +class SpGlue : public SpBase< typename T1::elem_type, SpGlue > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = spglue_type::template traits::is_row; + static constexpr bool is_col = spglue_type::template traits::is_col; + static constexpr bool is_xvec = spglue_type::template traits::is_xvec; + + inline SpGlue(const T1& in_A, const T2& in_B); + inline SpGlue(const T1& in_A, const T2& in_B, const elem_type in_aux); + inline ~SpGlue(); + + arma_inline bool is_alias(const SpMat& X) const; + + const T1& A; //!< first operand; must be derived from SpBase + const T2& B; //!< second operand; must be derived from SpBase + elem_type aux; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpGlue_meat.hpp b/src/armadillo/include/armadillo_bits/SpGlue_meat.hpp new file mode 100644 index 0000000..04d40e1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpGlue_meat.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpGlue +//! @{ + + + +template +inline +SpGlue::SpGlue(const T1& in_A, const T2& in_B) + : A(in_A) + , B(in_B) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpGlue::SpGlue(const T1& in_A, const T2& in_B, const typename T1::elem_type in_aux) + : A(in_A) + , B(in_B) + , aux(in_aux) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpGlue::~SpGlue() + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +bool +SpGlue::is_alias(const SpMat& X) const + { + return (A.is_alias(X) || B.is_alias(X)); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpMat_bones.hpp b/src/armadillo/include/armadillo_bits/SpMat_bones.hpp new file mode 100644 index 0000000..e96f153 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpMat_bones.hpp @@ -0,0 +1,747 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpMat +//! @{ + + +//! Sparse matrix class, with data stored in compressed sparse column (CSC) format +template +class SpMat : public SpBase< eT, SpMat > + { + public: + + typedef eT elem_type; //!< the type of elements stored in the matrix + typedef typename get_pod_type::result pod_type; //!< if eT is std::complex, pod_type is T; otherwise pod_type is eT + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + const uword n_rows; //!< number of rows (read-only) + const uword n_cols; //!< number of columns (read-only) + const uword n_elem; //!< number of elements (read-only) + const uword n_nonzero; //!< number of nonzero elements (read-only) + const uword vec_state; //!< 0: matrix; 1: column vector; 2: row vector + + + // The memory used to store the values of the matrix. + // In accordance with the CSC format, this stores only the actual values. + // The correct locations of the values are assembled from the row indices and column pointers. + // + // The length of this array is (n_nonzero + 1). + // The final value values[n_nonzero] must be zero to ensure integrity of iterators. + // Use mem_resize(new_n_nonzero) to resize this array. + // + // WARNING: the 'values' array is only valid after sync() is called; + // WARNING: there is a separate cache for fast element insertion + + arma_aligned const eT* const values; + + + // The row indices of each value. row_indices[i] is the row of values[i]. + // + // The length of this array is (n_nonzero + 1). + // The final value row_indices[n_nonzero] must be zero to ensure integrity of iterators. + // Use mem_resize(new_n_nonzero) to resize this array. + // + // WARNING: the 'row_indices' array is only valid after sync() is called; + // WARNING: there is a separate cache for fast element insertion + + arma_aligned const uword* const row_indices; + + + // The column pointers. This stores the index of the first item in column i. + // That is, values[col_ptrs[i]] is the first value in column i, + // and it is in the row indicated by row_indices[col_ptrs[i]]. + // + // The length of this array is (n_cols + 2). + // The element col_ptrs[n_cols] must be equal to n_nonzero. + // The element col_ptrs[n_cols + 1] must be an invalid very large value to ensure integrity of iterators. + // + // The col_ptrs array is set by the init() function + // (which is called by constructors, set_size() and other functions that change the matrix size). + // + // WARNING: the 'col_ptrs' array is only valid after sync() is called; + // WARNING: there is a separate cache for fast element insertion + + arma_aligned const uword* const col_ptrs; + + inline SpMat(); + inline ~SpMat(); + + inline explicit SpMat(const uword in_rows, const uword in_cols); + inline explicit SpMat(const SizeMat& s); + + inline SpMat(const char* text); + inline SpMat& operator=(const char* text); + inline SpMat(const std::string& text); + inline SpMat& operator=(const std::string& text); + inline SpMat(const SpMat& x); + + inline SpMat(SpMat&& m); + inline SpMat& operator=(SpMat&& m); + + inline explicit SpMat(const MapMat& x); + inline SpMat& operator= (const MapMat& x); + + template + inline SpMat(const Base& rowind, const Base& colptr, const Base& values, const uword n_rows, const uword n_cols, const bool check_for_zeros = true); + + template + inline SpMat(const Base& locations, const Base& values, const bool sort_locations = true); + + template + inline SpMat(const Base& locations, const Base& values, const uword n_rows, const uword n_cols, const bool sort_locations = true, const bool check_for_zeros = true); + + template + inline SpMat(const bool add_values, const Base& locations, const Base& values, const uword n_rows, const uword n_cols, const bool sort_locations = true, const bool check_for_zeros = true); + + inline SpMat& operator= (const eT val); //! sets size to 1x1 + inline SpMat& operator*=(const eT val); + inline SpMat& operator/=(const eT val); + // operator+=(val) and operator-=(val) are not defined as they don't make sense for sparse matrices + + inline SpMat& operator= (const SpMat& m); + inline SpMat& operator+=(const SpMat& m); + inline SpMat& operator-=(const SpMat& m); + inline SpMat& operator*=(const SpMat& m); + inline SpMat& operator%=(const SpMat& m); + inline SpMat& operator/=(const SpMat& m); + + template inline explicit SpMat(const Base& m); + template inline SpMat& operator= (const Base& m); + template inline SpMat& operator+=(const Base& m); + template inline SpMat& operator-=(const Base& m); + template inline SpMat& operator*=(const Base& m); + template inline SpMat& operator/=(const Base& m); + template inline SpMat& operator%=(const Base& m); + + template inline explicit SpMat(const Op& expr); + template inline SpMat& operator= (const Op& expr); + template inline SpMat& operator+=(const Op& expr); + template inline SpMat& operator-=(const Op& expr); + template inline SpMat& operator*=(const Op& expr); + template inline SpMat& operator/=(const Op& expr); + template inline SpMat& operator%=(const Op& expr); + + //! explicit specification of sparse +/- scalar + template inline explicit SpMat(const SpToDOp& expr); + + //! construction of complex matrix out of two non-complex matrices + template + inline explicit SpMat(const SpBase& A, const SpBase& B); + + inline SpMat(const SpSubview& X); + inline SpMat& operator= (const SpSubview& X); + inline SpMat& operator+=(const SpSubview& X); + inline SpMat& operator-=(const SpSubview& X); + inline SpMat& operator*=(const SpSubview& X); + inline SpMat& operator%=(const SpSubview& X); + inline SpMat& operator/=(const SpSubview& X); + + template inline SpMat(const SpSubview_col_list& X); + template inline SpMat& operator= (const SpSubview_col_list& X); + template inline SpMat& operator+=(const SpSubview_col_list& X); + template inline SpMat& operator-=(const SpSubview_col_list& X); + template inline SpMat& operator*=(const SpSubview_col_list& X); + template inline SpMat& operator%=(const SpSubview_col_list& X); + template inline SpMat& operator/=(const SpSubview_col_list& X); + + inline SpMat(const spdiagview& X); + inline SpMat& operator= (const spdiagview& X); + inline SpMat& operator+=(const spdiagview& X); + inline SpMat& operator-=(const spdiagview& X); + inline SpMat& operator*=(const spdiagview& X); + inline SpMat& operator%=(const spdiagview& X); + inline SpMat& operator/=(const spdiagview& X); + + // delayed unary ops + template inline SpMat(const SpOp& X); + template inline SpMat& operator= (const SpOp& X); + template inline SpMat& operator+=(const SpOp& X); + template inline SpMat& operator-=(const SpOp& X); + template inline SpMat& operator*=(const SpOp& X); + template inline SpMat& operator%=(const SpOp& X); + template inline SpMat& operator/=(const SpOp& X); + + // delayed binary ops + template inline SpMat(const SpGlue& X); + template inline SpMat& operator= (const SpGlue& X); + template inline SpMat& operator+=(const SpGlue& X); + template inline SpMat& operator-=(const SpGlue& X); + template inline SpMat& operator*=(const SpGlue& X); + template inline SpMat& operator%=(const SpGlue& X); + template inline SpMat& operator/=(const SpGlue& X); + + // delayed mixed-type unary ops + template inline SpMat(const mtSpOp& X); + template inline SpMat& operator= (const mtSpOp& X); + template inline SpMat& operator+=(const mtSpOp& X); + template inline SpMat& operator-=(const mtSpOp& X); + template inline SpMat& operator*=(const mtSpOp& X); + template inline SpMat& operator%=(const mtSpOp& X); + template inline SpMat& operator/=(const mtSpOp& X); + + // delayed mixed-type binary ops + template inline SpMat(const mtSpGlue& X); + template inline SpMat& operator= (const mtSpGlue& X); + template inline SpMat& operator+=(const mtSpGlue& X); + template inline SpMat& operator-=(const mtSpGlue& X); + template inline SpMat& operator*=(const mtSpGlue& X); + template inline SpMat& operator%=(const mtSpGlue& X); + template inline SpMat& operator/=(const mtSpGlue& X); + + + arma_inline SpSubview_row row(const uword row_num); + arma_inline const SpSubview_row row(const uword row_num) const; + + inline SpSubview_row operator()(const uword row_num, const span& col_span); + inline const SpSubview_row operator()(const uword row_num, const span& col_span) const; + + arma_inline SpSubview_col col(const uword col_num); + arma_inline const SpSubview_col col(const uword col_num) const; + + inline SpSubview_col operator()(const span& row_span, const uword col_num); + inline const SpSubview_col operator()(const span& row_span, const uword col_num) const; + + arma_inline SpSubview rows(const uword in_row1, const uword in_row2); + arma_inline const SpSubview rows(const uword in_row1, const uword in_row2) const; + + arma_inline SpSubview cols(const uword in_col1, const uword in_col2); + arma_inline const SpSubview cols(const uword in_col1, const uword in_col2) const; + + arma_inline SpSubview submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2); + arma_inline const SpSubview submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const; + + arma_inline SpSubview submat(const uword in_row1, const uword in_col1, const SizeMat& s); + arma_inline const SpSubview submat(const uword in_row1, const uword in_col1, const SizeMat& s) const; + + inline SpSubview submat (const span& row_span, const span& col_span); + inline const SpSubview submat (const span& row_span, const span& col_span) const; + + inline SpSubview operator()(const span& row_span, const span& col_span); + inline const SpSubview operator()(const span& row_span, const span& col_span) const; + + arma_inline SpSubview operator()(const uword in_row1, const uword in_col1, const SizeMat& s); + arma_inline const SpSubview operator()(const uword in_row1, const uword in_col1, const SizeMat& s) const; + + + inline SpSubview head_rows(const uword N); + inline const SpSubview head_rows(const uword N) const; + + inline SpSubview tail_rows(const uword N); + inline const SpSubview tail_rows(const uword N) const; + + inline SpSubview head_cols(const uword N); + inline const SpSubview head_cols(const uword N) const; + + inline SpSubview tail_cols(const uword N); + inline const SpSubview tail_cols(const uword N) const; + + + template arma_inline SpSubview_col_list cols(const Base& ci); + template arma_inline const SpSubview_col_list cols(const Base& ci) const; + + + inline spdiagview diag(const sword in_id = 0); + inline const spdiagview diag(const sword in_id = 0) const; + + + inline void swap_rows(const uword in_row1, const uword in_row2); + inline void swap_cols(const uword in_col1, const uword in_col2); + + inline void shed_row(const uword row_num); + inline void shed_col(const uword col_num); + + inline void shed_rows(const uword in_row1, const uword in_row2); + inline void shed_cols(const uword in_col1, const uword in_col2); + + + // access the i-th element; if there is nothing at element i, 0 is returned + arma_warn_unused arma_inline SpMat_MapMat_val operator[] (const uword i); + arma_warn_unused arma_inline eT operator[] (const uword i) const; + + arma_warn_unused arma_inline SpMat_MapMat_val at (const uword i); + arma_warn_unused arma_inline eT at (const uword i) const; + + arma_warn_unused arma_inline SpMat_MapMat_val operator() (const uword i); + arma_warn_unused arma_inline eT operator() (const uword i) const; + + // access the element at the given row and column; if there is nothing at that position, 0 is returned + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline SpMat_MapMat_val operator[] (const uword in_row, const uword in_col); + arma_warn_unused arma_inline eT operator[] (const uword in_row, const uword in_col) const; + #endif + + arma_warn_unused arma_inline SpMat_MapMat_val at (const uword in_row, const uword in_col); + arma_warn_unused arma_inline eT at (const uword in_row, const uword in_col) const; + + arma_warn_unused arma_inline SpMat_MapMat_val operator() (const uword in_row, const uword in_col); + arma_warn_unused arma_inline eT operator() (const uword in_row, const uword in_col) const; + + + arma_warn_unused arma_inline bool is_empty() const; + arma_warn_unused arma_inline bool is_vec() const; + arma_warn_unused arma_inline bool is_rowvec() const; + arma_warn_unused arma_inline bool is_colvec() const; + arma_warn_unused arma_inline bool is_square() const; + + arma_warn_unused inline bool is_symmetric() const; + arma_warn_unused inline bool is_symmetric(const typename get_pod_type::result tol) const; + + arma_warn_unused inline bool is_hermitian() const; + arma_warn_unused inline bool is_hermitian(const typename get_pod_type::result tol) const; + + arma_warn_unused inline bool internal_is_finite() const; + arma_warn_unused inline bool internal_has_inf() const; + arma_warn_unused inline bool internal_has_nan() const; + arma_warn_unused inline bool internal_has_nonfinite() const; + + arma_warn_unused arma_inline bool in_range(const uword i) const; + arma_warn_unused arma_inline bool in_range(const span& x) const; + + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const span& col_span) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const span& col_span) const; + + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const SizeMat& s) const; + + + template inline SpMat& copy_size(const SpMat& m); + template inline SpMat& copy_size(const Mat& m); + + inline SpMat& set_size(const uword in_elem); + inline SpMat& set_size(const uword in_rows, const uword in_cols); + inline SpMat& set_size(const SizeMat& s); + + inline SpMat& resize(const uword in_rows, const uword in_cols); + inline SpMat& resize(const SizeMat& s); + + inline SpMat& reshape(const uword in_rows, const uword in_cols); + inline SpMat& reshape(const SizeMat& s); + + inline void reshape_helper_generic(const uword in_rows, const uword in_cols); //! internal use only + inline void reshape_helper_intovec(); //! internal use only + + template inline SpMat& for_each(functor F); + template inline const SpMat& for_each(functor F) const; + + template inline SpMat& transform(functor F); + + inline SpMat& replace(const eT old_val, const eT new_val); + + inline SpMat& clean(const pod_type threshold); + + inline SpMat& clamp(const eT min_val, const eT max_val); + + inline SpMat& zeros(); + inline SpMat& zeros(const uword in_elem); + inline SpMat& zeros(const uword in_rows, const uword in_cols); + inline SpMat& zeros(const SizeMat& s); + + inline SpMat& eye(); + inline SpMat& eye(const uword in_rows, const uword in_cols); + inline SpMat& eye(const SizeMat& s); + + inline SpMat& speye(); + inline SpMat& speye(const uword in_rows, const uword in_cols); + inline SpMat& speye(const SizeMat& s); + + inline SpMat& sprandu(const uword in_rows, const uword in_cols, const double density); + inline SpMat& sprandu(const SizeMat& s, const double density); + + inline SpMat& sprandn(const uword in_rows, const uword in_cols, const double density); + inline SpMat& sprandn(const SizeMat& s, const double density); + + inline void reset(); + inline void reset_cache(); + + //! don't use this unless you're writing internal Armadillo code + inline void reserve(const uword in_rows, const uword in_cols, const uword new_n_nonzero); + + //! don't use this unless you're writing internal Armadillo code + inline SpMat(const arma_reserve_indicator&, const uword in_rows, const uword in_cols, const uword new_n_nonzero); + + //! don't use this unless you're writing internal Armadillo code + template + inline SpMat(const arma_layout_indicator&, const SpMat& x); + + template inline void set_real(const SpBase& X); + template inline void set_imag(const SpBase& X); + + + // saving and loading + // TODO: implement auto_detect for sparse matrices + + arma_cold inline bool save(const std::string name, const file_type type = arma_binary) const; + arma_cold inline bool save(const csv_name& spec, const file_type type = csv_ascii) const; + arma_cold inline bool save( std::ostream& os, const file_type type = arma_binary) const; + + arma_cold inline bool load(const std::string name, const file_type type = arma_binary); + arma_cold inline bool load(const csv_name& spec, const file_type type = csv_ascii); + arma_cold inline bool load( std::istream& is, const file_type type = arma_binary); + + arma_deprecated inline bool quiet_save(const std::string name, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; + + arma_deprecated inline bool quiet_load(const std::string name, const file_type type = arma_binary); + arma_deprecated inline bool quiet_load( std::istream& is, const file_type type = arma_binary); + + + + // necessary forward declarations + class iterator_base; + class iterator; + class const_iterator; + class row_iterator; + class const_row_iterator; + + // iterator_base provides basic operators but not how to compare or how to iterate + class iterator_base + { + public: + + inline iterator_base(); + inline iterator_base(const SpMat& in_M); + inline iterator_base(const SpMat& in_M, const uword col, const uword pos); + + arma_inline eT operator*() const; + + // don't hold location internally; call "dummy" methods to get that information + arma_inline uword row() const { return M->row_indices[internal_pos]; } + arma_inline uword col() const { return internal_col; } + arma_inline uword pos() const { return internal_pos; } + + arma_aligned const SpMat* M; + arma_aligned uword internal_col; + arma_aligned uword internal_pos; + + typedef std::bidirectional_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef const eT* pointer; + typedef const eT& reference; + }; + + class const_iterator : public iterator_base + { + public: + + inline const_iterator(); + + inline const_iterator(const SpMat& in_M, uword initial_pos = 0); // assumes initial_pos is valid + inline const_iterator(const SpMat& in_M, uword in_row, uword in_col); // iterator will be at the first nonzero value after the given position (using forward columnwise traversal) + inline const_iterator(const SpMat& in_M, uword in_row, uword in_col, uword in_pos); // if the exact position of the iterator is known; in_row is a dummy argument + + inline const_iterator(const const_iterator& other); + inline const_iterator& operator= (const const_iterator& other) = default; + + arma_hot inline const_iterator& operator++(); + arma_warn_unused inline const_iterator operator++(int); + + arma_hot inline const_iterator& operator--(); + arma_warn_unused inline const_iterator operator--(int); + + arma_hot inline bool operator==(const const_iterator& rhs) const; + arma_hot inline bool operator!=(const const_iterator& rhs) const; + + arma_hot inline bool operator==(const typename SpSubview::const_iterator& rhs) const; + arma_hot inline bool operator!=(const typename SpSubview::const_iterator& rhs) const; + + arma_hot inline bool operator==(const const_row_iterator& rhs) const; + arma_hot inline bool operator!=(const const_row_iterator& rhs) const; + + arma_hot inline bool operator==(const typename SpSubview::const_row_iterator& rhs) const; + arma_hot inline bool operator!=(const typename SpSubview::const_row_iterator& rhs) const; + }; + + /** + * So that we can iterate over nonzero values, we need an iterator implementation. + * This can't be as simple as for Mat, which is just a pointer to an eT. + * If a value is set to 0 using this iterator, the iterator is no longer valid! + */ + class iterator : public const_iterator + { + public: + + inline iterator() : const_iterator() { } + + inline iterator(SpMat& in_M, uword initial_pos = 0) : const_iterator(in_M, initial_pos) { } + inline iterator(SpMat& in_M, uword in_row, uword in_col) : const_iterator(in_M, in_row, in_col) { } + inline iterator(SpMat& in_M, uword in_row, uword in_col, uword in_pos) : const_iterator(in_M, in_row, in_col, in_pos) { } + + inline iterator (const iterator& other) : const_iterator(other) { } + inline iterator& operator=(const iterator& other) = default; + + arma_hot inline SpValProxy< SpMat > operator*(); + + // overloads needed for return type correctness + arma_hot inline iterator& operator++(); + arma_warn_unused inline iterator operator++(int); + + arma_hot inline iterator& operator--(); + arma_warn_unused inline iterator operator--(int); + + // this has a different value_type than iterator_base + typedef SpValProxy< SpMat > value_type; + typedef const SpValProxy< SpMat >* pointer; + typedef const SpValProxy< SpMat >& reference; + }; + + class const_row_iterator : public iterator_base + { + public: + + inline const_row_iterator(); + inline const_row_iterator(const SpMat& in_M, uword initial_pos = 0); + inline const_row_iterator(const SpMat& in_M, uword in_row, uword in_col); + + inline const_row_iterator(const const_row_iterator& other); + inline const_row_iterator& operator= (const const_row_iterator& other) = default; + + arma_hot inline const_row_iterator& operator++(); + arma_warn_unused inline const_row_iterator operator++(int); + + arma_hot inline const_row_iterator& operator--(); + arma_warn_unused inline const_row_iterator operator--(int); + + uword internal_row; // hold row internally + uword actual_pos; // hold the true position we are at in the matrix, as column-major indexing + + arma_inline eT operator*() const { return iterator_base::M->values[actual_pos]; } + + arma_inline uword row() const { return internal_row; } + + arma_hot inline bool operator==(const const_iterator& rhs) const; + arma_hot inline bool operator!=(const const_iterator& rhs) const; + + arma_hot inline bool operator==(const typename SpSubview::const_iterator& rhs) const; + arma_hot inline bool operator!=(const typename SpSubview::const_iterator& rhs) const; + + arma_hot inline bool operator==(const const_row_iterator& rhs) const; + arma_hot inline bool operator!=(const const_row_iterator& rhs) const; + + arma_hot inline bool operator==(const typename SpSubview::const_row_iterator& rhs) const; + arma_hot inline bool operator!=(const typename SpSubview::const_row_iterator& rhs) const; + }; + + class row_iterator : public const_row_iterator + { + public: + + inline row_iterator() : const_row_iterator() {} + + inline row_iterator(SpMat& in_M, uword initial_pos = 0) : const_row_iterator(in_M, initial_pos) { } + inline row_iterator(SpMat& in_M, uword in_row, uword in_col) : const_row_iterator(in_M, in_row, in_col) { } + + inline row_iterator(const row_iterator& other) : const_row_iterator(other) { } + inline row_iterator& operator= (const row_iterator& other) = default; + + arma_hot inline SpValProxy< SpMat > operator*(); + + // overloads required for return type correctness + arma_hot inline row_iterator& operator++(); + arma_warn_unused inline row_iterator operator++(int); + + arma_hot inline row_iterator& operator--(); + arma_warn_unused inline row_iterator operator--(int); + + // this has a different value_type than iterator_base + typedef SpValProxy< SpMat > value_type; + typedef const SpValProxy< SpMat >* pointer; + typedef const SpValProxy< SpMat >& reference; + }; + + + typedef iterator col_iterator; + typedef const_iterator const_col_iterator; + + typedef iterator row_col_iterator; + typedef const_iterator const_row_col_iterator; + + + inline iterator begin(); + inline const_iterator begin() const; + inline const_iterator cbegin() const; + + inline iterator end(); + inline const_iterator end() const; + inline const_iterator cend() const; + + inline col_iterator begin_col(const uword col_num); + inline const_col_iterator begin_col(const uword col_num) const; + + inline col_iterator begin_col_no_sync(const uword col_num); + inline const_col_iterator begin_col_no_sync(const uword col_num) const; + + inline col_iterator end_col(const uword col_num); + inline const_col_iterator end_col(const uword col_num) const; + + inline col_iterator end_col_no_sync(const uword col_num); + inline const_col_iterator end_col_no_sync(const uword col_num) const; + + inline row_iterator begin_row(const uword row_num = 0); + inline const_row_iterator begin_row(const uword row_num = 0) const; + + inline row_iterator end_row(); + inline const_row_iterator end_row() const; + + inline row_iterator end_row(const uword row_num); + inline const_row_iterator end_row(const uword row_num) const; + + inline row_col_iterator begin_row_col(); + inline const_row_col_iterator begin_row_col() const; + + inline row_col_iterator end_row_col(); + inline const_row_col_iterator end_row_col() const; + + + inline void clear(); + inline bool empty() const; + inline uword size() const; + + arma_warn_unused arma_inline SpMat_MapMat_val front(); + arma_warn_unused arma_inline eT front() const; + + arma_warn_unused arma_inline SpMat_MapMat_val back(); + arma_warn_unused arma_inline eT back() const; + + // Resize memory. + // If the new size is larger, the column pointers and new memory still need to be correctly set. + // If the new size is smaller, the first new_n_nonzero elements will be copied. + // n_nonzero is updated. + inline void mem_resize(const uword new_n_nonzero); + + //! synchronise CSC from cache + inline void sync() const; + + //! don't use this unless you're writing internal Armadillo code + inline void remove_zeros(); + + //! don't use this unless you're writing internal Armadillo code + inline void steal_mem(SpMat& X); + + //! don't use this unless you're writing internal Armadillo code + inline void steal_mem_simple(SpMat& X); + + //! don't use this unless you're writing internal Armadillo code + template< typename T1, typename Functor> inline void init_xform (const SpBase& x, const Functor& func); + template inline void init_xform_mt(const SpBase& x, const Functor& func); + + //! don't use this unless you're writing internal Armadillo code + arma_inline bool is_alias(const SpMat& X) const; + + + protected: + + inline void init(uword in_rows, uword in_cols, const uword new_n_nonzero = 0); + arma_cold inline void init_cold(uword in_rows, uword in_cols, const uword new_n_nonzero = 0); + + inline void init(const std::string& text); + inline void init(const SpMat& x); + inline void init(const MapMat& x); + + inline void init_simple(const SpMat& x); + + inline void init_batch_std(const Mat& locations, const Mat& values, const bool sort_locations); + inline void init_batch_add(const Mat& locations, const Mat& values, const bool sort_locations); + + inline SpMat(const arma_vec_indicator&, const uword in_vec_state); + inline SpMat(const arma_vec_indicator&, const uword in_n_rows, const uword in_n_cols, const uword in_vec_state); + + + private: + + arma_warn_unused arma_hot inline const eT* find_value_csc(const uword in_row, const uword in_col) const; + + arma_warn_unused arma_hot inline eT get_value(const uword i ) const; + arma_warn_unused arma_hot inline eT get_value(const uword in_row, const uword in_col) const; + + arma_warn_unused arma_hot inline eT get_value_csc(const uword i ) const; + arma_warn_unused arma_hot inline eT get_value_csc(const uword in_row, const uword in_col) const; + + arma_warn_unused arma_hot inline bool try_set_value_csc(const uword in_row, const uword in_col, const eT in_val); + arma_warn_unused arma_hot inline bool try_add_value_csc(const uword in_row, const uword in_col, const eT in_val); + arma_warn_unused arma_hot inline bool try_sub_value_csc(const uword in_row, const uword in_col, const eT in_val); + arma_warn_unused arma_hot inline bool try_mul_value_csc(const uword in_row, const uword in_col, const eT in_val); + arma_warn_unused arma_hot inline bool try_div_value_csc(const uword in_row, const uword in_col, const eT in_val); + + arma_warn_unused inline eT& insert_element(const uword in_row, const uword in_col, const eT in_val = eT(0)); + inline void delete_element(const uword in_row, const uword in_col); + + + // cache related + + arma_aligned mutable MapMat cache; + arma_aligned mutable state_type sync_state; + // 0: cache needs to be updated from CSC (ie. CSC has more recent data) + // 1: CSC needs to be updated from cache (ie. cache has more recent data) + // 2: no update required (ie. CSC and cache contain the same data) + + #if (!defined(ARMA_DONT_USE_STD_MUTEX)) + arma_aligned mutable std::mutex cache_mutex; + #endif + + arma_inline void invalidate_cache() const; + arma_inline void invalidate_csc() const; + + inline void sync_cache() const; + inline void sync_cache_simple() const; + inline void sync_csc() const; + inline void sync_csc_simple() const; + + + friend class SpValProxy< SpMat >; // allow SpValProxy to call insert_element() and delete_element() + friend class SpSubview; + friend class SpRow; + friend class SpCol; + friend class SpMat_MapMat_val; + friend class SpSubview_MapMat_val; + friend class spdiagview; + + template friend class SpSubview_col_list; + + public: + + #if defined(ARMA_EXTRA_SPMAT_PROTO) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPMAT_PROTO) + #endif + }; + + + +class SpMat_aux + { + public: + + template inline static void set_real(SpMat& out, const SpBase& X); + template inline static void set_real(SpMat< std::complex >& out, const SpBase< T,T1>& X); + + template inline static void set_imag(SpMat& out, const SpBase& X); + template inline static void set_imag(SpMat< std::complex >& out, const SpBase< T,T1>& X); + }; + + + +#define ARMA_HAS_SPMAT + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpMat_iterators_meat.hpp b/src/armadillo/include/armadillo_bits/SpMat_iterators_meat.hpp new file mode 100644 index 0000000..ed29640 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpMat_iterators_meat.hpp @@ -0,0 +1,964 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpMat +//! @{ + + +/////////////////////////////////////////////////////////////////////////////// +// SpMat::iterator_base implementation // +/////////////////////////////////////////////////////////////////////////////// + + +template +inline +SpMat::iterator_base::iterator_base() + : M(nullptr) + , internal_col(0) + , internal_pos(0) + { + // Technically this iterator is invalid (it does not point to a valid element) + } + + + +template +inline +SpMat::iterator_base::iterator_base(const SpMat& in_M) + : M(&in_M) + , internal_col(0) + , internal_pos(0) + { + // Technically this iterator is invalid (it may not point to a valid element) + } + + + +template +inline +SpMat::iterator_base::iterator_base(const SpMat& in_M, const uword in_col, const uword in_pos) + : M(&in_M) + , internal_col(in_col) + , internal_pos(in_pos) + { + // Nothing to do. + } + + + +template +arma_inline +eT +SpMat::iterator_base::operator*() const + { + return M->values[internal_pos]; + } + + + +/////////////////////////////////////////////////////////////////////////////// +// SpMat::const_iterator implementation // +/////////////////////////////////////////////////////////////////////////////// + +template +inline +SpMat::const_iterator::const_iterator() + : iterator_base() + { + } + + + +template +inline +SpMat::const_iterator::const_iterator(const SpMat& in_M, uword initial_pos) + : iterator_base(in_M, 0, initial_pos) + { + // Corner case for empty matrices. + if(in_M.n_nonzero == 0) + { + iterator_base::internal_col = in_M.n_cols; + return; + } + + // Determine which column we should be in. + while(iterator_base::M->col_ptrs[iterator_base::internal_col + 1] <= iterator_base::internal_pos) + { + iterator_base::internal_col++; + } + } + + + +template +inline +SpMat::const_iterator::const_iterator(const SpMat& in_M, uword in_row, uword in_col) + : iterator_base(in_M, in_col, 0) + { + // So we have a position we want to be right after. Skip to the column. + iterator_base::internal_pos = iterator_base::M->col_ptrs[iterator_base::internal_col]; + + // Now we have to make sure that is the right column. + while(iterator_base::M->col_ptrs[iterator_base::internal_col + 1] <= iterator_base::internal_pos) + { + iterator_base::internal_col++; + } + + // Now we have to get to the right row. + while((iterator_base::M->row_indices[iterator_base::internal_pos] < in_row) && (iterator_base::internal_col == in_col)) + { + ++(*this); // Increment iterator. + } + } + + + +template +inline +SpMat::const_iterator::const_iterator(const SpMat& in_M, const uword /* in_row */, const uword in_col, const uword in_pos) + : iterator_base(in_M, in_col, in_pos) + { + // Nothing to do. + } + + + +template +inline +SpMat::const_iterator::const_iterator(const typename SpMat::const_iterator& other) + : iterator_base(*other.M, other.internal_col, other.internal_pos) + { + // Nothing to do. + } + + + +template +inline +typename SpMat::const_iterator& +SpMat::const_iterator::operator++() + { + ++iterator_base::internal_pos; + + if(iterator_base::internal_pos == iterator_base::M->n_nonzero) + { + iterator_base::internal_col = iterator_base::M->n_cols; + return *this; + } + + // Check to see if we moved a column. + while(iterator_base::M->col_ptrs[iterator_base::internal_col + 1] <= iterator_base::internal_pos) + { + ++iterator_base::internal_col; + } + + return *this; + } + + + +template +inline +typename SpMat::const_iterator +SpMat::const_iterator::operator++(int) + { + typename SpMat::const_iterator tmp(*this); + + ++(*this); + + return tmp; + } + + + +template +inline +typename SpMat::const_iterator& +SpMat::const_iterator::operator--() + { + --iterator_base::internal_pos; + + // First, see if we moved back a column. + while(iterator_base::internal_pos < iterator_base::M->col_ptrs[iterator_base::internal_col]) + { + --iterator_base::internal_col; + } + + return *this; + } + + + +template +inline +typename SpMat::const_iterator +SpMat::const_iterator::operator--(int) + { + typename SpMat::const_iterator tmp(*this); + + --(*this); + + return tmp; + } + + + +template +inline +bool +SpMat::const_iterator::operator==(const const_iterator& rhs) const + { + return (rhs.row() == (*this).row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_iterator::operator!=(const const_iterator& rhs) const + { + return (rhs.row() != (*this).row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_iterator::operator==(const typename SpSubview::const_iterator& rhs) const + { + return (rhs.row() == (*this).row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_iterator::operator!=(const typename SpSubview::const_iterator& rhs) const + { + return (rhs.row() != (*this).row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_iterator::operator==(const const_row_iterator& rhs) const + { + return (rhs.row() == (*this).row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_iterator::operator!=(const const_row_iterator& rhs) const + { + return (rhs.row() != (*this).row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_iterator::operator==(const typename SpSubview::const_row_iterator& rhs) const + { + return (rhs.row() == (*this).row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_iterator::operator!=(const typename SpSubview::const_row_iterator& rhs) const + { + return (rhs.row() != (*this).row()) || (rhs.col() != iterator_base::internal_col); + } + + + +/////////////////////////////////////////////////////////////////////////////// +// SpMat::iterator implementation // +/////////////////////////////////////////////////////////////////////////////// + +template +inline +SpValProxy< SpMat > +SpMat::iterator::operator*() + { + return SpValProxy< SpMat >( + iterator_base::M->row_indices[iterator_base::internal_pos], + iterator_base::internal_col, + access::rw(*iterator_base::M), + &access::rw(iterator_base::M->values[iterator_base::internal_pos])); + } + + + +template +inline +typename SpMat::iterator& +SpMat::iterator::operator++() + { + const_iterator::operator++(); + + return *this; + } + + + +template +inline +typename SpMat::iterator +SpMat::iterator::operator++(int) + { + typename SpMat::iterator tmp(*this); + + const_iterator::operator++(); + + return tmp; + } + + + +template +inline +typename SpMat::iterator& +SpMat::iterator::operator--() + { + const_iterator::operator--(); + + return *this; + } + + + +template +inline +typename SpMat::iterator +SpMat::iterator::operator--(int) + { + typename SpMat::iterator tmp(*this); + + const_iterator::operator--(); + + return tmp; + } + + + +/////////////////////////////////////////////////////////////////////////////// +// SpMat::const_row_iterator implementation // +/////////////////////////////////////////////////////////////////////////////// + +/** + * Initialize the const_row_iterator. + */ + +template +inline +SpMat::const_row_iterator::const_row_iterator() + : iterator_base() + , internal_row(0) + , actual_pos(0) + { + } + + + +template +inline +SpMat::const_row_iterator::const_row_iterator(const SpMat& in_M, uword initial_pos) + : iterator_base(in_M, 0, initial_pos) + , internal_row(0) + , actual_pos(0) + { + // Corner case for the end of a matrix. + if(initial_pos == in_M.n_nonzero) + { + iterator_base::internal_col = 0; + internal_row = in_M.n_rows; + actual_pos = in_M.n_nonzero; + iterator_base::internal_pos = in_M.n_nonzero; + + return; + } + + // We don't count zeros in our position count, so we have to find the nonzero + // value corresponding to the given initial position. We assume initial_pos + // is valid. + + // This is irritating because we don't know where the elements are in each row. + // What we will do is loop across all columns looking for elements in row 0 + // (and add to our sum), then in row 1, and so forth, until we get to the desired position. + uword cur_pos = std::numeric_limits::max(); // Invalid value. + uword cur_actual_pos = 0; + + for(uword row = 0; row < iterator_base::M->n_rows; ++row) + { + for(uword col = 0; col < iterator_base::M->n_cols; ++col) + { + // Find the first element with row greater than or equal to in_row. + const uword col_offset = iterator_base::M->col_ptrs[col ]; + const uword next_col_offset = iterator_base::M->col_ptrs[col + 1]; + + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, row); + + // This is the number of elements in the column with row index less than in_row. + const uword offset = uword(pos_ptr - start_ptr); + + if(iterator_base::M->row_indices[col_offset + offset] == row) + { + cur_actual_pos = col_offset + offset; + + // Increment position portably. + if(cur_pos == std::numeric_limits::max()) + { cur_pos = 0; } + else + { ++cur_pos; } + + // Do we terminate? + if(cur_pos == initial_pos) + { + internal_row = row; + iterator_base::internal_col = col; + iterator_base::internal_pos = cur_pos; + actual_pos = cur_actual_pos; + + return; + } + } + } + } + } + + // If we got to here, then we have gone past the end of the matrix. + // This shouldn't happen... + iterator_base::internal_pos = iterator_base::M->n_nonzero; + iterator_base::internal_col = 0; + internal_row = iterator_base::M->n_rows; + actual_pos = iterator_base::M->n_nonzero; + } + + + +template +inline +SpMat::const_row_iterator::const_row_iterator(const SpMat& in_M, uword in_row, uword in_col) + : iterator_base(in_M, in_col, 0) + , internal_row(0) + , actual_pos(0) + { + // Start our search in the given row. We need to find two things: + // + // 1. The first nonzero element (iterating by rows) after (in_row, in_col). + // 2. The number of nonzero elements (iterating by rows) that come before + // (in_row, in_col). + // + // We'll find these simultaneously, though we will have to loop over all + // columns. + + // This will hold the total number of points with rows less than in_row. + uword cur_pos = 0; + uword cur_min_row = iterator_base::M->n_rows; + uword cur_min_col = 0; + uword cur_actual_pos = 0; + + for(uword col = 0; col < iterator_base::M->n_cols; ++col) + { + // Find the first element with row greater than or equal to in_row. + const uword col_offset = iterator_base::M->col_ptrs[col ]; + const uword next_col_offset = iterator_base::M->col_ptrs[col + 1]; + + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, in_row); + + // This is the number of elements in the column with row index less than in_row. + const uword offset = uword(pos_ptr - start_ptr); + + cur_pos += offset; + + if(pos_ptr != end_ptr) + { + // This is the row index of the first element in the column with row index + // greater than or equal to in_row. + if((*pos_ptr) < cur_min_row) + { + // If we are in the desired row but before the desired column, + // we can't take this. + if(col >= in_col) + { + cur_min_row = (*pos_ptr); + cur_min_col = col; + cur_actual_pos = col_offset + offset; + } + } + } + } + } + + // Now we know what the minimum row is. + internal_row = cur_min_row; + iterator_base::internal_col = cur_min_col; + iterator_base::internal_pos = cur_pos; + actual_pos = cur_actual_pos; + } + + + +/** + * Initialize the const_row_iterator from another const_row_iterator. + */ +template +inline +SpMat::const_row_iterator::const_row_iterator(const typename SpMat::const_row_iterator& other) + : iterator_base(*other.M, other.internal_col, other.internal_pos) + , internal_row(other.internal_row) + , actual_pos(other.actual_pos) + { + // Nothing to do. + } + + + +/** + * Increment the row_iterator. + */ +template +inline +typename SpMat::const_row_iterator& +SpMat::const_row_iterator::operator++() + { + // We just need to find the next nonzero element. + iterator_base::internal_pos++; + + if(iterator_base::internal_pos == iterator_base::M->n_nonzero) + { + internal_row = iterator_base::M->n_rows; + iterator_base::internal_col = 0; + + return *this; + } + + // Otherwise, we need to search. We can start in the next column and use + // lower_bound() to find the next element. + uword next_min_row = iterator_base::M->n_rows; + uword next_min_col = iterator_base::M->n_cols; + uword next_actual_pos = 0; + + // Search from the current column to the end of the matrix. + for(uword col = iterator_base::internal_col + 1; col < iterator_base::M->n_cols; ++col) + { + // Find the first element with row greater than or equal to in_row. + const uword col_offset = iterator_base::M->col_ptrs[col ]; + const uword next_col_offset = iterator_base::M->col_ptrs[col + 1]; + + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + // Find the first element in the column with row greater than or equal to + // the current row. + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row); + + if(pos_ptr != end_ptr) + { + // We found something in the column, but is the row index correct? + if((*pos_ptr) == internal_row) + { + // Exact match---so we are done. + iterator_base::internal_col = col; + actual_pos = col_offset + (pos_ptr - start_ptr); + return *this; + } + else if((*pos_ptr) < next_min_row) + { + // The first element in this column is in a subsequent row, but it's + // the minimum row we've seen so far. + next_min_row = (*pos_ptr); + next_min_col = col; + next_actual_pos = col_offset + (pos_ptr - start_ptr); + } + else if((*pos_ptr) == next_min_row && col < next_min_col) + { + // The first element in this column is in a subsequent row that we + // already have another element for, but the column index is less so + // this element will come first. + next_min_col = col; + next_actual_pos = col_offset + (pos_ptr - start_ptr); + } + } + } + } + + // Restart the search in the next row. + for(uword col = 0; col <= iterator_base::internal_col; ++col) + { + // Find the first element with row greater than or equal to in_row + 1. + const uword col_offset = iterator_base::M->col_ptrs[col ]; + const uword next_col_offset = iterator_base::M->col_ptrs[col + 1]; + + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + 1); + + if(pos_ptr != end_ptr) + { + // We found something in the column, but is the row index correct? + if((*pos_ptr) == internal_row + 1) + { + // Exact match---so we are done. + iterator_base::internal_col = col; + internal_row++; + actual_pos = col_offset + (pos_ptr - start_ptr); + return *this; + } + else if((*pos_ptr) < next_min_row) + { + // The first element in this column is in a subsequent row, + // but it's the minimum row we've seen so far. + next_min_row = (*pos_ptr); + next_min_col = col; + next_actual_pos = col_offset + (pos_ptr - start_ptr); + } + else if((*pos_ptr) == next_min_row && col < next_min_col) + { + // The first element in this column is in a subsequent row that we + // already have another element for, but the column index is less so + // this element will come first. + next_min_col = col; + next_actual_pos = col_offset + (pos_ptr - start_ptr); + } + } + } + } + + iterator_base::internal_col = next_min_col; + internal_row = next_min_row; + actual_pos = next_actual_pos; + + return *this; // Now we are done. + } + + + +/** + * Increment the row_iterator (but do not return anything. + */ +template +inline +typename SpMat::const_row_iterator +SpMat::const_row_iterator::operator++(int) + { + typename SpMat::const_row_iterator tmp(*this); + + ++(*this); + + return tmp; + } + + + +/** + * Decrement the row_iterator. + */ +template +inline +typename SpMat::const_row_iterator& +SpMat::const_row_iterator::operator--() + { + if(iterator_base::internal_pos == 0) + { + // Do nothing; we are already at the beginning. + return *this; + } + + iterator_base::internal_pos--; + + // We have to search backwards. We'll do this by going backwards over columns + // and seeing if we find an element in the same row. + uword max_row = 0; + uword max_col = 0; + uword next_actual_pos = 0; + + //for(uword col = iterator_base::internal_col; col > 1; --col) + for(uword col = iterator_base::internal_col; col >= 1; --col) + { + // Find the first element with row greater than or equal to in_row + 1. + const uword col_offset = iterator_base::M->col_ptrs[col - 1]; + const uword next_col_offset = iterator_base::M->col_ptrs[col ]; + + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + // There are elements in this column. + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + 1); + + if(pos_ptr != start_ptr) + { + // The element before pos_ptr is the one we are interested in. + if(*(pos_ptr - 1) > max_row) + { + max_row = *(pos_ptr - 1); + max_col = col - 1; + next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); + } + else if(*(pos_ptr - 1) == max_row && (col - 1) > max_col) + { + max_col = col - 1; + next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); + } + } + } + } + + // Now loop around to the columns at the end of the matrix. + for(uword col = iterator_base::M->n_cols - 1; col >= iterator_base::internal_col; --col) + { + // Find the first element with row greater than or equal to in_row + 1. + const uword col_offset = iterator_base::M->col_ptrs[col ]; + const uword next_col_offset = iterator_base::M->col_ptrs[col + 1]; + + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + // There are elements in this column. + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row); + + if(pos_ptr != start_ptr) + { + // There are elements in this column with row index < internal_row. + if(*(pos_ptr - 1) > max_row) + { + max_row = *(pos_ptr - 1); + max_col = col; + next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); + } + else if(*(pos_ptr - 1) == max_row && col > max_col) + { + max_col = col; + next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); + } + } + } + + if(col == 0) // Catch edge case that the loop termination condition won't. + { + break; + } + } + + iterator_base::internal_col = max_col; + internal_row = max_row; + actual_pos = next_actual_pos; + + return *this; + } + + + +/** + * Decrement the row_iterator. + */ +template +inline +typename SpMat::const_row_iterator +SpMat::const_row_iterator::operator--(int) + { + typename SpMat::const_row_iterator tmp(*this); + + --(*this); + + return tmp; + } + + + +template +inline +bool +SpMat::const_row_iterator::operator==(const const_iterator& rhs) const + { + return (rhs.row() == row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_row_iterator::operator!=(const const_iterator& rhs) const + { + return (rhs.row() != row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_row_iterator::operator==(const typename SpSubview::const_iterator& rhs) const + { + return (rhs.row() == row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_row_iterator::operator!=(const typename SpSubview::const_iterator& rhs) const + { + return (rhs.row() != row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_row_iterator::operator==(const const_row_iterator& rhs) const + { + return (rhs.row() == row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_row_iterator::operator!=(const const_row_iterator& rhs) const + { + return (rhs.row() != row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_row_iterator::operator==(const typename SpSubview::const_row_iterator& rhs) const + { + return (rhs.row() == row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpMat::const_row_iterator::operator!=(const typename SpSubview::const_row_iterator& rhs) const + { + return (rhs.row() != row()) || (rhs.col() != iterator_base::internal_col); + } + + + +/////////////////////////////////////////////////////////////////////////////// +// SpMat::row_iterator implementation // +/////////////////////////////////////////////////////////////////////////////// + +template +inline +SpValProxy< SpMat > +SpMat::row_iterator::operator*() + { + return SpValProxy< SpMat >( + const_row_iterator::internal_row, + iterator_base::internal_col, + access::rw(*iterator_base::M), + &access::rw(iterator_base::M->values[const_row_iterator::actual_pos])); + } + + + +template +inline +typename SpMat::row_iterator& +SpMat::row_iterator::operator++() + { + const_row_iterator::operator++(); + + return *this; + } + + + +template +inline +typename SpMat::row_iterator +SpMat::row_iterator::operator++(int) + { + typename SpMat::row_iterator tmp(*this); + + const_row_iterator::operator++(); + + return tmp; + } + + + +template +inline +typename SpMat::row_iterator& +SpMat::row_iterator::operator--() + { + const_row_iterator::operator--(); + + return *this; + } + + + +template +inline +typename SpMat::row_iterator +SpMat::row_iterator::operator--(int) + { + typename SpMat::row_iterator tmp(*this); + + const_row_iterator::operator--(); + + return tmp; + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpMat_meat.hpp b/src/armadillo/include/armadillo_bits/SpMat_meat.hpp new file mode 100644 index 0000000..b8d51cf --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpMat_meat.hpp @@ -0,0 +1,6855 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpMat +//! @{ + + +/** + * Initialize a sparse matrix with size 0x0 (empty). + */ +template +inline +SpMat::SpMat() + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init_cold(0,0); + } + + + +/** + * Clean up the memory of a sparse matrix and destruct it. + */ +template +inline +SpMat::~SpMat() + { + arma_extra_debug_sigprint_this(this); + + if(values ) { memory::release(access::rw(values)); } + if(row_indices) { memory::release(access::rw(row_indices)); } + if(col_ptrs ) { memory::release(access::rw(col_ptrs)); } + } + + + +/** + * Constructor with size given. + */ +template +inline +SpMat::SpMat(const uword in_rows, const uword in_cols) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init_cold(in_rows, in_cols); + } + + + +template +inline +SpMat::SpMat(const SizeMat& s) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init_cold(s.n_rows, s.n_cols); + } + + + +template +inline +SpMat::SpMat(const arma_reserve_indicator&, const uword in_rows, const uword in_cols, const uword new_n_nonzero) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init_cold(in_rows, in_cols, new_n_nonzero); + } + + + +template +template +inline +SpMat::SpMat(const arma_layout_indicator&, const SpMat& x) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init_cold(x.n_rows, x.n_cols, x.n_nonzero); + + if(x.n_nonzero == 0) { return; } + + if(x.row_indices) { arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1); } + if(x.col_ptrs ) { arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1); } + + // NOTE: 'values' array is not initialised + } + + + +/** + * Assemble from text. + */ +template +inline +SpMat::SpMat(const char* text) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init(std::string(text)); + } + + + +template +inline +SpMat& +SpMat::operator=(const char* text) + { + arma_extra_debug_sigprint(); + + init(std::string(text)); + + return *this; + } + + + +template +inline +SpMat::SpMat(const std::string& text) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint(); + + init(text); + } + + + +template +inline +SpMat& +SpMat::operator=(const std::string& text) + { + arma_extra_debug_sigprint(); + + init(text); + + return *this; + } + + + +template +inline +SpMat::SpMat(const SpMat& x) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init(x); + } + + + +template +inline +SpMat::SpMat(SpMat&& in_mat) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat); + + (*this).steal_mem(in_mat); + } + + + +template +inline +SpMat& +SpMat::operator=(SpMat&& in_mat) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat); + + (*this).steal_mem(in_mat); + + return *this; + } + + + +template +inline +SpMat::SpMat(const MapMat& x) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init(x); + } + + + +template +inline +SpMat& +SpMat::operator=(const MapMat& x) + { + arma_extra_debug_sigprint(); + + init(x); + + return *this; + } + + + +//! Insert a large number of values at once. +//! locations.row[0] should be row indices, locations.row[1] should be column indices, +//! and values should be the corresponding values. +//! If sort_locations is false, then it is assumed that the locations and values +//! are already sorted in column-major ordering. +template +template +inline +SpMat::SpMat(const Base& locations_expr, const Base& vals_expr, const bool sort_locations) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + const quasi_unwrap locs_tmp( locations_expr.get_ref() ); + const quasi_unwrap vals_tmp( vals_expr.get_ref() ); + + const Mat& locs = locs_tmp.M; + const Mat& vals = vals_tmp.M; + + arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" ); + arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" ); + arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" ); + + // If there are no elements in the list, max() will fail. + if(locs.n_cols == 0) { init_cold(0, 0); return; } + + // Automatically determine size before pruning zeros. + uvec bounds = arma::max(locs, 1); + init_cold(bounds[0] + 1, bounds[1] + 1); + + // Ensure that there are no zeros + const uword N_old = vals.n_elem; + uword N_new = 0; + + for(uword i=0; i < N_old; ++i) { N_new += (vals[i] != eT(0)) ? uword(1) : uword(0); } + + if(N_new != N_old) + { + Col filtered_vals( N_new, arma_nozeros_indicator()); + Mat filtered_locs(2, N_new, arma_nozeros_indicator()); + + uword index = 0; + for(uword i = 0; i < N_old; ++i) + { + if(vals[i] != eT(0)) + { + filtered_vals[index] = vals[i]; + + filtered_locs.at(0, index) = locs.at(0, i); + filtered_locs.at(1, index) = locs.at(1, i); + + ++index; + } + } + + init_batch_std(filtered_locs, filtered_vals, sort_locations); + } + else + { + init_batch_std(locs, vals, sort_locations); + } + } + + + +//! Insert a large number of values at once. +//! locations.row[0] should be row indices, locations.row[1] should be column indices, +//! and values should be the corresponding values. +//! If sort_locations is false, then it is assumed that the locations and values +//! are already sorted in column-major ordering. +//! In this constructor the size is explicitly given. +template +template +inline +SpMat::SpMat(const Base& locations_expr, const Base& vals_expr, const uword in_n_rows, const uword in_n_cols, const bool sort_locations, const bool check_for_zeros) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + const quasi_unwrap locs_tmp( locations_expr.get_ref() ); + const quasi_unwrap vals_tmp( vals_expr.get_ref() ); + + const Mat& locs = locs_tmp.M; + const Mat& vals = vals_tmp.M; + + arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" ); + arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" ); + arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" ); + + init_cold(in_n_rows, in_n_cols); + + // Ensure that there are no zeros, unless the user asked not to. + if(check_for_zeros) + { + const uword N_old = vals.n_elem; + uword N_new = 0; + + for(uword i=0; i < N_old; ++i) { N_new += (vals[i] != eT(0)) ? uword(1) : uword(0); } + + if(N_new != N_old) + { + Col filtered_vals( N_new, arma_nozeros_indicator()); + Mat filtered_locs(2, N_new, arma_nozeros_indicator()); + + uword index = 0; + for(uword i = 0; i < N_old; ++i) + { + if(vals[i] != eT(0)) + { + filtered_vals[index] = vals[i]; + + filtered_locs.at(0, index) = locs.at(0, i); + filtered_locs.at(1, index) = locs.at(1, i); + + ++index; + } + } + + init_batch_std(filtered_locs, filtered_vals, sort_locations); + } + else + { + init_batch_std(locs, vals, sort_locations); + } + } + else + { + init_batch_std(locs, vals, sort_locations); + } + } + + + +template +template +inline +SpMat::SpMat(const bool add_values, const Base& locations_expr, const Base& vals_expr, const uword in_n_rows, const uword in_n_cols, const bool sort_locations, const bool check_for_zeros) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + const quasi_unwrap locs_tmp( locations_expr.get_ref() ); + const quasi_unwrap vals_tmp( vals_expr.get_ref() ); + + const Mat& locs = locs_tmp.M; + const Mat& vals = vals_tmp.M; + + arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" ); + arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" ); + arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" ); + + init_cold(in_n_rows, in_n_cols); + + // Ensure that there are no zeros, unless the user asked not to. + if(check_for_zeros) + { + const uword N_old = vals.n_elem; + uword N_new = 0; + + for(uword i=0; i < N_old; ++i) { N_new += (vals[i] != eT(0)) ? uword(1) : uword(0); } + + if(N_new != N_old) + { + Col filtered_vals( N_new, arma_nozeros_indicator()); + Mat filtered_locs(2, N_new, arma_nozeros_indicator()); + + uword index = 0; + for(uword i = 0; i < N_old; ++i) + { + if(vals[i] != eT(0)) + { + filtered_vals[index] = vals[i]; + + filtered_locs.at(0, index) = locs.at(0, i); + filtered_locs.at(1, index) = locs.at(1, i); + + ++index; + } + } + + add_values ? init_batch_add(filtered_locs, filtered_vals, sort_locations) : init_batch_std(filtered_locs, filtered_vals, sort_locations); + } + else + { + add_values ? init_batch_add(locs, vals, sort_locations) : init_batch_std(locs, vals, sort_locations); + } + } + else + { + add_values ? init_batch_add(locs, vals, sort_locations) : init_batch_std(locs, vals, sort_locations); + } + } + + + +//! Insert a large number of values at once. +//! Per CSC format, rowind_expr should be row indices, +//! colptr_expr should column ptr indices locations, +//! and values should be the corresponding values. +//! In this constructor the size is explicitly given. +//! Values are assumed to be sorted, and the size +//! information is trusted +template +template +inline +SpMat::SpMat + ( + const Base& rowind_expr, + const Base& colptr_expr, + const Base& values_expr, + const uword in_n_rows, + const uword in_n_cols, + const bool check_for_zeros + ) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + const quasi_unwrap rowind_tmp( rowind_expr.get_ref() ); + const quasi_unwrap colptr_tmp( colptr_expr.get_ref() ); + const quasi_unwrap vals_tmp( values_expr.get_ref() ); + + const Mat& rowind = rowind_tmp.M; + const Mat& colptr = colptr_tmp.M; + const Mat& vals = vals_tmp.M; + + arma_debug_check( (rowind.is_vec() == false), "SpMat::SpMat(): given 'rowind' object must be a vector" ); + arma_debug_check( (colptr.is_vec() == false), "SpMat::SpMat(): given 'colptr' object must be a vector" ); + arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" ); + + // Resize to correct number of elements (this also sets n_nonzero) + init_cold(in_n_rows, in_n_cols, vals.n_elem); + + arma_debug_check( (rowind.n_elem != vals.n_elem), "SpMat::SpMat(): number of row indices is not equal to number of values" ); + arma_debug_check( (colptr.n_elem != (n_cols+1) ), "SpMat::SpMat(): number of column pointers is not equal to n_cols+1" ); + + // copy supplied values into sparse matrix -- not checked for consistency + arrayops::copy(access::rwp(row_indices), rowind.memptr(), rowind.n_elem ); + arrayops::copy(access::rwp(col_ptrs), colptr.memptr(), colptr.n_elem ); + arrayops::copy(access::rwp(values), vals.memptr(), vals.n_elem ); + + // important: set the sentinel as well + access::rw(col_ptrs[n_cols + 1]) = std::numeric_limits::max(); + + // make sure no zeros are stored + if(check_for_zeros) { remove_zeros(); } + } + + + +template +inline +SpMat& +SpMat::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + if(val != eT(0)) + { + // Resize to 1x1 then set that to the right value. + init(1, 1, 1); // Sets col_ptrs to 0. + + // Manually set element. + access::rw(values[0]) = val; + access::rw(row_indices[0]) = 0; + access::rw(col_ptrs[1]) = 1; + } + else + { + init(0, 0); + } + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator*=(const eT val) + { + arma_extra_debug_sigprint(); + + if(val != eT(0)) + { + sync_csc(); + invalidate_cache(); + + const uword n_nz = n_nonzero; + + eT* vals = access::rwp(values); + + bool has_zero = false; + + for(uword i=0; i +inline +SpMat& +SpMat::operator/=(const eT val) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (val == eT(0)), "element-wise division: division by zero" ); + + sync_csc(); + invalidate_cache(); + + const uword n_nz = n_nonzero; + + eT* vals = access::rwp(values); + + bool has_zero = false; + + for(uword i=0; i +inline +SpMat& +SpMat::operator=(const SpMat& x) + { + arma_extra_debug_sigprint(); + + init(x); + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator+=(const SpMat& x) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + SpMat out = (*this) + x; + + steal_mem(out); + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator-=(const SpMat& x) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + SpMat out = (*this) - x; + + steal_mem(out); + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator*=(const SpMat& y) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + SpMat z = (*this) * y; + + steal_mem(z); + + return *this; + } + + + +// This is in-place element-wise matrix multiplication. +template +inline +SpMat& +SpMat::operator%=(const SpMat& y) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + SpMat z = (*this) % y; + + steal_mem(z); + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator/=(const SpMat& x) + { + arma_extra_debug_sigprint(); + + // NOTE: use of this function is not advised; it is implemented only for completeness + + arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division"); + + for(uword c = 0; c < n_cols; ++c) + for(uword r = 0; r < n_rows; ++r) + { + at(r, c) /= x.at(r, c); + } + + return *this; + } + + + +template +template +inline +SpMat::SpMat(const SpToDOp& expr) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + typedef typename T1::elem_type T; + + // Make sure the type is compatible. + arma_type_check(( is_same_type< eT, T >::no )); + + op_type::apply(*this, expr); + } + + + +// Construct a complex matrix out of two non-complex matrices +template +template +inline +SpMat::SpMat + ( + const SpBase::pod_type, T1>& A, + const SpBase::pod_type, T2>& B + ) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type T; + + // Make sure eT is complex and T is not (compile-time check). + arma_type_check(( is_cx::no )); + arma_type_check(( is_cx< T>::yes )); + + // Compile-time abort if types are not compatible. + arma_type_check(( is_same_type< std::complex, eT >::no )); + + const unwrap_spmat tmp1(A.get_ref()); + const unwrap_spmat tmp2(B.get_ref()); + + const SpMat& X = tmp1.M; + const SpMat& Y = tmp2.M; + + arma_debug_assert_same_size(X.n_rows, X.n_cols, Y.n_rows, Y.n_cols, "SpMat()"); + + const uword l_n_rows = X.n_rows; + const uword l_n_cols = X.n_cols; + + // Set size of matrix correctly. + init_cold(l_n_rows, l_n_cols, n_unique(X, Y, op_n_unique_count())); + + // Now on a second iteration, fill it. + typename SpMat::const_iterator x_it = X.begin(); + typename SpMat::const_iterator x_end = X.end(); + + typename SpMat::const_iterator y_it = Y.begin(); + typename SpMat::const_iterator y_end = Y.end(); + + uword cur_pos = 0; + + while((x_it != x_end) || (y_it != y_end)) + { + if(x_it == y_it) // if we are at the same place + { + access::rw(values[cur_pos]) = std::complex((T) *x_it, (T) *y_it); + access::rw(row_indices[cur_pos]) = x_it.row(); + ++access::rw(col_ptrs[x_it.col() + 1]); + + ++x_it; + ++y_it; + } + else + { + if((x_it.col() < y_it.col()) || ((x_it.col() == y_it.col()) && (x_it.row() < y_it.row()))) // if y is closer to the end + { + access::rw(values[cur_pos]) = std::complex((T) *x_it, T(0)); + access::rw(row_indices[cur_pos]) = x_it.row(); + ++access::rw(col_ptrs[x_it.col() + 1]); + + ++x_it; + } + else // x is closer to the end + { + access::rw(values[cur_pos]) = std::complex(T(0), (T) *y_it); + access::rw(row_indices[cur_pos]) = y_it.row(); + ++access::rw(col_ptrs[y_it.col() + 1]); + + ++y_it; + } + } + + ++cur_pos; + } + + // Now fix the column pointers; they are supposed to be a sum. + for(uword c = 1; c <= n_cols; ++c) + { + access::rw(col_ptrs[c]) += col_ptrs[c - 1]; + } + + } + + + +template +template +inline +SpMat::SpMat(const Base& x) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(x); + } + + + +template +template +inline +SpMat& +SpMat::operator=(const Base& expr) + { + arma_extra_debug_sigprint(); + + if(is_same_type< T1, Gen, gen_zeros> >::yes) + { + const Proxy P(expr.get_ref()); + + (*this).zeros( P.get_n_rows(), P.get_n_cols() ); + + return *this; + } + + if(is_same_type< T1, Gen, gen_eye> >::yes) + { + const Proxy P(expr.get_ref()); + + (*this).eye( P.get_n_rows(), P.get_n_cols() ); + + return *this; + } + + const quasi_unwrap tmp(expr.get_ref()); + const Mat& x = tmp.M; + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + const uword x_n_elem = x.n_elem; + + // Count number of nonzero elements in base object. + uword n = 0; + + const eT* x_mem = x.memptr(); + + for(uword i=0; i < x_n_elem; ++i) { n += (x_mem[i] != eT(0)) ? uword(1) : uword(0); } + + init(x_n_rows, x_n_cols, n); + + if(n == 0) { return *this; } + + // Now the memory is resized correctly; set nonzero elements. + n = 0; + for(uword j = 0; j < x_n_cols; ++j) + for(uword i = 0; i < x_n_rows; ++i) + { + const eT val = (*x_mem); x_mem++; + + if(val != eT(0)) + { + access::rw(values[n]) = val; + access::rw(row_indices[n]) = i; + access::rw(col_ptrs[j + 1])++; + ++n; + } + } + + // Sum column counts to be column pointers. + for(uword c = 1; c <= n_cols; ++c) + { + access::rw(col_ptrs[c]) += col_ptrs[c - 1]; + } + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator+=(const Base& x) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return (*this).operator=( (*this) + x.get_ref() ); + } + + + +template +template +inline +SpMat& +SpMat::operator-=(const Base& x) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return (*this).operator=( (*this) - x.get_ref() ); + } + + + +template +template +inline +SpMat& +SpMat::operator*=(const Base& x) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return (*this).operator=( (*this) * x.get_ref() ); + } + + + +// NOTE: use of this function is not advised; it is implemented only for completeness +template +template +inline +SpMat& +SpMat::operator/=(const Base& x) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + SpMat tmp = (*this) / x.get_ref(); + + steal_mem(tmp); + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator%=(const Base& x) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap U(x.get_ref()); + const Mat& B = U.M; + + arma_debug_assert_same_size(n_rows, n_cols, B.n_rows, B.n_cols, "element-wise multiplication"); + + sync_csc(); + invalidate_cache(); + + constexpr eT zero = eT(0); + + bool has_zero = false; + + for(uword c=0; c < n_cols; ++c) + { + const uword index_start = col_ptrs[c ]; + const uword index_end = col_ptrs[c + 1]; + + for(uword i=index_start; i < index_end; ++i) + { + const uword r = row_indices[i]; + + eT& val = access::rw(values[i]); + + const eT result = val * B.at(r,c); + + val = result; + + if(result == zero) { has_zero = true; } + } + } + + if(has_zero) { remove_zeros(); } + + return *this; + } + + + +template +template +inline +SpMat::SpMat(const Op& expr) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(expr); + } + + + +template +template +inline +SpMat& +SpMat::operator=(const Op& expr) + { + arma_extra_debug_sigprint(); + + const diagmat_proxy P(expr.m); + + const uword max_n_nonzero = (std::min)(P.n_rows, P.n_cols); + + // resize memory to upper bound + init(P.n_rows, P.n_cols, max_n_nonzero); + + uword count = 0; + + for(uword i=0; i < max_n_nonzero; ++i) + { + const eT val = P[i]; + + if(val != eT(0)) + { + access::rw(values[count]) = val; + access::rw(row_indices[count]) = i; + access::rw(col_ptrs[i + 1])++; + ++count; + } + } + + // fix column pointers to be cumulative + for(uword i = 1; i < n_cols + 1; ++i) + { + access::rw(col_ptrs[i]) += col_ptrs[i - 1]; + } + + // quick resize without reallocating memory and copying data + access::rw( n_nonzero) = count; + access::rw( values[count]) = eT(0); + access::rw(row_indices[count]) = uword(0); + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator+=(const Op& expr) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(expr); + + return (*this).operator+=(tmp); + } + + + +template +template +inline +SpMat& +SpMat::operator-=(const Op& expr) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(expr); + + return (*this).operator-=(tmp); + } + + + +template +template +inline +SpMat& +SpMat::operator*=(const Op& expr) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(expr); + + return (*this).operator*=(tmp); + } + + + +template +template +inline +SpMat& +SpMat::operator/=(const Op& expr) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(expr); + + return (*this).operator/=(tmp); + } + + + +template +template +inline +SpMat& +SpMat::operator%=(const Op& expr) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(expr); + + return (*this).operator%=(tmp); + } + + + +/** + * Functions on subviews. + */ +template +inline +SpMat::SpMat(const SpSubview& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(X); + } + + + +template +inline +SpMat& +SpMat::operator=(const SpSubview& X) + { + arma_extra_debug_sigprint(); + + if(X.n_nonzero == 0) { zeros(X.n_rows, X.n_cols); return *this; } + + X.m.sync_csc(); + + const bool alias = (this == &(X.m)); + + if(alias) + { + SpMat tmp(X); + + steal_mem(tmp); + } + else + { + init(X.n_rows, X.n_cols, X.n_nonzero); + + if(X.n_rows == X.m.n_rows) + { + const uword sv_col_start = X.aux_col1; + const uword sv_col_end = X.aux_col1 + X.n_cols - 1; + + typename SpMat::const_col_iterator m_it = X.m.begin_col_no_sync(sv_col_start); + typename SpMat::const_col_iterator m_it_end = X.m.end_col_no_sync(sv_col_end); + + uword count = 0; + + while(m_it != m_it_end) + { + const uword m_it_col_adjusted = m_it.col() - sv_col_start; + + access::rw(row_indices[count]) = m_it.row(); + access::rw(values[count]) = (*m_it); + ++access::rw(col_ptrs[m_it_col_adjusted + 1]); + + count++; + + ++m_it; + } + } + else + { + typename SpSubview::const_iterator it = X.begin(); + typename SpSubview::const_iterator it_end = X.end(); + + while(it != it_end) + { + const uword it_pos = it.pos(); + + access::rw(row_indices[it_pos]) = it.row(); + access::rw(values[it_pos]) = (*it); + ++access::rw(col_ptrs[it.col() + 1]); + ++it; + } + } + + // Now sum column pointers. + for(uword c = 1; c <= n_cols; ++c) + { + access::rw(col_ptrs[c]) += col_ptrs[c - 1]; + } + } + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator+=(const SpSubview& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + SpMat tmp = (*this) + X; + + steal_mem(tmp); + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator-=(const SpSubview& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + SpMat tmp = (*this) - X; + + steal_mem(tmp); + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator*=(const SpSubview& y) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + SpMat z = (*this) * y; + + steal_mem(z); + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator%=(const SpSubview& x) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + SpMat tmp = (*this) % x; + + steal_mem(tmp); + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator/=(const SpSubview& x) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division"); + + // There is no pretty way to do this. + for(uword elem = 0; elem < n_elem; elem++) + { + at(elem) /= x(elem); + } + + return *this; + } + + + +template +template +inline +SpMat::SpMat(const SpSubview_col_list& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + SpSubview_col_list::extract(*this, X); + } + + + +template +template +inline +SpMat& +SpMat::operator=(const SpSubview_col_list& X) + { + arma_extra_debug_sigprint(); + + const bool alias = (this == &(X.m)); + + if(alias == false) + { + SpSubview_col_list::extract(*this, X); + } + else + { + SpMat tmp(X); + + steal_mem(tmp); + } + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator+=(const SpSubview_col_list& X) + { + arma_extra_debug_sigprint(); + + SpSubview_col_list::plus_inplace(*this, X); + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator-=(const SpSubview_col_list& X) + { + arma_extra_debug_sigprint(); + + SpSubview_col_list::minus_inplace(*this, X); + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator*=(const SpSubview_col_list& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + SpMat z = (*this) * X; + + steal_mem(z); + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator%=(const SpSubview_col_list& X) + { + arma_extra_debug_sigprint(); + + SpSubview_col_list::schur_inplace(*this, X); + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator/=(const SpSubview_col_list& X) + { + arma_extra_debug_sigprint(); + + SpSubview_col_list::div_inplace(*this, X); + + return *this; + } + + + +template +inline +SpMat::SpMat(const spdiagview& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + spdiagview::extract(*this, X); + } + + + +template +inline +SpMat& +SpMat::operator=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + spdiagview::extract(*this, X); + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator+=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(X); + + return (*this).operator+=(tmp); + } + + + +template +inline +SpMat& +SpMat::operator-=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(X); + + return (*this).operator-=(tmp); + } + + + +template +inline +SpMat& +SpMat::operator*=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(X); + + return (*this).operator*=(tmp); + } + + + +template +inline +SpMat& +SpMat::operator%=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(X); + + return (*this).operator%=(tmp); + } + + + +template +inline +SpMat& +SpMat::operator/=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(X); + + return (*this).operator/=(tmp); + } + + + +template +template +inline +SpMat::SpMat(const SpOp& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) // set in application of sparse operation + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + spop_type::apply(*this, X); + + sync_csc(); // in case apply() used element accessors + invalidate_cache(); // in case apply() modified the CSC representation + } + + + +template +template +inline +SpMat& +SpMat::operator=(const SpOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + spop_type::apply(*this, X); + + sync_csc(); // in case apply() used element accessors + invalidate_cache(); // in case apply() modified the CSC representation + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator+=(const SpOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator+=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator-=(const SpOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator-=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator*=(const SpOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator*=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator%=(const SpOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator%=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator/=(const SpOp& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator/=(m); + } + + + +template +template +inline +SpMat::SpMat(const SpGlue& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + spglue_type::apply(*this, X); + + sync_csc(); // in case apply() used element accessors + invalidate_cache(); // in case apply() modified the CSC representation + } + + + +template +template +inline +SpMat& +SpMat::operator=(const SpGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + spglue_type::apply(*this, X); + + sync_csc(); // in case apply() used element accessors + invalidate_cache(); // in case apply() modified the CSC representation + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator+=(const SpGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator+=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator-=(const SpGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator-=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator*=(const SpGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator*=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator%=(const SpGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator%=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator/=(const SpGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator/=(m); + } + + + +template +template +inline +SpMat::SpMat(const mtSpOp& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + spop_type::apply(*this, X); + + sync_csc(); // in case apply() used element accessors + invalidate_cache(); // in case apply() modified the CSC representation + } + + + +template +template +inline +SpMat& +SpMat::operator=(const mtSpOp& X) + { + arma_extra_debug_sigprint(); + + spop_type::apply(*this, X); + + sync_csc(); // in case apply() used element accessors + invalidate_cache(); // in case apply() modified the CSC representation + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator+=(const mtSpOp& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator+=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator-=(const mtSpOp& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator-=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator*=(const mtSpOp& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator*=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator%=(const mtSpOp& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator%=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator/=(const mtSpOp& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator/=(m); + } + + + +template +template +inline +SpMat::SpMat(const mtSpGlue& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + spglue_type::apply(*this, X); + + sync_csc(); // in case apply() used element accessors + invalidate_cache(); // in case apply() modified the CSC representation + } + + + +template +template +inline +SpMat& +SpMat::operator=(const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + spglue_type::apply(*this, X); + + sync_csc(); // in case apply() used element accessors + invalidate_cache(); // in case apply() modified the CSC representation + + return *this; + } + + + +template +template +inline +SpMat& +SpMat::operator+=(const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator+=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator-=(const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator-=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator*=(const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator*=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator%=(const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator%=(m); + } + + + +template +template +inline +SpMat& +SpMat::operator/=(const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const SpMat m(X); + + return (*this).operator/=(m); + } + + + +template +arma_inline +SpSubview_row +SpMat::row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds(row_num >= n_rows, "SpMat::row(): out of bounds"); + + return SpSubview_row(*this, row_num); + } + + + +template +arma_inline +const SpSubview_row +SpMat::row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds(row_num >= n_rows, "SpMat::row(): out of bounds"); + + return SpSubview_row(*this, row_num); + } + + + +template +inline +SpSubview_row +SpMat::operator()(const uword row_num, const span& col_span) + { + arma_extra_debug_sigprint(); + + const bool col_all = col_span.whole; + + const uword local_n_cols = n_cols; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + (row_num >= n_rows) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "SpMat::operator(): indices out of bounds or incorrectly used" + ); + + return SpSubview_row(*this, row_num, in_col1, submat_n_cols); + } + + + +template +inline +const SpSubview_row +SpMat::operator()(const uword row_num, const span& col_span) const + { + arma_extra_debug_sigprint(); + + const bool col_all = col_span.whole; + + const uword local_n_cols = n_cols; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + (row_num >= n_rows) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "SpMat::operator(): indices out of bounds or incorrectly used" + ); + + return SpSubview_row(*this, row_num, in_col1, submat_n_cols); + } + + + +template +arma_inline +SpSubview_col +SpMat::col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds(col_num >= n_cols, "SpMat::col(): out of bounds"); + + return SpSubview_col(*this, col_num); + } + + + +template +arma_inline +const SpSubview_col +SpMat::col(const uword col_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds(col_num >= n_cols, "SpMat::col(): out of bounds"); + + return SpSubview_col(*this, col_num); + } + + + +template +inline +SpSubview_col +SpMat::operator()(const span& row_span, const uword col_num) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + + const uword local_n_rows = n_rows; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + arma_debug_check_bounds + ( + (col_num >= n_cols) + || + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + , + "SpMat::operator(): indices out of bounds or incorrectly used" + ); + + return SpSubview_col(*this, col_num, in_row1, submat_n_rows); + } + + + +template +inline +const SpSubview_col +SpMat::operator()(const span& row_span, const uword col_num) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + + const uword local_n_rows = n_rows; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + arma_debug_check_bounds + ( + (col_num >= n_cols) + || + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + , + "SpMat::operator(): indices out of bounds or incorrectly used" + ); + + return SpSubview_col(*this, col_num, in_row1, submat_n_rows); + } + + + +template +arma_inline +SpSubview +SpMat::rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "SpMat::rows(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + return SpSubview(*this, in_row1, 0, subview_n_rows, n_cols); + } + + + +template +arma_inline +const SpSubview +SpMat::rows(const uword in_row1, const uword in_row2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "SpMat::rows(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + return SpSubview(*this, in_row1, 0, subview_n_rows, n_cols); + } + + + +template +arma_inline +SpSubview +SpMat::cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "SpMat::cols(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return SpSubview(*this, 0, in_col1, n_rows, subview_n_cols); + } + + + +template +arma_inline +const SpSubview +SpMat::cols(const uword in_col1, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "SpMat::cols(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return SpSubview(*this, 0, in_col1, n_rows, subview_n_cols); + } + + + +template +arma_inline +SpSubview +SpMat::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), + "SpMat::submat(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return SpSubview(*this, in_row1, in_col1, subview_n_rows, subview_n_cols); + } + + + +template +arma_inline +const SpSubview +SpMat::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), + "SpMat::submat(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + const uword subview_n_cols = in_col2 - in_col1 + 1; + + return SpSubview(*this, in_row1, in_col1, subview_n_rows, subview_n_cols); + } + + + +template +arma_inline +SpSubview +SpMat::submat(const uword in_row1, const uword in_col1, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + arma_debug_check_bounds + ( + ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), + "SpMat::submat(): indices or size out of bounds" + ); + + return SpSubview(*this, in_row1, in_col1, s_n_rows, s_n_cols); + } + + + +template +arma_inline +const SpSubview +SpMat::submat(const uword in_row1, const uword in_col1, const SizeMat& s) const + { + arma_extra_debug_sigprint(); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + arma_debug_check_bounds + ( + ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), + "SpMat::submat(): indices or size out of bounds" + ); + + return SpSubview(*this, in_row1, in_col1, s_n_rows, s_n_cols); + } + + + +template +inline +SpSubview +SpMat::submat(const span& row_span, const span& col_span) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "SpMat::submat(): indices out of bounds or incorrectly used" + ); + + return SpSubview(*this, in_row1, in_col1, submat_n_rows, submat_n_cols); + } + + + +template +inline +const SpSubview +SpMat::submat(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "SpMat::submat(): indices out of bounds or incorrectly used" + ); + + return SpSubview(*this, in_row1, in_col1, submat_n_rows, submat_n_cols); + } + + + +template +inline +SpSubview +SpMat::operator()(const span& row_span, const span& col_span) + { + arma_extra_debug_sigprint(); + + return submat(row_span, col_span); + } + + + +template +inline +const SpSubview +SpMat::operator()(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + return submat(row_span, col_span); + } + + + +template +arma_inline +SpSubview +SpMat::operator()(const uword in_row1, const uword in_col1, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).submat(in_row1, in_col1, s); + } + + + +template +arma_inline +const SpSubview +SpMat::operator()(const uword in_row1, const uword in_col1, const SizeMat& s) const + { + arma_extra_debug_sigprint(); + + return (*this).submat(in_row1, in_col1, s); + } + + + +template +inline +SpSubview +SpMat::head_rows(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_rows), "SpMat::head_rows(): size out of bounds" ); + + return SpSubview(*this, 0, 0, N, n_cols); + } + + + +template +inline +const SpSubview +SpMat::head_rows(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_rows), "SpMat::head_rows(): size out of bounds" ); + + return SpSubview(*this, 0, 0, N, n_cols); + } + + + +template +inline +SpSubview +SpMat::tail_rows(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_rows), "SpMat::tail_rows(): size out of bounds" ); + + const uword start_row = n_rows - N; + + return SpSubview(*this, start_row, 0, N, n_cols); + } + + + +template +inline +const SpSubview +SpMat::tail_rows(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_rows), "SpMat::tail_rows(): size out of bounds" ); + + const uword start_row = n_rows - N; + + return SpSubview(*this, start_row, 0, N, n_cols); + } + + + +template +inline +SpSubview +SpMat::head_cols(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_cols), "SpMat::head_cols(): size out of bounds" ); + + return SpSubview(*this, 0, 0, n_rows, N); + } + + + +template +inline +const SpSubview +SpMat::head_cols(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_cols), "SpMat::head_cols(): size out of bounds" ); + + return SpSubview(*this, 0, 0, n_rows, N); + } + + + +template +inline +SpSubview +SpMat::tail_cols(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_cols), "SpMat::tail_cols(): size out of bounds" ); + + const uword start_col = n_cols - N; + + return SpSubview(*this, 0, start_col, n_rows, N); + } + + + +template +inline +const SpSubview +SpMat::tail_cols(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > n_cols), "SpMat::tail_cols(): size out of bounds" ); + + const uword start_col = n_cols - N; + + return SpSubview(*this, 0, start_col, n_rows, N); + } + + + +template +template +arma_inline +SpSubview_col_list +SpMat::cols(const Base& indices) + { + arma_extra_debug_sigprint(); + + return SpSubview_col_list(*this, indices); + } + + + +template +template +arma_inline +const SpSubview_col_list +SpMat::cols(const Base& indices) const + { + arma_extra_debug_sigprint(); + + return SpSubview_col_list(*this, indices); + } + + + +//! creation of spdiagview (diagonal) +template +inline +spdiagview +SpMat::diag(const sword in_id) + { + arma_extra_debug_sigprint(); + + const uword row_offset = (in_id < 0) ? uword(-in_id) : 0; + const uword col_offset = (in_id > 0) ? uword( in_id) : 0; + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "SpMat::diag(): requested diagonal out of bounds" + ); + + const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset); + + return spdiagview(*this, row_offset, col_offset, len); + } + + + +//! creation of spdiagview (diagonal) +template +inline +const spdiagview +SpMat::diag(const sword in_id) const + { + arma_extra_debug_sigprint(); + + const uword row_offset = uword( (in_id < 0) ? -in_id : 0 ); + const uword col_offset = uword( (in_id > 0) ? in_id : 0 ); + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "SpMat::diag(): requested diagonal out of bounds" + ); + + const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset); + + return spdiagview(*this, row_offset, col_offset, len); + } + + + +template +inline +void +SpMat::swap_rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ((in_row1 >= n_rows) || (in_row2 >= n_rows)), "SpMat::swap_rows(): out of bounds" ); + + if(in_row1 == in_row2) { return; } + + sync_csc(); + invalidate_cache(); + + // The easier way to do this, instead of collecting all the elements in one row and then swapping with the other, will be + // to iterate over each column of the matrix (since we store in column-major format) and then swap the two elements in the two rows at that time. + // We will try to avoid using the at() call since it is expensive, instead preferring to use an iterator to track our position. + uword col1 = (in_row1 < in_row2) ? in_row1 : in_row2; + uword col2 = (in_row1 < in_row2) ? in_row2 : in_row1; + + for(uword lcol = 0; lcol < n_cols; lcol++) + { + // If there is nothing in this column we can ignore it. + if(col_ptrs[lcol] == col_ptrs[lcol + 1]) + { + continue; + } + + // These will represent the positions of the items themselves. + uword loc1 = n_nonzero + 1; + uword loc2 = n_nonzero + 1; + + for(uword search_pos = col_ptrs[lcol]; search_pos < col_ptrs[lcol + 1]; search_pos++) + { + if(row_indices[search_pos] == col1) + { + loc1 = search_pos; + } + + if(row_indices[search_pos] == col2) + { + loc2 = search_pos; + break; // No need to look any further. + } + } + + // There are four cases: we found both elements; we found one element (loc1); we found one element (loc2); we found zero elements. + // If we found zero elements no work needs to be done and we can continue to the next column. + if((loc1 != (n_nonzero + 1)) && (loc2 != (n_nonzero + 1))) + { + // This is an easy case: just swap the values. No index modifying necessary. + eT tmp = values[loc1]; + access::rw(values[loc1]) = values[loc2]; + access::rw(values[loc2]) = tmp; + } + else if(loc1 != (n_nonzero + 1)) // We only found loc1 and not loc2. + { + // We need to find the correct place to move our value to. It will be forward (not backwards) because in_row2 > in_row1. + // Each iteration of the loop swaps the current value (loc1) with (loc1 + 1); in this manner we move our value down to where it should be. + while(((loc1 + 1) < col_ptrs[lcol + 1]) && (row_indices[loc1 + 1] < in_row2)) + { + // Swap both the values and the indices. The column should not change. + eT tmp = values[loc1]; + access::rw(values[loc1]) = values[loc1 + 1]; + access::rw(values[loc1 + 1]) = tmp; + + uword tmp_index = row_indices[loc1]; + access::rw(row_indices[loc1]) = row_indices[loc1 + 1]; + access::rw(row_indices[loc1 + 1]) = tmp_index; + + loc1++; // And increment the counter. + } + + // Now set the row index correctly. + access::rw(row_indices[loc1]) = in_row2; + + } + else if(loc2 != (n_nonzero + 1)) + { + // We need to find the correct place to move our value to. It will be backwards (not forwards) because in_row1 < in_row2. + // Each iteration of the loop swaps the current value (loc2) with (loc2 - 1); in this manner we move our value up to where it should be. + while(((loc2 - 1) >= col_ptrs[lcol]) && (row_indices[loc2 - 1] > in_row1)) + { + // Swap both the values and the indices. The column should not change. + eT tmp = values[loc2]; + access::rw(values[loc2]) = values[loc2 - 1]; + access::rw(values[loc2 - 1]) = tmp; + + uword tmp_index = row_indices[loc2]; + access::rw(row_indices[loc2]) = row_indices[loc2 - 1]; + access::rw(row_indices[loc2 - 1]) = tmp_index; + + loc2--; // And decrement the counter. + } + + // Now set the row index correctly. + access::rw(row_indices[loc2]) = in_row1; + + } + /* else: no need to swap anything; both values are zero */ + } + } + + + +template +inline +void +SpMat::swap_cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ((in_col1 >= n_cols) || (in_col2 >= n_cols)), "SpMat::swap_cols(): out of bounds" ); + + if(in_col1 == in_col2) { return; } + + // TODO: this is a rudimentary implementation + + const SpMat tmp1 = (*this).col(in_col1); + const SpMat tmp2 = (*this).col(in_col2); + + (*this).col(in_col2) = tmp1; + (*this).col(in_col1) = tmp2; + + // for(uword lrow = 0; lrow < n_rows; ++lrow) + // { + // const eT tmp = at(lrow, in_col1); + // at(lrow, in_col1) = eT( at(lrow, in_col2) ); + // at(lrow, in_col2) = tmp; + // } + } + + + +template +inline +void +SpMat::shed_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds(row_num >= n_rows, "SpMat::shed_row(): out of bounds"); + + shed_rows (row_num, row_num); + } + + + +template +inline +void +SpMat::shed_col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds(col_num >= n_cols, "SpMat::shed_col(): out of bounds"); + + shed_cols(col_num, col_num); + } + + + +template +inline +void +SpMat::shed_rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "SpMat::shed_rows(): indices out of bounds or incorectly used" + ); + + sync_csc(); + + SpMat newmat(n_rows - (in_row2 - in_row1 + 1), n_cols); + + // First, count the number of elements we will be removing. + uword removing = 0; + for(uword i = 0; i < n_nonzero; ++i) + { + const uword lrow = row_indices[i]; + if(lrow >= in_row1 && lrow <= in_row2) + { + ++removing; + } + } + + // Obtain counts of the number of points in each column and store them as the + // (invalid) column pointers of the new matrix. + for(uword i = 1; i < n_cols + 1; ++i) + { + access::rw(newmat.col_ptrs[i]) = col_ptrs[i] - col_ptrs[i - 1]; + } + + // Now initialize memory for the new matrix. + newmat.mem_resize(n_nonzero - removing); + + // Now, copy over the elements. + // i is the index in the old matrix; j is the index in the new matrix. + const_iterator it = cbegin(); + const_iterator it_end = cend(); + + uword j = 0; // The index in the new matrix. + while(it != it_end) + { + const uword lrow = it.row(); + const uword lcol = it.col(); + + if(lrow >= in_row1 && lrow <= in_row2) + { + // This element is being removed. Subtract it from the column counts. + --access::rw(newmat.col_ptrs[lcol + 1]); + } + else + { + // This element is being kept. We may need to map the row index, + // if it is past the section of rows we are removing. + if(lrow > in_row2) + { + access::rw(newmat.row_indices[j]) = lrow - (in_row2 - in_row1 + 1); + } + else + { + access::rw(newmat.row_indices[j]) = lrow; + } + + access::rw(newmat.values[j]) = (*it); + ++j; // Increment index in new matrix. + } + + ++it; + } + + // Finally, sum the column counts so they are correct column pointers. + for(uword i = 1; i < n_cols + 1; ++i) + { + access::rw(newmat.col_ptrs[i]) += newmat.col_ptrs[i - 1]; + } + + // Now steal the memory of the new matrix. + steal_mem(newmat); + } + + + +template +inline +void +SpMat::shed_cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "SpMat::shed_cols(): indices out of bounds or incorrectly used" + ); + + sync_csc(); + invalidate_cache(); + + // First we find the locations in values and row_indices for the column entries. + uword col_beg = col_ptrs[in_col1]; + uword col_end = col_ptrs[in_col2 + 1]; + + // Then we find the number of entries in the column. + uword diff = col_end - col_beg; + + if(diff > 0) + { + eT* new_values = memory::acquire (n_nonzero + 1 - diff); + uword* new_row_indices = memory::acquire(n_nonzero + 1 - diff); + + // Copy first part. + if(col_beg != 0) + { + arrayops::copy(new_values, values, col_beg); + arrayops::copy(new_row_indices, row_indices, col_beg); + } + + // Copy second part. + if(col_end != n_nonzero) + { + arrayops::copy(new_values + col_beg, values + col_end, n_nonzero - col_end); + arrayops::copy(new_row_indices + col_beg, row_indices + col_end, n_nonzero - col_end); + } + + // Copy sentry element. + new_values[n_nonzero - diff] = values[n_nonzero]; + new_row_indices[n_nonzero - diff] = row_indices[n_nonzero]; + + if(values) { memory::release(access::rw(values)); } + if(row_indices) { memory::release(access::rw(row_indices)); } + + access::rw(values) = new_values; + access::rw(row_indices) = new_row_indices; + + // Update counts and such. + access::rw(n_nonzero) -= diff; + } + + // Update column pointers. + const uword new_n_cols = n_cols - ((in_col2 - in_col1) + 1); + + uword* new_col_ptrs = memory::acquire(new_n_cols + 2); + new_col_ptrs[new_n_cols + 1] = std::numeric_limits::max(); + + // Copy first set of columns (no manipulation required). + if(in_col1 != 0) + { + arrayops::copy(new_col_ptrs, col_ptrs, in_col1); + } + + // Copy second set of columns (manipulation required). + uword cur_col = in_col1; + for(uword i = in_col2 + 1; i <= n_cols; ++i, ++cur_col) + { + new_col_ptrs[cur_col] = col_ptrs[i] - diff; + } + + if(col_ptrs) { memory::release(access::rw(col_ptrs)); } + access::rw(col_ptrs) = new_col_ptrs; + + // We update the element and column counts, and we're done. + access::rw(n_cols) = new_n_cols; + access::rw(n_elem) = n_cols * n_rows; + } + + + +/** + * Element access; acces the i'th element (works identically to the Mat accessors). + * If there is nothing at element i, 0 is returned. + */ + +template +arma_inline +SpMat_MapMat_val +SpMat::operator[](const uword i) + { + const uword in_col = i / n_rows; + const uword in_row = i % n_rows; + + return SpMat_MapMat_val((*this), cache, in_row, in_col); + } + + + +template +arma_inline +eT +SpMat::operator[](const uword i) const + { + return get_value(i); + } + + + +template +arma_inline +SpMat_MapMat_val +SpMat::at(const uword i) + { + const uword in_col = i / n_rows; + const uword in_row = i % n_rows; + + return SpMat_MapMat_val((*this), cache, in_row, in_col); + } + + + +template +arma_inline +eT +SpMat::at(const uword i) const + { + return get_value(i); + } + + + +template +arma_inline +SpMat_MapMat_val +SpMat::operator()(const uword i) + { + arma_debug_check_bounds( (i >= n_elem), "SpMat::operator(): out of bounds" ); + + const uword in_col = i / n_rows; + const uword in_row = i % n_rows; + + return SpMat_MapMat_val((*this), cache, in_row, in_col); + } + + + +template +arma_inline +eT +SpMat::operator()(const uword i) const + { + arma_debug_check_bounds( (i >= n_elem), "SpMat::operator(): out of bounds" ); + + return get_value(i); + } + + + +/** + * Element access; access the element at row in_rows and column in_col. + * If there is nothing at that position, 0 is returned. + */ + +#if defined(__cpp_multidimensional_subscript) + + template + arma_inline + SpMat_MapMat_val + SpMat::operator[] (const uword in_row, const uword in_col) + { + return SpMat_MapMat_val((*this), cache, in_row, in_col); + } + + + + template + arma_inline + eT + SpMat::operator[] (const uword in_row, const uword in_col) const + { + return get_value(in_row, in_col); + } + +#endif + + + +template +arma_inline +SpMat_MapMat_val +SpMat::at(const uword in_row, const uword in_col) + { + return SpMat_MapMat_val((*this), cache, in_row, in_col); + } + + + +template +arma_inline +eT +SpMat::at(const uword in_row, const uword in_col) const + { + return get_value(in_row, in_col); + } + + + +template +arma_inline +SpMat_MapMat_val +SpMat::operator()(const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds" ); + + return SpMat_MapMat_val((*this), cache, in_row, in_col); + } + + + +template +arma_inline +eT +SpMat::operator()(const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds" ); + + return get_value(in_row, in_col); + } + + + +/** + * Check if matrix is empty (no size, no values). + */ +template +arma_inline +bool +SpMat::is_empty() const + { + return (n_elem == 0); + } + + + +//! returns true if the object can be interpreted as a column or row vector +template +arma_inline +bool +SpMat::is_vec() const + { + return ( (n_rows == 1) || (n_cols == 1) ); + } + + + +//! returns true if the object can be interpreted as a row vector +template +arma_inline +bool +SpMat::is_rowvec() const + { + return (n_rows == 1); + } + + + +//! returns true if the object can be interpreted as a column vector +template +arma_inline +bool +SpMat::is_colvec() const + { + return (n_cols == 1); + } + + + +//! returns true if the object has the same number of non-zero rows and columnns +template +arma_inline +bool +SpMat::is_square() const + { + return (n_rows == n_cols); + } + + + +template +inline +bool +SpMat::is_symmetric() const + { + arma_extra_debug_sigprint(); + + const SpMat& A = (*this); + + if(A.n_rows != A.n_cols) { return false; } + + const SpMat tmp = A - A.st(); + + return (tmp.n_nonzero == uword(0)); + } + + + +template +inline +bool +SpMat::is_symmetric(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(tol == T(0)) { return (*this).is_symmetric(); } + + arma_debug_check( (tol < T(0)), "is_symmetric(): parameter 'tol' must be >= 0" ); + + const SpMat& A = (*this); + + if(A.n_rows != A.n_cols) { return false; } + + const T norm_A = as_scalar( arma::max(sum(abs(A), 1), 0) ); + + if(norm_A == T(0)) { return true; } + + const T norm_A_Ast = as_scalar( arma::max(sum(abs(A - A.st()), 1), 0) ); + + return ( (norm_A_Ast / norm_A) <= tol ); + } + + + +template +inline +bool +SpMat::is_hermitian() const + { + arma_extra_debug_sigprint(); + + const SpMat& A = (*this); + + if(A.n_rows != A.n_cols) { return false; } + + const SpMat tmp = A - A.t(); + + return (tmp.n_nonzero == uword(0)); + } + + + +template +inline +bool +SpMat::is_hermitian(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(tol == T(0)) { return (*this).is_hermitian(); } + + arma_debug_check( (tol < T(0)), "is_hermitian(): parameter 'tol' must be >= 0" ); + + const SpMat& A = (*this); + + if(A.n_rows != A.n_cols) { return false; } + + const T norm_A = as_scalar( arma::max(sum(abs(A), 1), 0) ); + + if(norm_A == T(0)) { return true; } + + const T norm_A_At = as_scalar( arma::max(sum(abs(A - A.t()), 1), 0) ); + + return ( (norm_A_At / norm_A) <= tol ); + } + + + +template +inline +bool +SpMat::internal_is_finite() const + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return arrayops::is_finite(values, n_nonzero); + } + + + +template +inline +bool +SpMat::internal_has_inf() const + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return arrayops::has_inf(values, n_nonzero); + } + + + +template +inline +bool +SpMat::internal_has_nan() const + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return arrayops::has_nan(values, n_nonzero); + } + + + +template +inline +bool +SpMat::internal_has_nonfinite() const + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return (arrayops::is_finite(values, n_nonzero) == false); + } + + + +//! returns true if the given index is currently in range +template +arma_inline +bool +SpMat::in_range(const uword i) const + { + return (i < n_elem); + } + + +//! returns true if the given start and end indices are currently in range +template +arma_inline +bool +SpMat::in_range(const span& x) const + { + arma_extra_debug_sigprint(); + + if(x.whole) + { + return true; + } + else + { + const uword a = x.a; + const uword b = x.b; + + return ( (a <= b) && (b < n_elem) ); + } + } + + + +//! returns true if the given location is currently in range +template +arma_inline +bool +SpMat::in_range(const uword in_row, const uword in_col) const + { + return ( (in_row < n_rows) && (in_col < n_cols) ); + } + + + +template +arma_inline +bool +SpMat::in_range(const span& row_span, const uword in_col) const + { + arma_extra_debug_sigprint(); + + if(row_span.whole) + { + return (in_col < n_cols); + } + else + { + const uword in_row1 = row_span.a; + const uword in_row2 = row_span.b; + + return ( (in_row1 <= in_row2) && (in_row2 < n_rows) && (in_col < n_cols) ); + } + } + + + +template +arma_inline +bool +SpMat::in_range(const uword in_row, const span& col_span) const + { + arma_extra_debug_sigprint(); + + if(col_span.whole) + { + return (in_row < n_rows); + } + else + { + const uword in_col1 = col_span.a; + const uword in_col2 = col_span.b; + + return ( (in_row < n_rows) && (in_col1 <= in_col2) && (in_col2 < n_cols) ); + } + } + + + +template +arma_inline +bool +SpMat::in_range(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + const uword in_row1 = row_span.a; + const uword in_row2 = row_span.b; + + const uword in_col1 = col_span.a; + const uword in_col2 = col_span.b; + + const bool rows_ok = row_span.whole ? true : ( (in_row1 <= in_row2) && (in_row2 < n_rows) ); + const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2) && (in_col2 < n_cols) ); + + return ( rows_ok && cols_ok ); + } + + + +template +arma_inline +bool +SpMat::in_range(const uword in_row, const uword in_col, const SizeMat& s) const + { + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + + if( (in_row >= l_n_rows) || (in_col >= l_n_cols) || ((in_row + s.n_rows) > l_n_rows) || ((in_col + s.n_cols) > l_n_cols) ) + { + return false; + } + else + { + return true; + } + } + + + +//! Set the size to the size of another matrix. +template +template +inline +SpMat& +SpMat::copy_size(const SpMat& m) + { + arma_extra_debug_sigprint(); + + return set_size(m.n_rows, m.n_cols); + } + + + +template +template +inline +SpMat& +SpMat::copy_size(const Mat& m) + { + arma_extra_debug_sigprint(); + + return set_size(m.n_rows, m.n_cols); + } + + + +template +inline +SpMat& +SpMat::set_size(const uword in_elem) + { + arma_extra_debug_sigprint(); + + // If this is a row vector, we resize to a row vector. + if(vec_state == 2) + { + set_size(1, in_elem); + } + else + { + set_size(in_elem, 1); + } + + return *this; + } + + + +template +inline +SpMat& +SpMat::set_size(const uword in_rows, const uword in_cols) + { + arma_extra_debug_sigprint(); + + invalidate_cache(); // placed here, as set_size() is used during matrix modification + + if( (n_rows == in_rows) && (n_cols == in_cols) ) { return *this; } + + init(in_rows, in_cols); + + return *this; + } + + + +template +inline +SpMat& +SpMat::set_size(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).set_size(s.n_rows, s.n_cols); + } + + + +template +inline +SpMat& +SpMat::resize(const uword in_rows, const uword in_cols) + { + arma_extra_debug_sigprint(); + + if( (n_rows == in_rows) && (n_cols == in_cols) ) { return *this; } + + if( (n_elem == 0) || (n_nonzero == 0) ) { return set_size(in_rows, in_cols); } + + SpMat tmp(in_rows, in_cols); + + if(tmp.n_elem > 0) + { + sync_csc(); + + const uword last_row = (std::min)(in_rows, n_rows) - 1; + const uword last_col = (std::min)(in_cols, n_cols) - 1; + + tmp.submat(0, 0, last_row, last_col) = (*this).submat(0, 0, last_row, last_col); + } + + steal_mem(tmp); + + return *this; + } + + + +template +inline +SpMat& +SpMat::resize(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).resize(s.n_rows, s.n_cols); + } + + + +template +inline +SpMat& +SpMat::reshape(const uword in_rows, const uword in_cols) + { + arma_extra_debug_sigprint(); + + arma_check( ((in_rows*in_cols) != n_elem), "SpMat::reshape(): changing the number of elements in a sparse matrix is currently not supported" ); + + if( (n_rows == in_rows) && (n_cols == in_cols) ) { return *this; } + + if(vec_state == 1) { arma_debug_check( (in_cols != 1), "SpMat::reshape(): object is a column vector; requested size is not compatible" ); } + if(vec_state == 2) { arma_debug_check( (in_rows != 1), "SpMat::reshape(): object is a row vector; requested size is not compatible" ); } + + if(n_nonzero == 0) { return (*this).zeros(in_rows, in_cols); } + + if(in_cols == 1) + { + (*this).reshape_helper_intovec(); + } + else + { + (*this).reshape_helper_generic(in_rows, in_cols); + } + + return *this; + } + + + +template +inline +SpMat& +SpMat::reshape(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).reshape(s.n_rows, s.n_cols); + } + + + +template +inline +void +SpMat::reshape_helper_generic(const uword in_rows, const uword in_cols) + { + arma_extra_debug_sigprint(); + + sync_csc(); + invalidate_cache(); + + // We have to modify all of the relevant row indices and the relevant column pointers. + // Iterate over all the points to do this. We won't be deleting any points, but we will be modifying + // columns and rows. We'll have to store a new set of column vectors. + uword* new_col_ptrs = memory::acquire(in_cols + 2); + new_col_ptrs[in_cols + 1] = std::numeric_limits::max(); + + uword* new_row_indices = memory::acquire(n_nonzero + 1); + access::rw(new_row_indices[n_nonzero]) = 0; + + arrayops::fill_zeros(new_col_ptrs, in_cols + 1); + + const_iterator it = cbegin(); + const_iterator it_end = cend(); + + for(; it != it_end; ++it) + { + uword vector_position = (it.col() * n_rows) + it.row(); + new_row_indices[it.pos()] = vector_position % in_rows; + ++new_col_ptrs[vector_position / in_rows + 1]; + } + + // Now sum the column counts to get the new column pointers. + for(uword i = 1; i <= in_cols; i++) + { + access::rw(new_col_ptrs[i]) += new_col_ptrs[i - 1]; + } + + // Copy the new row indices. + if(row_indices) { memory::release(access::rw(row_indices)); } + if(col_ptrs) { memory::release(access::rw(col_ptrs)); } + + access::rw(row_indices) = new_row_indices; + access::rw(col_ptrs) = new_col_ptrs; + + // Now set the size. + access::rw(n_rows) = in_rows; + access::rw(n_cols) = in_cols; + } + + + +template +inline +void +SpMat::reshape_helper_intovec() + { + arma_extra_debug_sigprint(); + + sync_csc(); + invalidate_cache(); + + const_iterator it = cbegin(); + + const uword t_n_rows = n_rows; + const uword t_n_nonzero = n_nonzero; + + for(uword i=0; i < t_n_nonzero; ++i) + { + const uword t_index = (it.col() * t_n_rows) + it.row(); + + // ensure the iterator is pointing to the next element + // before we overwrite the row index of the current element + ++it; + + access::rw(row_indices[i]) = t_index; + } + + access::rw(row_indices[n_nonzero]) = 0; + + access::rw(col_ptrs[0]) = 0; + access::rw(col_ptrs[1]) = n_nonzero; + access::rw(col_ptrs[2]) = std::numeric_limits::max(); + + access::rw(n_rows) = (n_rows * n_cols); + access::rw(n_cols) = 1; + } + + + +//! apply a functor to each non-zero element +template +template +inline +SpMat& +SpMat::for_each(functor F) + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const uword N = (*this).n_nonzero; + + eT* rw_values = access::rwp(values); + + bool modified = false; + bool has_zero = false; + + for(uword i=0; i < N; ++i) + { + eT& new_value = rw_values[i]; + const eT old_value = new_value; + + F(new_value); + + if(new_value != old_value) { modified = true; } + if(new_value == eT(0) ) { has_zero = true; } + } + + if(modified) { invalidate_cache(); } + if(has_zero) { remove_zeros(); } + + return *this; + } + + + +template +template +inline +const SpMat& +SpMat::for_each(functor F) const + { + arma_extra_debug_sigprint(); + + sync_csc(); + + const uword N = (*this).n_nonzero; + + for(uword i=0; i < N; ++i) { F(values[i]); } + + return *this; + } + + + +//! transform each non-zero element using a functor +template +template +inline +SpMat& +SpMat::transform(functor F) + { + arma_extra_debug_sigprint(); + + sync_csc(); + invalidate_cache(); + + const uword N = (*this).n_nonzero; + + eT* rw_values = access::rwp(values); + + bool has_zero = false; + + for(uword i=0; i < N; ++i) + { + eT& rw_values_i = rw_values[i]; + + rw_values_i = eT( F(rw_values_i) ); + + if(rw_values_i == eT(0)) { has_zero = true; } + } + + if(has_zero) { remove_zeros(); } + + return *this; + } + + + +template +inline +SpMat& +SpMat::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + if(old_val == eT(0)) + { + arma_debug_warn_level(1, "SpMat::replace(): replacement not done, as old_val = 0"); + } + else + { + sync_csc(); + invalidate_cache(); + + arrayops::replace(access::rwp(values), n_nonzero, old_val, new_val); + + if(new_val == eT(0)) { remove_zeros(); } + } + + return *this; + } + + + +template +inline +SpMat& +SpMat::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + if(n_nonzero == 0) { return *this; } + + sync_csc(); + invalidate_cache(); + + arrayops::clean(access::rwp(values), n_nonzero, threshold); + + remove_zeros(); + + return *this; + } + + + +template +inline +SpMat& +SpMat::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "SpMat::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "SpMat::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "SpMat::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + if(n_nonzero == 0) { return *this; } + + sync_csc(); + invalidate_cache(); + + arrayops::clamp(access::rwp(values), n_nonzero, min_val, max_val); + + if( (min_val == eT(0)) || (max_val == eT(0)) ) { remove_zeros(); } + + return *this; + } + + + +template +inline +SpMat& +SpMat::zeros() + { + arma_extra_debug_sigprint(); + + if((n_nonzero == 0) && (values != nullptr)) + { + invalidate_cache(); + } + else + { + init(n_rows, n_cols); + } + + return *this; + } + + + +template +inline +SpMat& +SpMat::zeros(const uword in_elem) + { + arma_extra_debug_sigprint(); + + if(vec_state == 2) + { + zeros(1, in_elem); // Row vector + } + else + { + zeros(in_elem, 1); + } + + return *this; + } + + + +template +inline +SpMat& +SpMat::zeros(const uword in_rows, const uword in_cols) + { + arma_extra_debug_sigprint(); + + if((n_nonzero == 0) && (n_rows == in_rows) && (n_cols == in_cols) && (values != nullptr)) + { + invalidate_cache(); + } + else + { + init(in_rows, in_cols); + } + + return *this; + } + + + +template +inline +SpMat& +SpMat::zeros(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).zeros(s.n_rows, s.n_cols); + } + + + +template +inline +SpMat& +SpMat::eye() + { + arma_extra_debug_sigprint(); + + return (*this).eye(n_rows, n_cols); + } + + + +template +inline +SpMat& +SpMat::eye(const uword in_rows, const uword in_cols) + { + arma_extra_debug_sigprint(); + + const uword N = (std::min)(in_rows, in_cols); + + init(in_rows, in_cols, N); + + arrayops::inplace_set(access::rwp(values), eT(1), N); + + for(uword i = 0; i < N; ++i) { access::rw(row_indices[i]) = i; } + + for(uword i = 0; i <= N; ++i) { access::rw(col_ptrs[i]) = i; } + + // take into account non-square matrices + for(uword i = (N+1); i <= in_cols; ++i) { access::rw(col_ptrs[i]) = N; } + + access::rw(n_nonzero) = N; + + return *this; + } + + + +template +inline +SpMat& +SpMat::eye(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).eye(s.n_rows, s.n_cols); + } + + + +template +inline +SpMat& +SpMat::speye() + { + arma_extra_debug_sigprint(); + + return (*this).eye(n_rows, n_cols); + } + + + +template +inline +SpMat& +SpMat::speye(const uword in_n_rows, const uword in_n_cols) + { + arma_extra_debug_sigprint(); + + return (*this).eye(in_n_rows, in_n_cols); + } + + + +template +inline +SpMat& +SpMat::speye(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).eye(s.n_rows, s.n_cols); + } + + + +template +inline +SpMat& +SpMat::sprandu(const uword in_rows, const uword in_cols, const double density) + { + arma_extra_debug_sigprint(); + + arma_debug_check( ( (density < double(0)) || (density > double(1)) ), "sprandu(): density must be in the [0,1] interval" ); + + const uword new_n_nonzero = uword(density * double(in_rows) * double(in_cols) + 0.5); + + init(in_rows, in_cols, new_n_nonzero); + + if(new_n_nonzero == 0) { return *this; } + + arma_rng::randu::fill( access::rwp(values), new_n_nonzero ); + + uvec indices = linspace( 0u, in_rows*in_cols-1, new_n_nonzero ); + + // perturb the indices + for(uword i=1; i < new_n_nonzero-1; ++i) + { + const uword index_left = indices[i-1]; + const uword index_right = indices[i+1]; + + const uword center = (index_left + index_right) / 2; + + const uword delta1 = center - index_left - 1; + const uword delta2 = index_right - center - 1; + + const uword min_delta = (std::min)(delta1, delta2); + + uword index_new = uword( double(center) + double(min_delta) * (2.0*randu()-1.0) ); + + // paranoia, but better be safe than sorry + if( (index_left < index_new) && (index_new < index_right) ) + { + indices[i] = index_new; + } + } + + uword cur_index = 0; + uword count = 0; + + for(uword lcol = 0; lcol < in_cols; ++lcol) + for(uword lrow = 0; lrow < in_rows; ++lrow) + { + if(count == indices[cur_index]) + { + access::rw(row_indices[cur_index]) = lrow; + access::rw(col_ptrs[lcol + 1])++; + ++cur_index; + } + + ++count; + } + + if(cur_index != new_n_nonzero) + { + // Fix size to correct size. + mem_resize(cur_index); + } + + // Sum column pointers. + for(uword lcol = 1; lcol <= in_cols; ++lcol) + { + access::rw(col_ptrs[lcol]) += col_ptrs[lcol - 1]; + } + + return *this; + } + + + +template +inline +SpMat& +SpMat::sprandu(const SizeMat& s, const double density) + { + arma_extra_debug_sigprint(); + + return (*this).sprandu(s.n_rows, s.n_cols, density); + } + + + +template +inline +SpMat& +SpMat::sprandn(const uword in_rows, const uword in_cols, const double density) + { + arma_extra_debug_sigprint(); + + arma_debug_check( ( (density < double(0)) || (density > double(1)) ), "sprandn(): density must be in the [0,1] interval" ); + + const uword new_n_nonzero = uword(density * double(in_rows) * double(in_cols) + 0.5); + + init(in_rows, in_cols, new_n_nonzero); + + if(new_n_nonzero == 0) { return *this; } + + arma_rng::randn::fill( access::rwp(values), new_n_nonzero ); + + uvec indices = linspace( 0u, in_rows*in_cols-1, new_n_nonzero ); + + // perturb the indices + for(uword i=1; i < new_n_nonzero-1; ++i) + { + const uword index_left = indices[i-1]; + const uword index_right = indices[i+1]; + + const uword center = (index_left + index_right) / 2; + + const uword delta1 = center - index_left - 1; + const uword delta2 = index_right - center - 1; + + const uword min_delta = (std::min)(delta1, delta2); + + uword index_new = uword( double(center) + double(min_delta) * (2.0*randu()-1.0) ); + + // paranoia, but better be safe than sorry + if( (index_left < index_new) && (index_new < index_right) ) + { + indices[i] = index_new; + } + } + + uword cur_index = 0; + uword count = 0; + + for(uword lcol = 0; lcol < in_cols; ++lcol) + for(uword lrow = 0; lrow < in_rows; ++lrow) + { + if(count == indices[cur_index]) + { + access::rw(row_indices[cur_index]) = lrow; + access::rw(col_ptrs[lcol + 1])++; + ++cur_index; + } + + ++count; + } + + if(cur_index != new_n_nonzero) + { + // Fix size to correct size. + mem_resize(cur_index); + } + + // Sum column pointers. + for(uword lcol = 1; lcol <= in_cols; ++lcol) + { + access::rw(col_ptrs[lcol]) += col_ptrs[lcol - 1]; + } + + return *this; + } + + + +template +inline +SpMat& +SpMat::sprandn(const SizeMat& s, const double density) + { + arma_extra_debug_sigprint(); + + return (*this).sprandn(s.n_rows, s.n_cols, density); + } + + + +template +inline +void +SpMat::reset() + { + arma_extra_debug_sigprint(); + + switch(vec_state) + { + default: init(0, 0); break; + case 1: init(0, 1); break; + case 2: init(1, 0); break; + } + } + + + +template +inline +void +SpMat::reset_cache() + { + arma_extra_debug_sigprint(); + + sync_csc(); + + #if defined(ARMA_USE_OPENMP) + { + #pragma omp critical (arma_SpMat_cache) + { + cache.reset(); + + sync_state = 0; + } + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + const std::lock_guard lock(cache_mutex); + + cache.reset(); + + sync_state = 0; + } + #else + { + cache.reset(); + + sync_state = 0; + } + #endif + } + + + +template +inline +void +SpMat::reserve(const uword in_rows, const uword in_cols, const uword new_n_nonzero) + { + arma_extra_debug_sigprint(); + + init(in_rows, in_cols, new_n_nonzero); + } + + + +template +template +inline +void +SpMat::set_real(const SpBase::pod_type,T1>& X) + { + arma_extra_debug_sigprint(); + + SpMat_aux::set_real(*this, X); + } + + + +template +template +inline +void +SpMat::set_imag(const SpBase::pod_type,T1>& X) + { + arma_extra_debug_sigprint(); + + SpMat_aux::set_imag(*this, X); + } + + + +//! save the matrix to a file +template +inline +bool +SpMat::save(const std::string name, const file_type type) const + { + arma_extra_debug_sigprint(); + + sync_csc(); + + bool save_okay; + + switch(type) + { + case csv_ascii: + return (*this).save(csv_name(name), type); + break; + + case ssv_ascii: + return (*this).save(csv_name(name), type); + break; + + case arma_binary: + save_okay = diskio::save_arma_binary(*this, name); + break; + + case coord_ascii: + save_okay = diskio::save_coord_ascii(*this, name); + break; + + default: + arma_debug_warn_level(1, "SpMat::save(): unsupported file type"); + save_okay = false; + } + + if(save_okay == false) { arma_debug_warn_level(3, "SpMat::save(): write failed; file: ", name); } + + return save_okay; + } + + + +template +inline +bool +SpMat::save(const csv_name& spec, const file_type type) const + { + arma_extra_debug_sigprint(); + + if( (type != csv_ascii) && (type != ssv_ascii) ) + { + arma_stop_runtime_error("SpMat::save(): unsupported file type for csv_name()"); + return false; + } + + const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans ); + const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header ); + const bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header) && (no_header == false); + const bool use_semicolon = bool(spec.opts.flags & csv_opts::flag_semicolon ) || (type == ssv_ascii); + + arma_extra_debug_print("SpMat::save(csv_name): enabled flags:"); + + if(do_trans ) { arma_extra_debug_print("trans"); } + if(no_header ) { arma_extra_debug_print("no_header"); } + if(with_header ) { arma_extra_debug_print("with_header"); } + if(use_semicolon) { arma_extra_debug_print("semicolon"); } + + const char separator = (use_semicolon) ? char(';') : char(','); + + if(with_header) + { + if( (spec.header_ro.n_cols != 1) && (spec.header_ro.n_rows != 1) ) + { + arma_debug_warn_level(1, "SpMat::save(): given header must have a vector layout"); + return false; + } + + for(uword i=0; i < spec.header_ro.n_elem; ++i) + { + const std::string& token = spec.header_ro.at(i); + + if(token.find(separator) != std::string::npos) + { + arma_debug_warn_level(1, "SpMat::save(): token within the header contains the separator character: '", token, "'"); + return false; + } + } + + const uword save_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols; + + if(spec.header_ro.n_elem != save_n_cols) + { + arma_debug_warn_level(1, "SpMat::save(): size mismatch between header and matrix"); + return false; + } + } + + bool save_okay = false; + + if(do_trans) + { + const SpMat tmp = (*this).st(); + + save_okay = diskio::save_csv_ascii(tmp, spec.filename, spec.header_ro, with_header, separator); + } + else + { + save_okay = diskio::save_csv_ascii(*this, spec.filename, spec.header_ro, with_header, separator); + } + + if(save_okay == false) { arma_debug_warn_level(3, "SpMat::save(): write failed; file: ", spec.filename); } + + return save_okay; + } + + + +//! save the matrix to a stream +template +inline +bool +SpMat::save(std::ostream& os, const file_type type) const + { + arma_extra_debug_sigprint(); + + sync_csc(); + + bool save_okay; + + switch(type) + { + case csv_ascii: + save_okay = diskio::save_csv_ascii(*this, os, char(',')); + break; + + case ssv_ascii: + save_okay = diskio::save_csv_ascii(*this, os, char(';')); + break; + + case arma_binary: + save_okay = diskio::save_arma_binary(*this, os); + break; + + case coord_ascii: + save_okay = diskio::save_coord_ascii(*this, os); + break; + + default: + arma_debug_warn_level(1, "SpMat::save(): unsupported file type"); + save_okay = false; + } + + if(save_okay == false) { arma_debug_warn_level(3, "SpMat::save(): stream write failed"); } + + return save_okay; + } + + + +//! load a matrix from a file +template +inline +bool +SpMat::load(const std::string name, const file_type type) + { + arma_extra_debug_sigprint(); + + invalidate_cache(); + + bool load_okay; + std::string err_msg; + + switch(type) + { + // case auto_detect: + // load_okay = diskio::load_auto_detect(*this, name, err_msg); + // break; + + case csv_ascii: + return (*this).load(csv_name(name), type); + break; + + case ssv_ascii: + return (*this).load(csv_name(name), type); + break; + + case arma_binary: + load_okay = diskio::load_arma_binary(*this, name, err_msg); + break; + + case coord_ascii: + load_okay = diskio::load_coord_ascii(*this, name, err_msg); + break; + + default: + arma_debug_warn_level(1, "SpMat::load(): unsupported file type"); + load_okay = false; + } + + if(load_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "SpMat::load(): ", err_msg, "; file: ", name); + } + else + { + arma_debug_warn_level(3, "SpMat::load(): read failed; file: ", name); + } + } + + if(load_okay == false) { (*this).reset(); } + + return load_okay; + } + + + +template +inline +bool +SpMat::load(const csv_name& spec, const file_type type) + { + arma_extra_debug_sigprint(); + + if( (type != csv_ascii) && (type != ssv_ascii) ) + { + arma_stop_runtime_error("SpMat::load(): unsupported file type for csv_name()"); + return false; + } + + const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans ); + const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header ); + const bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header) && (no_header == false); + const bool use_semicolon = bool(spec.opts.flags & csv_opts::flag_semicolon ) || (type == ssv_ascii); + const bool strict = bool(spec.opts.flags & csv_opts::flag_strict ); + + arma_extra_debug_print("SpMat::load(csv_name): enabled flags:"); + + if(do_trans ) { arma_extra_debug_print("trans"); } + if(no_header ) { arma_extra_debug_print("no_header"); } + if(with_header ) { arma_extra_debug_print("with_header"); } + if(use_semicolon) { arma_extra_debug_print("semicolon"); } + if(strict ) { arma_extra_debug_print("strict"); } + + if(strict) { arma_debug_warn_level(1, "SpMat::load(): option 'strict' not implemented for sparse matrices"); } + + const char separator = (use_semicolon) ? char(';') : char(','); + + bool load_okay = false; + std::string err_msg; + + if(do_trans) + { + SpMat tmp_mat; + + load_okay = diskio::load_csv_ascii(tmp_mat, spec.filename, err_msg, spec.header_rw, with_header, separator); + + if(load_okay) + { + (*this) = tmp_mat.st(); + + if(with_header) + { + // field::set_size() preserves data if the number of elements hasn't changed + spec.header_rw.set_size(spec.header_rw.n_elem, 1); + } + } + } + else + { + load_okay = diskio::load_csv_ascii(*this, spec.filename, err_msg, spec.header_rw, with_header, separator); + } + + if(load_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "SpMat::load(): ", err_msg, "; file: ", spec.filename); + } + else + { + arma_debug_warn_level(3, "SpMat::load(): read failed; file: ", spec.filename); + } + } + else + { + const uword load_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols; + + if(with_header && (spec.header_rw.n_elem != load_n_cols)) + { + arma_debug_warn_level(3, "SpMat::load(): size mismatch between header and matrix"); + } + } + + if(load_okay == false) + { + (*this).reset(); + + if(with_header) { spec.header_rw.reset(); } + } + + return load_okay; + } + + + +//! load a matrix from a stream +template +inline +bool +SpMat::load(std::istream& is, const file_type type) + { + arma_extra_debug_sigprint(); + + invalidate_cache(); + + bool load_okay; + std::string err_msg; + + switch(type) + { + // case auto_detect: + // load_okay = diskio::load_auto_detect(*this, is, err_msg); + // break; + + case csv_ascii: + load_okay = diskio::load_csv_ascii(*this, is, err_msg, char(',')); + break; + + case ssv_ascii: + load_okay = diskio::load_csv_ascii(*this, is, err_msg, char(';')); + break; + + case arma_binary: + load_okay = diskio::load_arma_binary(*this, is, err_msg); + break; + + case coord_ascii: + load_okay = diskio::load_coord_ascii(*this, is, err_msg); + break; + + default: + arma_debug_warn_level(1, "SpMat::load(): unsupported file type"); + load_okay = false; + } + + if(load_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "SpMat::load(): ", err_msg); + } + else + { + arma_debug_warn_level(3, "SpMat::load(): stream read failed"); + } + } + + if(load_okay == false) { (*this).reset(); } + + return load_okay; + } + + + +template +inline +bool +SpMat::quiet_save(const std::string name, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(name, type); + } + + + +template +inline +bool +SpMat::quiet_save(std::ostream& os, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(os, type); + } + + + +template +inline +bool +SpMat::quiet_load(const std::string name, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(name, type); + } + + + +template +inline +bool +SpMat::quiet_load(std::istream& is, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(is, type); + } + + + +/** + * Initialize the matrix to the specified size. Data is not preserved, so the matrix is assumed to be entirely sparse (empty). + */ +template +inline +void +SpMat::init(uword in_rows, uword in_cols, const uword new_n_nonzero) + { + arma_extra_debug_sigprint(); + + invalidate_cache(); // placed here, as init() is used during matrix modification + + // Clean out the existing memory. + if(values ) { memory::release(access::rw(values)); } + if(row_indices) { memory::release(access::rw(row_indices)); } + if(col_ptrs ) { memory::release(access::rw(col_ptrs)); } + + // in case init_cold() throws an exception + access::rw(n_rows) = 0; + access::rw(n_cols) = 0; + access::rw(n_elem) = 0; + access::rw(n_nonzero) = 0; + access::rw(values) = nullptr; + access::rw(row_indices) = nullptr; + access::rw(col_ptrs) = nullptr; + + init_cold(in_rows, in_cols, new_n_nonzero); + } + + + +template +inline +void +SpMat::init_cold(uword in_rows, uword in_cols, const uword new_n_nonzero) + { + arma_extra_debug_sigprint(); + + // Verify that we are allowed to do this. + if(vec_state > 0) + { + if((in_rows == 0) && (in_cols == 0)) + { + if(vec_state == 1) { in_cols = 1; } + if(vec_state == 2) { in_rows = 1; } + } + else + { + if(vec_state == 1) { arma_debug_check( (in_cols != 1), "SpMat::init(): object is a column vector; requested size is not compatible" ); } + if(vec_state == 2) { arma_debug_check( (in_rows != 1), "SpMat::init(): object is a row vector; requested size is not compatible" ); } + } + } + + #if defined(ARMA_64BIT_WORD) + const char* error_message = "SpMat::init(): requested size is too large"; + #else + const char* error_message = "SpMat::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; + #endif + + // Ensure that n_elem can hold the result of (n_rows * n_cols) + arma_debug_check + ( + ( + ( (in_rows > ARMA_MAX_UHWORD) || (in_cols > ARMA_MAX_UHWORD) ) + ? ( (double(in_rows) * double(in_cols)) > double(ARMA_MAX_UWORD) ) + : false + ), + error_message + ); + + access::rw(col_ptrs) = memory::acquire(in_cols + 2); + access::rw(values) = memory::acquire (new_n_nonzero + 1); + access::rw(row_indices) = memory::acquire(new_n_nonzero + 1); + + // fill column pointers with 0, + // except for the last element which contains the maximum possible element + // (so iterators terminate correctly). + arrayops::fill_zeros(access::rwp(col_ptrs), in_cols + 1); + + access::rw(col_ptrs[in_cols + 1]) = std::numeric_limits::max(); + + access::rw( values[new_n_nonzero]) = 0; + access::rw(row_indices[new_n_nonzero]) = 0; + + // Set the new size accordingly. + access::rw(n_rows) = in_rows; + access::rw(n_cols) = in_cols; + access::rw(n_elem) = (in_rows * in_cols); + access::rw(n_nonzero) = new_n_nonzero; + } + + + +template +inline +void +SpMat::init(const std::string& text) + { + arma_extra_debug_sigprint(); + + Mat tmp(text); + + if(vec_state == 1) + { + if((tmp.n_elem > 0) && tmp.is_vec()) + { + access::rw(tmp.n_rows) = tmp.n_elem; + access::rw(tmp.n_cols) = 1; + } + } + + if(vec_state == 2) + { + if((tmp.n_elem > 0) && tmp.is_vec()) + { + access::rw(tmp.n_rows) = 1; + access::rw(tmp.n_cols) = tmp.n_elem; + } + } + + (*this).operator=(tmp); + } + + + +template +inline +void +SpMat::init(const SpMat& x) + { + arma_extra_debug_sigprint(); + + if(this == &x) { return; } + + bool init_done = false; + + #if defined(ARMA_USE_OPENMP) + if(x.sync_state == 1) + { + #pragma omp critical (arma_SpMat_init) + if(x.sync_state == 1) + { + (*this).init(x.cache); + init_done = true; + } + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + if(x.sync_state == 1) + { + const std::lock_guard lock(x.cache_mutex); + + if(x.sync_state == 1) + { + (*this).init(x.cache); + init_done = true; + } + } + #else + if(x.sync_state == 1) + { + (*this).init(x.cache); + init_done = true; + } + #endif + + if(init_done == false) + { + (*this).init_simple(x); + } + } + + + +template +inline +void +SpMat::init(const MapMat& x) + { + arma_extra_debug_sigprint(); + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + const uword x_n_nz = x.get_n_nonzero(); + + init(x_n_rows, x_n_cols, x_n_nz); + + if(x_n_nz == 0) { return; } + + typename MapMat::map_type& x_map_ref = *(x.map_ptr); + + typename MapMat::map_type::const_iterator x_it = x_map_ref.begin(); + + uword x_col = 0; + uword x_col_index_start = 0; + uword x_col_index_endp1 = x_n_rows; + + for(uword i=0; i < x_n_nz; ++i) + { + const std::pair& x_entry = (*x_it); + + const uword x_index = x_entry.first; + const eT x_val = x_entry.second; + + // have we gone past the curent column? + if(x_index >= x_col_index_endp1) + { + x_col = x_index / x_n_rows; + + x_col_index_start = x_col * x_n_rows; + x_col_index_endp1 = x_col_index_start + x_n_rows; + } + + const uword x_row = x_index - x_col_index_start; + + // // sanity check + // + // const uword tmp_x_row = x_index % x_n_rows; + // const uword tmp_x_col = x_index / x_n_rows; + // + // if(x_row != tmp_x_row) { cout << "x_row != tmp_x_row" << endl; exit(-1); } + // if(x_col != tmp_x_col) { cout << "x_col != tmp_x_col" << endl; exit(-1); } + + access::rw(values[i]) = x_val; + access::rw(row_indices[i]) = x_row; + + access::rw(col_ptrs[ x_col + 1 ])++; + + ++x_it; + } + + + for(uword i = 0; i < x_n_cols; ++i) + { + access::rw(col_ptrs[i + 1]) += col_ptrs[i]; + } + + + // // OLD METHOD + // + // for(uword i=0; i < x_n_nz; ++i) + // { + // const std::pair& x_entry = (*x_it); + // + // const uword x_index = x_entry.first; + // const eT x_val = x_entry.second; + // + // const uword x_row = x_index % x_n_rows; + // const uword x_col = x_index / x_n_rows; + // + // access::rw(values[i]) = x_val; + // access::rw(row_indices[i]) = x_row; + // + // access::rw(col_ptrs[ x_col + 1 ])++; + // + // ++x_it; + // } + // + // + // for(uword i = 0; i < x_n_cols; ++i) + // { + // access::rw(col_ptrs[i + 1]) += col_ptrs[i]; + // } + } + + + +template +inline +void +SpMat::init_simple(const SpMat& x) + { + arma_extra_debug_sigprint(); + + if(this == &x) { return; } + + if((x.n_nonzero == 0) && (n_nonzero == 0) && (n_rows == x.n_rows) && (n_cols == x.n_cols) && (values != nullptr)) + { + invalidate_cache(); + } + else + { + init(x.n_rows, x.n_cols, x.n_nonzero); + } + + if(x.n_nonzero != 0) + { + if(x.values ) { arrayops::copy(access::rwp(values), x.values, x.n_nonzero + 1); } + if(x.row_indices) { arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1); } + if(x.col_ptrs ) { arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1); } + } + } + + + +template +inline +void +SpMat::init_batch_std(const Mat& locs, const Mat& vals, const bool sort_locations) + { + arma_extra_debug_sigprint(); + + // Resize to correct number of elements. + mem_resize(vals.n_elem); + + // Reset column pointers to zero. + arrayops::fill_zeros(access::rwp(col_ptrs), n_cols + 1); + + bool actually_sorted = true; + + if(sort_locations) + { + // check if we really need a time consuming sort + + const uword locs_n_cols = locs.n_cols; + + for(uword i = 1; i < locs_n_cols; ++i) + { + const uword* locs_i = locs.colptr(i ); + const uword* locs_im1 = locs.colptr(i-1); + + const uword row_i = locs_i[0]; + const uword col_i = locs_i[1]; + + const uword row_im1 = locs_im1[0]; + const uword col_im1 = locs_im1[1]; + + if( (col_i < col_im1) || ((col_i == col_im1) && (row_i <= row_im1)) ) + { + actually_sorted = false; + break; + } + } + + if(actually_sorted == false) + { + // see op_sort_index_bones.hpp for the definition of arma_sort_index_packet and arma_sort_index_helper_ascend + + std::vector< arma_sort_index_packet > packet_vec(locs_n_cols); + + const uword* locs_mem = locs.memptr(); + + for(uword i = 0; i < locs_n_cols; ++i) + { + const uword row = (*locs_mem); locs_mem++; + const uword col = (*locs_mem); locs_mem++; + + packet_vec[i].val = (col * n_rows) + row; + packet_vec[i].index = i; + } + + arma_sort_index_helper_ascend comparator; + + std::sort( packet_vec.begin(), packet_vec.end(), comparator ); + + // insert the elements in the sorted order + for(uword i = 0; i < locs_n_cols; ++i) + { + const uword index = packet_vec[i].index; + + const uword* locs_i = locs.colptr(index); + + const uword row_i = locs_i[0]; + const uword col_i = locs_i[1]; + + arma_debug_check( ( (row_i >= n_rows) || (col_i >= n_cols) ), "SpMat::SpMat(): invalid row or column index" ); + + if(i > 0) + { + const uword prev_index = packet_vec[i-1].index; + + const uword* locs_im1 = locs.colptr(prev_index); + + const uword row_im1 = locs_im1[0]; + const uword col_im1 = locs_im1[1]; + + arma_debug_check( ( (row_i == row_im1) && (col_i == col_im1) ), "SpMat::SpMat(): detected identical locations" ); + } + + access::rw(values[i]) = vals[index]; + access::rw(row_indices[i]) = row_i; + + access::rw(col_ptrs[ col_i + 1 ])++; + } + } + } + + if( (sort_locations == false) || (actually_sorted == true) ) + { + // Now set the values and row indices correctly. + // Increment the column pointers in each column (so they are column "counts"). + + const uword locs_n_cols = locs.n_cols; + + for(uword i=0; i < locs_n_cols; ++i) + { + const uword* locs_i = locs.colptr(i); + + const uword row_i = locs_i[0]; + const uword col_i = locs_i[1]; + + arma_debug_check( ( (row_i >= n_rows) || (col_i >= n_cols) ), "SpMat::SpMat(): invalid row or column index" ); + + if(i > 0) + { + const uword* locs_im1 = locs.colptr(i-1); + + const uword row_im1 = locs_im1[0]; + const uword col_im1 = locs_im1[1]; + + arma_debug_check + ( + ( (col_i < col_im1) || ((col_i == col_im1) && (row_i < row_im1)) ), + "SpMat::SpMat(): out of order points; either pass sort_locations = true, or sort points in column-major ordering" + ); + + arma_debug_check( ( (col_i == col_im1) && (row_i == row_im1) ), "SpMat::SpMat(): detected identical locations" ); + } + + access::rw(values[i]) = vals[i]; + access::rw(row_indices[i]) = row_i; + + access::rw(col_ptrs[ col_i + 1 ])++; + } + } + + // Now fix the column pointers. + for(uword i = 0; i < n_cols; ++i) + { + access::rw(col_ptrs[i + 1]) += col_ptrs[i]; + } + } + + + +template +inline +void +SpMat::init_batch_add(const Mat& locs, const Mat& vals, const bool sort_locations) + { + arma_extra_debug_sigprint(); + + if(locs.n_cols < 2) + { + init_batch_std(locs, vals, false); + return; + } + + // Reset column pointers to zero. + arrayops::fill_zeros(access::rwp(col_ptrs), n_cols + 1); + + bool actually_sorted = true; + + if(sort_locations) + { + // sort_index() uses std::sort() which may use quicksort... so we better + // make sure it's not already sorted before taking an O(N^2) sort penalty. + for(uword i = 1; i < locs.n_cols; ++i) + { + const uword* locs_i = locs.colptr(i ); + const uword* locs_im1 = locs.colptr(i-1); + + if( (locs_i[1] < locs_im1[1]) || (locs_i[1] == locs_im1[1] && locs_i[0] <= locs_im1[0]) ) + { + actually_sorted = false; + break; + } + } + + if(actually_sorted == false) + { + // This may not be the fastest possible implementation but it maximizes code reuse. + Col abslocs(locs.n_cols, arma_nozeros_indicator()); + + for(uword i = 0; i < locs.n_cols; ++i) + { + const uword* locs_i = locs.colptr(i); + + abslocs[i] = locs_i[1] * n_rows + locs_i[0]; + } + + uvec sorted_indices = sort_index(abslocs); // Ascending sort. + + // work out the number of unique elments + uword n_unique = 1; // first element is unique + + for(uword i=1; i < sorted_indices.n_elem; ++i) + { + const uword* locs_i = locs.colptr( sorted_indices[i ] ); + const uword* locs_im1 = locs.colptr( sorted_indices[i-1] ); + + if( (locs_i[1] != locs_im1[1]) || (locs_i[0] != locs_im1[0]) ) { ++n_unique; } + } + + // resize to correct number of elements + mem_resize(n_unique); + + // Now we add the elements in this sorted order. + uword count = 0; + + // first element + { + const uword i = 0; + const uword* locs_i = locs.colptr( sorted_indices[i] ); + + arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" ); + + access::rw(values[count]) = vals[ sorted_indices[i] ]; + access::rw(row_indices[count]) = locs_i[0]; + + access::rw(col_ptrs[ locs_i[1] + 1 ])++; + } + + for(uword i=1; i < sorted_indices.n_elem; ++i) + { + const uword* locs_i = locs.colptr( sorted_indices[i ] ); + const uword* locs_im1 = locs.colptr( sorted_indices[i-1] ); + + arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" ); + + if( (locs_i[1] == locs_im1[1]) && (locs_i[0] == locs_im1[0]) ) + { + access::rw(values[count]) += vals[ sorted_indices[i] ]; + } + else + { + count++; + access::rw(values[count]) = vals[ sorted_indices[i] ]; + access::rw(row_indices[count]) = locs_i[0]; + + access::rw(col_ptrs[ locs_i[1] + 1 ])++; + } + } + } + } + + if( (sort_locations == false) || (actually_sorted == true) ) + { + // work out the number of unique elments + uword n_unique = 1; // first element is unique + + for(uword i=1; i < locs.n_cols; ++i) + { + const uword* locs_i = locs.colptr(i ); + const uword* locs_im1 = locs.colptr(i-1); + + if( (locs_i[1] != locs_im1[1]) || (locs_i[0] != locs_im1[0]) ) { ++n_unique; } + } + + // resize to correct number of elements + mem_resize(n_unique); + + // Now set the values and row indices correctly. + // Increment the column pointers in each column (so they are column "counts"). + + uword count = 0; + + // first element + { + const uword i = 0; + const uword* locs_i = locs.colptr(i); + + arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" ); + + access::rw(values[count]) = vals[i]; + access::rw(row_indices[count]) = locs_i[0]; + + access::rw(col_ptrs[ locs_i[1] + 1 ])++; + } + + for(uword i=1; i < locs.n_cols; ++i) + { + const uword* locs_i = locs.colptr(i ); + const uword* locs_im1 = locs.colptr(i-1); + + arma_debug_check( ( (locs_i[0] >= n_rows) || (locs_i[1] >= n_cols) ), "SpMat::SpMat(): invalid row or column index" ); + + arma_debug_check + ( + ( (locs_i[1] < locs_im1[1]) || (locs_i[1] == locs_im1[1] && locs_i[0] < locs_im1[0]) ), + "SpMat::SpMat(): out of order points; either pass sort_locations = true, or sort points in column-major ordering" + ); + + if( (locs_i[1] == locs_im1[1]) && (locs_i[0] == locs_im1[0]) ) + { + access::rw(values[count]) += vals[i]; + } + else + { + count++; + + access::rw(values[count]) = vals[i]; + access::rw(row_indices[count]) = locs_i[0]; + + access::rw(col_ptrs[ locs_i[1] + 1 ])++; + } + } + } + + // Now fix the column pointers. + for(uword i = 0; i < n_cols; ++i) + { + access::rw(col_ptrs[i + 1]) += col_ptrs[i]; + } + } + + + +//! constructor used by SpRow and SpCol classes +template +inline +SpMat::SpMat(const arma_vec_indicator&, const uword in_vec_state) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(in_vec_state) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + const uword in_n_rows = (in_vec_state == 2) ? 1 : 0; + const uword in_n_cols = (in_vec_state == 1) ? 1 : 0; + + init_cold(in_n_rows, in_n_cols); + } + + + +//! constructor used by SpRow and SpCol classes +template +inline +SpMat::SpMat(const arma_vec_indicator&, const uword in_n_rows, const uword in_n_cols, const uword in_vec_state) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(in_vec_state) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init_cold(in_n_rows, in_n_cols); + } + + + +template +inline +void +SpMat::mem_resize(const uword new_n_nonzero) + { + arma_extra_debug_sigprint(); + + invalidate_cache(); // placed here, as mem_resize() is used during matrix modification + + if(n_nonzero == new_n_nonzero) { return; } + + eT* new_values = memory::acquire (new_n_nonzero + 1); + uword* new_row_indices = memory::acquire(new_n_nonzero + 1); + + if( (n_nonzero > 0 ) && (new_n_nonzero > 0) ) + { + // Copy old elements. + uword copy_len = (std::min)(n_nonzero, new_n_nonzero); + + arrayops::copy(new_values, values, copy_len); + arrayops::copy(new_row_indices, row_indices, copy_len); + } + + if(values) { memory::release(access::rw(values)); } + if(row_indices) { memory::release(access::rw(row_indices)); } + + access::rw(values) = new_values; + access::rw(row_indices) = new_row_indices; + + // Set the "fake end" of the matrix by setting the last value and row index to 0. + // This helps the iterators work correctly. + access::rw( values[new_n_nonzero]) = 0; + access::rw(row_indices[new_n_nonzero]) = 0; + + access::rw(n_nonzero) = new_n_nonzero; + } + + + +template +inline +void +SpMat::sync() const + { + arma_extra_debug_sigprint(); + + sync_csc(); + } + + + +template +inline +void +SpMat::remove_zeros() + { + arma_extra_debug_sigprint(); + + sync_csc(); + + invalidate_cache(); // placed here, as remove_zeros() is used during matrix modification + + const uword old_n_nonzero = n_nonzero; + uword new_n_nonzero = 0; + + const eT* old_values = values; + + constexpr eT zero = eT(0); + + for(uword i=0; i < old_n_nonzero; ++i) + { + new_n_nonzero += (old_values[i] != zero) ? uword(1) : uword(0); + } + + if(new_n_nonzero != old_n_nonzero) + { + if(new_n_nonzero == 0) { init(n_rows, n_cols); return; } + + SpMat tmp(arma_reserve_indicator(), n_rows, n_cols, new_n_nonzero); + + uword new_index = 0; + + const_iterator it = cbegin(); + const_iterator it_end = cend(); + + for(; it != it_end; ++it) + { + const eT val = eT(*it); + + if(val != zero) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + access::rw(tmp.values[new_index]) = val; + access::rw(tmp.row_indices[new_index]) = it_row; + access::rw(tmp.col_ptrs[it_col + 1])++; + ++new_index; + } + } + + for(uword i=0; i < n_cols; ++i) + { + access::rw(tmp.col_ptrs[i + 1]) += tmp.col_ptrs[i]; + } + + steal_mem(tmp); + } + } + + + +// Steal memory from another matrix. +template +inline +void +SpMat::steal_mem(SpMat& x) + { + arma_extra_debug_sigprint(); + + if(this == &x) { return; } + + bool layout_ok = false; + + if((*this).vec_state == x.vec_state) + { + layout_ok = true; + } + else + { + if( ((*this).vec_state == 1) && (x.n_cols == 1) ) { layout_ok = true; } + if( ((*this).vec_state == 2) && (x.n_rows == 1) ) { layout_ok = true; } + } + + if(layout_ok) + { + arma_extra_debug_print("SpMat::steal_mem(): stealing memory"); + + x.sync_csc(); + + steal_mem_simple(x); + + x.invalidate_cache(); + + invalidate_cache(); + } + else + { + arma_extra_debug_print("SpMat::steal_mem(): copying memory"); + + (*this).operator=(x); + } + } + + + +template +inline +void +SpMat::steal_mem_simple(SpMat& x) + { + arma_extra_debug_sigprint(); + + if(this == &x) { return; } + + if(values ) { memory::release(access::rw(values)); } + if(row_indices) { memory::release(access::rw(row_indices)); } + if(col_ptrs ) { memory::release(access::rw(col_ptrs)); } + + access::rw(n_rows) = x.n_rows; + access::rw(n_cols) = x.n_cols; + access::rw(n_elem) = x.n_elem; + access::rw(n_nonzero) = x.n_nonzero; + + access::rw(values) = x.values; + access::rw(row_indices) = x.row_indices; + access::rw(col_ptrs) = x.col_ptrs; + + // Set other matrix to empty. + access::rw(x.n_rows) = 0; + access::rw(x.n_cols) = 0; + access::rw(x.n_elem) = 0; + access::rw(x.n_nonzero) = 0; + + access::rw(x.values) = nullptr; + access::rw(x.row_indices) = nullptr; + access::rw(x.col_ptrs) = nullptr; + } + + + +template +template +inline +void +SpMat::init_xform(const SpBase& A, const Functor& func) + { + arma_extra_debug_sigprint(); + + // if possible, avoid doing a copy and instead apply func to the generated elements + if(SpProxy::Q_is_generated) + { + (*this) = A.get_ref(); + + const uword nnz = n_nonzero; + + eT* t_values = access::rwp(values); + + bool has_zero = false; + + for(uword i=0; i < nnz; ++i) + { + eT& t_values_i = t_values[i]; + + t_values_i = func(t_values_i); + + if(t_values_i == eT(0)) { has_zero = true; } + } + + if(has_zero) { remove_zeros(); } + } + else + { + init_xform_mt(A.get_ref(), func); + } + } + + + +template +template +inline +void +SpMat::init_xform_mt(const SpBase& A, const Functor& func) + { + arma_extra_debug_sigprint(); + + const SpProxy P(A.get_ref()); + + if( P.is_alias(*this) || (is_SpMat::stored_type>::value) ) + { + // NOTE: unwrap_spmat will convert a submatrix to a matrix, which in effect takes care of aliasing with submatrices; + // NOTE: however, when more delayed ops are implemented, more elaborate handling of aliasing will be necessary + const unwrap_spmat::stored_type> tmp(P.Q); + + const SpMat& x = tmp.M; + + if(void_ptr(this) != void_ptr(&x)) + { + init(x.n_rows, x.n_cols, x.n_nonzero); + + arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1); + arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1); + } + + + // initialise the elements array with a transformed version of the elements from x + + const uword nnz = n_nonzero; + + const eT2* x_values = x.values; + eT* t_values = access::rwp(values); + + bool has_zero = false; + + for(uword i=0; i < nnz; ++i) + { + eT& t_values_i = t_values[i]; + + t_values_i = func(x_values[i]); // NOTE: func() must produce a value of type eT (ie. act as a convertor between eT2 and eT) + + if(t_values_i == eT(0)) { has_zero = true; } + } + + if(has_zero) { remove_zeros(); } + } + else + { + init(P.get_n_rows(), P.get_n_cols(), P.get_n_nonzero()); + + typename SpProxy::const_iterator_type it = P.begin(); + typename SpProxy::const_iterator_type it_end = P.end(); + + bool has_zero = false; + + while(it != it_end) + { + const eT val = func(*it); // NOTE: func() must produce a value of type eT (ie. act as a convertor between eT2 and eT) + + if(val == eT(0)) { has_zero = true; } + + const uword it_pos = it.pos(); + + access::rw(row_indices[it_pos]) = it.row(); + access::rw(values[it_pos]) = val; + ++access::rw(col_ptrs[it.col() + 1]); + ++it; + } + + // Now sum column pointers. + for(uword c = 1; c <= n_cols; ++c) + { + access::rw(col_ptrs[c]) += col_ptrs[c - 1]; + } + + if(has_zero) { remove_zeros(); } + } + } + + + +template +arma_inline +bool +SpMat::is_alias(const SpMat& X) const + { + return (&X == this); + } + + + +template +inline +typename SpMat::iterator +SpMat::begin() + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return iterator(*this); + } + + + +template +inline +typename SpMat::const_iterator +SpMat::begin() const + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return const_iterator(*this); + } + + + +template +inline +typename SpMat::const_iterator +SpMat::cbegin() const + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return const_iterator(*this); + } + + + +template +inline +typename SpMat::iterator +SpMat::end() + { + sync_csc(); + + return iterator(*this, 0, n_cols, n_nonzero); + } + + + +template +inline +typename SpMat::const_iterator +SpMat::end() const + { + sync_csc(); + + return const_iterator(*this, 0, n_cols, n_nonzero); + } + + + +template +inline +typename SpMat::const_iterator +SpMat::cend() const + { + sync_csc(); + + return const_iterator(*this, 0, n_cols, n_nonzero); + } + + + +template +inline +typename SpMat::col_iterator +SpMat::begin_col(const uword col_num) + { + sync_csc(); + + return col_iterator(*this, 0, col_num); + } + + + +template +inline +typename SpMat::const_col_iterator +SpMat::begin_col(const uword col_num) const + { + sync_csc(); + + return const_col_iterator(*this, 0, col_num); + } + + + +template +inline +typename SpMat::col_iterator +SpMat::begin_col_no_sync(const uword col_num) + { + return col_iterator(*this, 0, col_num); + } + + + +template +inline +typename SpMat::const_col_iterator +SpMat::begin_col_no_sync(const uword col_num) const + { + return const_col_iterator(*this, 0, col_num); + } + + + +template +inline +typename SpMat::col_iterator +SpMat::end_col(const uword col_num) + { + sync_csc(); + + return col_iterator(*this, 0, col_num + 1); + } + + + +template +inline +typename SpMat::const_col_iterator +SpMat::end_col(const uword col_num) const + { + sync_csc(); + + return const_col_iterator(*this, 0, col_num + 1); + } + + + +template +inline +typename SpMat::col_iterator +SpMat::end_col_no_sync(const uword col_num) + { + return col_iterator(*this, 0, col_num + 1); + } + + + +template +inline +typename SpMat::const_col_iterator +SpMat::end_col_no_sync(const uword col_num) const + { + return const_col_iterator(*this, 0, col_num + 1); + } + + + +template +inline +typename SpMat::row_iterator +SpMat::begin_row(const uword row_num) + { + sync_csc(); + + return row_iterator(*this, row_num, 0); + } + + + +template +inline +typename SpMat::const_row_iterator +SpMat::begin_row(const uword row_num) const + { + sync_csc(); + + return const_row_iterator(*this, row_num, 0); + } + + + +template +inline +typename SpMat::row_iterator +SpMat::end_row() + { + sync_csc(); + + return row_iterator(*this, n_nonzero); + } + + + +template +inline +typename SpMat::const_row_iterator +SpMat::end_row() const + { + sync_csc(); + + return const_row_iterator(*this, n_nonzero); + } + + + +template +inline +typename SpMat::row_iterator +SpMat::end_row(const uword row_num) + { + sync_csc(); + + return row_iterator(*this, row_num + 1, 0); + } + + + +template +inline +typename SpMat::const_row_iterator +SpMat::end_row(const uword row_num) const + { + sync_csc(); + + return const_row_iterator(*this, row_num + 1, 0); + } + + + +template +inline +typename SpMat::row_col_iterator +SpMat::begin_row_col() + { + sync_csc(); + + return begin(); + } + + + +template +inline +typename SpMat::const_row_col_iterator +SpMat::begin_row_col() const + { + sync_csc(); + + return begin(); + } + + + +template +inline typename SpMat::row_col_iterator +SpMat::end_row_col() + { + sync_csc(); + + return end(); + } + + + +template +inline +typename SpMat::const_row_col_iterator +SpMat::end_row_col() const + { + sync_csc(); + + return end(); + } + + + +template +inline +void +SpMat::clear() + { + (*this).reset(); + } + + + +template +inline +bool +SpMat::empty() const + { + return (n_elem == 0); + } + + + +template +inline +uword +SpMat::size() const + { + return n_elem; + } + + + +template +arma_inline +SpMat_MapMat_val +SpMat::front() + { + arma_debug_check( (n_elem == 0), "SpMat::front(): matrix is empty" ); + + return SpMat_MapMat_val((*this), cache, 0, 0); + } + + + +template +arma_inline +eT +SpMat::front() const + { + arma_debug_check( (n_elem == 0), "SpMat::front(): matrix is empty" ); + + return get_value(0,0); + } + + + +template +arma_inline +SpMat_MapMat_val +SpMat::back() + { + arma_debug_check( (n_elem == 0), "SpMat::back(): matrix is empty" ); + + return SpMat_MapMat_val((*this), cache, n_rows-1, n_cols-1); + } + + + +template +arma_inline +eT +SpMat::back() const + { + arma_debug_check( (n_elem == 0), "SpMat::back(): matrix is empty" ); + + return get_value(n_rows-1, n_cols-1); + } + + + +template +inline +eT +SpMat::get_value(const uword i) const + { + const MapMat& const_cache = cache; // declare as const for clarity of intent + + // get the element from the cache if it has more recent data than CSC + + return (sync_state == 1) ? const_cache.operator[](i) : get_value_csc(i); + } + + + +template +inline +eT +SpMat::get_value(const uword in_row, const uword in_col) const + { + const MapMat& const_cache = cache; // declare as const for clarity of intent + + // get the element from the cache if it has more recent data than CSC + + return (sync_state == 1) ? const_cache.at(in_row, in_col) : get_value_csc(in_row, in_col); + } + + + +template +inline +eT +SpMat::get_value_csc(const uword i) const + { + // First convert to the actual location. + uword lcol = i / n_rows; // Integer division. + uword lrow = i % n_rows; + + return get_value_csc(lrow, lcol); + } + + + +template +inline +const eT* +SpMat::find_value_csc(const uword in_row, const uword in_col) const + { + const uword col_offset = col_ptrs[in_col ]; + const uword next_col_offset = col_ptrs[in_col + 1]; + + const uword* start_ptr = &row_indices[ col_offset]; + const uword* end_ptr = &row_indices[next_col_offset]; + + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, in_row); // binary search + + if( (pos_ptr != end_ptr) && ((*pos_ptr) == in_row) ) + { + const uword offset = uword(pos_ptr - start_ptr); + const uword index = offset + col_offset; + + return &(values[index]); + } + + return nullptr; + } + + + +template +inline +eT +SpMat::get_value_csc(const uword in_row, const uword in_col) const + { + const eT* val_ptr = find_value_csc(in_row, in_col); + + return (val_ptr != nullptr) ? eT(*val_ptr) : eT(0); + } + + + +template +inline +bool +SpMat::try_set_value_csc(const uword in_row, const uword in_col, const eT in_val) + { + const eT* val_ptr = find_value_csc(in_row, in_col); + + // element not found, ie. it's zero; fail if trying to set it to non-zero value + if(val_ptr == nullptr) { return (in_val == eT(0)); } + + // fail if trying to erase an existing element + if(in_val == eT(0)) { return false; } + + access::rw(*val_ptr) = in_val; + + invalidate_cache(); + + return true; + } + + + +template +inline +bool +SpMat::try_add_value_csc(const uword in_row, const uword in_col, const eT in_val) + { + const eT* val_ptr = find_value_csc(in_row, in_col); + + // element not found, ie. it's zero; fail if trying to add a non-zero value + if(val_ptr == nullptr) { return (in_val == eT(0)); } + + const eT new_val = eT(*val_ptr) + in_val; + + // fail if trying to erase an existing element + if(new_val == eT(0)) { return false; } + + access::rw(*val_ptr) = new_val; + + invalidate_cache(); + + return true; + } + + + +template +inline +bool +SpMat::try_sub_value_csc(const uword in_row, const uword in_col, const eT in_val) + { + const eT* val_ptr = find_value_csc(in_row, in_col); + + // element not found, ie. it's zero; fail if trying to subtract a non-zero value + if(val_ptr == nullptr) { return (in_val == eT(0)); } + + const eT new_val = eT(*val_ptr) - in_val; + + // fail if trying to erase an existing element + if(new_val == eT(0)) { return false; } + + access::rw(*val_ptr) = new_val; + + invalidate_cache(); + + return true; + } + + + +template +inline +bool +SpMat::try_mul_value_csc(const uword in_row, const uword in_col, const eT in_val) + { + const eT* val_ptr = find_value_csc(in_row, in_col); + + // element not found, ie. it's zero; succeed if given value is finite; zero multiplied by anything is zero, except for nan and inf + if(val_ptr == nullptr) { return arma_isfinite(in_val); } + + const eT new_val = eT(*val_ptr) * in_val; + + // fail if trying to erase an existing element + if(new_val == eT(0)) { return false; } + + access::rw(*val_ptr) = new_val; + + invalidate_cache(); + + return true; + } + + + +template +inline +bool +SpMat::try_div_value_csc(const uword in_row, const uword in_col, const eT in_val) + { + const eT* val_ptr = find_value_csc(in_row, in_col); + + // element not found, ie. it's zero; succeed if given value is not zero and not nan; zero divided by anything is zero, except for zero and nan + if(val_ptr == nullptr) { return ((in_val != eT(0)) && (arma_isnan(in_val) == false)); } + + const eT new_val = eT(*val_ptr) / in_val; + + // fail if trying to erase an existing element + if(new_val == eT(0)) { return false; } + + access::rw(*val_ptr) = new_val; + + invalidate_cache(); + + return true; + } + + + +/** + * Insert an element at the given position, and return a reference to it. + * The element will be set to 0, unless otherwise specified. + * If the element already exists, its value will be overwritten. + */ +template +inline +eT& +SpMat::insert_element(const uword in_row, const uword in_col, const eT val) + { + arma_extra_debug_sigprint(); + + sync_csc(); + invalidate_cache(); + + // We will assume the new element does not exist and begin the search for + // where to insert it. If we find that it already exists, we will then + // overwrite it. + uword colptr = col_ptrs[in_col ]; + uword next_colptr = col_ptrs[in_col + 1]; + + uword pos = colptr; // The position in the matrix of this value. + + if(colptr != next_colptr) + { + // There are other elements in this column, so we must find where this + // element will fit as compared to those. + while(pos < next_colptr && in_row > row_indices[pos]) + { + pos++; + } + + // We aren't inserting into the last position, so it is still possible + // that the element may exist. + if(pos != next_colptr && row_indices[pos] == in_row) + { + // It already exists. Then, just overwrite it. + access::rw(values[pos]) = val; + + return access::rw(values[pos]); + } + } + + + // + // Element doesn't exist, so we have to insert it + // + + // We have to update the rest of the column pointers. + for(uword i = in_col + 1; i < n_cols + 1; i++) + { + access::rw(col_ptrs[i])++; // We are only inserting one new element. + } + + const uword old_n_nonzero = n_nonzero; + + access::rw(n_nonzero)++; // Add to count of nonzero elements. + + // Allocate larger memory. + eT* new_values = memory::acquire (n_nonzero + 1); + uword* new_row_indices = memory::acquire(n_nonzero + 1); + + // Copy things over, before the new element. + if(pos > 0) + { + arrayops::copy(new_values, values, pos); + arrayops::copy(new_row_indices, row_indices, pos); + } + + // Insert the new element. + new_values[pos] = val; + new_row_indices[pos] = in_row; + + // Copy the rest of things over (including the extra element at the end). + arrayops::copy(new_values + pos + 1, values + pos, (old_n_nonzero - pos) + 1); + arrayops::copy(new_row_indices + pos + 1, row_indices + pos, (old_n_nonzero - pos) + 1); + + // Assign new pointers. + if(values) { memory::release(access::rw(values)); } + if(row_indices) { memory::release(access::rw(row_indices)); } + + access::rw(values) = new_values; + access::rw(row_indices) = new_row_indices; + + return access::rw(values[pos]); + } + + + +/** + * Delete an element at the given position. + */ +template +inline +void +SpMat::delete_element(const uword in_row, const uword in_col) + { + arma_extra_debug_sigprint(); + + sync_csc(); + invalidate_cache(); + + // We assume the element exists (although... it may not) and look for its + // exact position. If it doesn't exist... well, we don't need to do anything. + uword colptr = col_ptrs[in_col]; + uword next_colptr = col_ptrs[in_col + 1]; + + if(colptr != next_colptr) + { + // There's at least one element in this column. + // Let's see if we are one of them. + for(uword pos = colptr; pos < next_colptr; pos++) + { + if(in_row == row_indices[pos]) + { + --access::rw(n_nonzero); // Remove one from the count of nonzero elements. + + // Found it. Now remove it. + + // Make new arrays. + eT* new_values = memory::acquire (n_nonzero + 1); + uword* new_row_indices = memory::acquire(n_nonzero + 1); + + if(pos > 0) + { + arrayops::copy(new_values, values, pos); + arrayops::copy(new_row_indices, row_indices, pos); + } + + arrayops::copy(new_values + pos, values + pos + 1, (n_nonzero - pos) + 1); + arrayops::copy(new_row_indices + pos, row_indices + pos + 1, (n_nonzero - pos) + 1); + + if(values) { memory::release(access::rw(values)); } + if(row_indices) { memory::release(access::rw(row_indices)); } + + access::rw(values) = new_values; + access::rw(row_indices) = new_row_indices; + + // And lastly, update all the column pointers (decrement by one). + for(uword i = in_col + 1; i < n_cols + 1; i++) + { + --access::rw(col_ptrs[i]); // We only removed one element. + } + + return; // There is nothing left to do. + } + } + } + + return; // The element does not exist, so there's nothing for us to do. + } + + + +template +arma_inline +void +SpMat::invalidate_cache() const + { + arma_extra_debug_sigprint(); + + if(sync_state == 0) { return; } + + cache.reset(); + + sync_state = 0; + } + + + +template +arma_inline +void +SpMat::invalidate_csc() const + { + arma_extra_debug_sigprint(); + + sync_state = 1; + } + + + +template +inline +void +SpMat::sync_cache() const + { + arma_extra_debug_sigprint(); + + // using approach adapted from http://preshing.com/20130930/double-checked-locking-is-fixed-in-cpp11/ + // + // OpenMP mode: + // sync_state uses atomic read/write, which has an implied flush; + // flush is also implicitly executed at the entrance and the exit of critical section; + // data races are prevented by the 'critical' directive + // + // C++11 mode: + // underlying type for sync_state is std::atomic; + // reading and writing to sync_state uses std::memory_order_seq_cst which has an implied fence; + // data races are prevented via the mutex + + #if defined(ARMA_USE_OPENMP) + { + if(sync_state == 0) + { + #pragma omp critical (arma_SpMat_cache) + { + sync_cache_simple(); + } + } + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + if(sync_state == 0) + { + const std::lock_guard lock(cache_mutex); + + sync_cache_simple(); + } + } + #else + { + sync_cache_simple(); + } + #endif + } + + + + +template +inline +void +SpMat::sync_cache_simple() const + { + arma_extra_debug_sigprint(); + + if(sync_state == 0) + { + cache = (*this); + sync_state = 2; + } + } + + + + +template +inline +void +SpMat::sync_csc() const + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_OPENMP) + if(sync_state == 1) + { + #pragma omp critical (arma_SpMat_cache) + { + sync_csc_simple(); + } + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + if(sync_state == 1) + { + const std::lock_guard lock(cache_mutex); + + sync_csc_simple(); + } + #else + { + sync_csc_simple(); + } + #endif + } + + + +template +inline +void +SpMat::sync_csc_simple() const + { + arma_extra_debug_sigprint(); + + // method: + // 1. construct temporary matrix to prevent the cache from getting zapped + // 2. steal memory from the temporary matrix + + // sync_state is only set to 1 by non-const element access operators, + // so the shenanigans with const_cast are to satisfy the compiler + + // see also the note in sync_cache() above + + if(sync_state == 1) + { + SpMat& x = const_cast< SpMat& >(*this); + + SpMat tmp(cache); + + x.steal_mem_simple(tmp); + + sync_state = 2; + } + } + + + + +// +// SpMat_aux + + + +template +inline +void +SpMat_aux::set_real(SpMat& out, const SpBase& X) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp(X.get_ref()); + const SpMat& A = tmp.M; + + arma_debug_assert_same_size( out, A, "SpMat::set_real()" ); + + out = A; + } + + + +template +inline +void +SpMat_aux::set_imag(SpMat&, const SpBase&) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +SpMat_aux::set_real(SpMat< std::complex >& out, const SpBase& X) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const unwrap_spmat U(X.get_ref()); + const SpMat& Y = U.M; + + arma_debug_assert_same_size(out, Y, "SpMat::set_real()"); + + SpMat tmp(Y,arma::imag(out)); // arma:: prefix required due to bugs in GCC 4.4 - 4.6 + + out.steal_mem(tmp); + } + + + +template +inline +void +SpMat_aux::set_imag(SpMat< std::complex >& out, const SpBase& X) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const unwrap_spmat U(X.get_ref()); + const SpMat& Y = U.M; + + arma_debug_assert_same_size(out, Y, "SpMat::set_imag()"); + + SpMat tmp(arma::real(out),Y); // arma:: prefix required due to bugs in GCC 4.4 - 4.6 + + out.steal_mem(tmp); + } + + + +#if defined(ARMA_EXTRA_SPMAT_MEAT) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPMAT_MEAT) +#endif + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpOp_bones.hpp b/src/armadillo/include/armadillo_bits/SpOp_bones.hpp new file mode 100644 index 0000000..af8a229 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpOp_bones.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpOp +//! @{ + + + +template +class SpOp : public SpBase< typename T1::elem_type, SpOp > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = op_type::template traits::is_row; + static constexpr bool is_col = op_type::template traits::is_col; + static constexpr bool is_xvec = op_type::template traits::is_xvec; + + inline explicit SpOp(const T1& in_m); + inline SpOp(const T1& in_m, const elem_type in_aux); + inline SpOp(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); + inline ~SpOp(); + + arma_inline bool is_alias(const SpMat& X) const; + + arma_aligned const T1& m; //!< the operand; must be derived from SpBase + arma_aligned elem_type aux; //!< auxiliary data, using the element type as used by T1 + arma_aligned uword aux_uword_a; //!< auxiliary data, uword format + arma_aligned uword aux_uword_b; //!< auxiliary data, uword format + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpOp_meat.hpp b/src/armadillo/include/armadillo_bits/SpOp_meat.hpp new file mode 100644 index 0000000..2a6f1f5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpOp_meat.hpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpOp +//! @{ + + + +template +inline +SpOp::SpOp(const T1& in_m) + : m(in_m) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpOp::SpOp(const T1& in_m, const typename T1::elem_type in_aux) + : m(in_m) + , aux(in_aux) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpOp::SpOp(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b) + : m(in_m) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpOp::~SpOp() + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +bool +SpOp::is_alias(const SpMat& X) const + { + return m.is_alias(X); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpProxy.hpp b/src/armadillo/include/armadillo_bits/SpProxy.hpp new file mode 100644 index 0000000..50adcc8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpProxy.hpp @@ -0,0 +1,688 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpProxy +//! @{ + + +// TODO: clarify and check which variables and functions are valid when 'use_iterator' is either true or false + + +// within each specialisation of the Proxy class: +// +// elem_type = the type of the elements obtained from object Q +// pod_type = the underlying type of elements if elem_type is std::complex +// stored_type = the type of the Q object +// +// const_iterator_type = the type of iterator provided by begin() and begin_col() +// const_row_iterator_type = the type of iterator provided by begin_row() +// +// use_iterator = boolean indicating that the provided iterators must be used for accessing elements +// Q_is_generated = boolean indicating that the Q object was generated by SpProxy +// +// is_row = boolean indicating whether the Q object can be treated a row vector +// is_col = boolean indicating whether the Q object can be treated a column vector +// is_xvec = boolean indicating whether the Q object is a vector with unknown orientation +// +// Q = object that can be unwrapped via the unwrap_spmat family of classes (ie. Q must be convertible to SpMat) +// +// get_n_rows() = return the number of rows in Q +// get_n_cols() = return the number of columns in Q +// get_n_elem() = return the number of elements in Q +// get_n_nonzero() = return the number of non-zero elements in Q +// +// operator[i] = linear element accessor; valid only if the 'use_iterator' boolean is false +// at(row,col) = access elements via (row,col); valid only if the 'use_iterator' boolean is false +// +// get_values() = return pointer to the CSC values array in Q; valid only if the 'use_iterator' boolean is false +// get_row_indices() = return pointer to the CSC row indices array in Q; valid only if the 'use_iterator' boolean is false +// get_col_ptrs() = return pointer to the CSC column pointers array in Q; valid only if the 'use_iterator' boolean is false +// +// begin() = column-wise iterator indicating the first element in Q +// begin_col(col_num) = column-wise iterator indicating the first element in column 'col_num' in Q +// begin_row(row_num = 0) = row-wise iterator indicating the first element in row 'row_num' in Q +// +// end() = column-wise iterator indicating the "one-past-end" element in Q +// end_row() = row-wise iterator indicating the "one-past-end" element in Q +// end_row(row_num) = row-wise iterator indicating the "one-past-end" element in row 'row_num' in Q +// +// is_alias(X) = return true/false indicating whether the Q object aliases matrix X + + + +template +struct SpProxy< SpMat > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpMat stored_type; + + typedef typename SpMat::const_iterator const_iterator_type; + typedef typename SpMat::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const SpMat& Q; + + inline explicit SpProxy(const SpMat& A) + : Q(A) + { + arma_extra_debug_sigprint(); + Q.sync(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } + + arma_inline const eT* get_values() const { return Q.values; } + arma_inline const uword* get_row_indices() const { return Q.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + arma_inline bool is_alias(const SpMat& X) const { return (void_ptr(&Q) == void_ptr(&X)); } + }; + + + +template +struct SpProxy< SpCol > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpCol stored_type; + + typedef typename SpCol::const_iterator const_iterator_type; + typedef typename SpCol::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const SpCol& Q; + + inline explicit SpProxy(const SpCol& A) + : Q(A) + { + arma_extra_debug_sigprint(); + Q.sync(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } + + arma_inline const eT* get_values() const { return Q.values; } + arma_inline const uword* get_row_indices() const { return Q.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword) const { return Q.begin(); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + arma_inline bool is_alias(const SpMat& X) const { return (void_ptr(&Q) == void_ptr(&X)); } + }; + + + +template +struct SpProxy< SpRow > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpRow stored_type; + + typedef typename SpRow::const_iterator const_iterator_type; + typedef typename SpRow::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = false; + + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const SpRow& Q; + + inline explicit SpProxy(const SpRow& A) + : Q(A) + { + arma_extra_debug_sigprint(); + Q.sync(); + } + + constexpr uword get_n_rows() const { return 1; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } + + arma_inline const eT* get_values() const { return Q.values; } + arma_inline const uword* get_row_indices() const { return Q.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + arma_inline bool is_alias(const SpMat& X) const { return (void_ptr(&Q) == void_ptr(&X)); } + }; + + + +template +struct SpProxy< SpSubview > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpSubview stored_type; + + typedef typename SpSubview::const_iterator const_iterator_type; + typedef typename SpSubview::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = true; + static constexpr bool Q_is_generated = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const SpSubview& Q; + + inline explicit SpProxy(const SpSubview& A) + : Q(A) + { + arma_extra_debug_sigprint(); + Q.m.sync(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } + + arma_inline const eT* get_values() const { return Q.m.values; } + arma_inline const uword* get_row_indices() const { return Q.m.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.m.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + arma_inline bool is_alias(const SpMat& X) const { return (void_ptr(&Q.m) == void_ptr(&X)); } + }; + + + +template +struct SpProxy< SpSubview_col > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpSubview_col stored_type; + + typedef typename SpSubview::const_iterator const_iterator_type; + typedef typename SpSubview::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = true; + static constexpr bool Q_is_generated = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const SpSubview_col& Q; + + inline explicit SpProxy(const SpSubview_col& A) + : Q(A) + { + arma_extra_debug_sigprint(); + Q.m.sync(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q.at(i, 0); } + arma_inline elem_type at (const uword row, const uword) const { return Q.at(row, 0); } + + arma_inline const eT* get_values() const { return Q.m.values; } + arma_inline const uword* get_row_indices() const { return Q.m.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.m.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + arma_inline bool is_alias(const SpMat& X) const { return (void_ptr(&Q.m) == void_ptr(&X)); } + }; + + + +template +struct SpProxy< SpSubview_col_list > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpMat stored_type; + + typedef typename SpMat::const_iterator const_iterator_type; + typedef typename SpMat::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const SpMat Q; + + inline explicit SpProxy(const SpSubview_col_list& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } + + arma_inline const eT* get_values() const { return Q.values; } + arma_inline const uword* get_row_indices() const { return Q.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +template +struct SpProxy< SpSubview_row > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpSubview_row stored_type; + + typedef typename SpSubview::const_iterator const_iterator_type; + typedef typename SpSubview::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = true; + static constexpr bool Q_is_generated = false; + + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const SpSubview_row& Q; + + inline explicit SpProxy(const SpSubview_row& A) + : Q(A) + { + arma_extra_debug_sigprint(); + Q.m.sync(); + } + + constexpr uword get_n_rows() const { return 1; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q.at(0, i ); } + arma_inline elem_type at (const uword, const uword col) const { return Q.at(0, col); } + + arma_inline const eT* get_values() const { return Q.m.values; } + arma_inline const uword* get_row_indices() const { return Q.m.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.m.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + arma_inline bool is_alias(const SpMat& X) const { return (void_ptr(&Q.m) == void_ptr(&X)); } + }; + + + +template +struct SpProxy< spdiagview > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpMat stored_type; + + typedef typename SpMat::const_iterator const_iterator_type; + typedef typename SpMat::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const SpMat Q; + + inline explicit SpProxy(const spdiagview& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } + + arma_inline const eT* get_values() const { return Q.values; } + arma_inline const uword* get_row_indices() const { return Q.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +template +struct SpProxy< SpOp > + { + typedef typename T1::elem_type elem_type; + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result pod_type; + typedef SpMat stored_type; + + typedef typename SpMat::const_iterator const_iterator_type; + typedef typename SpMat::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; + + static constexpr bool is_row = SpOp::is_row; + static constexpr bool is_col = SpOp::is_col; + static constexpr bool is_xvec = SpOp::is_xvec; + + arma_aligned const SpMat Q; + + inline explicit SpProxy(const SpOp& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } + + arma_inline const eT* get_values() const { return Q.values; } + arma_inline const uword* get_row_indices() const { return Q.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +template +struct SpProxy< SpGlue > + { + typedef typename T1::elem_type elem_type; + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result pod_type; + typedef SpMat stored_type; + + typedef typename SpMat::const_iterator const_iterator_type; + typedef typename SpMat::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; + + static constexpr bool is_row = SpGlue::is_row; + static constexpr bool is_col = SpGlue::is_col; + static constexpr bool is_xvec = SpGlue::is_xvec; + + arma_aligned const SpMat Q; + + inline explicit SpProxy(const SpGlue& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } + + arma_inline const eT* get_values() const { return Q.values; } + arma_inline const uword* get_row_indices() const { return Q.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +template +struct SpProxy< mtSpOp > + { + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpMat stored_type; + + typedef typename SpMat::const_iterator const_iterator_type; + typedef typename SpMat::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; + + static constexpr bool is_row = mtSpOp::is_row; + static constexpr bool is_col = mtSpOp::is_col; + static constexpr bool is_xvec = mtSpOp::is_xvec; + + arma_aligned const SpMat Q; + + inline explicit SpProxy(const mtSpOp& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } + + arma_inline const out_eT* get_values() const { return Q.values; } + arma_inline const uword* get_row_indices() const { return Q.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +template +struct SpProxy< mtSpGlue > + { + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpMat stored_type; + + typedef typename SpMat::const_iterator const_iterator_type; + typedef typename SpMat::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; + + static constexpr bool is_row = mtSpGlue::is_row; + static constexpr bool is_col = mtSpGlue::is_col; + static constexpr bool is_xvec = mtSpGlue::is_xvec; + + arma_aligned const SpMat Q; + + inline explicit SpProxy(const mtSpGlue& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } + + arma_inline const out_eT* get_values() const { return Q.values; } + arma_inline const uword* get_row_indices() const { return Q.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpRow_bones.hpp b/src/armadillo/include/armadillo_bits/SpRow_bones.hpp new file mode 100644 index 0000000..c13de1b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpRow_bones.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpRow +//! @{ + + +//! Class for sparse row vectors (sparse matrices with only one row) +template +class SpRow : public SpMat + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + + inline SpRow(); + inline explicit SpRow(const uword N); + inline explicit SpRow(const uword in_rows, const uword in_cols); + inline explicit SpRow(const SizeMat& s); + + inline SpRow(const char* text); + inline SpRow& operator=(const char* text); + + inline SpRow(const std::string& text); + inline SpRow& operator=(const std::string& text); + + inline SpRow& operator=(const eT val); + + template inline SpRow(const Base& X); + template inline SpRow& operator=(const Base& X); + + template inline SpRow(const SpBase& X); + template inline SpRow& operator=(const SpBase& X); + + template + inline explicit SpRow(const SpBase& A, const SpBase& B); + + arma_warn_unused inline const SpOp,spop_htrans> t() const; + arma_warn_unused inline const SpOp,spop_htrans> ht() const; + arma_warn_unused inline const SpOp,spop_strans> st() const; + + inline void shed_col (const uword col_num); + inline void shed_cols(const uword in_col1, const uword in_col2); + + // inline void insert_cols(const uword col_num, const uword N, const bool set_to_zero = true); + + + typedef typename SpMat::iterator row_iterator; + typedef typename SpMat::const_iterator const_row_iterator; + + inline row_iterator begin_row(const uword row_num = 0); + inline const_row_iterator begin_row(const uword row_num = 0) const; + + inline row_iterator end_row(const uword row_num = 0); + inline const_row_iterator end_row(const uword row_num = 0) const; + + #if defined(ARMA_EXTRA_SPROW_PROTO) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPROW_PROTO) + #endif + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpRow_meat.hpp b/src/armadillo/include/armadillo_bits/SpRow_meat.hpp new file mode 100644 index 0000000..10f052f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpRow_meat.hpp @@ -0,0 +1,433 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpRow +//! @{ + + + +template +inline +SpRow::SpRow() + : SpMat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpRow::SpRow(const uword in_n_elem) + : SpMat(arma_vec_indicator(), 1, in_n_elem, 2) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpRow::SpRow(const uword in_n_rows, const uword in_n_cols) + : SpMat(arma_vec_indicator(), in_n_rows, in_n_cols, 2) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpRow::SpRow(const SizeMat& s) + : SpMat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + SpMat::init(s.n_rows, s.n_cols); + } + + + +template +inline +SpRow::SpRow(const char* text) + : SpMat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(); + + SpMat::init(std::string(text)); + } + + + +template +inline +SpRow& +SpRow::operator=(const char* text) + { + arma_extra_debug_sigprint(); + + SpMat::init(std::string(text)); + + return *this; + } + + + +template +inline +SpRow::SpRow(const std::string& text) + : SpMat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(); + + SpMat::init(text); + } + + + +template +inline +SpRow& +SpRow::operator=(const std::string& text) + { + arma_extra_debug_sigprint(); + + SpMat::init(text); + + return *this; + } + + + +template +inline +SpRow& +SpRow::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + SpMat::operator=(val); + + return *this; + } + + + +template +template +inline +SpRow::SpRow(const Base& X) + : SpMat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(); + + SpMat::operator=(X.get_ref()); + } + + + +template +template +inline +SpRow& +SpRow::operator=(const Base& X) + { + arma_extra_debug_sigprint(); + + SpMat::operator=(X.get_ref()); + + return *this; + } + + + +template +template +inline +SpRow::SpRow(const SpBase& X) + : SpMat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(); + + SpMat::operator=(X.get_ref()); + } + + + +template +template +inline +SpRow& +SpRow::operator=(const SpBase& X) + { + arma_extra_debug_sigprint(); + + SpMat::operator=(X.get_ref()); + + return *this; + } + + + +template +template +inline +SpRow::SpRow + ( + const SpBase::pod_type, T1>& A, + const SpBase::pod_type, T2>& B + ) + : SpMat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(); + + SpMat::init(A,B); + } + + + +template +inline +const SpOp,spop_htrans> +SpRow::t() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_htrans> +SpRow::ht() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_strans> +SpRow::st() const + { + return SpOp,spop_strans>(*this); + } + + + +//! remove specified columns +template +inline +void +SpRow::shed_col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= SpMat::n_cols, "SpRow::shed_col(): out of bounds" ); + + shed_cols(col_num, col_num); + } + + + +//! remove specified columns +template +inline +void +SpRow::shed_cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= SpMat::n_cols), + "SpRow::shed_cols(): indices out of bounds or incorrectly used" + ); + + SpMat::sync_csc(); + + const uword diff = (in_col2 - in_col1 + 1); + + // This is doubleplus easy because we have all the column pointers stored. + const uword start = SpMat::col_ptrs[in_col1]; + const uword end = SpMat::col_ptrs[in_col2 + 1]; + + if(start != end) + { + const uword elem_diff = end - start; + + eT* new_values = memory::acquire (SpMat::n_nonzero - elem_diff); + uword* new_row_indices = memory::acquire(SpMat::n_nonzero - elem_diff); + + // Copy first set of elements, if necessary. + if(start > 0) + { + arrayops::copy(new_values, SpMat::values, start); + arrayops::copy(new_row_indices, SpMat::row_indices, start); + } + + // Copy last set of elements, if necessary. + if(end != SpMat::n_nonzero) + { + arrayops::copy(new_values + start, SpMat::values + end, (SpMat::n_nonzero - end)); + arrayops::copy(new_row_indices + start, SpMat::row_indices + end, (SpMat::n_nonzero - end)); + } + + memory::release(SpMat::values); + memory::release(SpMat::row_indices); + + access::rw(SpMat::values) = new_values; + access::rw(SpMat::row_indices) = new_row_indices; + + access::rw(SpMat::n_nonzero) -= elem_diff; + } + + // Update column pointers. + uword* new_col_ptrs = memory::acquire(SpMat::n_cols - diff + 1); + + // Copy first part of column pointers. + if(in_col1 > 0) + { + arrayops::copy(new_col_ptrs, SpMat::col_ptrs, in_col1); + } + + // Copy last part of column pointers (and adjust their values as necessary). + if(in_col2 < SpMat::n_cols - 1) + { + arrayops::copy(new_col_ptrs + in_col1, SpMat::col_ptrs + in_col2 + 1, SpMat::n_cols - in_col2); + // Modify their values. + arrayops::inplace_minus(new_col_ptrs + in_col1, (end - start), SpMat::n_cols - in_col2); + } + + memory::release(SpMat::col_ptrs); + + access::rw(SpMat::col_ptrs) = new_col_ptrs; + + access::rw(SpMat::n_cols) -= diff; + access::rw(SpMat::n_elem) -= diff; + + SpMat::invalidate_cache(); + } + + + +// //! insert N cols at the specified col position, +// //! optionally setting the elements of the inserted cols to zero +// template +// inline +// void +// SpRow::insert_cols(const uword col_num, const uword N, const bool set_to_zero) +// { +// arma_extra_debug_sigprint(); +// +// // insertion at col_num == n_cols is in effect an append operation +// arma_debug_check_bounds( (col_num > SpMat::n_cols), "SpRow::insert_cols(): out of bounds" ); +// +// arma_debug_check( (set_to_zero == false), "SpRow::insert_cols(): cannot set elements to nonzero values" ); +// +// uword newVal = (col_num == 0) ? 0 : SpMat::col_ptrs[col_num]; +// SpMat::col_ptrs.insert(col_num, N, newVal); +// uword* new_col_ptrs = memory::acquire(SpMat::n_cols + N); +// +// arrayops::copy(new_col_ptrs, SpMat::col_ptrs, col_num); +// +// uword fill_value = (col_num == 0) ? 0 : SpMat::col_ptrs[col_num - 1]; +// arrayops::inplace_set(new_col_ptrs + col_num, fill_value, N); +// +// arrayops::copy(new_col_ptrs + col_num + N, SpMat::col_ptrs + col_num, SpMat::n_cols - col_num); +// +// access::rw(SpMat::n_cols) += N; +// access::rw(SpMat::n_elem) += N; +// } + + + +template +inline +typename SpRow::row_iterator +SpRow::begin_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + // Since this is a row, row_num can only be 0. But the option is provided for + // compatibility. + arma_debug_check_bounds((row_num >= 1), "SpRow::begin_row(): index out of bounds"); + + return SpMat::begin(); + } + + + +template +inline +typename SpRow::const_row_iterator +SpRow::begin_row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + // Since this is a row, row_num can only be 0. But the option is provided for + // compatibility. + arma_debug_check_bounds((row_num >= 1), "SpRow::begin_row(): index out of bounds"); + + return SpMat::begin(); + } + + + +template +inline +typename SpRow::row_iterator +SpRow::end_row(const uword row_num) + { + arma_extra_debug_sigprint(); + + // Since this is a row, row_num can only be 0. But the option is provided for + // compatibility. + arma_debug_check_bounds((row_num >= 1), "SpRow::end_row(): index out of bounds"); + + return SpMat::end(); + } + + + +template +inline +typename SpRow::const_row_iterator +SpRow::end_row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + // Since this is a row, row_num can only be 0. But the option is provided for + // compatibility. + arma_debug_check_bounds((row_num >= 1), "SpRow::end_row(): index out of bounds"); + + return SpMat::end(); + } + + + + +#if defined(ARMA_EXTRA_SPROW_MEAT) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPROW_MEAT) +#endif + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpSubview_bones.hpp b/src/armadillo/include/armadillo_bits/SpSubview_bones.hpp new file mode 100644 index 0000000..6be50b3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpSubview_bones.hpp @@ -0,0 +1,418 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpSubview +//! @{ + + +template +class SpSubview : public SpBase< eT, SpSubview > + { + public: + + const SpMat& m; + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + const uword aux_row1; + const uword aux_col1; + const uword n_rows; + const uword n_cols; + const uword n_elem; + const uword n_nonzero; + + protected: + + inline SpSubview(const SpMat& in_m, const uword in_row1, const uword in_col1, const uword in_n_rows, const uword in_n_cols); + + public: + + inline ~SpSubview(); + inline SpSubview() = delete; + + inline SpSubview(const SpSubview& in); + inline SpSubview( SpSubview&& in); + + inline const SpSubview& operator+= (const eT val); + inline const SpSubview& operator-= (const eT val); + inline const SpSubview& operator*= (const eT val); + inline const SpSubview& operator/= (const eT val); + + inline const SpSubview& operator=(const SpSubview& x); + + template inline const SpSubview& operator= (const Base& x); + template inline const SpSubview& operator+=(const Base& x); + template inline const SpSubview& operator-=(const Base& x); + template inline const SpSubview& operator*=(const Base& x); + template inline const SpSubview& operator%=(const Base& x); + template inline const SpSubview& operator/=(const Base& x); + + template inline const SpSubview& operator_equ_common(const SpBase& x); + + template inline const SpSubview& operator= (const SpBase& x); + template inline const SpSubview& operator+=(const SpBase& x); + template inline const SpSubview& operator-=(const SpBase& x); + template inline const SpSubview& operator*=(const SpBase& x); + template inline const SpSubview& operator%=(const SpBase& x); + template inline const SpSubview& operator/=(const SpBase& x); + + /* + inline static void extract(SpMat& out, const SpSubview& in); + + inline static void plus_inplace(Mat& out, const subview& in); + inline static void minus_inplace(Mat& out, const subview& in); + inline static void schur_inplace(Mat& out, const subview& in); + inline static void div_inplace(Mat& out, const subview& in); + */ + + template inline void for_each(functor F); + template inline void for_each(functor F) const; + + template inline void transform(functor F); + + inline void replace(const eT old_val, const eT new_val); + + inline void clean(const pod_type threshold); + + inline void clamp(const eT min_val, const eT max_val); + + inline void fill(const eT val); + inline void zeros(); + inline void ones(); + inline void eye(); + inline void randu(); + inline void randn(); + + + arma_warn_unused inline SpSubview_MapMat_val operator[](const uword i); + arma_warn_unused inline eT operator[](const uword i) const; + + arma_warn_unused inline SpSubview_MapMat_val operator()(const uword i); + arma_warn_unused inline eT operator()(const uword i) const; + + arma_warn_unused inline SpSubview_MapMat_val operator()(const uword in_row, const uword in_col); + arma_warn_unused inline eT operator()(const uword in_row, const uword in_col) const; + + arma_warn_unused inline SpSubview_MapMat_val at(const uword i); + arma_warn_unused inline eT at(const uword i) const; + + arma_warn_unused inline SpSubview_MapMat_val at(const uword in_row, const uword in_col); + arma_warn_unused inline eT at(const uword in_row, const uword in_col) const; + + inline bool check_overlap(const SpSubview& x) const; + + arma_warn_unused inline bool is_vec() const; + + inline SpSubview_row row(const uword row_num); + inline const SpSubview_row row(const uword row_num) const; + + inline SpSubview_col col(const uword col_num); + inline const SpSubview_col col(const uword col_num) const; + + inline SpSubview rows(const uword in_row1, const uword in_row2); + inline const SpSubview rows(const uword in_row1, const uword in_row2) const; + + inline SpSubview cols(const uword in_col1, const uword in_col2); + inline const SpSubview cols(const uword in_col1, const uword in_col2) const; + + inline SpSubview submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2); + inline const SpSubview submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const; + + inline SpSubview submat(const span& row_span, const span& col_span); + inline const SpSubview submat(const span& row_span, const span& col_span) const; + + inline SpSubview operator()(const uword row_num, const span& col_span); + inline const SpSubview operator()(const uword row_num, const span& col_span) const; + + inline SpSubview operator()(const span& row_span, const uword col_num); + inline const SpSubview operator()(const span& row_span, const uword col_num) const; + + inline SpSubview operator()(const span& row_span, const span& col_span); + inline const SpSubview operator()(const span& row_span, const span& col_span) const; + + + inline void swap_rows(const uword in_row1, const uword in_row2); + inline void swap_cols(const uword in_col1, const uword in_col2); + + // Forward declarations. + class iterator_base; + class const_iterator; + class iterator; + class const_row_iterator; + class row_iterator; + + // Similar to SpMat iterators but automatically iterates past and ignores values not in the subview. + class iterator_base + { + public: + + inline iterator_base(const SpSubview& in_M); + inline iterator_base(const SpSubview& in_M, const uword col, const uword pos); + + arma_inline uword col() const { return internal_col; } + arma_inline uword pos() const { return internal_pos; } + + arma_aligned const SpSubview* M; + arma_aligned uword internal_col; + arma_aligned uword internal_pos; + + typedef std::bidirectional_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef const eT* pointer; + typedef const eT& reference; + }; + + class const_iterator : public iterator_base + { + public: + + inline const_iterator(const SpSubview& in_M, uword initial_pos = 0); + inline const_iterator(const SpSubview& in_M, uword in_row, uword in_col); + inline const_iterator(const SpSubview& in_M, uword in_row, uword in_col, uword in_pos, uword skip_pos); + inline const_iterator(const const_iterator& other); + + arma_inline eT operator*() const; + + // Don't hold location internally; call "dummy" methods to get that information. + arma_inline uword row() const { return iterator_base::M->m.row_indices[iterator_base::internal_pos + skip_pos] - iterator_base::M->aux_row1; } + + arma_hot inline const_iterator& operator++(); + arma_warn_unused inline const_iterator operator++(int); + + arma_hot inline const_iterator& operator--(); + arma_warn_unused inline const_iterator operator--(int); + + arma_hot inline bool operator!=(const const_iterator& rhs) const; + arma_hot inline bool operator==(const const_iterator& rhs) const; + + arma_hot inline bool operator!=(const typename SpMat::const_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpMat::const_iterator& rhs) const; + + arma_hot inline bool operator!=(const const_row_iterator& rhs) const; + arma_hot inline bool operator==(const const_row_iterator& rhs) const; + + arma_hot inline bool operator!=(const typename SpMat::const_row_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpMat::const_row_iterator& rhs) const; + + arma_aligned uword skip_pos; // not used in row_iterator or const_row_iterator + }; + + class iterator : public const_iterator + { + public: + + inline iterator(SpSubview& in_M, const uword initial_pos = 0) : const_iterator(in_M, initial_pos) { } + inline iterator(SpSubview& in_M, const uword in_row, const uword in_col) : const_iterator(in_M, in_row, in_col) { } + inline iterator(SpSubview& in_M, const uword in_row, const uword in_col, const uword in_pos, const uword in_skip_pos) : const_iterator(in_M, in_row, in_col, in_pos, in_skip_pos) { } + inline iterator(const iterator& other) : const_iterator(other) { } + + arma_hot inline SpValProxy< SpSubview > operator*(); + + // overloads needed for return type correctness + arma_hot inline iterator& operator++(); + arma_warn_unused inline iterator operator++(int); + + arma_hot inline iterator& operator--(); + arma_warn_unused inline iterator operator--(int); + + // This has a different value_type than iterator_base. + typedef SpValProxy< SpSubview > value_type; + typedef const SpValProxy< SpSubview >* pointer; + typedef const SpValProxy< SpSubview >& reference; + }; + + class const_row_iterator : public iterator_base + { + public: + + inline const_row_iterator(); + inline const_row_iterator(const SpSubview& in_M, uword initial_pos = 0); + inline const_row_iterator(const SpSubview& in_M, uword in_row, uword in_col); + inline const_row_iterator(const const_row_iterator& other); + + arma_hot inline const_row_iterator& operator++(); + arma_warn_unused inline const_row_iterator operator++(int); + + arma_hot inline const_row_iterator& operator--(); + arma_warn_unused inline const_row_iterator operator--(int); + + uword internal_row; // Hold row internally because we use internal_pos differently. + uword actual_pos; // Actual position in subview's parent matrix. + + arma_inline eT operator*() const { return iterator_base::M->m.values[actual_pos]; } + + arma_inline uword row() const { return internal_row; } + + arma_hot inline bool operator!=(const const_iterator& rhs) const; + arma_hot inline bool operator==(const const_iterator& rhs) const; + + arma_hot inline bool operator!=(const typename SpMat::const_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpMat::const_iterator& rhs) const; + + arma_hot inline bool operator!=(const const_row_iterator& rhs) const; + arma_hot inline bool operator==(const const_row_iterator& rhs) const; + + arma_hot inline bool operator!=(const typename SpMat::const_row_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpMat::const_row_iterator& rhs) const; + }; + + class row_iterator : public const_row_iterator + { + public: + + inline row_iterator(SpSubview& in_M, uword initial_pos = 0) : const_row_iterator(in_M, initial_pos) { } + inline row_iterator(SpSubview& in_M, uword in_row, uword in_col) : const_row_iterator(in_M, in_row, in_col) { } + inline row_iterator(const row_iterator& other) : const_row_iterator(other) { } + + arma_hot inline SpValProxy< SpSubview > operator*(); + + // overloads needed for return type correctness + arma_hot inline row_iterator& operator++(); + arma_warn_unused inline row_iterator operator++(int); + + arma_hot inline row_iterator& operator--(); + arma_warn_unused inline row_iterator operator--(int); + + // This has a different value_type than iterator_base. + typedef SpValProxy< SpSubview > value_type; + typedef const SpValProxy< SpSubview >* pointer; + typedef const SpValProxy< SpSubview >& reference; + }; + + inline iterator begin(); + inline const_iterator begin() const; + inline const_iterator cbegin() const; + + inline iterator begin_col(const uword col_num); + inline const_iterator begin_col(const uword col_num) const; + + inline row_iterator begin_row(const uword row_num = 0); + inline const_row_iterator begin_row(const uword row_num = 0) const; + + inline iterator end(); + inline const_iterator end() const; + inline const_iterator cend() const; + + inline row_iterator end_row(); + inline const_row_iterator end_row() const; + + inline row_iterator end_row(const uword row_num); + inline const_row_iterator end_row(const uword row_num) const; + + //! don't use this unless you're writing internal Armadillo code + arma_inline bool is_alias(const SpMat& X) const; + + + private: + + friend class SpMat; + friend class SpSubview_col; + friend class SpSubview_row; + friend class SpValProxy< SpSubview >; // allow SpValProxy to call insert_element() and delete_element() + + arma_warn_unused inline eT& insert_element(const uword in_row, const uword in_col, const eT in_val = eT(0)); + inline void delete_element(const uword in_row, const uword in_col); + + inline void invalidate_cache() const; + }; + + + +template +class SpSubview_col : public SpSubview + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + inline void operator= (const SpSubview& x); + inline void operator= (const SpSubview_col& x); + + template inline void operator= (const SpBase& x); + template inline void operator= (const Base& x); + + arma_warn_unused inline const SpOp,spop_htrans> t() const; + arma_warn_unused inline const SpOp,spop_htrans> ht() const; + arma_warn_unused inline const SpOp,spop_strans> st() const; + + + protected: + + inline SpSubview_col(const SpMat& in_m, const uword in_col); + inline SpSubview_col(const SpMat& in_m, const uword in_col, const uword in_row1, const uword in_n_rows); + inline SpSubview_col() = delete; + + + private: + + friend class SpMat; + friend class SpSubview; + }; + + + +template +class SpSubview_row : public SpSubview + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + inline void operator= (const SpSubview& x); + inline void operator= (const SpSubview_row& x); + + template inline void operator= (const SpBase& x); + template inline void operator= (const Base& x); + + arma_warn_unused inline const SpOp,spop_htrans> t() const; + arma_warn_unused inline const SpOp,spop_htrans> ht() const; + arma_warn_unused inline const SpOp,spop_strans> st() const; + + + protected: + + inline SpSubview_row(const SpMat& in_m, const uword in_row); + inline SpSubview_row(const SpMat& in_m, const uword in_row, const uword in_col1, const uword in_n_cols); + inline SpSubview_row() = delete; + + + private: + + friend class SpMat; + friend class SpSubview; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpSubview_col_list_bones.hpp b/src/armadillo/include/armadillo_bits/SpSubview_col_list_bones.hpp new file mode 100644 index 0000000..8501291 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpSubview_col_list_bones.hpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpSubview_col_list +//! @{ + + + +template +class SpSubview_col_list : public SpBase< eT, SpSubview_col_list > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + const SpMat& m; + const quasi_unwrap U_ci; + + + protected: + + arma_inline SpSubview_col_list(const SpMat& in_m, const Base& in_ci); + + + public: + + inline ~SpSubview_col_list(); + inline SpSubview_col_list() = delete; + + template inline void for_each(functor F); + template inline void for_each(functor F) const; + + template inline void transform(functor F); + + inline void replace(const eT old_val, const eT new_val); + + inline void clean(const pod_type threshold); + + inline void fill(const eT val); + inline void zeros(); + inline void ones(); + + inline void operator+= (const eT val); + inline void operator-= (const eT val); + inline void operator*= (const eT val); + inline void operator/= (const eT val); + + template inline void operator= (const Base& x); + template inline void operator+=(const Base& x); + template inline void operator-=(const Base& x); + template inline void operator%=(const Base& x); + template inline void operator/=(const Base& x); + + inline void operator= (const SpSubview_col_list& x); + template inline void operator= (const SpSubview_col_list& x); + + template inline void operator= (const SpBase& x); + template inline void operator+= (const SpBase& x); + template inline void operator-= (const SpBase& x); + template inline void operator%= (const SpBase& x); + template inline void operator/= (const SpBase& x); + + inline static void extract(SpMat& out, const SpSubview_col_list& in); + + inline static void plus_inplace(SpMat& out, const SpSubview_col_list& in); + inline static void minus_inplace(SpMat& out, const SpSubview_col_list& in); + inline static void schur_inplace(SpMat& out, const SpSubview_col_list& in); + inline static void div_inplace(SpMat& out, const SpSubview_col_list& in); + + + friend class SpMat; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpSubview_col_list_meat.hpp b/src/armadillo/include/armadillo_bits/SpSubview_col_list_meat.hpp new file mode 100644 index 0000000..46d2d8d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpSubview_col_list_meat.hpp @@ -0,0 +1,719 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpSubview_col_list +//! @{ + + + +template +inline +SpSubview_col_list::~SpSubview_col_list() + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +SpSubview_col_list::SpSubview_col_list + ( + const SpMat& in_m, + const Base& in_ci + ) + : m (in_m ) + , U_ci(in_ci.get_ref()) + { + arma_extra_debug_sigprint(); + + const umat& ci = U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + arma_debug_check + ( + ( (ci.is_vec() == false) && (ci.is_empty() == false) ), + "SpMat::cols(): given object must be a vector" + ); + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + arma_debug_check_bounds( (i >= in_m.n_cols), "SpMat::cols(): index out of bounds" ); + } + } + + + +//! apply a functor to each element +template +template +inline +void +SpSubview_col_list::for_each(functor F) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp.for_each(F); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::for_each(functor F) const + { + arma_extra_debug_sigprint(); + + const SpMat tmp(*this); + + tmp.for_each(F); + } + + + +//! transform each element using a functor +template +template +inline +void +SpSubview_col_list::transform(functor F) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp.transform(F); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp.replace(old_val, new_val); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp.clean(threshold); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::fill(const eT val) + { + arma_extra_debug_sigprint(); + + Mat tmp(m.n_rows, U_ci.M.n_elem, arma_nozeros_indicator()); tmp.fill(val); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::zeros() + { + arma_extra_debug_sigprint(); + + SpMat& m_local = const_cast< SpMat& >(m); + + const umat& ci = U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + m_local.sync_csc(); + m_local.invalidate_cache(); + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + const uword col_n_nonzero = m_local.col_ptrs[i+1] - m_local.col_ptrs[i]; + + uword offset = m_local.col_ptrs[i]; + + for(uword j=0; j < col_n_nonzero; ++j) + { + access::rw(m_local.values[offset]) = eT(0); + + ++offset; + } + } + + m_local.remove_zeros(); + } + + + +template +inline +void +SpSubview_col_list::ones() + { + arma_extra_debug_sigprint(); + + const Mat tmp(m.n_rows, U_ci.M.n_elem, fill::ones); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::operator+= (const eT val) + { + arma_extra_debug_sigprint(); + + const SpMat tmp1(*this); + + Mat tmp2(tmp1.n_rows, tmp1.n_cols, arma_nozeros_indicator()); tmp2.fill(val); + + const Mat tmp3 = tmp1 + tmp2; + + (*this).operator=(tmp3); + } + + + +template +inline +void +SpSubview_col_list::operator-= (const eT val) + { + arma_extra_debug_sigprint(); + + const SpMat tmp1(*this); + + Mat tmp2(tmp1.n_rows, tmp1.n_cols, arma_nozeros_indicator()); tmp2.fill(val); + + const Mat tmp3 = tmp1 - tmp2; + + (*this).operator=(tmp3); + } + + + +template +inline +void +SpSubview_col_list::operator*= (const eT val) + { + arma_extra_debug_sigprint(); + + if(val == eT(0)) { (*this).zeros(); return; } + + SpMat& m_local = const_cast< SpMat& >(m); + + const umat& ci = U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + m_local.sync_csc(); + m_local.invalidate_cache(); + + bool has_zero = false; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + const uword col_n_nonzero = m_local.col_ptrs[i+1] - m_local.col_ptrs[i]; + + uword offset = m_local.col_ptrs[i]; + + for(uword j=0; j < col_n_nonzero; ++j) + { + eT& m_local_val = access::rw(m_local.values[offset]); + + m_local_val *= val; + + if(m_local_val == eT(0)) { has_zero = true; } + + ++offset; + } + } + + if(has_zero) { m_local.remove_zeros(); } + } + + + +template +inline +void +SpSubview_col_list::operator/= (const eT val) + { + arma_extra_debug_sigprint(); + + const SpMat tmp1(*this); + + Mat tmp2(tmp1.n_rows, tmp1.n_cols, arma_nozeros_indicator()); tmp2.fill(val); + + const SpMat tmp3 = tmp1 / tmp2; + + (*this).operator=(tmp3); + } + + + +template +template +inline +void +SpSubview_col_list::operator= (const Base& x) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap U(x.get_ref()); + const Mat& X = U.M; + + SpMat& m_local = const_cast< SpMat& >(m); + + const umat& ci = U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + arma_debug_assert_same_size( m_local.n_rows, ci_n_elem, X.n_rows, X.n_cols, "SpMat::cols()" ); + + const uword X_n_elem = X.n_elem; + const eT* X_mem = X.memptr(); + + uword X_n_nonzero = 0; + + for(uword i=0; i < X_n_elem; ++i) { X_n_nonzero += (X_mem[i] != eT(0)) ? uword(1) : uword(0); } + + SpMat Y(arma_reserve_indicator(), X.n_rows, m_local.n_cols, X_n_nonzero); + + uword count = 0; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + for(uword row=0; row < X.n_rows; ++row) + { + const eT X_val = (*X_mem); ++X_mem; + + if(X_val != eT(0)) + { + access::rw(Y.row_indices[count]) = row; + access::rw(Y.values [count]) = X_val; + ++count; + ++access::rw(Y.col_ptrs[i + 1]); + } + } + } + + // fix the column pointers + for(uword i = 0; i < Y.n_cols; ++i) + { + access::rw(Y.col_ptrs[i+1]) += Y.col_ptrs[i]; + } + + (*this).zeros(); + + SpMat tmp = m_local + Y; + + m_local.steal_mem(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator+= (const Base& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp = SpMat(*this) + x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator-= (const Base& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp = SpMat(*this) - x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator%= (const Base& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp = SpMat(*this) % x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator/= (const Base& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp = SpMat(*this) / x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::operator= (const SpSubview_col_list& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(x); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator= (const SpSubview_col_list& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(x); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(x.get_ref()); + const SpMat& X = U.M; + + if(U.is_alias(m)) + { + const SpMat tmp(X); + + (*this).operator=(tmp); + + return; + } + + SpMat& m_local = const_cast< SpMat& >(m); + + const umat& ci = U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + arma_debug_assert_same_size( m_local.n_rows, ci_n_elem, X.n_rows, X.n_cols, "SpMat::cols()" ); + + SpMat Y(arma_reserve_indicator(), X.n_rows, m_local.n_cols, X.n_nonzero); + + uword count = 0; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + typename SpMat::const_col_iterator X_col_it = X.begin_col(ci_count); + typename SpMat::const_col_iterator X_col_it_end = X.end_col(ci_count); + + while(X_col_it != X_col_it_end) + { + access::rw(Y.row_indices[count]) = X_col_it.row(); + access::rw(Y.values [count]) = (*X_col_it); + ++count; + ++access::rw(Y.col_ptrs[i + 1]); + ++X_col_it; + } + } + + // fix the column pointers + for(uword i = 0; i < Y.n_cols; ++i) + { + access::rw(Y.col_ptrs[i+1]) += Y.col_ptrs[i]; + } + + (*this).zeros(); + + SpMat tmp = m_local + Y; + + m_local.steal_mem(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator+= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp = SpMat(*this) + x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator-= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp = SpMat(*this) - x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator%= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp = SpMat(*this) % x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator/= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp /= x.get_ref(); + + (*this).operator=(tmp); + } + + + +// +// + + + +template +inline +void +SpSubview_col_list::extract(SpMat& out, const SpSubview_col_list& in) + { + arma_extra_debug_sigprint(); + + // NOTE: aliasing is handled by SpMat::operator=(const SpSubview_col_list& in) + + const umat& ci = in.U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + const SpMat& in_m = in.m; + + in_m.sync_csc(); + + uword total_n_nonzero = 0; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + const uword col_n_nonzero = in_m.col_ptrs[i+1] - in_m.col_ptrs[i]; + + total_n_nonzero += col_n_nonzero; + } + + out.reserve(in.m.n_rows, ci_n_elem, total_n_nonzero); + + uword out_n_nonzero = 0; + uword out_col_count = 0; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + const uword col_n_nonzero = in_m.col_ptrs[i+1] - in_m.col_ptrs[i]; + + uword offset = in_m.col_ptrs[i]; + + for(uword j=0; j < col_n_nonzero; ++j) + { + const eT val = in_m.values [ offset ]; + const uword row = in_m.row_indices[ offset ]; + + ++offset; + + access::rw(out.values [out_n_nonzero]) = val; + access::rw(out.row_indices[out_n_nonzero]) = row; + + access::rw(out.col_ptrs[out_col_count+1])++; + + ++out_n_nonzero; + } + + ++out_col_count; + } + + // fix the column pointers + for(uword i = 0; i < out.n_cols; ++i) + { + access::rw(out.col_ptrs[i+1]) += out.col_ptrs[i]; + } + } + + + +template +inline +void +SpSubview_col_list::plus_inplace(SpMat& out, const SpSubview_col_list& in) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(in); + + out += tmp; + } + + + +template +inline +void +SpSubview_col_list::minus_inplace(SpMat& out, const SpSubview_col_list& in) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(in); + + out -= tmp; + } + + + +template +inline +void +SpSubview_col_list::schur_inplace(SpMat& out, const SpSubview_col_list& in) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(in); + + out %= tmp; + } + + + +template +inline +void +SpSubview_col_list::div_inplace(SpMat& out, const SpSubview_col_list& in) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(in); + + out /= tmp; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpSubview_iterators_meat.hpp b/src/armadillo/include/armadillo_bits/SpSubview_iterators_meat.hpp new file mode 100644 index 0000000..d97d7c6 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpSubview_iterators_meat.hpp @@ -0,0 +1,1154 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpSubview +//! @{ + + +/////////////////////////////////////////////////////////////////////////////// +// SpSubview::iterator_base implementation // +/////////////////////////////////////////////////////////////////////////////// + +template +inline +SpSubview::iterator_base::iterator_base(const SpSubview& in_M) + : M(&in_M) + , internal_col(0) + , internal_pos(0) + { + // Technically this iterator is invalid (it may not point to a valid element) + } + + + +template +inline +SpSubview::iterator_base::iterator_base(const SpSubview& in_M, const uword in_col, const uword in_pos) + : M(&in_M) + , internal_col(in_col) + , internal_pos(in_pos) + { + // Nothing to do. + } + + + +/////////////////////////////////////////////////////////////////////////////// +// SpSubview::const_iterator implementation // +/////////////////////////////////////////////////////////////////////////////// + +template +inline +SpSubview::const_iterator::const_iterator(const SpSubview& in_M, const uword initial_pos) + : iterator_base(in_M, 0, initial_pos) + { + // Corner case for empty subviews. + if(in_M.n_nonzero == 0) + { + iterator_base::internal_col = in_M.n_cols; + skip_pos = in_M.m.n_nonzero; + return; + } + + // Figure out the row and column of the position. + // lskip_pos holds the number of values which aren't part of this subview. + const uword aux_col = iterator_base::M->aux_col1; + const uword aux_row = iterator_base::M->aux_row1; + const uword ln_rows = iterator_base::M->n_rows; + const uword ln_cols = iterator_base::M->n_cols; + + uword cur_pos = 0; // off by one because we might be searching for pos 0 + uword lskip_pos = iterator_base::M->m.col_ptrs[aux_col]; + uword cur_col = 0; + + while(cur_pos < (iterator_base::internal_pos + 1)) + { + // Have we stepped forward a column (or multiple columns)? + while(((lskip_pos + cur_pos) >= iterator_base::M->m.col_ptrs[cur_col + aux_col + 1]) && (cur_col < ln_cols)) + { + ++cur_col; + } + + // See if the current position is in the subview. + const uword row_index = iterator_base::M->m.row_indices[cur_pos + lskip_pos]; + if(row_index < aux_row) + { + ++lskip_pos; // not valid + } + else if(row_index < (aux_row + ln_rows)) + { + ++cur_pos; // valid, in the subview + } + else + { + // skip to end of column + const uword next_colptr = iterator_base::M->m.col_ptrs[cur_col + aux_col + 1]; + lskip_pos += (next_colptr - (cur_pos + lskip_pos)); + } + } + + iterator_base::internal_col = cur_col; + skip_pos = lskip_pos; + } + + + +template +inline +SpSubview::const_iterator::const_iterator(const SpSubview& in_M, const uword in_row, const uword in_col) + : iterator_base(in_M, in_col, 0) + { + // Corner case for empty subviews. + if(in_M.n_nonzero == 0) + { + // We must be at the last position. + iterator_base::internal_col = in_M.n_cols; + skip_pos = in_M.m.n_nonzero; + return; + } + + // We have a destination we want to be just after, but don't know what position that is. + // Because we have to count the points in this subview and not in this subview, this becomes a little difficult and slow. + const uword aux_col = iterator_base::M->aux_col1; + const uword aux_row = iterator_base::M->aux_row1; + const uword ln_rows = iterator_base::M->n_rows; + const uword ln_cols = iterator_base::M->n_cols; + + uword cur_pos = 0; + skip_pos = iterator_base::M->m.col_ptrs[aux_col]; + uword cur_col = 0; + + // Skip any empty columns. + while(((skip_pos + cur_pos) >= iterator_base::M->m.col_ptrs[cur_col + aux_col + 1]) && (cur_col < ln_cols)) + { + ++cur_col; + } + + while(cur_col < in_col) + { + // See if the current position is in the subview. + const uword row_index = iterator_base::M->m.row_indices[cur_pos + skip_pos]; + if(row_index < aux_row) + { + ++skip_pos; + } + else if(row_index < (aux_row + ln_rows)) + { + ++cur_pos; + } + else + { + // skip to end of column + const uword next_colptr = iterator_base::M->m.col_ptrs[cur_col + aux_col + 1]; + skip_pos += (next_colptr - (cur_pos + skip_pos)); + } + + // Have we stepped forward a column (or multiple columns)? + while(((skip_pos + cur_pos) >= iterator_base::M->m.col_ptrs[cur_col + aux_col + 1]) && (cur_col < ln_cols)) + { + ++cur_col; + } + } + + // Now we are either on the right column or ahead of it. + if(cur_col == in_col) + { + // We have to find the right row index. + uword row_index = iterator_base::M->m.row_indices[cur_pos + skip_pos]; + while((row_index < (in_row + aux_row))) + { + if(row_index < aux_row) + { + ++skip_pos; + } + else + { + ++cur_pos; + } + + // Ensure we didn't step forward a column; if we did, we need to stop. + while(((skip_pos + cur_pos) >= iterator_base::M->m.col_ptrs[cur_col + aux_col + 1]) && (cur_col < ln_cols)) + { + ++cur_col; + } + + if(cur_col != in_col) + { + break; + } + + row_index = iterator_base::M->m.row_indices[cur_pos + skip_pos]; + } + } + + // Now we need to find the next valid position in the subview. + uword row_index; + while(true) + { + const uword next_colptr = iterator_base::M->m.col_ptrs[cur_col + aux_col + 1]; + row_index = iterator_base::M->m.row_indices[cur_pos + skip_pos]; + + // Are we at the last position? + if(cur_col >= ln_cols) + { + cur_col = ln_cols; + // Make sure we will be pointing at the last element in the parent matrix. + skip_pos = iterator_base::M->m.n_nonzero - iterator_base::M->n_nonzero; + break; + } + + if(row_index < aux_row) + { + ++skip_pos; + } + else if(row_index < (aux_row + ln_rows)) + { + break; // found + } + else + { + skip_pos += (next_colptr - (cur_pos + skip_pos)); + } + + // Did we move any columns? + while(((skip_pos + cur_pos) >= iterator_base::M->m.col_ptrs[cur_col + aux_col + 1]) && (cur_col < ln_cols)) + { + ++cur_col; + } + } + + // It is possible we have moved another column. + while(((skip_pos + cur_pos) >= iterator_base::M->m.col_ptrs[cur_col + aux_col + 1]) && (cur_col < ln_cols)) + { + ++cur_col; + } + + iterator_base::internal_pos = cur_pos; + iterator_base::internal_col = cur_col; + } + + + +template +inline +SpSubview::const_iterator::const_iterator(const SpSubview& in_M, uword in_row, uword in_col, uword in_pos, uword in_skip_pos) + : iterator_base(in_M, in_col, in_pos) + , skip_pos(in_skip_pos) + { + arma_ignore(in_row); + + // Nothing to do. + } + + + +template +inline +SpSubview::const_iterator::const_iterator(const const_iterator& other) + : iterator_base(*other.M, other.internal_col, other.internal_pos) + , skip_pos(other.skip_pos) + { + // Nothing to do. + } + + + +template +arma_inline +eT +SpSubview::const_iterator::operator*() const + { + return iterator_base::M->m.values[iterator_base::internal_pos + skip_pos]; + } + + + +template +inline +typename SpSubview::const_iterator& +SpSubview::const_iterator::operator++() + { + const uword aux_col = iterator_base::M->aux_col1; + const uword aux_row = iterator_base::M->aux_row1; + const uword ln_rows = iterator_base::M->n_rows; + const uword ln_cols = iterator_base::M->n_cols; + + uword cur_col = iterator_base::internal_col; + uword cur_pos = iterator_base::internal_pos + 1; + uword lskip_pos = skip_pos; + uword row_index; + + while(true) + { + const uword next_colptr = iterator_base::M->m.col_ptrs[cur_col + aux_col + 1]; + row_index = iterator_base::M->m.row_indices[cur_pos + lskip_pos]; + + // Did we move any columns? + while((cur_col < ln_cols) && ((lskip_pos + cur_pos) >= iterator_base::M->m.col_ptrs[cur_col + aux_col + 1])) + { + ++cur_col; + } + + // Are we at the last position? + if(cur_col >= ln_cols) + { + cur_col = ln_cols; + // Make sure we will be pointing at the last element in the parent matrix. + lskip_pos = iterator_base::M->m.n_nonzero - iterator_base::M->n_nonzero; + break; + } + + if(row_index < aux_row) + { + ++lskip_pos; + } + else if(row_index < (aux_row + ln_rows)) + { + break; // found + } + else + { + lskip_pos += (next_colptr - (cur_pos + lskip_pos)); + } + } + + iterator_base::internal_pos = cur_pos; + iterator_base::internal_col = cur_col; + skip_pos = lskip_pos; + + return *this; + } + + + +template +inline +typename SpSubview::const_iterator +SpSubview::const_iterator::operator++(int) + { + typename SpSubview::const_iterator tmp(*this); + + ++(*this); + + return tmp; + } + + + +template +inline +typename SpSubview::const_iterator& +SpSubview::const_iterator::operator--() + { + const uword aux_col = iterator_base::M->aux_col1; + const uword aux_row = iterator_base::M->aux_row1; + const uword ln_rows = iterator_base::M->n_rows; + + uword cur_col = iterator_base::internal_col; + uword cur_pos = iterator_base::internal_pos - 1; + + // Special condition for end of iterator. + if((skip_pos + cur_pos + 1) == iterator_base::M->m.n_nonzero) + { + // We are at the last element. So we need to set skip_pos back to what it + // would be if we didn't manually modify it back in operator++(). + skip_pos = iterator_base::M->m.col_ptrs[cur_col + aux_col] - iterator_base::internal_pos; + } + + uword row_index; + + while(true) + { + const uword colptr = iterator_base::M->m.col_ptrs[cur_col + aux_col]; + row_index = iterator_base::M->m.row_indices[cur_pos + skip_pos]; + + // Did we move back any columns? + while((skip_pos + cur_pos) < iterator_base::M->m.col_ptrs[cur_col + aux_col]) + { + --cur_col; + } + + if(row_index < aux_row) + { + skip_pos -= (colptr - (cur_pos + skip_pos) + 1); + } + else if(row_index < (aux_row + ln_rows)) + { + break; // found + } + else + { + --skip_pos; + } + } + + iterator_base::internal_pos = cur_pos; + iterator_base::internal_col = cur_col; + + return *this; + } + + + +template +inline +typename SpSubview::const_iterator +SpSubview::const_iterator::operator--(int) + { + typename SpSubview::const_iterator tmp(*this); + + --(*this); + + return tmp; + } + + + +template +inline +bool +SpSubview::const_iterator::operator==(const const_iterator& rhs) const + { + return (rhs.row() == (*this).row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_iterator::operator!=(const const_iterator& rhs) const + { + return (rhs.row() != (*this).row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_iterator::operator==(const typename SpMat::const_iterator& rhs) const + { + return (rhs.row() == (*this).row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_iterator::operator!=(const typename SpMat::const_iterator& rhs) const + { + return (rhs.row() != (*this).row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_iterator::operator==(const const_row_iterator& rhs) const + { + return (rhs.row() == (*this).row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_iterator::operator!=(const const_row_iterator& rhs) const + { + return (rhs.row() != (*this).row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_iterator::operator==(const typename SpMat::const_row_iterator& rhs) const + { + return (rhs.row() == (*this).row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_iterator::operator!=(const typename SpMat::const_row_iterator& rhs) const + { + return (rhs.row() != (*this).row()) || (rhs.col() != iterator_base::internal_col); + } + + + +/////////////////////////////////////////////////////////////////////////////// +// SpSubview::iterator implementation // +/////////////////////////////////////////////////////////////////////////////// + +template +inline +SpValProxy< SpSubview > +SpSubview::iterator::operator*() + { + return SpValProxy< SpSubview >( + const_iterator::row(), + iterator_base::col(), + access::rw(*iterator_base::M), + &(access::rw(iterator_base::M->m.values[iterator_base::internal_pos + const_iterator::skip_pos]))); + } + + + +template +inline +typename SpSubview::iterator& +SpSubview::iterator::operator++() + { + const_iterator::operator++(); + return *this; + } + + + +template +inline +typename SpSubview::iterator +SpSubview::iterator::operator++(int) + { + typename SpSubview::iterator tmp(*this); + + const_iterator::operator++(); + + return tmp; + } + + + +template +inline +typename SpSubview::iterator& +SpSubview::iterator::operator--() + { + const_iterator::operator--(); + return *this; + } + + + +template +inline +typename SpSubview::iterator +SpSubview::iterator::operator--(int) + { + typename SpSubview::iterator tmp(*this); + + const_iterator::operator--(); + + return tmp; + } + + + +/////////////////////////////////////////////////////////////////////////////// +// SpSubview::const_row_iterator implementation // +/////////////////////////////////////////////////////////////////////////////// + +template +inline +SpSubview::const_row_iterator::const_row_iterator() + : iterator_base() + , internal_row(0) + , actual_pos(0) + { + } + + + +template +inline +SpSubview::const_row_iterator::const_row_iterator(const SpSubview& in_M, uword initial_pos) + : iterator_base(in_M, 0, initial_pos) + , internal_row(0) + , actual_pos(0) + { + // Corner case for the end of a subview. + if(initial_pos == in_M.n_nonzero) + { + iterator_base::internal_col = 0; + internal_row = in_M.n_rows; + return; + } + + const uword aux_col = iterator_base::M->aux_col1; + const uword aux_row = iterator_base::M->aux_row1; + + // We don't count zeros in our position count, so we have to find the nonzero + // value corresponding to the given initial position, and we also have to skip + // any nonzero elements that aren't a part of the subview. + + uword cur_pos = std::numeric_limits::max(); + uword cur_actual_pos = 0; + + // Since we don't know where the elements are in each row, we have to loop + // across all columns looking for elements in row 0 and add to our sum, then + // in row 1, and so forth, until we get to the desired position. + for(uword row = 0; row < iterator_base::M->n_rows; ++row) + { + for(uword col = 0; col < iterator_base::M->n_cols; ++col) + { + // Find the first element with row greater than or equal to row + aux_row. + const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col ]; + const uword next_col_offset = iterator_base::M->m.col_ptrs[col + aux_col + 1]; + + const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, row + aux_row); + + const uword offset = uword(pos_ptr - start_ptr); + + if(iterator_base::M->m.row_indices[col_offset + offset] == row + aux_row) + { + cur_actual_pos = col_offset + offset; + + // Increment position portably. + if(cur_pos == std::numeric_limits::max()) + cur_pos = 0; + else + ++cur_pos; + + // Do we terminate? + if(cur_pos == initial_pos) + { + internal_row = row; + iterator_base::internal_col = col; + iterator_base::internal_pos = cur_pos; + actual_pos = cur_actual_pos; + + return; + } + } + } + } + } + + // This shouldn't happen. + iterator_base::internal_pos = iterator_base::M->n_nonzero; + iterator_base::internal_col = 0; + internal_row = iterator_base::M->n_rows; + actual_pos = iterator_base::M->n_nonzero; + } + + + +template +inline +SpSubview::const_row_iterator::const_row_iterator(const SpSubview& in_M, uword in_row, uword in_col) + : iterator_base(in_M, in_col, 0) + , internal_row(0) + , actual_pos(0) + { + // Start our search in the given row. We need to find two things: + // + // 1. The first nonzero element (iterating by rows) after (in_row, in_col). + // 2. The number of nonzero elements (iterating by rows) that come before + // (in_row, in_col). + // + // We'll find these simultaneously, though we will have to loop over all + // columns. + + const uword aux_col = iterator_base::M->aux_col1; + const uword aux_row = iterator_base::M->aux_row1; + + // This will hold the total number of points in the subview with rows less + // than in_row. + uword cur_pos = 0; + uword cur_min_row = iterator_base::M->n_rows; + uword cur_min_col = 0; + uword cur_actual_pos = 0; + + for(uword col = 0; col < iterator_base::M->n_cols; ++col) + { + // Find the first element with row greater than or equal to in_row. + const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col ]; + const uword next_col_offset = iterator_base::M->m.col_ptrs[col + aux_col + 1]; + + const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + // First let us find the first element that is in the subview. + const uword* first_subview_ptr = std::lower_bound(start_ptr, end_ptr, aux_row); + + if(first_subview_ptr != end_ptr && (*first_subview_ptr) < aux_row + iterator_base::M->n_rows) + { + // There exists at least one element in the subview. + const uword* pos_ptr = std::lower_bound(first_subview_ptr, end_ptr, aux_row + in_row); + + // This is the number of elements in the subview with row index less + // than in_row. + cur_pos += uword(pos_ptr - first_subview_ptr); + + if(pos_ptr != end_ptr && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + { + // This is the row index of the first element in the column with row + // index greater than or equal to in_row + aux_row. + if((*pos_ptr) - aux_row < cur_min_row) + { + // If we are in the desired row but before the desired column, we + // can't take this. + if(col >= in_col) + { + cur_min_row = (*pos_ptr) - aux_row; + cur_min_col = col; + cur_actual_pos = col_offset + (pos_ptr - start_ptr); + } + } + } + } + } + } + + // Now we know what the minimum row is. + internal_row = cur_min_row; + iterator_base::internal_col = cur_min_col; + iterator_base::internal_pos = cur_pos; + actual_pos = cur_actual_pos; + } + + + +template +inline +SpSubview::const_row_iterator::const_row_iterator(const const_row_iterator& other) + : iterator_base(*other.M, other.internal_col, other.internal_pos) + , internal_row(other.internal_row) + , actual_pos(other.actual_pos) + { + // Nothing to do. + } + + + +template +inline +typename SpSubview::const_row_iterator& +SpSubview::const_row_iterator::operator++() + { + // We just need to find the next nonzero element. + ++iterator_base::internal_pos; + + // If we have exceeded the bounds, update accordingly. + if(iterator_base::internal_pos >= iterator_base::M->n_nonzero) + { + internal_row = iterator_base::M->n_rows; + iterator_base::internal_col = 0; + actual_pos = iterator_base::M->n_nonzero; + + return *this; + } + + const uword aux_col = iterator_base::M->aux_col1; + const uword aux_row = iterator_base::M->aux_row1; + const uword M_n_cols = iterator_base::M->n_cols; + + // Otherwise, we need to search. We have to loop over all of the columns in + // the subview. + uword next_min_row = iterator_base::M->n_rows; + uword next_min_col = 0; + uword next_actual_pos = 0; + + for(uword col = iterator_base::internal_col + 1; col < M_n_cols; ++col) + { + // Find the first element with row greater than or equal to row. + const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col ]; + const uword next_col_offset = iterator_base::M->m.col_ptrs[col + aux_col + 1]; + + const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + // Find the first element in the column with row greater than or equal to + // the current row. Since this is a subview, it's possible that we may + // find rows past the end of the subview. + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + aux_row); + + if(pos_ptr != end_ptr) + { + // We found something; is the row index correct? + if((*pos_ptr) == internal_row + aux_row && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + { + // Exact match---so we are done. + iterator_base::internal_col = col; + actual_pos = col_offset + (pos_ptr - start_ptr); + return *this; + } + else if((*pos_ptr) < next_min_row + aux_row && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + { + // The first element in this column is in a subsequent row, but it's + // the minimum row we've seen so far. + next_min_row = (*pos_ptr) - aux_row; + next_min_col = col; + next_actual_pos = col_offset + (pos_ptr - start_ptr); + } + else if((*pos_ptr) == next_min_row + aux_row && col < next_min_col && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + { + // The first element in this column is in a subsequent row that we + // already have another elemnt for, but the column index is less so + // this element will come first. + next_min_col = col; + next_actual_pos = col_offset + (pos_ptr - start_ptr); + } + } + } + } + + // Restart the search in the next row. + for(uword col = 0; col <= iterator_base::internal_col; ++col) + { + // Find the first element with row greater than or equal to row + 1. + const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col ]; + const uword next_col_offset = iterator_base::M->m.col_ptrs[col + aux_col + 1]; + + const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + aux_row + 1); + + if(pos_ptr != end_ptr) + { + // We found something in the column, but is the row index correct? + if((*pos_ptr) == internal_row + aux_row + 1 && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + { + // Exact match---so we are done. + iterator_base::internal_col = col; + internal_row++; + actual_pos = col_offset + (pos_ptr - start_ptr); + return *this; + } + else if((*pos_ptr) < next_min_row + aux_row && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + { + // The first element in this column is in a subsequent row, but it's + // the minimum row we've seen so far. + next_min_row = (*pos_ptr) - aux_row; + next_min_col = col; + next_actual_pos = col_offset + (pos_ptr - start_ptr); + } + else if((*pos_ptr) == next_min_row + aux_row && col < next_min_col && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + { + // We've found a better column. + next_min_col = col; + next_actual_pos = col_offset + (pos_ptr - start_ptr); + } + } + } + } + + iterator_base::internal_col = next_min_col; + internal_row = next_min_row; + actual_pos = next_actual_pos; + + return *this; + } + + + +template +inline +typename SpSubview::const_row_iterator +SpSubview::const_row_iterator::operator++(int) + { + typename SpSubview::const_row_iterator tmp(*this); + + ++(*this); + + return tmp; + } + + + +template +inline +typename SpSubview::const_row_iterator& +SpSubview::const_row_iterator::operator--() + { + if(iterator_base::internal_pos == 0) + { + // We are already at the beginning. + return *this; + } + + iterator_base::internal_pos--; + + const uword aux_col = iterator_base::M->aux_col1; + const uword aux_row = iterator_base::M->aux_row1; + + // We have to search backwards. + uword max_row = 0; + uword max_col = 0; + uword next_actual_pos = 0; + + for(uword col = iterator_base::internal_col; col >= 1; --col) + { + // Find the first element with row greater than or equal to in_row + 1. + const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col - 1]; + const uword next_col_offset = iterator_base::M->m.col_ptrs[col + aux_col ]; + + const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + // There are elements in this column. + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + aux_row + 1); + + if(pos_ptr != start_ptr) + { + if(*(pos_ptr - 1) > max_row + aux_row) + { + // There are elements in this column with row index < internal_row. + max_row = *(pos_ptr - 1) - aux_row; + max_col = col - 1; + next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); + } + else if(*(pos_ptr - 1) == max_row + aux_row && (col - 1) >= max_col) + { + max_col = col - 1; + next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); + } + } + } + } + + for(uword col = iterator_base::M->n_cols - 1; col >= iterator_base::internal_col; --col) + { + // Find the first element with row greater than or equal to row + 1. + const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col ]; + const uword next_col_offset = iterator_base::M->m.col_ptrs[col + aux_col + 1]; + + const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; + const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; + + if(start_ptr != end_ptr) + { + // There are elements in this column. + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + aux_row); + + if(pos_ptr != start_ptr) + { + // There are elements in this column with row index < internal_row. + if(*(pos_ptr - 1) > max_row + aux_row) + { + max_row = *(pos_ptr - 1) - aux_row; + max_col = col; + next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); + } + else if(*(pos_ptr - 1) == max_row + aux_row && col >= max_col) + { + max_col = col; + next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); + } + } + } + + if(col == 0) // Catch edge case that the loop termination condition won't. + { + break; + } + } + + iterator_base::internal_col = max_col; + internal_row = max_row; + actual_pos = next_actual_pos; + + return *this; + } + + + +template +inline +typename SpSubview::const_row_iterator +SpSubview::const_row_iterator::operator--(int) + { + typename SpSubview::const_row_iterator tmp(*this); + + --(*this); + + return tmp; + } + + + +template +inline +bool +SpSubview::const_row_iterator::operator==(const const_iterator& rhs) const + { + return (rhs.row() == row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_row_iterator::operator!=(const const_iterator& rhs) const + { + return (rhs.row() != row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_row_iterator::operator==(const typename SpMat::const_iterator& rhs) const + { + return (rhs.row() == row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_row_iterator::operator!=(const typename SpMat::const_iterator& rhs) const + { + return (rhs.row() != row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_row_iterator::operator==(const const_row_iterator& rhs) const + { + return (rhs.row() == row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_row_iterator::operator!=(const const_row_iterator& rhs) const + { + return (rhs.row() != row()) || (rhs.col() != iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_row_iterator::operator==(const typename SpMat::const_row_iterator& rhs) const + { + return (rhs.row() == row()) && (rhs.col() == iterator_base::internal_col); + } + + + +template +inline +bool +SpSubview::const_row_iterator::operator!=(const typename SpMat::const_row_iterator& rhs) const + { + return (rhs.row() != row()) || (rhs.col() != iterator_base::internal_col); + } + + + +/////////////////////////////////////////////////////////////////////////////// +// SpSubview::row_iterator implementation // +/////////////////////////////////////////////////////////////////////////////// + +template +inline +SpValProxy< SpSubview > +SpSubview::row_iterator::operator*() + { + return SpValProxy< SpSubview >( + const_row_iterator::internal_row, + iterator_base::internal_col, + access::rw(*iterator_base::M), + &access::rw(iterator_base::M->m.values[const_row_iterator::actual_pos])); + } + + + +template +inline +typename SpSubview::row_iterator& +SpSubview::row_iterator::operator++() + { + const_row_iterator::operator++(); + return *this; + } + + + +template +inline +typename SpSubview::row_iterator +SpSubview::row_iterator::operator++(int) + { + typename SpSubview::row_iterator tmp(*this); + + ++(*this); + + return tmp; + } + + + +template +inline +typename SpSubview::row_iterator& +SpSubview::row_iterator::operator--() + { + const_row_iterator::operator--(); + return *this; + } + + + +template +inline +typename SpSubview::row_iterator +SpSubview::row_iterator::operator--(int) + { + typename SpSubview::row_iterator tmp(*this); + + --(*this); + + return tmp; + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpSubview_meat.hpp b/src/armadillo/include/armadillo_bits/SpSubview_meat.hpp new file mode 100644 index 0000000..481359b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpSubview_meat.hpp @@ -0,0 +1,2006 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpSubview +//! @{ + + +template +inline +SpSubview::~SpSubview() + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +SpSubview::SpSubview(const SpMat& in_m, const uword in_row1, const uword in_col1, const uword in_n_rows, const uword in_n_cols) + : m(in_m) + , aux_row1(in_row1) + , aux_col1(in_col1) + , n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem(in_n_rows * in_n_cols) + , n_nonzero(0) + { + arma_extra_debug_sigprint_this(this); + + m.sync_csc(); + + // There must be a O(1) way to do this + uword lend = m.col_ptrs[in_col1 + in_n_cols]; + uword lend_row = in_row1 + in_n_rows; + uword count = 0; + + for(uword i = m.col_ptrs[in_col1]; i < lend; ++i) + { + const uword m_row_indices_i = m.row_indices[i]; + + const bool condition = (m_row_indices_i >= in_row1) && (m_row_indices_i < lend_row); + + count += condition ? uword(1) : uword(0); + } + + access::rw(n_nonzero) = count; + } + + + +template +inline +SpSubview::SpSubview(const SpSubview& in) + : m (in.m ) + , aux_row1 (in.aux_row1 ) + , aux_col1 (in.aux_col1 ) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem (in.n_elem ) + , n_nonzero(in.n_nonzero) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + } + + + +template +inline +SpSubview::SpSubview(SpSubview&& in) + : m (in.m ) + , aux_row1 (in.aux_row1 ) + , aux_col1 (in.aux_col1 ) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem (in.n_elem ) + , n_nonzero(in.n_nonzero) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + + // for paranoia + + access::rw(in.aux_row1 ) = 0; + access::rw(in.aux_col1 ) = 0; + access::rw(in.n_rows ) = 0; + access::rw(in.n_cols ) = 0; + access::rw(in.n_elem ) = 0; + access::rw(in.n_nonzero) = 0; + } + + + +template +inline +const SpSubview& +SpSubview::operator+=(const eT val) + { + arma_extra_debug_sigprint(); + + if(val == eT(0)) { return *this; } + + Mat tmp( (*this).n_rows, (*this).n_cols, arma_nozeros_indicator() ); + + tmp.fill(val); + + return (*this).operator=( (*this) + tmp ); + } + + + +template +inline +const SpSubview& +SpSubview::operator-=(const eT val) + { + arma_extra_debug_sigprint(); + + if(val == eT(0)) { return *this; } + + Mat tmp( (*this).n_rows, (*this).n_cols, arma_nozeros_indicator() ); + + tmp.fill(val); + + return (*this).operator=( (*this) - tmp ); + } + + + +template +inline +const SpSubview& +SpSubview::operator*=(const eT val) + { + arma_extra_debug_sigprint(); + + if(val == eT(0)) { (*this).zeros(); return *this; } + + if((n_elem == 0) || (n_nonzero == 0)) { return *this; } + + m.sync_csc(); + m.invalidate_cache(); + + const uword lstart_row = aux_row1; + const uword lend_row = aux_row1 + n_rows; + + const uword lstart_col = aux_col1; + const uword lend_col = aux_col1 + n_cols; + + const uword* m_row_indices = m.row_indices; + eT* m_values = access::rwp(m.values); + + bool has_zero = false; + + for(uword c = lstart_col; c < lend_col; ++c) + { + const uword r_start = m.col_ptrs[c ]; + const uword r_end = m.col_ptrs[c + 1]; + + for(uword r = r_start; r < r_end; ++r) + { + const uword m_row_indices_r = m_row_indices[r]; + + if( (m_row_indices_r >= lstart_row) && (m_row_indices_r < lend_row) ) + { + eT& m_values_r = m_values[r]; + + m_values_r *= val; + + if(m_values_r == eT(0)) { has_zero = true; } + } + } + } + + if(has_zero) + { + const uword old_m_n_nonzero = m.n_nonzero; + + access::rw(m).remove_zeros(); + + if(m.n_nonzero != old_m_n_nonzero) + { + access::rw(n_nonzero) = n_nonzero - (old_m_n_nonzero - m.n_nonzero); + } + } + + return *this; + } + + + +template +inline +const SpSubview& +SpSubview::operator/=(const eT val) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (val == eT(0)), "element-wise division: division by zero" ); + + m.sync_csc(); + m.invalidate_cache(); + + const uword lstart_row = aux_row1; + const uword lend_row = aux_row1 + n_rows; + + const uword lstart_col = aux_col1; + const uword lend_col = aux_col1 + n_cols; + + const uword* m_row_indices = m.row_indices; + eT* m_values = access::rwp(m.values); + + bool has_zero = false; + + for(uword c = lstart_col; c < lend_col; ++c) + { + const uword r_start = m.col_ptrs[c ]; + const uword r_end = m.col_ptrs[c + 1]; + + for(uword r = r_start; r < r_end; ++r) + { + const uword m_row_indices_r = m_row_indices[r]; + + if( (m_row_indices_r >= lstart_row) && (m_row_indices_r < lend_row) ) + { + eT& m_values_r = m_values[r]; + + m_values_r /= val; + + if(m_values_r == eT(0)) { has_zero = true; } + } + } + } + + if(has_zero) + { + const uword old_m_n_nonzero = m.n_nonzero; + + access::rw(m).remove_zeros(); + + if(m.n_nonzero != old_m_n_nonzero) + { + access::rw(n_nonzero) = n_nonzero - (old_m_n_nonzero - m.n_nonzero); + } + } + + return *this; + } + + + +template +template +inline +const SpSubview& +SpSubview::operator=(const Base& in) + { + arma_extra_debug_sigprint(); + + if(is_same_type< T1, Gen, gen_zeros> >::yes) + { + const Proxy P(in.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, P.get_n_rows(), P.get_n_cols(), "insertion into sparse submatrix"); + + (*this).zeros(); + + return *this; + } + + if(is_same_type< T1, Gen, gen_eye> >::yes) + { + const Proxy P(in.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, P.get_n_rows(), P.get_n_cols(), "insertion into sparse submatrix"); + + (*this).eye(); + + return *this; + } + + const quasi_unwrap U(in.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, U.M.n_rows, U.M.n_cols, "insertion into sparse submatrix"); + + spglue_merge::subview_merge(*this, U.M); + + return *this; + } + + + +template +template +inline +const SpSubview& +SpSubview::operator+=(const Base& x) + { + arma_extra_debug_sigprint(); + + return (*this).operator=( (*this) + x.get_ref() ); + } + + + +template +template +inline +const SpSubview& +SpSubview::operator-=(const Base& x) + { + arma_extra_debug_sigprint(); + + return (*this).operator=( (*this) - x.get_ref() ); + } + + + +template +template +inline +const SpSubview& +SpSubview::operator*=(const Base& x) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp *= x.get_ref(); + + return (*this).operator=(tmp); + } + + + +template +template +inline +const SpSubview& +SpSubview::operator%=(const Base& x) + { + arma_extra_debug_sigprint(); + + SpSubview& sv = (*this); + + const quasi_unwrap U(x.get_ref()); + const Mat& B = U.M; + + arma_debug_assert_same_size(sv.n_rows, sv.n_cols, B.n_rows, B.n_cols, "element-wise multiplication"); + + SpMat& sv_m = access::rw(sv.m); + + sv_m.sync_csc(); + sv_m.invalidate_cache(); + + const uword m_row_start = sv.aux_row1; + const uword m_row_end = sv.aux_row1 + sv.n_rows - 1; + + const uword m_col_start = sv.aux_col1; + const uword m_col_end = sv.aux_col1 + sv.n_cols - 1; + + constexpr eT zero = eT(0); + + bool has_zero = false; + uword count = 0; + + for(uword m_col = m_col_start; m_col <= m_col_end; ++m_col) + { + const uword sv_col = m_col - m_col_start; + + const uword index_start = sv_m.col_ptrs[m_col ]; + const uword index_end = sv_m.col_ptrs[m_col + 1]; + + for(uword i=index_start; i < index_end; ++i) + { + const uword m_row = sv_m.row_indices[i]; + + if(m_row < m_row_start) { continue; } + if(m_row > m_row_end ) { break; } + + const uword sv_row = m_row - m_row_start; + + eT& m_val = access::rw(sv_m.values[i]); + + const eT result = m_val * B.at(sv_row, sv_col); + + m_val = result; + + if(result == zero) { has_zero = true; } else { ++count; } + } + } + + if(has_zero) { sv_m.remove_zeros(); } + + access::rw(sv.n_nonzero) = count; + + return (*this); + } + + + +template +template +inline +const SpSubview& +SpSubview::operator/=(const Base& x) + { + arma_extra_debug_sigprint(); + + const SpSubview& A = (*this); + + const quasi_unwrap U(x.get_ref()); + const Mat& B = U.M; + + arma_debug_assert_same_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "element-wise division"); + + bool result_ok = true; + + constexpr eT zero = eT(0); + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + for(uword c=0; c < B_n_cols; ++c) + { + for(uword r=0; r < B_n_rows; ++r) + { + // a zero in B and A at the same location implies the division result is NaN; + // hence a zero in A (not stored) needs to be changed into a non-zero + + // for efficiency, an element in B is checked before checking the corresponding element in A + + if((B.at(r,c) == zero) && (A.at(r,c) == zero)) { result_ok = false; break; } + } + + if(result_ok == false) { break; } + } + + if(result_ok) + { + const_iterator cit = A.begin(); + const_iterator cit_end = A.end(); + + while(cit != cit_end) + { + const eT tmp = (*cit) / B.at(cit.row(), cit.col()); + + if(tmp == zero) { result_ok = false; break; } + + ++cit; + } + } + + if(result_ok) + { + iterator it = (*this).begin(); + iterator it_end = (*this).end(); + + while(it != it_end) + { + (*it) /= B.at(it.row(), it.col()); + + ++it; + } + } + else + { + (*this).operator=( (*this) / B ); + } + + return (*this); + } + + + +template +inline +const SpSubview& +SpSubview::operator=(const SpSubview& x) + { + arma_extra_debug_sigprint(); + + return (*this).operator_equ_common(x); + } + + + +template +template +inline +const SpSubview& +SpSubview::operator=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + return (*this).operator_equ_common( x.get_ref() ); + } + + + +template +template +inline +const SpSubview& +SpSubview::operator_equ_common(const SpBase& in) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(in.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, U.M.n_rows, U.M.n_cols, "insertion into sparse submatrix"); + + if(U.is_alias(m)) + { + const SpMat tmp(U.M); + + spglue_merge::subview_merge(*this, tmp); + } + else + { + spglue_merge::subview_merge(*this, U.M); + } + + return *this; + } + + + +template +template +inline +const SpSubview& +SpSubview::operator+=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + // TODO: implement dedicated machinery + return (*this).operator=( (*this) + x.get_ref() ); + } + + + +template +template +inline +const SpSubview& +SpSubview::operator-=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + // TODO: implement dedicated machinery + return (*this).operator=( (*this) - x.get_ref() ); + } + + + +template +template +inline +const SpSubview& +SpSubview::operator*=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + return (*this).operator=( (*this) * x.get_ref() ); + } + + + +template +template +inline +const SpSubview& +SpSubview::operator%=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + // TODO: implement dedicated machinery + return (*this).operator=( (*this) % x.get_ref() ); + } + + + +template +template +inline +const SpSubview& +SpSubview::operator/=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + // NOTE: use of this function is not advised; it is implemented only for completeness + + SpProxy p(x.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "element-wise division"); + + if(p.is_alias(m) == false) + { + for(uword lcol = 0; lcol < n_cols; ++lcol) + for(uword lrow = 0; lrow < n_rows; ++lrow) + { + at(lrow,lcol) /= p.at(lrow,lcol); + } + } + else + { + const SpMat tmp(p.Q); + + (*this).operator/=(tmp); + } + + return *this; + } + + + +//! apply a functor to each element +template +template +inline +void +SpSubview::for_each(functor F) + { + arma_extra_debug_sigprint(); + + m.sync_csc(); + m.invalidate_cache(); + + const uword lstart_row = aux_row1; + const uword lend_row = aux_row1 + n_rows; + + const uword lstart_col = aux_col1; + const uword lend_col = aux_col1 + n_cols; + + const uword* m_row_indices = m.row_indices; + eT* m_values = access::rwp(m.values); + + bool has_zero = false; + + for(uword c = lstart_col; c < lend_col; ++c) + { + const uword r_start = m.col_ptrs[c ]; + const uword r_end = m.col_ptrs[c + 1]; + + for(uword r = r_start; r < r_end; ++r) + { + const uword m_row_indices_r = m_row_indices[r]; + + if( (m_row_indices_r >= lstart_row) && (m_row_indices_r < lend_row) ) + { + eT& m_values_r = m_values[r]; + + F(m_values_r); + + if(m_values_r == eT(0)) { has_zero = true; } + } + } + } + + if(has_zero) + { + const uword old_m_n_nonzero = m.n_nonzero; + + access::rw(m).remove_zeros(); + + if(m.n_nonzero != old_m_n_nonzero) + { + access::rw(n_nonzero) = n_nonzero - (old_m_n_nonzero - m.n_nonzero); + } + } + } + + + +template +template +inline +void +SpSubview::for_each(functor F) const + { + arma_extra_debug_sigprint(); + + m.sync_csc(); + + const uword lstart_row = aux_row1; + const uword lend_row = aux_row1 + n_rows; + + const uword lstart_col = aux_col1; + const uword lend_col = aux_col1 + n_cols; + + const uword* m_row_indices = m.row_indices; + + for(uword c = lstart_col; c < lend_col; ++c) + { + const uword r_start = m.col_ptrs[c ]; + const uword r_end = m.col_ptrs[c + 1]; + + for(uword r = r_start; r < r_end; ++r) + { + const uword m_row_indices_r = m_row_indices[r]; + + if( (m_row_indices_r >= lstart_row) && (m_row_indices_r < lend_row) ) + { + F(m.values[r]); + } + } + } + } + + + +//! transform each element using a functor +template +template +inline +void +SpSubview::transform(functor F) + { + arma_extra_debug_sigprint(); + + m.sync_csc(); + m.invalidate_cache(); + + const uword lstart_row = aux_row1; + const uword lend_row = aux_row1 + n_rows; + + const uword lstart_col = aux_col1; + const uword lend_col = aux_col1 + n_cols; + + const uword* m_row_indices = m.row_indices; + eT* m_values = access::rwp(m.values); + + bool has_zero = false; + + for(uword c = lstart_col; c < lend_col; ++c) + { + const uword r_start = m.col_ptrs[c ]; + const uword r_end = m.col_ptrs[c + 1]; + + for(uword r = r_start; r < r_end; ++r) + { + const uword m_row_indices_r = m_row_indices[r]; + + if( (m_row_indices_r >= lstart_row) && (m_row_indices_r < lend_row) ) + { + eT& m_values_r = m_values[r]; + + m_values_r = eT( F(m_values_r) ); + + if(m_values_r == eT(0)) { has_zero = true; } + } + } + } + + if(has_zero) + { + const uword old_m_n_nonzero = m.n_nonzero; + + access::rw(m).remove_zeros(); + + if(m.n_nonzero != old_m_n_nonzero) + { + access::rw(n_nonzero) = n_nonzero - (old_m_n_nonzero - m.n_nonzero); + } + } + } + + + +template +inline +void +SpSubview::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + if(old_val == eT(0)) + { + if(new_val != eT(0)) + { + Mat tmp(*this); + + tmp.replace(old_val, new_val); + + (*this).operator=(tmp); + } + + return; + } + + m.sync_csc(); + m.invalidate_cache(); + + const uword lstart_row = aux_row1; + const uword lend_row = aux_row1 + n_rows; + + const uword lstart_col = aux_col1; + const uword lend_col = aux_col1 + n_cols; + + const uword* m_row_indices = m.row_indices; + eT* m_values = access::rwp(m.values); + + if(arma_isnan(old_val)) + { + for(uword c = lstart_col; c < lend_col; ++c) + { + const uword r_start = m.col_ptrs[c ]; + const uword r_end = m.col_ptrs[c + 1]; + + for(uword r = r_start; r < r_end; ++r) + { + const uword m_row_indices_r = m_row_indices[r]; + + if( (m_row_indices_r >= lstart_row) && (m_row_indices_r < lend_row) ) + { + eT& val = m_values[r]; + + val = (arma_isnan(val)) ? new_val : val; + } + } + } + } + else + { + for(uword c = lstart_col; c < lend_col; ++c) + { + const uword r_start = m.col_ptrs[c ]; + const uword r_end = m.col_ptrs[c + 1]; + + for(uword r = r_start; r < r_end; ++r) + { + const uword m_row_indices_r = m_row_indices[r]; + + if( (m_row_indices_r >= lstart_row) && (m_row_indices_r < lend_row) ) + { + eT& val = m_values[r]; + + val = (val == old_val) ? new_val : val; + } + } + } + } + + if(new_val == eT(0)) { access::rw(m).remove_zeros(); } + } + + + +template +inline +void +SpSubview::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + if((n_elem == 0) || (n_nonzero == 0)) { return; } + + // TODO: replace with a more efficient implementation + + SpMat tmp(*this); + + tmp.clean(threshold); + + if(is_cx::yes) + { + (*this).operator=(tmp); + } + else + if(tmp.n_nonzero != n_nonzero) + { + (*this).operator=(tmp); + } + } + + + +template +inline +void +SpSubview::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "SpSubview::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "SpSubview::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "SpSubview::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + if((n_elem == 0) || (n_nonzero == 0)) { return; } + + // TODO: replace with a more efficient implementation + + SpMat tmp(*this); + + tmp.clamp(min_val, max_val); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview::fill(const eT val) + { + arma_extra_debug_sigprint(); + + if(val != eT(0)) + { + Mat tmp( (*this).n_rows, (*this).n_cols, arma_nozeros_indicator() ); + + tmp.fill(val); + + (*this).operator=(tmp); + } + else + { + (*this).zeros(); + } + } + + + +template +inline +void +SpSubview::zeros() + { + arma_extra_debug_sigprint(); + + if((n_elem == 0) || (n_nonzero == 0)) { return; } + + if((m.n_nonzero - n_nonzero) == 0) + { + access::rw(m).zeros(); + access::rw(n_nonzero) = 0; + return; + } + + SpMat tmp(arma_reserve_indicator(), m.n_rows, m.n_cols, m.n_nonzero - n_nonzero); + + const uword sv_row_start = aux_row1; + const uword sv_col_start = aux_col1; + + const uword sv_row_end = aux_row1 + n_rows - 1; + const uword sv_col_end = aux_col1 + n_cols - 1; + + typename SpMat::const_iterator m_it = m.begin(); + typename SpMat::const_iterator m_it_end = m.end(); + + uword tmp_count = 0; + + for(; m_it != m_it_end; ++m_it) + { + const uword m_it_row = m_it.row(); + const uword m_it_col = m_it.col(); + + const bool inside_box = ((m_it_row >= sv_row_start) && (m_it_row <= sv_row_end)) && ((m_it_col >= sv_col_start) && (m_it_col <= sv_col_end)); + + if(inside_box == false) + { + access::rw(tmp.values[tmp_count]) = (*m_it); + access::rw(tmp.row_indices[tmp_count]) = m_it_row; + access::rw(tmp.col_ptrs[m_it_col + 1])++; + ++tmp_count; + } + } + + for(uword i=0; i < tmp.n_cols; ++i) + { + access::rw(tmp.col_ptrs[i + 1]) += tmp.col_ptrs[i]; + } + + access::rw(m).steal_mem(tmp); + + access::rw(n_nonzero) = 0; + } + + + +template +inline +void +SpSubview::ones() + { + arma_extra_debug_sigprint(); + + (*this).fill(eT(1)); + } + + + +template +inline +void +SpSubview::eye() + { + arma_extra_debug_sigprint(); + + SpMat tmp; + + tmp.eye( (*this).n_rows, (*this).n_cols ); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview::randu() + { + arma_extra_debug_sigprint(); + + Mat tmp( (*this).n_rows, (*this).n_cols, fill::randu ); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview::randn() + { + arma_extra_debug_sigprint(); + + Mat tmp( (*this).n_rows, (*this).n_cols, fill::randn ); + + (*this).operator=(tmp); + } + + + +template +inline +SpSubview_MapMat_val +SpSubview::operator[](const uword i) + { + const uword lrow = i % n_rows; + const uword lcol = i / n_rows; + + return (*this).at(lrow, lcol); + } + + + +template +inline +eT +SpSubview::operator[](const uword i) const + { + const uword lrow = i % n_rows; + const uword lcol = i / n_rows; + + return (*this).at(lrow, lcol); + } + + + +template +inline +SpSubview_MapMat_val +SpSubview::operator()(const uword i) + { + arma_debug_check_bounds( (i >= n_elem), "SpSubview::operator(): index out of bounds" ); + + const uword lrow = i % n_rows; + const uword lcol = i / n_rows; + + return (*this).at(lrow, lcol); + } + + + +template +inline +eT +SpSubview::operator()(const uword i) const + { + arma_debug_check_bounds( (i >= n_elem), "SpSubview::operator(): index out of bounds" ); + + const uword lrow = i % n_rows; + const uword lcol = i / n_rows; + + return (*this).at(lrow, lcol); + } + + + +template +inline +SpSubview_MapMat_val +SpSubview::operator()(const uword in_row, const uword in_col) + { + arma_debug_check_bounds( (in_row >= n_rows) || (in_col >= n_cols), "SpSubview::operator(): index out of bounds" ); + + return (*this).at(in_row, in_col); + } + + + +template +inline +eT +SpSubview::operator()(const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( (in_row >= n_rows) || (in_col >= n_cols), "SpSubview::operator(): index out of bounds" ); + + return (*this).at(in_row, in_col); + } + + + +template +inline +SpSubview_MapMat_val +SpSubview::at(const uword i) + { + const uword lrow = i % n_rows; + const uword lcol = i / n_cols; + + return (*this).at(lrow, lcol); + } + + + +template +inline +eT +SpSubview::at(const uword i) const + { + const uword lrow = i % n_rows; + const uword lcol = i / n_cols; + + return (*this).at(lrow, lcol); + } + + + +template +inline +SpSubview_MapMat_val +SpSubview::at(const uword in_row, const uword in_col) + { + return SpSubview_MapMat_val((*this), m.cache, aux_row1 + in_row, aux_col1 + in_col); + } + + + +template +inline +eT +SpSubview::at(const uword in_row, const uword in_col) const + { + return m.at(aux_row1 + in_row, aux_col1 + in_col); + } + + + +template +inline +bool +SpSubview::check_overlap(const SpSubview& x) const + { + const SpSubview& t = *this; + + if(&t.m != &x.m) + { + return false; + } + else + { + if( (t.n_elem == 0) || (x.n_elem == 0) ) + { + return false; + } + else + { + const uword t_row_start = t.aux_row1; + const uword t_row_end_p1 = t_row_start + t.n_rows; + + const uword t_col_start = t.aux_col1; + const uword t_col_end_p1 = t_col_start + t.n_cols; + + const uword x_row_start = x.aux_row1; + const uword x_row_end_p1 = x_row_start + x.n_rows; + + const uword x_col_start = x.aux_col1; + const uword x_col_end_p1 = x_col_start + x.n_cols; + + const bool outside_rows = ( (x_row_start >= t_row_end_p1) || (t_row_start >= x_row_end_p1) ); + const bool outside_cols = ( (x_col_start >= t_col_end_p1) || (t_col_start >= x_col_end_p1) ); + + return ( (outside_rows == false) && (outside_cols == false) ); + } + } + } + + + +template +inline +bool +SpSubview::is_vec() const + { + return ( (n_rows == 1) || (n_cols == 1) ); + } + + + +template +inline +SpSubview_row +SpSubview::row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds(row_num >= n_rows, "SpSubview::row(): out of bounds"); + + return SpSubview_row(const_cast< SpMat& >(m), row_num + aux_row1, aux_col1, n_cols); + } + + + +template +inline +const SpSubview_row +SpSubview::row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds(row_num >= n_rows, "SpSubview::row(): out of bounds"); + + return SpSubview_row(m, row_num + aux_row1, aux_col1, n_cols); + } + + + +template +inline +SpSubview_col +SpSubview::col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds(col_num >= n_cols, "SpSubview::col(): out of bounds"); + + return SpSubview_col(const_cast< SpMat& >(m), col_num + aux_col1, aux_row1, n_rows); + } + + + +template +inline +const SpSubview_col +SpSubview::col(const uword col_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds(col_num >= n_cols, "SpSubview::col(): out of bounds"); + + return SpSubview_col(m, col_num + aux_col1, aux_row1, n_rows); + } + + + +template +inline +SpSubview +SpSubview::rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "SpSubview::rows(): indices out of bounds or incorrectly used" + ); + + return submat(in_row1, 0, in_row2, n_cols - 1); + } + + + +template +inline +const SpSubview +SpSubview::rows(const uword in_row1, const uword in_row2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "SpSubview::rows(): indices out of bounds or incorrectly used" + ); + + return submat(in_row1, 0, in_row2, n_cols - 1); + } + + + +template +inline +SpSubview +SpSubview::cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "SpSubview::cols(): indices out of bounds or incorrectly used" + ); + + return submat(0, in_col1, n_rows - 1, in_col2); + } + + + +template +inline +const SpSubview +SpSubview::cols(const uword in_col1, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "SpSubview::cols(): indices out of bounds or incorrectly used" + ); + + return submat(0, in_col1, n_rows - 1, in_col2); + } + + + +template +inline +SpSubview +SpSubview::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), + "SpSubview::submat(): indices out of bounds or incorrectly used" + ); + + return access::rw(m).submat(in_row1 + aux_row1, in_col1 + aux_col1, in_row2 + aux_row1, in_col2 + aux_col1); + } + + + +template +inline +const SpSubview +SpSubview::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), + "SpSubview::submat(): indices out of bounds or incorrectly used" + ); + + return m.submat(in_row1 + aux_row1, in_col1 + aux_col1, in_row2 + aux_row1, in_col2 + aux_col1); + } + + + +template +inline +SpSubview +SpSubview::submat(const span& row_span, const span& col_span) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = row_span.whole; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_all ? n_rows : row_span.b; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_all ? n_cols : col_span.b; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= n_rows))) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= n_cols))), + "SpSubview::submat(): indices out of bounds or incorrectly used" + ); + + return submat(in_row1, in_col1, in_row2, in_col2); + } + + + +template +inline +const SpSubview +SpSubview::submat(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = row_span.whole; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_all ? n_rows - 1 : row_span.b; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_all ? n_cols - 1 : col_span.b; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= n_rows))) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= n_cols))), + "SpSubview::submat(): indices out of bounds or incorrectly used" + ); + + return submat(in_row1, in_col1, in_row2, in_col2); + } + + + +template +inline +SpSubview +SpSubview::operator()(const uword row_num, const span& col_span) + { + arma_extra_debug_sigprint(); + + return submat(span(row_num, row_num), col_span); + } + + + +template +inline +const SpSubview +SpSubview::operator()(const uword row_num, const span& col_span) const + { + arma_extra_debug_sigprint(); + + return submat(span(row_num, row_num), col_span); + } + + + +template +inline +SpSubview +SpSubview::operator()(const span& row_span, const uword col_num) + { + arma_extra_debug_sigprint(); + + return submat(row_span, span(col_num, col_num)); + } + + + +template +inline +const SpSubview +SpSubview::operator()(const span& row_span, const uword col_num) const + { + arma_extra_debug_sigprint(); + + return submat(row_span, span(col_num, col_num)); + } + + + +template +inline +SpSubview +SpSubview::operator()(const span& row_span, const span& col_span) + { + arma_extra_debug_sigprint(); + + return submat(row_span, col_span); + } + + + +template +inline +const SpSubview +SpSubview::operator()(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + return submat(row_span, col_span); + } + + + +template +inline +void +SpSubview::swap_rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check((in_row1 >= n_rows) || (in_row2 >= n_rows), "SpSubview::swap_rows(): invalid row index"); + + const uword lstart_col = aux_col1; + const uword lend_col = aux_col1 + n_cols; + + for(uword c = lstart_col; c < lend_col; ++c) + { + const eT val = access::rw(m).at(in_row1 + aux_row1, c); + access::rw(m).at(in_row2 + aux_row1, c) = eT( access::rw(m).at(in_row1 + aux_row1, c) ); + access::rw(m).at(in_row1 + aux_row1, c) = val; + } + } + + + +template +inline +void +SpSubview::swap_cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check((in_col1 >= n_cols) || (in_col2 >= n_cols), "SpSubview::swap_cols(): invalid column index"); + + const uword lstart_row = aux_row1; + const uword lend_row = aux_row1 + n_rows; + + for(uword r = lstart_row; r < lend_row; ++r) + { + const eT val = access::rw(m).at(r, in_col1 + aux_col1); + access::rw(m).at(r, in_col1 + aux_col1) = eT( access::rw(m).at(r, in_col2 + aux_col1) ); + access::rw(m).at(r, in_col2 + aux_col1) = val; + } + } + + + +template +inline +typename SpSubview::iterator +SpSubview::begin() + { + m.sync_csc(); + + return iterator(*this); + } + + + +template +inline +typename SpSubview::const_iterator +SpSubview::begin() const + { + m.sync_csc(); + + return const_iterator(*this); + } + + + +template +inline +typename SpSubview::const_iterator +SpSubview::cbegin() const + { + m.sync_csc(); + + return const_iterator(*this); + } + + + +template +inline +typename SpSubview::iterator +SpSubview::begin_col(const uword col_num) + { + m.sync_csc(); + + return iterator(*this, 0, col_num); + } + + +template +inline +typename SpSubview::const_iterator +SpSubview::begin_col(const uword col_num) const + { + m.sync_csc(); + + return const_iterator(*this, 0, col_num); + } + + + +template +inline +typename SpSubview::row_iterator +SpSubview::begin_row(const uword row_num) + { + m.sync_csc(); + + return row_iterator(*this, row_num, 0); + } + + + +template +inline +typename SpSubview::const_row_iterator +SpSubview::begin_row(const uword row_num) const + { + m.sync_csc(); + + return const_row_iterator(*this, row_num, 0); + } + + + +template +inline +typename SpSubview::iterator +SpSubview::end() + { + m.sync_csc(); + + return iterator(*this, 0, n_cols, n_nonzero, m.n_nonzero - n_nonzero); + } + + + +template +inline +typename SpSubview::const_iterator +SpSubview::end() const + { + m.sync_csc(); + + return const_iterator(*this, 0, n_cols, n_nonzero, m.n_nonzero - n_nonzero); + } + + + +template +inline +typename SpSubview::const_iterator +SpSubview::cend() const + { + m.sync_csc(); + + return const_iterator(*this, 0, n_cols, n_nonzero, m.n_nonzero - n_nonzero); + } + + + +template +inline +typename SpSubview::row_iterator +SpSubview::end_row() + { + m.sync_csc(); + + return row_iterator(*this, n_nonzero); + } + + + +template +inline +typename SpSubview::const_row_iterator +SpSubview::end_row() const + { + m.sync_csc(); + + return const_row_iterator(*this, n_nonzero); + } + + + +template +inline +typename SpSubview::row_iterator +SpSubview::end_row(const uword row_num) + { + m.sync_csc(); + + return row_iterator(*this, row_num + 1, 0); + } + + + +template +inline +typename SpSubview::const_row_iterator +SpSubview::end_row(const uword row_num) const + { + m.sync_csc(); + + return const_row_iterator(*this, row_num + 1, 0); + } + + + +template +arma_inline +bool +SpSubview::is_alias(const SpMat& X) const + { + return m.is_alias(X); + } + + + +template +inline +eT& +SpSubview::insert_element(const uword in_row, const uword in_col, const eT in_val) + { + arma_extra_debug_sigprint(); + + // This may not actually insert an element. + const uword old_n_nonzero = m.n_nonzero; + eT& retval = access::rw(m).insert_element(in_row + aux_row1, in_col + aux_col1, in_val); + // Update n_nonzero (if necessary). + access::rw(n_nonzero) += (m.n_nonzero - old_n_nonzero); + + return retval; + } + + + +template +inline +void +SpSubview::delete_element(const uword in_row, const uword in_col) + { + arma_extra_debug_sigprint(); + + // This may not actually delete an element. + const uword old_n_nonzero = m.n_nonzero; + access::rw(m).delete_element(in_row + aux_row1, in_col + aux_col1); + access::rw(n_nonzero) -= (old_n_nonzero - m.n_nonzero); + } + + + +template +inline +void +SpSubview::invalidate_cache() const + { + arma_extra_debug_sigprint(); + + m.invalidate_cache(); + } + + + +// +// +// + + + +template +inline +SpSubview_col::SpSubview_col(const SpMat& in_m, const uword in_col) + : SpSubview(in_m, 0, in_col, in_m.n_rows, 1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpSubview_col::SpSubview_col(const SpMat& in_m, const uword in_col, const uword in_row1, const uword in_n_rows) + : SpSubview(in_m, in_row1, in_col, in_n_rows, 1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +SpSubview_col::operator=(const SpSubview& x) + { + arma_extra_debug_sigprint(); + + SpSubview::operator=(x); + } + + + +template +inline +void +SpSubview_col::operator=(const SpSubview_col& x) + { + arma_extra_debug_sigprint(); + + SpSubview::operator=(x); // interprets 'SpSubview_col' as 'SpSubview' + } + + + +template +template +inline +void +SpSubview_col::operator=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + SpSubview::operator=(x); + } + + + +template +template +inline +void +SpSubview_col::operator=(const Base& x) + { + arma_extra_debug_sigprint(); + + SpSubview::operator=(x); + } + + + +template +inline +const SpOp,spop_htrans> +SpSubview_col::t() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_htrans> +SpSubview_col::ht() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_strans> +SpSubview_col::st() const + { + return SpOp,spop_strans>(*this); + } + + + +// +// +// + + + +template +inline +SpSubview_row::SpSubview_row(const SpMat& in_m, const uword in_row) + : SpSubview(in_m, in_row, 0, 1, in_m.n_cols) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpSubview_row::SpSubview_row(const SpMat& in_m, const uword in_row, const uword in_col1, const uword in_n_cols) + : SpSubview(in_m, in_row, in_col1, 1, in_n_cols) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +SpSubview_row::operator=(const SpSubview& x) + { + arma_extra_debug_sigprint(); + + SpSubview::operator=(x); + } + + + +template +inline +void +SpSubview_row::operator=(const SpSubview_row& x) + { + arma_extra_debug_sigprint(); + + SpSubview::operator=(x); // interprets 'SpSubview_row' as 'SpSubview' + } + + + +template +template +inline +void +SpSubview_row::operator=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + SpSubview::operator=(x); + } + + + +template +template +inline +void +SpSubview_row::operator=(const Base& x) + { + arma_extra_debug_sigprint(); + + SpSubview::operator=(x); + } + + + +template +inline +const SpOp,spop_htrans> +SpSubview_row::t() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_htrans> +SpSubview_row::ht() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_strans> +SpSubview_row::st() const + { + return SpOp,spop_strans>(*this); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpToDGlue_bones.hpp b/src/armadillo/include/armadillo_bits/SpToDGlue_bones.hpp new file mode 100644 index 0000000..7d4ce9d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpToDGlue_bones.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpToDGlue +//! @{ + + + +template +class SpToDGlue : public Base< typename T1::elem_type, SpToDGlue > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + inline explicit SpToDGlue(const T1& in_A, const T2& in_B); + inline ~SpToDGlue(); + + const T1& A; //!< first operand; must be derived from Base or SpBase + const T2& B; //!< second operand; must be derived from Base or SpBase + + static constexpr bool is_row = glue_type::template traits::is_row; + static constexpr bool is_col = glue_type::template traits::is_col; + static constexpr bool is_xvec = glue_type::template traits::is_xvec; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpToDGlue_meat.hpp b/src/armadillo/include/armadillo_bits/SpToDGlue_meat.hpp new file mode 100644 index 0000000..1d3d095 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpToDGlue_meat.hpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpToDGlue +//! @{ + + + +template +inline +SpToDGlue::SpToDGlue(const T1& in_A, const T2& in_B) + : A(in_A) + , B(in_B) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpToDGlue::~SpToDGlue() + { + arma_extra_debug_sigprint(); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpToDOp_bones.hpp b/src/armadillo/include/armadillo_bits/SpToDOp_bones.hpp new file mode 100644 index 0000000..b8ae6cc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpToDOp_bones.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpToDOp +//! @{ + + + +//! Class for storing data required for delayed unary operations on a sparse +//! matrix that produce a dense matrix; the data for storage may include +//! the operand (eg. the matrix to which the operation is to be applied) and the unary operator (eg. inverse). +//! The operand is stored as a reference (which can be optimised away), +//! while the operator is "stored" through the template definition (op_type). +//! The operands can be 'SpMat', 'SpRow', 'SpCol', 'SpOp', and 'SpGlue'. +//! Note that as 'SpGlue' can be one of the operands, more than one matrix can be stored. +//! +//! For example, we could have: +//! SpToDOp< SpGlue< SpMat, SpMat, sp_glue_times >, op_sp_plus > + +template +class SpToDOp : public Base< typename T1::elem_type, SpToDOp > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + inline explicit SpToDOp(const T1& in_m); + inline SpToDOp(const T1& in_m, const elem_type in_aux); + inline ~SpToDOp(); + + arma_aligned const T1& m; //!< the operand; must be derived from SpBase + arma_aligned elem_type aux; //!< auxiliary data, using the element type as used by T1 + + static constexpr bool is_row = op_type::template traits::is_row; + static constexpr bool is_col = op_type::template traits::is_col; + static constexpr bool is_xvec = op_type::template traits::is_xvec; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpToDOp_meat.hpp b/src/armadillo/include/armadillo_bits/SpToDOp_meat.hpp new file mode 100644 index 0000000..66ab640 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpToDOp_meat.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpToDOp +//! @{ + + + +template +inline +SpToDOp::SpToDOp(const T1& in_m) + : m(in_m) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpToDOp::SpToDOp(const T1& in_m, const typename T1::elem_type in_aux) + : m(in_m) + , aux(in_aux) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +SpToDOp::~SpToDOp() + { + arma_extra_debug_sigprint(); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpValProxy_bones.hpp b/src/armadillo/include/armadillo_bits/SpValProxy_bones.hpp new file mode 100644 index 0000000..af9a52d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpValProxy_bones.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpValProxy +//! @{ + + +// Sparse value proxy class, to prevent inserting 0s into sparse matrices. +// T1 must be either SpMat or SpSubview. +// This class uses T1::insert_element(), T1::delete_element(), T1::invalidate_cache() + +template +class SpValProxy + { + public: + + typedef typename T1::elem_type eT; // Convenience typedef + + friend class SpMat; + friend class SpSubview; + + /** + * Create the sparse value proxy. + * Otherwise, pass a pointer to a reference of the value. + */ + arma_inline SpValProxy(uword row, uword col, T1& in_parent, eT* in_val_ptr = nullptr); + inline SpValProxy() = delete; + + //! For swapping operations. + arma_inline SpValProxy& operator=(const SpValProxy& rhs); + template + arma_inline SpValProxy& operator=(const SpValProxy& rhs); + + //! Overload all of the potential operators. + + //! First, the ones that could modify a value. + inline SpValProxy& operator= (const eT rhs); + inline SpValProxy& operator+=(const eT rhs); + inline SpValProxy& operator-=(const eT rhs); + inline SpValProxy& operator*=(const eT rhs); + inline SpValProxy& operator/=(const eT rhs); + + inline SpValProxy& operator++(); + inline SpValProxy& operator--(); + + inline eT operator++(const int); + inline eT operator--(const int); + + //! This will work for any other operations that do not modify a value. + arma_inline operator eT() const; + + arma_inline typename get_pod_type::result real() const; + arma_inline typename get_pod_type::result imag() const; + + + private: + + // Deletes the element if it is zero; NOTE: does not check if val_ptr == nullptr + arma_inline void check_zero(); + + arma_aligned const uword row; + arma_aligned const uword col; + + arma_aligned eT* val_ptr; + + arma_aligned T1& parent; // We will call this object if we need to insert or delete an element. + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/SpValProxy_meat.hpp b/src/armadillo/include/armadillo_bits/SpValProxy_meat.hpp new file mode 100644 index 0000000..242ec07 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/SpValProxy_meat.hpp @@ -0,0 +1,364 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpValProxy +//! @{ + + +//! SpValProxy implementation. +template +arma_inline +SpValProxy::SpValProxy(uword in_row, uword in_col, T1& in_parent, eT* in_val_ptr) + : row(in_row) + , col(in_col) + , val_ptr(in_val_ptr) + , parent(in_parent) + { + // Nothing to do. + } + + + +template +arma_inline +SpValProxy& +SpValProxy::operator=(const SpValProxy& rhs) + { + return (*this).operator=(eT(rhs)); + } + + + +template +template +arma_inline +SpValProxy& +SpValProxy::operator=(const SpValProxy& rhs) + { + return (*this).operator=(eT(rhs)); + } + + + +template +inline +SpValProxy& +SpValProxy::operator=(const eT rhs) + { + if(rhs != eT(0)) // A nonzero element is being assigned. + { + if(val_ptr) + { + // The value exists and merely needs to be updated. + *val_ptr = rhs; + parent.invalidate_cache(); + } + else + { + // The value is nonzero and must be inserted. + val_ptr = &parent.insert_element(row, col, rhs); + } + } + else // A zero is being assigned.~ + { + if(val_ptr) + { + // The element exists, but we need to remove it, because it is being set to 0. + parent.delete_element(row, col); + val_ptr = nullptr; + } + + // If the element does not exist, we do not need to do anything at all. + } + + return *this; + } + + + +template +inline +SpValProxy& +SpValProxy::operator+=(const eT rhs) + { + if(val_ptr) + { + // The value already exists and merely needs to be updated. + *val_ptr += rhs; + parent.invalidate_cache(); + check_zero(); + } + else + { + if(rhs != eT(0)) + { + // The value does not exist and must be inserted. + val_ptr = &parent.insert_element(row, col, rhs); + } + } + + return *this; + } + + + +template +inline +SpValProxy& +SpValProxy::operator-=(const eT rhs) + { + if(val_ptr) + { + // The value already exists and merely needs to be updated. + *val_ptr -= rhs; + parent.invalidate_cache(); + check_zero(); + } + else + { + if(rhs != eT(0)) + { + // The value does not exist and must be inserted. + val_ptr = &parent.insert_element(row, col, -rhs); + } + } + + return *this; + } + + + +template +inline +SpValProxy& +SpValProxy::operator*=(const eT rhs) + { + if(rhs != eT(0)) + { + if(val_ptr) + { + // The value already exists and merely needs to be updated. + *val_ptr *= rhs; + parent.invalidate_cache(); + check_zero(); + } + } + else + { + if(val_ptr) + { + // Since we are multiplying by zero, the value can be deleted. + parent.delete_element(row, col); + val_ptr = nullptr; + } + } + + return *this; + } + + + +template +inline +SpValProxy& +SpValProxy::operator/=(const eT rhs) + { + if(rhs != eT(0)) // I hope this is true! + { + if(val_ptr) + { + *val_ptr /= rhs; + parent.invalidate_cache(); + check_zero(); + } + } + else + { + if(val_ptr) + { + *val_ptr /= rhs; // That is where it gets ugly. + // Now check if it's 0. + if(*val_ptr == eT(0)) + { + parent.delete_element(row, col); + val_ptr = nullptr; + } + } + else + { + eT val = eT(0) / rhs; // This may vary depending on type and implementation. + + if(val != eT(0)) + { + // Ok, now we have to insert it. + val_ptr = &parent.insert_element(row, col, val); + } + } + } + + return *this; + } + + + +template +inline +SpValProxy& +SpValProxy::operator++() + { + if(val_ptr) + { + (*val_ptr) += eT(1); + parent.invalidate_cache(); + check_zero(); + } + else + { + val_ptr = &parent.insert_element(row, col, eT(1)); + } + + return *this; + } + + + +template +inline +SpValProxy& +SpValProxy::operator--() + { + if(val_ptr) + { + (*val_ptr) -= eT(1); + parent.invalidate_cache(); + check_zero(); + } + else + { + val_ptr = &parent.insert_element(row, col, eT(-1)); + } + + return *this; + } + + + +template +inline +typename T1::elem_type +SpValProxy::operator++(const int) + { + if(val_ptr) + { + (*val_ptr) += eT(1); + parent.invalidate_cache(); + check_zero(); + } + else + { + val_ptr = &parent.insert_element(row, col, eT(1)); + } + + if(val_ptr) // It may have changed to now be 0. + { + return *(val_ptr) - eT(1); + } + else + { + return eT(0); + } + } + + + +template +inline +typename T1::elem_type +SpValProxy::operator--(const int) + { + if(val_ptr) + { + (*val_ptr) -= eT(1); + parent.invalidate_cache(); + check_zero(); + } + else + { + val_ptr = &parent.insert_element(row, col, eT(-1)); + } + + if(val_ptr) // It may have changed to now be 0. + { + return *(val_ptr) + eT(1); + } + else + { + return eT(0); + } + } + + + +template +arma_inline +SpValProxy::operator eT() const + { + return (val_ptr) ? eT(*val_ptr) : eT(0); + } + + + +template +arma_inline +typename get_pod_type::eT>::result +SpValProxy::real() const + { + typedef typename get_pod_type::result T; + + return T( access::tmp_real( (val_ptr) ? eT(*val_ptr) : eT(0) ) ); + } + + + +template +arma_inline +typename get_pod_type::eT>::result +SpValProxy::imag() const + { + typedef typename get_pod_type::result T; + + return T( access::tmp_imag( (val_ptr) ? eT(*val_ptr) : eT(0) ) ); + } + + + +template +arma_inline +void +SpValProxy::check_zero() + { + if(*val_ptr == eT(0)) + { + parent.delete_element(row, col); + val_ptr = nullptr; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/access.hpp b/src/armadillo/include/armadillo_bits/access.hpp new file mode 100644 index 0000000..77db862 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/access.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup access +//! @{ + + +class access + { + public: + + //! internal function to allow modification of data declared as read-only (use with caution) + template constexpr static T1& rw (const T1& x) { return const_cast(x); } + template constexpr static T1*& rwp(const T1* const& x) { return const_cast(x); } + + //! internal function to obtain the real part of either a plain number or a complex number + template constexpr static const eT& tmp_real(const eT& X) { return X; } + template constexpr static const T tmp_real(const std::complex& X) { return X.real(); } + + //! internal function to obtain the imag part of either a plain number or a complex number + template constexpr static const eT tmp_imag(const eT ) { return eT(0); } + template constexpr static const T tmp_imag(const std::complex& X) { return X.imag(); } + + //! internal function to work around braindead compilers + template constexpr static const typename enable_if2::no, const eT&>::result alt_conj(const eT& X) { return X; } + template arma_inline static const typename enable_if2::yes, const eT >::result alt_conj(const eT& X) { return std::conj(X); } + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/arma_cmath.hpp b/src/armadillo/include/armadillo_bits/arma_cmath.hpp new file mode 100644 index 0000000..22df4bf --- /dev/null +++ b/src/armadillo/include/armadillo_bits/arma_cmath.hpp @@ -0,0 +1,378 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup arma_cmath +//! @{ + + + +// +// wrappers for isfinite + + +template +inline +bool +arma_isfinite(eT) + { + return true; + } + + + +template<> +inline +bool +arma_isfinite(float x) + { + return std::isfinite(x); + } + + + +template<> +inline +bool +arma_isfinite(double x) + { + return std::isfinite(x); + } + + + +template +inline +bool +arma_isfinite(const std::complex& x) + { + return ( arma_isfinite(x.real()) && arma_isfinite(x.imag()) ); + } + + + +// +// wrappers for isinf + + +template +inline +bool +arma_isinf(eT) + { + return false; + } + + + +template<> +inline +bool +arma_isinf(float x) + { + return std::isinf(x); + } + + + +template<> +inline +bool +arma_isinf(double x) + { + return std::isinf(x); + } + + + +template +inline +bool +arma_isinf(const std::complex& x) + { + return ( arma_isinf(x.real()) || arma_isinf(x.imag()) ); + } + + + +// +// wrappers for isnan + + +template +inline +bool +arma_isnan(eT val) + { + arma_ignore(val); + + return false; + } + + + +template<> +inline +bool +arma_isnan(float x) + { + return std::isnan(x); + } + + + +template<> +inline +bool +arma_isnan(double x) + { + return std::isnan(x); + } + + + +template +inline +bool +arma_isnan(const std::complex& x) + { + return ( arma_isnan(x.real()) || arma_isnan(x.imag()) ); + } + + + +// +// implementation of arma_sign() + + +template +constexpr +typename arma_unsigned_integral_only::result +arma_sign(const eT x) + { + return (x > eT(0)) ? eT(+1) : eT(0); + } + + + +template +constexpr +typename arma_signed_integral_only::result +arma_sign(const eT x) + { + return (x > eT(0)) ? eT(+1) : ( (x < eT(0)) ? eT(-1) : eT(0) ); + } + + + +template +constexpr +typename arma_real_only::result +arma_sign(const eT x) + { + return (x > eT(0)) ? eT(+1) : ( (x < eT(0)) ? eT(-1) : ((x == eT(0)) ? eT(0) : x) ); + } + + + +template +inline +typename arma_cx_only::result +arma_sign(const eT& x) + { + typedef typename eT::value_type T; + + const T abs_x = std::abs(x); + + return (abs_x != T(0)) ? (x / abs_x) : x; + } + + + +// +// wrappers for hypot(x, y) = sqrt(x^2 + y^2) + + +template +inline +eT +arma_hypot(const eT x, const eT y) + { + arma_ignore(x); + arma_ignore(y); + + arma_stop_runtime_error("arma_hypot(): not implemented for integer or complex element types"); + + return eT(0); + } + + + +template<> +inline +float +arma_hypot(const float x, const float y) + { + return std::hypot(x, y); + } + + + +template<> +inline +double +arma_hypot(const double x, const double y) + { + return std::hypot(x, y); + } + + + +// +// implementation of arma_sinc() + + +template +inline +eT +arma_sinc_generic(const eT x) + { + typedef typename get_pod_type::result T; + + const eT tmp = Datum::pi * x; + + return (tmp == eT(0)) ? eT(1) : eT( std::sin(tmp) / tmp ); + } + + + +template +inline +eT +arma_sinc(const eT x) + { + return eT( arma_sinc_generic( double(x) ) ); + } + + + +template<> +inline +float +arma_sinc(const float x) + { + return arma_sinc_generic(x); + } + + + +template<> +inline +double +arma_sinc(const double x) + { + return arma_sinc_generic(x); + } + + + +template +inline +std::complex +arma_sinc(const std::complex& x) + { + return arma_sinc_generic(x); + } + + + +// +// wrappers for arg() + + +template +struct arma_arg + { + static + inline + eT + eval(const eT x) + { + return eT( std::arg(x) ); + } + }; + + + +template<> +struct arma_arg + { + static + inline + float + eval(const float x) + { + return std::arg(x); + } + }; + + + +template<> +struct arma_arg + { + static + inline + double + eval(const double x) + { + return std::arg(x); + } + }; + + + +template<> +struct arma_arg< std::complex > + { + static + inline + float + eval(const std::complex& x) + { + return std::arg(x); + } + }; + + + +template<> +struct arma_arg< std::complex > + { + static + inline + double + eval(const std::complex& x) + { + return std::arg(x); + } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/arma_config.hpp b/src/armadillo/include/armadillo_bits/arma_config.hpp new file mode 100644 index 0000000..3670199 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/arma_config.hpp @@ -0,0 +1,252 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup arma_config +//! @{ + + + +struct arma_config + { + #if defined(ARMA_MAT_PREALLOC) + static constexpr uword mat_prealloc = (sword(ARMA_MAT_PREALLOC) > 0) ? uword(ARMA_MAT_PREALLOC) : 1; + #else + static constexpr uword mat_prealloc = 16; + #endif + + + #if defined(ARMA_OPENMP_THRESHOLD) + static constexpr uword mp_threshold = (sword(ARMA_OPENMP_THRESHOLD) > 0) ? uword(ARMA_OPENMP_THRESHOLD) : 320; + #else + static constexpr uword mp_threshold = 320; + #endif + + + #if defined(ARMA_OPENMP_THREADS) + static constexpr uword mp_threads = (sword(ARMA_OPENMP_THREADS) > 0) ? uword(ARMA_OPENMP_THREADS) : 8; + #else + static constexpr uword mp_threads = 8; + #endif + + + #if defined(ARMA_OPTIMISE_BAND) + static constexpr bool optimise_band = true; + #else + static constexpr bool optimise_band = false; + #endif + + + #if defined(ARMA_OPTIMISE_SYM) + static constexpr bool optimise_sym = true; + #else + static constexpr bool optimise_sym = false; + #endif + + + #if defined(ARMA_OPTIMISE_INVEXPR) + static constexpr bool optimise_invexpr = true; + #else + static constexpr bool optimise_invexpr = false; + #endif + + + #if defined(ARMA_CHECK_NONFINITE) + static constexpr bool check_nonfinite = true; + #else + static constexpr bool check_nonfinite = false; + #endif + + + #if defined(ARMA_USE_LAPACK) + static constexpr bool lapack = true; + #else + static constexpr bool lapack = false; + #endif + + + #if defined(ARMA_USE_BLAS) + static constexpr bool blas = true; + #else + static constexpr bool blas = false; + #endif + + + #if defined(ARMA_USE_ATLAS) + static constexpr bool atlas = true; + #else + static constexpr bool atlas = false; + #endif + + + #if defined(ARMA_USE_NEWARP) + static constexpr bool newarp = true; + #else + static constexpr bool newarp = false; + #endif + + + #if defined(ARMA_USE_ARPACK) + static constexpr bool arpack = true; + #else + static constexpr bool arpack = false; + #endif + + + #if defined(ARMA_USE_SUPERLU) + static constexpr bool superlu = true; + #else + static constexpr bool superlu = false; + #endif + + + #if defined(ARMA_USE_HDF5) + static constexpr bool hdf5 = true; + #else + static constexpr bool hdf5 = false; + #endif + + + #if defined(ARMA_NO_DEBUG) + static constexpr bool debug = false; + #else + static constexpr bool debug = true; + #endif + + + #if defined(ARMA_EXTRA_DEBUG) + static constexpr bool extra_debug = true; + #else + static constexpr bool extra_debug = false; + #endif + + + #if defined(ARMA_GOOD_COMPILER) + static constexpr bool good_comp = true; + #else + static constexpr bool good_comp = false; + #endif + + + #if ( \ + defined(ARMA_EXTRA_MAT_PROTO) || defined(ARMA_EXTRA_MAT_MEAT) \ + || defined(ARMA_EXTRA_COL_PROTO) || defined(ARMA_EXTRA_COL_MEAT) \ + || defined(ARMA_EXTRA_ROW_PROTO) || defined(ARMA_EXTRA_ROW_MEAT) \ + || defined(ARMA_EXTRA_CUBE_PROTO) || defined(ARMA_EXTRA_CUBE_MEAT) \ + || defined(ARMA_EXTRA_FIELD_PROTO) || defined(ARMA_EXTRA_FIELD_MEAT) \ + || defined(ARMA_EXTRA_SPMAT_PROTO) || defined(ARMA_EXTRA_SPMAT_MEAT) \ + || defined(ARMA_EXTRA_SPCOL_PROTO) || defined(ARMA_EXTRA_SPCOL_MEAT) \ + || defined(ARMA_EXTRA_SPROW_PROTO) || defined(ARMA_EXTRA_SPROW_MEAT) \ + || defined(ARMA_ALIEN_MEM_ALLOC_FUNCTION) \ + || defined(ARMA_ALIEN_MEM_FREE_FUNCTION) \ + ) + static constexpr bool extra_code = true; + #else + static constexpr bool extra_code = false; + #endif + + + #if defined(ARMA_HAVE_CXX14) + static constexpr bool cxx14 = true; + #else + static constexpr bool cxx14 = false; + #endif + + + #if defined(ARMA_HAVE_CXX17) + static constexpr bool cxx17 = true; + #else + static constexpr bool cxx17 = false; + #endif + + + #if defined(ARMA_HAVE_CXX20) + static constexpr bool cxx20 = true; + #else + static constexpr bool cxx20 = false; + #endif + + + #if (!defined(ARMA_DONT_USE_STD_MUTEX)) + static constexpr bool std_mutex = true; + #else + static constexpr bool std_mutex = false; + #endif + + + #if (defined(_POSIX_C_SOURCE) && (_POSIX_C_SOURCE >= 200112L)) + static constexpr bool posix = true; + #else + static constexpr bool posix = false; + #endif + + + #if defined(ARMA_USE_WRAPPER) + static constexpr bool wrapper = true; + #else + static constexpr bool wrapper = false; + #endif + + + #if defined(ARMA_USE_OPENMP) + static constexpr bool openmp = true; + #else + static constexpr bool openmp = false; + #endif + + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + static constexpr bool hidden_args = true; + #else + static constexpr bool hidden_args = false; + #endif + + + #if defined(ARMA_DONT_ZERO_INIT) + static constexpr bool zero_init = false; + #else + static constexpr bool zero_init = true; + #endif + + + #if defined(ARMA_FAST_MATH) + static constexpr bool fast_math = true; + #else + static constexpr bool fast_math = false; + #endif + + + #if defined(ARMA_FAST_MATH) && !defined(ARMA_DONT_PRINT_FAST_MATH_WARNING) + static constexpr bool fast_math_warn = true; + #else + static constexpr bool fast_math_warn = false; + #endif + + + #if (!defined(ARMA_DONT_TREAT_TEXT_AS_BINARY)) + static constexpr bool text_as_binary = true; + #else + static constexpr bool text_as_binary = false; + #endif + + + static constexpr uword warn_level = (sword(ARMA_WARN_LEVEL) > 0) ? uword(ARMA_WARN_LEVEL) : 0; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/arma_forward.hpp b/src/armadillo/include/armadillo_bits/arma_forward.hpp new file mode 100644 index 0000000..4b2f37f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/arma_forward.hpp @@ -0,0 +1,475 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +using std::cout; +using std::cerr; +using std::endl; +using std::ios; +using std::size_t; + +template struct Base; +template struct BaseCube; + +template class Mat; +template class Col; +template class Row; +template class Cube; +template class xvec_htrans; +template class field; + +template class xtrans_mat; + + +template class subview; +template class subview_col; +template class subview_cols; +template class subview_row; +template class subview_row_strans; +template class subview_row_htrans; +template class subview_cube; +template class subview_field; + +template class SpValProxy; +template class SpMat; +template class SpCol; +template class SpRow; +template class SpSubview; +template class SpSubview_col; +template class SpSubview_row; + +template class diagview; +template class spdiagview; + +template class MapMat; +template class MapMat_val; +template class SpMat_MapMat_val; +template class SpSubview_MapMat_val; + +template class subview_elem1; +template class subview_elem2; + +template class subview_each1; +template class subview_each2; + +template class subview_cube_each1; +template class subview_cube_each2; +template class subview_cube_slices; + +template class SpSubview_col_list; + + +class SizeMat; +class SizeCube; + +class arma_empty_class {}; + +class diskio; + +class op_strans; +class op_htrans; +class op_htrans2; +class op_inv_gen_default; +class op_inv_spd_default; +class op_inv_gen_full; +class op_inv_spd_full; +class op_diagmat; +class op_trimat; +class op_vectorise_row; +class op_vectorise_col; + +class op_row_as_mat; +class op_col_as_mat; + +class glue_times; +class glue_times_diag; + +class glue_rel_lt; +class glue_rel_gt; +class glue_rel_lteq; +class glue_rel_gteq; +class glue_rel_eq; +class glue_rel_noteq; +class glue_rel_and; +class glue_rel_or; + +class op_rel_lt_pre; +class op_rel_lt_post; +class op_rel_gt_pre; +class op_rel_gt_post; +class op_rel_lteq_pre; +class op_rel_lteq_post; +class op_rel_gteq_pre; +class op_rel_gteq_post; +class op_rel_eq; +class op_rel_noteq; + +class gen_eye; +class gen_ones; +class gen_zeros; + + + +class spop_strans; +class spop_htrans; +class spop_vectorise_row; +class spop_vectorise_col; + +class spglue_plus; +class spglue_minus; +class spglue_schur; +class spglue_times; +class spglue_max; +class spglue_min; +class spglue_rel_lt; +class spglue_rel_gt; + + + +class op_internal_equ; +class op_internal_plus; +class op_internal_minus; +class op_internal_schur; +class op_internal_div; + + + +struct traits_op_default + { + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + }; + }; + + +struct traits_op_xvec + { + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = true; + }; + }; + + +struct traits_op_col + { + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + }; + }; + + +struct traits_op_row + { + template + struct traits + { + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + }; + }; + + +struct traits_op_passthru + { + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; + }; + }; + + +struct traits_glue_default + { + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + }; + }; + + +struct traits_glue_or + { + template + struct traits + { + static constexpr bool is_row = (T1::is_row || T2::is_row ); + static constexpr bool is_col = (T1::is_col || T2::is_col ); + static constexpr bool is_xvec = (T1::is_xvec || T2::is_xvec); + }; + }; + + + +template class gemm; +template class gemv; + + +template< typename eT, typename gen_type> class Gen; + +template< typename T1, typename op_type> class Op; +template< typename T1, typename eop_type> class eOp; +template< typename T1, typename op_type> class SpToDOp; +template< typename T1, typename op_type> class CubeToMatOp; +template class mtOp; + +template< typename T1, typename T2, typename glue_type> class Glue; +template< typename T1, typename T2, typename eglue_type> class eGlue; +template< typename T1, typename T2, typename glue_type> class SpToDGlue; +template class mtGlue; + + + +template< typename eT, typename gen_type> class GenCube; + +template< typename T1, typename op_type> class OpCube; +template< typename T1, typename eop_type> class eOpCube; +template class mtOpCube; + +template< typename T1, typename T2, typename glue_type> class GlueCube; +template< typename T1, typename T2, typename eglue_type> class eGlueCube; +template class mtGlueCube; + + +template struct Proxy; +template struct ProxyCube; + +template class diagmat_proxy; + +template struct unwrap; +template struct quasi_unwrap; +template struct unwrap_cube; +template struct unwrap_spmat; + + + + +struct state_type + { + #if defined(ARMA_USE_OPENMP) + int state; + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + std::atomic state; + #else + int state; + #endif + + arma_inline state_type() : state(int(0)) {} + + // openmp: "omp atomic" does an implicit flush on the affected variable + // C++11: std::atomic<>::load() and std::atomic<>::store() use std::memory_order_seq_cst by default, which has an implied fence + + arma_inline + operator int () const + { + int out; + + #if defined(ARMA_USE_OPENMP) + #pragma omp atomic read + out = state; + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + out = state.load(); + #else + out = state; + #endif + + return out; + } + + arma_inline + void + operator= (const int in_state) + { + #if defined(ARMA_USE_OPENMP) + #pragma omp atomic write + state = in_state; + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + state.store(in_state); + #else + state = in_state; + #endif + } + }; + + +template< typename T1, typename spop_type> class SpOp; +template class mtSpOp; + +template< typename T1, typename T2, typename spglue_type> class SpGlue; +template class mtSpGlue; + + +template struct SpProxy; + + + +struct arma_vec_indicator {}; +struct arma_fixed_indicator {}; +struct arma_reserve_indicator {}; +struct arma_layout_indicator {}; + +template struct arma_initmode_indicator {}; + +struct arma_zeros_indicator : public arma_initmode_indicator {}; +struct arma_nozeros_indicator : public arma_initmode_indicator {}; + + +//! \addtogroup injector +//! @{ + +template struct injector_end_of_row {}; + +// DEPRECATED: DO NOT USE IN NEW CODE +static const injector_end_of_row<> endr = injector_end_of_row<>(); +//!< endr indicates "end of row" when using the << operator; +//!< similar conceptual meaning to std::endl + +//! @} + + + +//! \addtogroup diskio +//! @{ + + +enum struct file_type : unsigned int + { + file_type_unknown, + auto_detect, //!< attempt to automatically detect the file type + raw_ascii, //!< raw text (ASCII), without a header + arma_ascii, //!< Armadillo text format, with a header specifying matrix type and size + csv_ascii, //!< comma separated values (CSV), without a header + raw_binary, //!< raw binary format (machine dependent), without a header + arma_binary, //!< Armadillo binary format (machine dependent), with a header specifying matrix type and size + pgm_binary, //!< Portable Grey Map (greyscale image) + ppm_binary, //!< Portable Pixel Map (colour image), used by the field and cube classes + hdf5_binary, //!< HDF5: open binary format, not specific to Armadillo, which can store arbitrary data + hdf5_binary_trans, //!< [NOTE: DO NOT USE - deprecated] as per hdf5_binary, but save/load the data with columns transposed to rows + coord_ascii, //!< simple co-ordinate format for sparse matrices (indices start at zero) + ssv_ascii, //!< similar to csv_ascii; uses semicolon (;) instead of comma (,) as the separator + }; + + +static constexpr file_type file_type_unknown = file_type::file_type_unknown; +static constexpr file_type auto_detect = file_type::auto_detect; +static constexpr file_type raw_ascii = file_type::raw_ascii; +static constexpr file_type arma_ascii = file_type::arma_ascii; +static constexpr file_type csv_ascii = file_type::csv_ascii; +static constexpr file_type raw_binary = file_type::raw_binary; +static constexpr file_type arma_binary = file_type::arma_binary; +static constexpr file_type pgm_binary = file_type::pgm_binary; +static constexpr file_type ppm_binary = file_type::ppm_binary; +static constexpr file_type hdf5_binary = file_type::hdf5_binary; +static constexpr file_type hdf5_binary_trans = file_type::hdf5_binary_trans; +static constexpr file_type coord_ascii = file_type::coord_ascii; +static constexpr file_type ssv_ascii = file_type::ssv_ascii; + + +struct hdf5_name; +struct csv_name; + + +//! @} + + + +//! \addtogroup fn_spsolve +//! @{ + + +struct spsolve_opts_base + { + const unsigned int id; + + inline spsolve_opts_base(const unsigned int in_id) : id(in_id) {} + }; + + +struct spsolve_opts_none : public spsolve_opts_base + { + inline spsolve_opts_none() : spsolve_opts_base(0) {} + }; + + +struct superlu_opts : public spsolve_opts_base + { + typedef enum {NATURAL, MMD_ATA, MMD_AT_PLUS_A, COLAMD} permutation_type; + + typedef enum {REF_NONE, REF_SINGLE, REF_DOUBLE, REF_EXTRA} refine_type; + + bool allow_ugly; + bool equilibrate; + bool symmetric; + double pivot_thresh; + permutation_type permutation; + refine_type refine; + + inline superlu_opts() + : spsolve_opts_base(1) + { + allow_ugly = false; + equilibrate = false; + symmetric = false; + pivot_thresh = 1.0; + permutation = COLAMD; + refine = REF_NONE; + } + }; + + +//! @} + + + +//! \ingroup fn_eigs_sym fs_eigs_gen +//! @{ + + +struct eigs_opts + { + double tol; // tolerance + unsigned int maxiter; // max iterations + unsigned int subdim; // subspace dimension + + inline eigs_opts() + { + tol = 0.0; + maxiter = 1000; + subdim = 0; + } + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/arma_ostream_bones.hpp b/src/armadillo/include/armadillo_bits/arma_ostream_bones.hpp new file mode 100644 index 0000000..e59c26f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/arma_ostream_bones.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup arma_ostream +//! @{ + + + +class arma_ostream_state + { + private: + + const ios::fmtflags orig_flags; + const std::streamsize orig_precision; + const std::streamsize orig_width; + const char orig_fill; + + + public: + + inline arma_ostream_state(const std::ostream& o); + + inline void restore(std::ostream& o) const; + }; + + + +class arma_ostream + { + public: + + template inline static std::streamsize modify_stream(std::ostream& o, const eT* data, const uword n_elem); + template inline static std::streamsize modify_stream(std::ostream& o, const std::complex* data, const uword n_elem); + template inline static std::streamsize modify_stream(std::ostream& o, typename SpMat::const_iterator begin, const uword n_elem, const typename arma_not_cx::result* junk = nullptr); + template inline static std::streamsize modify_stream(std::ostream& o, typename SpMat::const_iterator begin, const uword n_elem, const typename arma_cx_only::result* junk = nullptr); + + template inline static void print_elem_zero(std::ostream& o, const bool modify); + + template inline static void print_elem(std::ostream& o, const eT& x, const bool modify); + template inline static void raw_print_elem(std::ostream& o, const eT& x); + + template inline static void print_elem(std::ostream& o, const std::complex& x, const bool modify); + template inline static void raw_print_elem(std::ostream& o, const std::complex& x); + + template arma_cold inline static void print(std::ostream& o, const Mat& m, const bool modify); + template arma_cold inline static void print(std::ostream& o, const Cube& m, const bool modify); + + template arma_cold inline static void print(std::ostream& o, const field& m); + template arma_cold inline static void print(std::ostream& o, const subview_field& m); + + template arma_cold inline static void print_dense(std::ostream& o, const SpMat& m, const bool modify); + template arma_cold inline static void print(std::ostream& o, const SpMat& m, const bool modify); + + arma_cold inline static void print(std::ostream& o, const SizeMat& S); + arma_cold inline static void print(std::ostream& o, const SizeCube& S); + + template arma_cold inline static void brief_print(std::ostream& o, const Mat& m, const bool print_size = true); + template arma_cold inline static void brief_print(std::ostream& o, const Cube& m); + template arma_cold inline static void brief_print(std::ostream& o, const SpMat& m); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/arma_ostream_meat.hpp b/src/armadillo/include/armadillo_bits/arma_ostream_meat.hpp new file mode 100644 index 0000000..dbd4b6c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/arma_ostream_meat.hpp @@ -0,0 +1,1274 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup arma_ostream +//! @{ + + + +inline +arma_ostream_state::arma_ostream_state(const std::ostream& o) + : orig_flags (o.flags()) + , orig_precision(o.precision()) + , orig_width (o.width()) + , orig_fill (o.fill()) + { + } + + + +inline +void +arma_ostream_state::restore(std::ostream& o) const + { + o.flags (orig_flags); + o.precision(orig_precision); + o.width (orig_width); + o.fill (orig_fill); + } + + + +// +// + + + +template +inline +std::streamsize +arma_ostream::modify_stream(std::ostream& o, const eT* data, const uword n_elem) + { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + + o.fill(' '); + + std::streamsize cell_width; + + bool use_layout_B = false; + bool use_layout_C = false; + bool use_layout_D = false; + + for(uword i=0; i 4) && (is_same_type::yes || is_same_type::yes) >::geq(val, eT(+10000000000)) ) + || + ( cond_rel< (sizeof(eT) > 4) && is_same_type::yes >::leq(val, eT(-10000000000)) ) + ) + { + use_layout_D = true; + break; + } + + if( + ( val >= eT(+100) ) + || + //( (is_signed::value) && (val <= eT(-100)) ) || + //( (is_non_integral::value) && (val > eT(0)) && (val <= eT(+1e-4)) ) || + //( (is_non_integral::value) && (is_signed::value) && (val < eT(0)) && (val >= eT(-1e-4)) ) + ( + cond_rel< is_signed::value >::leq(val, eT(-100)) + ) + || + ( + cond_rel< is_non_integral::value >::gt(val, eT(0)) + && + cond_rel< is_non_integral::value >::leq(val, eT(+1e-4)) + ) + || + ( + cond_rel< is_non_integral::value && is_signed::value >::lt(val, eT(0)) + && + cond_rel< is_non_integral::value && is_signed::value >::geq(val, eT(-1e-4)) + ) + ) + { + use_layout_C = true; + break; + } + + if( + // (val >= eT(+10)) || ( (is_signed::value) && (val <= eT(-10)) ) + (val >= eT(+10)) || ( cond_rel< is_signed::value >::leq(val, eT(-10)) ) + ) + { + use_layout_B = true; + } + } + + if(use_layout_D) + { + o.setf(ios::scientific); + o.setf(ios::right); + o.unsetf(ios::fixed); + o.precision(4); + cell_width = 21; + } + else + if(use_layout_C) + { + o.setf(ios::scientific); + o.setf(ios::right); + o.unsetf(ios::fixed); + o.precision(4); + cell_width = 13; + } + else + if(use_layout_B) + { + o.unsetf(ios::scientific); + o.setf(ios::right); + o.setf(ios::fixed); + o.precision(4); + cell_width = 10; + } + else + { + o.unsetf(ios::scientific); + o.setf(ios::right); + o.setf(ios::fixed); + o.precision(4); + cell_width = 9; + } + + return cell_width; + } + + + +//! "better than nothing" settings for complex numbers +template +inline +std::streamsize +arma_ostream::modify_stream(std::ostream& o, const std::complex* data, const uword n_elem) + { + arma_ignore(data); + arma_ignore(n_elem); + + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.fill(' '); + + o.setf(ios::scientific); + o.setf(ios::showpos); + o.setf(ios::right); + o.unsetf(ios::fixed); + + std::streamsize cell_width; + + o.precision(3); + cell_width = 2 + 2*(1 + 3 + o.precision() + 5) + 1; + + return cell_width; + } + + +template +inline +std::streamsize +arma_ostream::modify_stream(std::ostream& o, typename SpMat::const_iterator begin, const uword n_elem, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + + o.fill(' '); + + std::streamsize cell_width; + + bool use_layout_B = false; + bool use_layout_C = false; + + for(typename SpMat::const_iterator it = begin; it.pos() < n_elem; ++it) + { + const eT val = (*it); + + if(arma_isfinite(val) == false) { continue; } + + if( + val >= eT(+100) || + ( (is_signed::value) && (val <= eT(-100)) ) || + ( (is_non_integral::value) && (val > eT(0)) && (val <= eT(+1e-4)) ) || + ( (is_non_integral::value) && (is_signed::value) && (val < eT(0)) && (val >= eT(-1e-4)) ) + ) + { + use_layout_C = true; + break; + } + + if( + (val >= eT(+10)) || ( (is_signed::value) && (val <= eT(-10)) ) + ) + { + use_layout_B = true; + } + } + + if(use_layout_C) + { + o.setf(ios::scientific); + o.setf(ios::right); + o.unsetf(ios::fixed); + o.precision(4); + cell_width = 13; + } + else + if(use_layout_B) + { + o.unsetf(ios::scientific); + o.setf(ios::right); + o.setf(ios::fixed); + o.precision(4); + cell_width = 10; + } + else + { + o.unsetf(ios::scientific); + o.setf(ios::right); + o.setf(ios::fixed); + o.precision(4); + cell_width = 9; + } + + return cell_width; + } + + + +//! "better than nothing" settings for complex numbers +template +inline +std::streamsize +arma_ostream::modify_stream(std::ostream& o, typename SpMat::const_iterator begin, const uword n_elem, const typename arma_cx_only::result* junk) + { + arma_ignore(begin); + arma_ignore(n_elem); + arma_ignore(junk); + + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.fill(' '); + + o.setf(ios::scientific); + o.setf(ios::showpos); + o.setf(ios::right); + o.unsetf(ios::fixed); + + std::streamsize cell_width; + + o.precision(3); + cell_width = 2 + 2*(1 + 3 + o.precision() + 5) + 1; + + return cell_width; + } + + + +template +inline +void +arma_ostream::print_elem_zero(std::ostream& o, const bool modify) + { + typedef typename promote_type::result promoted_eT; + + if(modify) + { + const ios::fmtflags save_flags = o.flags(); + const std::streamsize save_precision = o.precision(); + + o.unsetf(ios::scientific); + o.setf(ios::fixed); + o.precision(0); + + o << promoted_eT(0); + + o.flags(save_flags); + o.precision(save_precision); + } + else + { + o << promoted_eT(0); + } + } + + + +template +inline +void +arma_ostream::print_elem(std::ostream& o, const eT& x, const bool modify) + { + if(x == eT(0)) + { + arma_ostream::print_elem_zero(o, modify); + } + else + { + arma_ostream::raw_print_elem(o, x); + } + } + + + +template +inline +void +arma_ostream::raw_print_elem(std::ostream& o, const eT& x) + { + if(is_signed::value) + { + typedef typename promote_type::result promoted_eT; + + if(arma_isfinite(x)) + { + o << promoted_eT(x); + } + else + { + o << ( arma_isinf(x) ? ((x <= eT(0)) ? "-inf" : "inf") : "nan" ); + } + } + else + { + typedef typename promote_type::result promoted_eT; + + o << promoted_eT(x); + } + } + + + +template +inline +void +arma_ostream::print_elem(std::ostream& o, const std::complex& x, const bool modify) + { + if( (x.real() == T(0)) && (x.imag() == T(0)) && (modify) ) + { + o << "(0,0)"; + } + else + { + arma_ostream::raw_print_elem(o, x); + } + } + + + +template +inline +void +arma_ostream::raw_print_elem(std::ostream& o, const std::complex& x) + { + std::ostringstream ss; + ss.flags(o.flags()); + //ss.imbue(o.getloc()); + ss.precision(o.precision()); + + ss << '('; + + const T a = x.real(); + + if(arma_isfinite(a)) + { + ss << a; + } + else + { + ss << ( arma_isinf(a) ? ((a <= T(0)) ? "-inf" : "+inf") : "nan" ); + } + + ss << ','; + + const T b = x.imag(); + + if(arma_isfinite(b)) + { + ss << b; + } + else + { + ss << ( arma_isinf(b) ? ((b <= T(0)) ? "-inf" : "+inf") : "nan" ); + } + + ss << ')'; + + o << ss.str(); + } + + + +//! Print a matrix to the specified stream +template +inline +void +arma_ostream::print(std::ostream& o, const Mat& m, const bool modify) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + const std::streamsize cell_width = modify ? arma_ostream::modify_stream(o, m.memptr(), m.n_elem) : o.width(); + + const uword m_n_rows = m.n_rows; + const uword m_n_cols = m.n_cols; + + if(m.is_empty() == false) + { + if(m_n_cols > 0) + { + if(cell_width > 0) + { + for(uword row=0; row < m_n_rows; ++row) + { + for(uword col=0; col < m_n_cols; ++col) + { + // the cell width appears to be reset after each element is printed, + // hence we need to restore it + o.width(cell_width); + arma_ostream::print_elem(o, m.at(row,col), modify); + } + + o << '\n'; + } + } + else + { + for(uword row=0; row < m_n_rows; ++row) + { + for(uword col=0; col < m_n_cols-1; ++col) + { + arma_ostream::print_elem(o, m.at(row,col), modify); + o << ' '; + } + + arma_ostream::print_elem(o, m.at(row, m_n_cols-1), modify); + o << '\n'; + } + } + } + } + else + { + if(modify) + { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + } + + o << "[matrix size: " << m_n_rows << 'x' << m_n_cols << "]\n"; + } + + o.flush(); + stream_state.restore(o); + } + + + +//! Print a cube to the specified stream +template +inline +void +arma_ostream::print(std::ostream& o, const Cube& x, const bool modify) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + if(x.is_empty() == false) + { + for(uword slice=0; slice < x.n_slices; ++slice) + { + const Mat tmp(const_cast(x.slice_memptr(slice)), x.n_rows, x.n_cols, false); + + o << "[cube slice: " << slice << ']' << '\n'; + arma_ostream::print(o, tmp, modify); + + if((slice+1) < x.n_slices) { o << '\n'; } + } + } + else + { + if(modify) + { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + } + + o << "[cube size: " << x.n_rows << 'x' << x.n_cols << 'x' << x.n_slices << "]\n"; + } + + stream_state.restore(o); + } + + + + +//! Print a field to the specified stream +//! Assumes type oT can be printed, ie. oT has std::ostream& operator<< (std::ostream&, const oT&) +template +inline +void +arma_ostream::print(std::ostream& o, const field& x) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + const std::streamsize cell_width = o.width(); + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + const uword x_n_slices = x.n_slices; + + if(x.is_empty() == false) + { + if(x_n_slices == 1) + { + for(uword col=0; col < x_n_cols; ++col) + { + o << "[field column: " << col << ']' << '\n'; + + for(uword row=0; row < x_n_rows; ++row) + { + o.width(cell_width); + o << x.at(row,col) << '\n'; + } + + o << '\n'; + } + } + else + { + for(uword slice=0; slice < x_n_slices; ++slice) + { + o << "[field slice: " << slice << ']' << '\n'; + + for(uword col=0; col < x_n_cols; ++col) + { + o << "[field column: " << col << ']' << '\n'; + + for(uword row=0; row < x_n_rows; ++row) + { + o.width(cell_width); + o << x.at(row,col,slice) << '\n'; + } + + o << '\n'; + } + + o << '\n'; + } + } + } + else + { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + + o << "[field size: " << x_n_rows << 'x' << x_n_cols << 'x' << x_n_slices << "]\n"; + } + + o.flush(); + stream_state.restore(o); + } + + + +//! Print a subfield to the specified stream +//! Assumes type oT can be printed, ie. oT has std::ostream& operator<< (std::ostream&, const oT&) +template +inline +void +arma_ostream::print(std::ostream& o, const subview_field& x) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + const std::streamsize cell_width = o.width(); + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + const uword x_n_slices = x.n_slices; + + if(x.is_empty() == false) + { + if(x_n_slices == 1) + { + for(uword col=0; col < x_n_cols; ++col) + { + o << "[field column: " << col << ']' << '\n'; + for(uword row=0; row +inline +void +arma_ostream::print_dense(std::ostream& o, const SpMat& m, const bool modify) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + std::streamsize cell_width = o.width(); + + if(modify) + { + if(m.n_nonzero > 0) + { + cell_width = arma_ostream::modify_stream(o, m.begin(), m.n_nonzero); + } + else + { + eT tmp[1]; tmp[0] = eT(0); + + cell_width = arma_ostream::modify_stream(o, &tmp[0], 1); + } + } + + const uword m_n_rows = m.n_rows; + const uword m_n_cols = m.n_cols; + + if(m.is_empty() == false) + { + if(m_n_cols > 0) + { + if(cell_width > 0) + { + for(uword row=0; row < m_n_rows; ++row) + { + for(uword col=0; col < m_n_cols; ++col) + { + // the cell width appears to be reset after each element is printed, + // hence we need to restore it + o.width(cell_width); + arma_ostream::print_elem(o, m.at(row,col), modify); + } + + o << '\n'; + } + } + else + { + for(uword row=0; row < m_n_rows; ++row) + { + for(uword col=0; col < m_n_cols-1; ++col) + { + arma_ostream::print_elem(o, m.at(row,col), modify); + o << ' '; + } + + arma_ostream::print_elem(o, m.at(row, m_n_cols-1), modify); + o << '\n'; + } + } + } + } + else + { + if(modify) + { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + } + + o << "[matrix size: " << m_n_rows << 'x' << m_n_cols << "]\n"; + } + + o.flush(); + stream_state.restore(o); + } + + + +template +inline +void +arma_ostream::print(std::ostream& o, const SpMat& m, const bool modify) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.unsetf(ios::scientific); + o.setf(ios::right); + o.setf(ios::fixed); + + const uword m_n_nonzero = m.n_nonzero; + const double density = (m.n_elem > 0) ? (double(m_n_nonzero) / double(m.n_elem) * double(100)) : double(0); + + o << "[matrix size: " << m.n_rows << 'x' << m.n_cols << "; n_nonzero: " << m_n_nonzero; + + if(density == double(0)) + { + o.precision(0); + } + else + if(density >= (double(10.0)-std::numeric_limits::epsilon())) + { + o.precision(1); + } + else + if(density > (double(0.01)-std::numeric_limits::epsilon())) + { + o.precision(2); + } + else + if(density > (double(0.001)-std::numeric_limits::epsilon())) + { + o.precision(3); + } + else + if(density > (double(0.0001)-std::numeric_limits::epsilon())) + { + o.precision(4); + } + else + { + o.unsetf(ios::fixed); + o.setf(ios::scientific); + o.precision(2); + } + + o << "; density: " << density << "%]\n\n"; + + if(modify == false) { stream_state.restore(o); } + + if(m_n_nonzero > 0) + { + const std::streamsize cell_width = modify ? arma_ostream::modify_stream(o, m.begin(), m_n_nonzero) : o.width(); + + typename SpMat::const_iterator it = m.begin(); + typename SpMat::const_iterator it_end = m.end(); + + while(it != it_end) + { + const uword row = it.row(); + const uword col = it.col(); + + // TODO: change the maximum number of spaces before and after each location to be dependent on n_rows and n_cols + + if(row < 10) { o << " "; } + else if(row < 100) { o << " "; } + else if(row < 1000) { o << " "; } + else if(row < 10000) { o << " "; } + else if(row < 100000) { o << " "; } + else if(row < 1000000) { o << ' '; } + + o << '(' << row << ", " << col << ") "; + + if(col < 10) { o << " "; } + else if(col < 100) { o << " "; } + else if(col < 1000) { o << " "; } + else if(col < 10000) { o << " "; } + else if(col < 100000) { o << " "; } + else if(col < 1000000) { o << ' '; } + + if(cell_width > 0) { o.width(cell_width); } + + arma_ostream::print_elem(o, eT(*it), modify); + o << '\n'; + + ++it; + } + + o << '\n'; + } + + o.flush(); + stream_state.restore(o); + } + + + +inline +void +arma_ostream::print(std::ostream& o, const SizeMat& S) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + + o.setf(ios::fixed); + + o << S.n_rows << 'x' << S.n_cols; + + stream_state.restore(o); + } + + + +inline +void +arma_ostream::print(std::ostream& o, const SizeCube& S) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + + o.setf(ios::fixed); + + o << S.n_rows << 'x' << S.n_cols << 'x' << S.n_slices; + + stream_state.restore(o); + } + + + +template +inline +void +arma_ostream::brief_print(std::ostream& o, const Mat& m, const bool print_size) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + if(print_size) + { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + + o << "[matrix size: " << m.n_rows << 'x' << m.n_cols << "]\n"; + } + + if(m.n_elem == 0) { o.flush(); stream_state.restore(o); return; } + + if((m.n_rows <= 5) && (m.n_cols <= 5)) { arma_ostream::print(o, m, true); return; } + + const bool print_row_ellipsis = (m.n_rows >= 6); + const bool print_col_ellipsis = (m.n_cols >= 6); + + if( (print_row_ellipsis == true) && (print_col_ellipsis == true) ) + { + Mat X(4, 4, arma_nozeros_indicator()); + + X( span(0,2), span(0,2) ) = m( span(0,2), span(0,2) ); // top left submatrix + X( 3, span(0,2) ) = m( m.n_rows-1, span(0,2) ); // truncated last row + X( span(0,2), 3 ) = m( span(0,2), m.n_cols-1 ); // truncated last column + X( 3, 3 ) = m( m.n_rows-1, m.n_cols-1 ); // bottom right element + + const std::streamsize cell_width = arma_ostream::modify_stream(o, X.memptr(), X.n_elem); + + for(uword row=0; row <= 2; ++row) + { + for(uword col=0; col <= 2; ++col) + { + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,col), true); + } + + o.width(6); + o << "..."; + + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,3), true); + o << '\n'; + } + + for(uword col=0; col <= 2; ++col) + { + o.width(cell_width); + o << ':'; + } + + o.width(6); + o << "..."; + + o.width(cell_width); + o << ':' << '\n'; + + const uword row = 3; + { + for(uword col=0; col <= 2; ++col) + { + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,col), true); + } + + o.width(6); + o << "..."; + + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,3), true); + o << '\n'; + } + } + + + if( (print_row_ellipsis == true) && (print_col_ellipsis == false) ) + { + Mat X(4, m.n_cols, arma_nozeros_indicator()); + + X( span(0,2), span::all ) = m( span(0,2), span::all ); // top + X( 3, span::all ) = m( m.n_rows-1, span::all ); // bottom + + const std::streamsize cell_width = arma_ostream::modify_stream(o, X.memptr(), X.n_elem); + + for(uword row=0; row <= 2; ++row) // first 3 rows + { + for(uword col=0; col < m.n_cols; ++col) + { + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,col), true); + } + + o << '\n'; + } + + for(uword col=0; col < m.n_cols; ++col) + { + o.width(cell_width); + o << ':'; + } + + o.width(cell_width); + o << '\n'; + + const uword row = 3; + { + for(uword col=0; col < m.n_cols; ++col) + { + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,col), true); + } + } + + o << '\n'; + } + + + if( (print_row_ellipsis == false) && (print_col_ellipsis == true) ) + { + Mat X(m.n_rows, 4, arma_nozeros_indicator()); + + X( span::all, span(0,2) ) = m( span::all, span(0,2) ); // left + X( span::all, 3 ) = m( span::all, m.n_cols-1 ); // right + + const std::streamsize cell_width = arma_ostream::modify_stream(o, X.memptr(), X.n_elem); + + for(uword row=0; row < m.n_rows; ++row) + { + for(uword col=0; col <= 2; ++col) + { + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,col), true); + } + + o.width(6); + o << "..."; + + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,3), true); + o << '\n'; + } + } + + + o.flush(); + stream_state.restore(o); + } + + + +template +inline +void +arma_ostream::brief_print(std::ostream& o, const Cube& x) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + + o << "[cube size: " << x.n_rows << 'x' << x.n_cols << 'x' << x.n_slices << "]\n"; + + if(x.n_elem == 0) { o.flush(); stream_state.restore(o); return; } + + if(x.n_slices <= 3) + { + for(uword slice=0; slice < x.n_slices; ++slice) + { + const Mat tmp(const_cast(x.slice_memptr(slice)), x.n_rows, x.n_cols, false); + + o << "[cube slice: " << slice << ']' << '\n'; + arma_ostream::brief_print(o, tmp, false); + + if((slice+1) < x.n_slices) { o << '\n'; } + } + } + else + { + for(uword slice=0; slice <= 1; ++slice) + { + const Mat tmp(const_cast(x.slice_memptr(slice)), x.n_rows, x.n_cols, false); + + o << "[cube slice: " << slice << ']' << '\n'; + arma_ostream::brief_print(o, tmp, false); + o << '\n'; + } + + o << "[cube slice: ...]\n\n"; + + const uword slice = x.n_slices-1; + { + const Mat tmp(const_cast(x.slice_memptr(slice)), x.n_rows, x.n_cols, false); + + o << "[cube slice: " << slice << ']' << '\n'; + arma_ostream::brief_print(o, tmp, false); + } + } + + stream_state.restore(o); + } + + + +template +inline +void +arma_ostream::brief_print(std::ostream& o, const SpMat& m) + { + arma_extra_debug_sigprint(); + + if(m.n_nonzero <= 10) { arma_ostream::print(o, m, true); return; } + + const arma_ostream_state stream_state(o); + + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.unsetf(ios::scientific); + o.setf(ios::right); + o.setf(ios::fixed); + + const uword m_n_nonzero = m.n_nonzero; + const double density = (m.n_elem > 0) ? (double(m_n_nonzero) / double(m.n_elem) * double(100)) : double(0); + + o << "[matrix size: " << m.n_rows << 'x' << m.n_cols << "; n_nonzero: " << m_n_nonzero; + + if(density == double(0)) + { + o.precision(0); + } + else + if(density >= (double(10.0)-std::numeric_limits::epsilon())) + { + o.precision(1); + } + else + if(density > (double(0.01)-std::numeric_limits::epsilon())) + { + o.precision(2); + } + else + if(density > (double(0.001)-std::numeric_limits::epsilon())) + { + o.precision(3); + } + else + if(density > (double(0.0001)-std::numeric_limits::epsilon())) + { + o.precision(4); + } + else + { + o.unsetf(ios::fixed); + o.setf(ios::scientific); + o.precision(2); + } + + o << "; density: " << density << "%]\n\n"; + + // get the first 9 elements and the last element + + typename SpMat::const_iterator it = m.begin(); + typename SpMat::const_iterator it_end = m.end(); + + uvec storage_row(10); + uvec storage_col(10); + Col storage_val(10); + + uword count = 0; + + while( (it != it_end) && (count < 9) ) + { + storage_row(count) = it.row(); + storage_col(count) = it.col(); + storage_val(count) = (*it); + + ++it; + ++count; + } + + it = it_end; + --it; + + storage_row(count) = it.row(); + storage_col(count) = it.col(); + storage_val(count) = (*it); + + const std::streamsize cell_width = arma_ostream::modify_stream(o, storage_val.memptr(), 10); + + for(uword i=0; i < 9; ++i) + { + const uword row = storage_row(i); + const uword col = storage_col(i); + + if(row < 10) { o << " "; } + else if(row < 100) { o << " "; } + else if(row < 1000) { o << " "; } + else if(row < 10000) { o << " "; } + else if(row < 100000) { o << " "; } + else if(row < 1000000) { o << ' '; } + + o << '(' << row << ", " << col << ") "; + + if(col < 10) { o << " "; } + else if(col < 100) { o << " "; } + else if(col < 1000) { o << " "; } + else if(col < 10000) { o << " "; } + else if(col < 100000) { o << " "; } + else if(col < 1000000) { o << ' '; } + + if(cell_width > 0) { o.width(cell_width); } + + arma_ostream::print_elem(o, storage_val(i), true); + o << '\n'; + } + + o << " (:, :) "; + if(cell_width > 0) { o.width(cell_width); } + o << "...\n"; + + + const uword i = 9; + { + const uword row = storage_row(i); + const uword col = storage_col(i); + + if(row < 10) { o << " "; } + else if(row < 100) { o << " "; } + else if(row < 1000) { o << " "; } + else if(row < 10000) { o << " "; } + else if(row < 100000) { o << " "; } + else if(row < 1000000) { o << ' '; } + + o << '(' << row << ", " << col << ") "; + + if(col < 10) { o << " "; } + else if(col < 100) { o << " "; } + else if(col < 1000) { o << " "; } + else if(col < 10000) { o << " "; } + else if(col < 100000) { o << " "; } + else if(col < 1000000) { o << ' '; } + + if(cell_width > 0) { o.width(cell_width); } + + arma_ostream::print_elem(o, storage_val(i), true); + o << '\n'; + } + + o.flush(); + stream_state.restore(o); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/arma_rel_comparators.hpp b/src/armadillo/include/armadillo_bits/arma_rel_comparators.hpp new file mode 100644 index 0000000..977617b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/arma_rel_comparators.hpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup arma_rel_comparators +//! @{ + + + +template +struct arma_lt_comparator + { + arma_inline bool operator() (const eT a, const eT b) const { return (a < b); } + }; + + + +template +struct arma_gt_comparator + { + arma_inline bool operator() (const eT a, const eT b) const { return (a > b); } + }; + + + +template +struct arma_leq_comparator + { + arma_inline bool operator() (const eT a, const eT b) const { return (a <= b); } + }; + + + +template +struct arma_geq_comparator + { + arma_inline bool operator() (const eT a, const eT b) const { return (a >= b); } + }; + + + +template +struct arma_lt_comparator< std::complex > + { + typedef typename std::complex eT; + + inline bool operator() (const eT& a, const eT& b) const { return (std::abs(a) < std::abs(b)); } + + // inline + // bool + // operator() (const eT& a, const eT& b) const + // { + // const T abs_a = std::abs(a); + // const T abs_b = std::abs(b); + // + // return ( (abs_a != abs_b) ? (abs_a < abs_b) : (std::arg(a) < std::arg(b)) ); + // } + + // inline + // bool + // operator() (const eT& a, const eT& b) const + // { + // const T a_real = a.real(); + // const T a_imag = a.imag(); + // + // const T a_mag_squared = a_real*a_real + a_imag*a_imag; + // + // const T b_real = b.real(); + // const T b_imag = b.imag(); + // + // const T b_mag_squared = b_real*b_real + b_imag*b_imag; + // + // if( (a_mag_squared != T(0)) && (b_mag_squared != T(0)) && std::isfinite(a_mag_squared) && std::isfinite(b_mag_squared) ) + // { + // return ( (a_mag_squared != b_mag_squared) ? (a_mag_squared < b_mag_squared) : (std::arg(a) < std::arg(b)) ); + // } + // else + // { + // const T abs_a = std::abs(a); + // const T abs_b = std::abs(b); + // + // return ( (abs_a != abs_b) ? (abs_a < abs_b) : (std::arg(a) < std::arg(b)) ); + // } + // } + }; + + + +template +struct arma_gt_comparator< std::complex > + { + typedef typename std::complex eT; + + inline bool operator() (const eT& a, const eT& b) const { return (std::abs(a) > std::abs(b)); } + + // inline + // bool + // operator() (const eT& a, const eT& b) const + // { + // const T abs_a = std::abs(a); + // const T abs_b = std::abs(b); + // + // return ( (abs_a != abs_b) ? (abs_a > abs_b) : (std::arg(a) > std::arg(b)) ); + // } + + // inline + // bool + // operator() (const eT& a, const eT& b) const + // { + // const T a_real = a.real(); + // const T a_imag = a.imag(); + // + // const T a_mag_squared = a_real*a_real + a_imag*a_imag; + // + // const T b_real = b.real(); + // const T b_imag = b.imag(); + // + // const T b_mag_squared = b_real*b_real + b_imag*b_imag; + // + // if( (a_mag_squared != T(0)) && (b_mag_squared != T(0)) && std::isfinite(a_mag_squared) && std::isfinite(b_mag_squared) ) + // { + // return ( (a_mag_squared != b_mag_squared) ? (a_mag_squared > b_mag_squared) : (std::arg(a) > std::arg(b)) ); + // } + // else + // { + // const T abs_a = std::abs(a); + // const T abs_b = std::abs(b); + // + // return ( (abs_a != abs_b) ? (abs_a > abs_b) : (std::arg(a) > std::arg(b)) ); + // } + // } + }; + + + +template +struct arma_leq_comparator< std::complex > + { + typedef typename std::complex eT; + + inline bool operator() (const eT& a, const eT& b) const { return (std::abs(a) <= std::abs(b)); } + }; + + + +template +struct arma_geq_comparator< std::complex > + { + typedef typename std::complex eT; + + inline bool operator() (const eT& a, const eT& b) const { return (std::abs(a) >= std::abs(b)); } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/arma_rng.hpp b/src/armadillo/include/armadillo_bits/arma_rng.hpp new file mode 100644 index 0000000..da1b4f7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/arma_rng.hpp @@ -0,0 +1,1042 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup arma_rng +//! @{ + + +#undef ARMA_USE_CXX11_RNG +#define ARMA_USE_CXX11_RNG + +#undef ARMA_USE_THREAD_LOCAL +#define ARMA_USE_THREAD_LOCAL + +#if (defined(ARMA_RNG_ALT) || defined(ARMA_DONT_USE_CXX11_RNG)) + #undef ARMA_USE_CXX11_RNG +#endif + +#if defined(ARMA_DONT_USE_THREAD_LOCAL) + #undef ARMA_USE_THREAD_LOCAL +#endif + + +// NOTE: ARMA_WARMUP_PRODUCER enables a workaround +// NOTE: for thread_local issue on macOS 11 and/or AppleClang 12.0 +// NOTE: see https://gitlab.com/conradsnicta/armadillo-code/-/issues/173 +// NOTE: if this workaround causes problems, please report it and +// NOTE: disable the workaround by commenting out the code block below: + +#if defined(__APPLE__) || defined(__apple_build_version__) + #undef ARMA_WARMUP_PRODUCER + #define ARMA_WARMUP_PRODUCER +#endif + +#if defined(ARMA_DONT_WARMUP_PRODUCER) + #undef ARMA_WARMUP_PRODUCER +#endif + +// NOTE: workaround for another thread_local issue on macOS +// NOTE: where GCC (not Clang) may not have support for thread_local + +#if (defined(__APPLE__) && defined(__GNUG__) && !defined(__clang__)) + #undef ARMA_USE_THREAD_LOCAL +#endif + +// NOTE: disable use of thread_local on MinGW et al; +// NOTE: i don't have the patience to keep looking into these broken platforms + +#if (defined(__MINGW32__) || defined(__MINGW64__) || defined(__CYGWIN__) || defined(__MSYS__) || defined(__MSYS2__)) + #undef ARMA_USE_THREAD_LOCAL +#endif + +#if defined(ARMA_FORCE_USE_THREAD_LOCAL) + #undef ARMA_USE_THREAD_LOCAL + #define ARMA_USE_THREAD_LOCAL +#endif + +#if (!defined(ARMA_USE_THREAD_LOCAL)) + #undef ARMA_GUARD_PRODUCER + #define ARMA_GUARD_PRODUCER +#endif + +#if (defined(ARMA_DONT_GUARD_PRODUCER) || defined(ARMA_DONT_USE_STD_MUTEX)) + #undef ARMA_GUARD_PRODUCER +#endif + + +class arma_rng + { + public: + + #if defined(ARMA_RNG_ALT) + typedef arma_rng_alt::seed_type seed_type; + #elif defined(ARMA_USE_CXX11_RNG) + typedef std::mt19937_64::result_type seed_type; + #else + typedef arma_rng_cxx03::seed_type seed_type; + #endif + + #if defined(ARMA_RNG_ALT) + static constexpr int rng_method = 2; + #elif defined(ARMA_USE_CXX11_RNG) + static constexpr int rng_method = 1; + #else + static constexpr int rng_method = 0; + #endif + + #if defined(ARMA_USE_CXX11_RNG) + inline static std::mt19937_64& get_producer(); + inline static void warmup_producer(std::mt19937_64& producer); + + inline static void lock_producer(); + inline static void unlock_producer(); + + #if defined(ARMA_GUARD_PRODUCER) + inline static std::mutex& get_producer_mutex(); + #endif + #endif + + inline static void set_seed(const seed_type val); + inline static void set_seed_random(); + + template struct randi; + template struct randu; + template struct randn; + template struct randg; + }; + + + +#if defined(ARMA_USE_CXX11_RNG) + +inline +std::mt19937_64& +arma_rng::get_producer() + { + #if defined(ARMA_USE_THREAD_LOCAL) + + // use a thread-safe RNG, with each thread having its own unique starting seed + + static std::atomic mt19937_64_producer_counter(0); + + static thread_local std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed + mt19937_64_producer_counter++ ); + + arma_rng::warmup_producer(mt19937_64_producer); + + #else + + // use a plain RNG in case we don't have thread_local + + static std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed ); + + arma_rng::warmup_producer(mt19937_64_producer); + + #endif + + return mt19937_64_producer; + } + + +inline +void +arma_rng::warmup_producer(std::mt19937_64& producer) + { + #if defined(ARMA_WARMUP_PRODUCER) + + static std::atomic_flag warmup_done = ATOMIC_FLAG_INIT; // init to false + + if(warmup_done.test_and_set() == false) + { + typename std::mt19937_64::result_type junk = producer(); + + arma_ignore(junk); + } + + #else + + arma_ignore(producer); + + #endif + } + + +inline +void +arma_rng::lock_producer() + { + #if defined(ARMA_GUARD_PRODUCER) + + std::mutex& producer_mutex = arma_rng::get_producer_mutex(); + + producer_mutex.lock(); + + #endif + } + + +inline +void +arma_rng::unlock_producer() + { + #if defined(ARMA_GUARD_PRODUCER) + + std::mutex& producer_mutex = arma_rng::get_producer_mutex(); + + producer_mutex.unlock(); + + #endif + } + + +#if defined(ARMA_GUARD_PRODUCER) + inline + std::mutex& + arma_rng::get_producer_mutex() + { + static std::mutex producer_mutex; + + return producer_mutex; + } +#endif + +#endif + + +inline +void +arma_rng::set_seed(const arma_rng::seed_type val) + { + #if defined(ARMA_RNG_ALT) + { + arma_rng_alt::set_seed(val); + } + #elif defined(ARMA_USE_CXX11_RNG) + { + arma_rng::lock_producer(); + arma_rng::get_producer().seed(val); + arma_rng::unlock_producer(); + } + #else + { + arma_rng_cxx03::set_seed(val); + } + #endif + } + + + +arma_cold +inline +void +arma_rng::set_seed_random() + { + seed_type seed1 = seed_type(0); + seed_type seed2 = seed_type(0); + seed_type seed3 = seed_type(0); + seed_type seed4 = seed_type(0); + + bool have_seed = false; + + try + { + std::random_device rd; + + if(rd.entropy() > double(0)) { seed1 = static_cast( rd() ); } + + have_seed = (seed1 != seed_type(0)); + } + catch(...) {} + + + if(have_seed == false) + { + try + { + union + { + seed_type a; + unsigned char b[sizeof(seed_type)]; + } tmp; + + tmp.a = seed_type(0); + + std::ifstream f("/dev/urandom", std::ifstream::binary); + + if(f.good()) { f.read((char*)(&(tmp.b[0])), sizeof(seed_type)); } + + if(f.good()) { seed2 = tmp.a; } + + have_seed = (seed2 != seed_type(0)); + } + catch(...) {} + } + + + if(have_seed == false) + { + // get better-than-nothing seeds in case reading /dev/urandom failed + + const std::chrono::system_clock::time_point tp_now = std::chrono::system_clock::now(); + + auto since_epoch_usec = std::chrono::duration_cast(tp_now.time_since_epoch()).count(); + + seed3 = static_cast( since_epoch_usec & 0xFFFF ); + + union + { + uword* a; + unsigned char b[sizeof(uword*)]; + } tmp; + + tmp.a = (uword*)malloc(sizeof(uword)); + + if(tmp.a != nullptr) + { + for(size_t i=0; i +struct arma_rng::randi + { + inline + operator eT () + { + #if defined(ARMA_RNG_ALT) + { + return eT( arma_rng_alt::randi_val() ); + } + #elif defined(ARMA_USE_CXX11_RNG) + { + constexpr double scale = double(std::numeric_limits::max()) / double(std::mt19937_64::max()); + + arma_rng::lock_producer(); + + const eT out = eT(double(arma_rng::get_producer()()) * scale); + + arma_rng::unlock_producer(); + + return out; + } + #else + { + return eT( arma_rng_cxx03::randi_val() ); + } + #endif + } + + + inline + static + int + max_val() + { + #if defined(ARMA_RNG_ALT) + { + return arma_rng_alt::randi_max_val(); + } + #elif defined(ARMA_USE_CXX11_RNG) + { + return std::numeric_limits::max(); + } + #else + { + return arma_rng_cxx03::randi_max_val(); + } + #endif + } + + + inline + static + void + fill(eT* mem, const uword N, const int a, const int b) + { + #if defined(ARMA_RNG_ALT) + { + arma_rng_alt::randi_fill(mem, N, a, b); + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::uniform_int_distribution local_i_distr(a, b); + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i local_i_distr(a, b); + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i +struct arma_rng::randu + { + inline + operator eT () + { + #if defined(ARMA_RNG_ALT) + { + return eT( arma_rng_alt::randu_val() ); + } + #elif defined(ARMA_USE_CXX11_RNG) + { + constexpr double scale = double(1.0) / double(std::mt19937_64::max()); + + arma_rng::lock_producer(); + + const eT out = eT( double(arma_rng::get_producer()()) * scale ); + + arma_rng::unlock_producer(); + + return out; + } + #else + { + return eT( arma_rng_cxx03::randu_val() ); + } + #endif + } + + + inline + static + void + fill(eT* mem, const uword N) + { + #if defined(ARMA_RNG_ALT) + { + for(uword i=0; i < N; ++i) { mem[i] = eT( arma_rng_alt::randu_val() ); } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::uniform_real_distribution local_u_distr; + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_u_distr(producer) ); } + + arma_rng::unlock_producer(); + } + #else + { + if(N == uword(1)) { mem[0] = eT( arma_rng_cxx03::randu_val() ); return; } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_u_distr(local_engine) ); } + } + #endif + } + + + inline + static + void + fill(eT* mem, const uword N, const double a, const double b) + { + #if defined(ARMA_RNG_ALT) + { + const double r = b - a; + + for(uword i=0; i < N; ++i) { mem[i] = eT( arma_rng_alt::randu_val() * r + a ); } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::uniform_real_distribution local_u_distr(a,b); + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_u_distr(producer) ); } + + arma_rng::unlock_producer(); + } + #else + { + if(N == uword(1)) { mem[0] = eT( arma_rng_cxx03::randu_val() * (b - a) + a ); return; } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr(a,b); + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_u_distr(local_engine) ); } + } + #endif + } + }; + + + +template +struct arma_rng::randu< std::complex > + { + arma_inline + operator std::complex () + { + #if defined(ARMA_RNG_ALT) + { + const T a = T( arma_rng_alt::randu_val() ); + const T b = T( arma_rng_alt::randu_val() ); + + return std::complex(a, b); + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::uniform_real_distribution local_u_distr; + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + const T a = T( local_u_distr(producer) ); + const T b = T( local_u_distr(producer) ); + + arma_rng::unlock_producer(); + + return std::complex(a, b); + } + #else + { + const T a = T( arma_rng_cxx03::randu_val() ); + const T b = T( arma_rng_cxx03::randu_val() ); + + return std::complex(a, b); + } + #endif + } + + + inline + static + void + fill(std::complex* mem, const uword N) + { + #if defined(ARMA_RNG_ALT) + { + for(uword i=0; i < N; ++i) + { + const T a = T( arma_rng_alt::randu_val() ); + const T b = T( arma_rng_alt::randu_val() ); + + mem[i] = std::complex(a, b); + } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::uniform_real_distribution local_u_distr; + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i < N; ++i) + { + const T a = T( local_u_distr(producer) ); + const T b = T( local_u_distr(producer) ); + + mem[i] = std::complex(a, b); + } + + arma_rng::unlock_producer(); + } + #else + { + if(N == uword(1)) + { + const T a = T( arma_rng_cxx03::randu_val() ); + const T b = T( arma_rng_cxx03::randu_val() ); + + mem[0] = std::complex(a, b); + + return; + } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) + { + const T a = T( local_u_distr(local_engine) ); + const T b = T( local_u_distr(local_engine) ); + + mem[i] = std::complex(a, b); + } + } + #endif + } + + + inline + static + void + fill(std::complex* mem, const uword N, const double a, const double b) + { + #if defined(ARMA_RNG_ALT) + { + const double r = b - a; + + for(uword i=0; i < N; ++i) + { + const T tmp1 = T( arma_rng_alt::randu_val() * r + a ); + const T tmp2 = T( arma_rng_alt::randu_val() * r + a ); + + mem[i] = std::complex(tmp1, tmp2); + } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::uniform_real_distribution local_u_distr(a,b); + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i < N; ++i) + { + const T tmp1 = T( local_u_distr(producer) ); + const T tmp2 = T( local_u_distr(producer) ); + + mem[i] = std::complex(tmp1, tmp2); + } + + arma_rng::unlock_producer(); + } + #else + { + if(N == uword(1)) + { + const double r = b - a; + + const T tmp1 = T( arma_rng_cxx03::randu_val() * r + a); + const T tmp2 = T( arma_rng_cxx03::randu_val() * r + a); + + mem[0] = std::complex(tmp1, tmp2); + + return; + } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr(a,b); + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) + { + const T tmp1 = T( local_u_distr(local_engine) ); + const T tmp2 = T( local_u_distr(local_engine) ); + + mem[i] = std::complex(tmp1, tmp2); + } + } + #endif + } + }; + + + +// + + + +template +struct arma_rng::randn + { + inline + operator eT () const + { + #if defined(ARMA_RNG_ALT) + { + return eT( arma_rng_alt::randn_val() ); + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::normal_distribution local_n_distr; + + arma_rng::lock_producer(); + + const eT out = eT( local_n_distr(arma_rng::get_producer()) ); + + arma_rng::unlock_producer(); + + return out; + } + #else + { + return eT( arma_rng_cxx03::randn_val() ); + } + #endif + } + + + inline + static + void + dual_val(eT& out1, eT& out2) + { + #if defined(ARMA_RNG_ALT) + { + arma_rng_alt::randn_dual_val(out1, out2); + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::normal_distribution local_n_distr; + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + out1 = eT( local_n_distr(producer) ); + out2 = eT( local_n_distr(producer) ); + + arma_rng::unlock_producer(); + } + #else + { + arma_rng_cxx03::randn_dual_val(out1, out2); + } + #endif + } + + + inline + static + void + fill(eT* mem, const uword N) + { + #if defined(ARMA_RNG_ALT) + { + // NOTE: old method to avoid regressions in user code that assumes specific sequence + + uword i, j; + + for(i=0, j=1; j < N; i+=2, j+=2) { arma_rng_alt::randn_dual_val( mem[i], mem[j] ); } + + if(i < N) { mem[i] = eT( arma_rng_alt::randn_val() ); } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::normal_distribution local_n_distr; + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_n_distr(producer) ); } + + arma_rng::unlock_producer(); + } + #else + { + if(N == uword(1)) { mem[0] = eT( arma_rng_cxx03::randn_val() ); return; } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::normal_distribution local_n_distr; + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_n_distr(local_engine) ); } + } + #endif + } + + + inline + static + void + fill(eT* mem, const uword N, const double mu, const double sd) + { + #if defined(ARMA_RNG_ALT) + { + // NOTE: old method to avoid regressions in user code that assumes specific sequence + + uword i, j; + + for(i=0, j=1; j < N; i+=2, j+=2) + { + eT val_i = eT(0); + eT val_j = eT(0); + + arma_rng_alt::randn_dual_val( val_i, val_j ); + + mem[i] = (val_i * sd) + mu; + mem[j] = (val_j * sd) + mu; + } + + if(i < N) + { + const eT val_i = eT( arma_rng_alt::randn_val() ); + + mem[i] = (val_i * sd) + mu; + } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::normal_distribution local_n_distr(mu, sd); + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_n_distr(producer) ); } + + arma_rng::unlock_producer(); + } + #else + { + if(N == uword(1)) + { + const eT val = eT( arma_rng_cxx03::randn_val() ); + + mem[0] = (val * sd) + mu; + + return; + } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::normal_distribution local_n_distr(mu, sd); + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_n_distr(local_engine) ); } + } + #endif + } + }; + + + +template +struct arma_rng::randn< std::complex > + { + inline + operator std::complex () const + { + #if defined(_MSC_VER) + // attempt at workaround for MSVC bug + // does MS even test their so-called compilers before release? + T a; + T b; + #else + T a(0); + T b(0); + #endif + + arma_rng::randn::dual_val(a, b); + + return std::complex(a, b); + } + + + inline + static + void + dual_val(std::complex& out1, std::complex& out2) + { + #if defined(_MSC_VER) + T a; + T b; + #else + T a(0); + T b(0); + #endif + + arma_rng::randn::dual_val(a,b); + out1 = std::complex(a,b); + + arma_rng::randn::dual_val(a,b); + out2 = std::complex(a,b); + } + + + inline + static + void + fill(std::complex* mem, const uword N) + { + #if defined(ARMA_RNG_ALT) + { + for(uword i=0; i < N; ++i) { mem[i] = std::complex( arma_rng::randn< std::complex >() ); } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::normal_distribution local_n_distr; + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i < N; ++i) + { + const T a = T( local_n_distr(producer) ); + const T b = T( local_n_distr(producer) ); + + mem[i] = std::complex(a,b); + } + + arma_rng::unlock_producer(); + } + #else + { + if(N == uword(1)) + { + T a = T(0); + T b = T(0); + + arma_rng_cxx03::randn_dual_val(a,b); + + mem[0] = std::complex(a,b); + + return; + } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::normal_distribution local_n_distr; + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) + { + const T a = T( local_n_distr(local_engine) ); + const T b = T( local_n_distr(local_engine) ); + + mem[i] = std::complex(a,b); + } + } + #endif + } + + + inline + static + void + fill(std::complex* mem, const uword N, const double mu, const double sd) + { + arma_rng::randn< std::complex >::fill(mem, N); + + if( (mu == double(0)) && (sd == double(1)) ) { return; } + + for(uword i=0; i& val = mem[i]; + + mem[i] = std::complex( ((val.real() * sd) + mu), ((val.imag() * sd) + mu) ); + } + } + }; + + + +// + + + +template +struct arma_rng::randg + { + inline + static + void + fill(eT* mem, const uword N, const double a, const double b) + { + #if defined(ARMA_USE_CXX11_RNG) + { + std::gamma_distribution local_g_distr(a,b); + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i local_g_distr(a,b); + + local_engine.seed( local_seed_type(arma_rng::randi()) ); + + for(uword i=0; i + inline static void randn_dual_val(eT& out1, eT& out2); + + template + inline static void randi_fill(eT* mem, const uword N, const int a, const int b); + + inline static int randi_max_val(); + }; + + + +inline +void +arma_rng_cxx03::set_seed(const arma_rng_cxx03::seed_type val) + { + std::srand(val); + } + + + +arma_inline +int +arma_rng_cxx03::randi_val() + { + #if (RAND_MAX == 32767) + { + // NOTE: this is a better-than-nothing solution + // NOTE: see also arma_rng_cxx03::randi_max_val() + + u32 val1 = u32(std::rand()); + u32 val2 = u32(std::rand()); + + val1 <<= 15; + + return (val1 | val2); + } + #else + { + return std::rand(); + } + #endif + } + + + +arma_inline +double +arma_rng_cxx03::randu_val() + { + return double( double(randi_val()) * ( double(1) / double(randi_max_val()) ) ); + } + + + +inline +double +arma_rng_cxx03::randn_val() + { + // polar form of the Box-Muller transformation: + // http://en.wikipedia.org/wiki/Box-Muller_transformation + // http://en.wikipedia.org/wiki/Marsaglia_polar_method + + double tmp1 = double(0); + double tmp2 = double(0); + double w = double(0); + + do + { + tmp1 = double(2) * double(randi_val()) * (double(1) / double(randi_max_val())) - double(1); + tmp2 = double(2) * double(randi_val()) * (double(1) / double(randi_max_val())) - double(1); + + w = tmp1*tmp1 + tmp2*tmp2; + } + while( w >= double(1) ); + + return double( tmp1 * std::sqrt( (double(-2) * std::log(w)) / w) ); + } + + + +template +inline +void +arma_rng_cxx03::randn_dual_val(eT& out1, eT& out2) + { + // make sure we are internally using at least floats + typedef typename promote_type::result eTp; + + eTp tmp1 = eTp(0); + eTp tmp2 = eTp(0); + eTp w = eTp(0); + + do + { + tmp1 = eTp(2) * eTp(randi_val()) * (eTp(1) / eTp(randi_max_val())) - eTp(1); + tmp2 = eTp(2) * eTp(randi_val()) * (eTp(1) / eTp(randi_max_val())) - eTp(1); + + w = tmp1*tmp1 + tmp2*tmp2; + } + while( w >= eTp(1) ); + + const eTp k = std::sqrt( (eTp(-2) * std::log(w)) / w); + + out1 = eT(tmp1*k); + out2 = eT(tmp2*k); + } + + + +template +inline +void +arma_rng_cxx03::randi_fill(eT* mem, const uword N, const int a, const int b) + { + if( (a == 0) && (b == RAND_MAX) ) + { + for(uword i=0; i n_chars_prealloc) { std::free(mem); } + + mem = nullptr; + n_chars = 0; + } + + inline + char_buffer() + { + mem = &(local_mem[0]); + n_chars = n_chars_prealloc; + + if(n_chars > 0) { mem[0] = char(0); } + } + + inline + void + set_size(const uword new_n_chars) + { + if(n_chars > n_chars_prealloc) { std::free(mem); } + + mem = (new_n_chars <= n_chars_prealloc) ? &(local_mem[0]) : (char*)std::malloc(new_n_chars); + n_chars = (new_n_chars <= n_chars_prealloc) ? n_chars_prealloc : new_n_chars; + + if(n_chars > 0) { mem[0] = char(0); } + } + }; + + + class format + { + public: + + const std::string fmt; + + inline format(const char* in_fmt) : fmt(in_fmt) { } + inline format(const std::string& in_fmt) : fmt(in_fmt) { } + + private: + format(); + }; + + + + template + class basic_format + { + public: + + const T1& A; + const T2& B; + + inline basic_format(const T1& in_A, const T2& in_B) : A(in_A) , B(in_B) { } + + private: + basic_format(); + }; + + + + template + inline + basic_format< format, T2 > + operator% (const format& X, const T2& arg) + { + return basic_format< format, T2 >(X, arg); + } + + + + template + inline + basic_format< basic_format, T3 > + operator% (const basic_format& X, const T3& arg) + { + return basic_format< basic_format, T3 >(X, arg); + } + + + + template + inline + std::string + str(const basic_format< format, T2>& X) + { + std::string out; + char_buffer buf; + + bool status = false; + + while(status == false) + { + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.fmt.c_str(), X.B); + + if(required_size < 0) { break; } + + if(uword(required_size) >= buf.n_chars) + { + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); + } + else + { + status = true; + } + + if(status) { out = buf.mem; } + } + + return out; + } + + + + template + inline + std::string + str(const basic_format< basic_format< format, T2>, T3>& X) + { + char_buffer buf; + std::string out; + + bool status = false; + + while(status == false) + { + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.A.fmt.c_str(), X.A.B, X.B); + + if(required_size < 0) { break; } + + if(uword(required_size) >= buf.n_chars) + { + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); + } + else + { + status = true; + } + + if(status) { out = buf.mem; } + } + + return out; + } + + + + template + inline + std::string + str(const basic_format< basic_format< basic_format< format, T2>, T3>, T4>& X) + { + char_buffer buf; + std::string out; + + bool status = false; + + while(status == false) + { + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.A.A.fmt.c_str(), X.A.A.B, X.A.B, X.B); + + if(required_size < 0) { break; } + + if(uword(required_size) >= buf.n_chars) + { + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); + } + else + { + status = true; + } + + if(status) { out = buf.mem; } + } + + return out; + } + + + + template + inline + std::string + str(const basic_format< basic_format< basic_format< basic_format< format, T2>, T3>, T4>, T5>& X) + { + char_buffer buf; + std::string out; + + bool status = false; + + while(status == false) + { + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.A.A.A.fmt.c_str(), X.A.A.A.B, X.A.A.B, X.A.B, X.B); + + if(required_size < 0) { break; } + + if(uword(required_size) >= buf.n_chars) + { + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); + } + else + { + status = true; + } + + if(status) { out = buf.mem; } + } + + return out; + } + + + + template + inline + std::string + str(const basic_format< basic_format< basic_format< basic_format< basic_format< format, T2>, T3>, T4>, T5>, T6>& X) + { + char_buffer buf; + std::string out; + + bool status = false; + + while(status == false) + { + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.A.A.A.A.fmt.c_str(), X.A.A.A.A.B, X.A.A.A.B, X.A.A.B, X.A.B, X.B); + + if(required_size < 0) { break; } + + if(uword(required_size) >= buf.n_chars) + { + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); + } + else + { + status = true; + } + + if(status) { out = buf.mem; } + } + + return out; + } + + + + template + inline + std::string + str(const basic_format< basic_format< basic_format< basic_format< basic_format< basic_format< format, T2>, T3>, T4>, T5>, T6>, T7>& X) + { + char_buffer buf; + std::string out; + + bool status = false; + + while(status == false) + { + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.A.A.A.A.A.fmt.c_str(), X.A.A.A.A.A.B, X.A.A.A.A.B, X.A.A.A.B, X.A.A.B, X.A.B, X.B); + + if(required_size < 0) { break; } + + if(uword(required_size) >= buf.n_chars) + { + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); + } + else + { + status = true; + } + + if(status) { out = buf.mem; } + } + + return out; + } + + + + template + struct format_metaprog + { + static constexpr uword depth = 0; + + inline + static + const std::string& + get_fmt(const T1& X) + { + return X.A; + } + }; + + + + //template<> + template + struct format_metaprog< basic_format > + { + static constexpr uword depth = 1 + format_metaprog::depth; + + inline + static + const std::string& + get_fmt(const T1& X) + { + return format_metaprog::get_fmt(X.A); + } + + }; + + + + template + inline + std::string + str(const basic_format& X) + { + return format_metaprog< basic_format >::get_fmt(X.A); + } + + + + template + inline + std::ostream& + operator<< (std::ostream& o, const basic_format& X) + { + o << str(X); + return o; + } + + + template struct string_only { }; + template<> struct string_only { typedef std::string result; }; + + template struct char_only { }; + template<> struct char_only { typedef char result; }; + + template + struct basic_format_only { }; + + template + struct basic_format_only< basic_format > { typedef basic_format result; }; + + + + template + inline + static + const T1& + str_wrapper(const T1& x, const typename string_only::result* junk = nullptr) + { + arma_ignore(junk); + + return x; + } + + + + template + inline + static + const T1* + str_wrapper(const T1* x, const typename char_only::result* junk = nullptr) + { + arma_ignore(junk); + + return x; + } + + + + template + inline + static + std::string + str_wrapper(const T1& x, const typename basic_format_only::result* junk = nullptr) + { + arma_ignore(junk); + + return str(x); + } + + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/arma_version.hpp b/src/armadillo/include/armadillo_bits/arma_version.hpp new file mode 100644 index 0000000..a335bb3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/arma_version.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup arma_version +//! @{ + + + +#define ARMA_VERSION_MAJOR 12 +#define ARMA_VERSION_MINOR 6 +#define ARMA_VERSION_PATCH 7 +#define ARMA_VERSION_NAME "Cortisol Retox" + + + +struct arma_version + { + static constexpr unsigned int major = ARMA_VERSION_MAJOR; + static constexpr unsigned int minor = ARMA_VERSION_MINOR; + static constexpr unsigned int patch = ARMA_VERSION_PATCH; + + static + inline + std::string + as_string() + { + const char* nickname = ARMA_VERSION_NAME; + + std::ostringstream ss; + + ss << arma_version::major + << '.' + << arma_version::minor + << '.' + << arma_version::patch + << " (" + << nickname + << ')'; + + return ss.str(); + } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/arrayops_bones.hpp b/src/armadillo/include/armadillo_bits/arrayops_bones.hpp new file mode 100644 index 0000000..0beec3a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/arrayops_bones.hpp @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup arrayops +//! @{ + + +class arrayops + { + public: + + template + arma_inline static void + copy(eT* dest, const eT* src, const uword n_elem); + + template + inline static void + fill_zeros(eT* dest, const uword n_elem); + + template + arma_hot inline static void + replace(eT* mem, const uword n_elem, const eT old_val, const eT new_val); + + template + arma_hot inline static void + clean(eT* mem, const uword n_elem, const eT abs_limit, const typename arma_not_cx::result* junk = nullptr); + + template + arma_hot inline static void + clean(std::complex* mem, const uword n_elem, const T abs_limit); + + template + inline static void + clamp(eT* mem, const uword n_elem, const eT min_val, const eT max_val, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void + clamp(std::complex* mem, const uword n_elem, const std::complex& min_val, const std::complex& max_val); + + + // + // array = convert(array) + + template + arma_inline static void + convert_cx_scalar(out_eT& out, const in_eT& in, const typename arma_not_cx::result* junk1 = nullptr, const typename arma_not_cx< in_eT>::result* junk2 = nullptr); + + template + arma_inline static void + convert_cx_scalar(out_eT& out, const std::complex& in, const typename arma_not_cx::result* junk = nullptr); + + template + arma_inline static void + convert_cx_scalar(std::complex& out, const std::complex< in_T>& in); + + template + arma_hot inline static void + convert(out_eT* dest, const in_eT* src, const uword n_elem); + + template + arma_hot inline static void + convert_cx(out_eT* dest, const in_eT* src, const uword n_elem); + + + // + // array op= array + + template + arma_hot inline static + void + inplace_plus(eT* dest, const eT* src, const uword n_elem); + + template + arma_hot inline static + void + inplace_minus(eT* dest, const eT* src, const uword n_elem); + + template + arma_hot inline static + void + inplace_mul(eT* dest, const eT* src, const uword n_elem); + + template + arma_hot inline static + void + inplace_div(eT* dest, const eT* src, const uword n_elem); + + + template + arma_hot inline static + void + inplace_plus_base(eT* dest, const eT* src, const uword n_elem); + + template + arma_hot inline static + void + inplace_minus_base(eT* dest, const eT* src, const uword n_elem); + + template + arma_hot inline static + void + inplace_mul_base(eT* dest, const eT* src, const uword n_elem); + + template + arma_hot inline static + void + inplace_div_base(eT* dest, const eT* src, const uword n_elem); + + + // + // array op= scalar + + template + arma_hot inline static + void + inplace_set(eT* dest, const eT val, const uword n_elem); + + template + arma_hot inline static + void + inplace_set_simple(eT* dest, const eT val, const uword n_elem); + + template + arma_hot inline static + void + inplace_set_base(eT* dest, const eT val, const uword n_elem); + + template + arma_hot inline static + void + inplace_set_fixed(eT* dest, const eT val); + + template + arma_hot inline static + void + inplace_plus(eT* dest, const eT val, const uword n_elem); + + template + arma_hot inline static + void + inplace_minus(eT* dest, const eT val, const uword n_elem); + + template + arma_hot inline static void + inplace_mul(eT* dest, const eT val, const uword n_elem); + + template + arma_hot inline static + void + inplace_div(eT* dest, const eT val, const uword n_elem); + + + template + arma_hot inline static + void + inplace_plus_base(eT* dest, const eT val, const uword n_elem); + + template + arma_hot inline static + void + inplace_minus_base(eT* dest, const eT val, const uword n_elem); + + template + arma_hot inline static void + inplace_mul_base(eT* dest, const eT val, const uword n_elem); + + template + arma_hot inline static + void + inplace_div_base(eT* dest, const eT val, const uword n_elem); + + + // + // scalar = op(array) + + template + arma_hot inline static + eT + accumulate(const eT* src, const uword n_elem); + + template + arma_hot inline static + eT + product(const eT* src, const uword n_elem); + + template + arma_hot inline static + bool + is_zero(const eT* mem, const uword n_elem, const eT abs_limit, const typename arma_not_cx::result* junk = nullptr); + + template + arma_hot inline static + bool + is_zero(const std::complex* mem, const uword n_elem, const T abs_limit); + + template + arma_hot inline static + bool + is_finite(const eT* src, const uword n_elem); + + template + arma_hot inline static + bool + has_inf(const eT* src, const uword n_elem); + + template + arma_hot inline static + bool + has_nan(const eT* src, const uword n_elem); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/arrayops_meat.hpp b/src/armadillo/include/armadillo_bits/arrayops_meat.hpp new file mode 100644 index 0000000..57f1a1d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/arrayops_meat.hpp @@ -0,0 +1,1108 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup arrayops +//! @{ + + + +template +arma_inline +void +arrayops::copy(eT* dest, const eT* src, const uword n_elem) + { + if( (dest == src) || (n_elem == 0) ) { return; } + + std::memcpy(dest, src, n_elem*sizeof(eT)); + } + + + +template +inline +void +arrayops::fill_zeros(eT* dest, const uword n_elem) + { + typedef typename get_pod_type::result pod_type; + + if(n_elem == 0) { return; } + + if(std::numeric_limits::is_integer || std::numeric_limits::is_iec559) + { + std::memset((void*)dest, 0, sizeof(eT)*n_elem); + } + else + { + arrayops::inplace_set_simple(dest, eT(0), n_elem); + } + } + + + +template +inline +void +arrayops::replace(eT* mem, const uword n_elem, const eT old_val, const eT new_val) + { + if(arma_isnan(old_val)) + { + for(uword i=0; i +inline +void +arrayops::clean(eT* mem, const uword n_elem, const eT abs_limit, const typename arma_not_cx::result* junk) + { + arma_ignore(junk); + + for(uword i=0; i +inline +void +arrayops::clean(std::complex* mem, const uword n_elem, const T abs_limit) + { + typedef typename std::complex eT; + + for(uword i=0; i(T(0), val_imag); + } + else + if(std::abs(val_imag) <= abs_limit) + { + val = std::complex(val_real, T(0)); + } + } + } + + + +template +inline +void +arrayops::clamp(eT* mem, const uword n_elem, const eT min_val, const eT max_val, const typename arma_not_cx::result* junk) + { + arma_ignore(junk); + + for(uword i=0; i max_val) ? max_val : val); + } + } + + + +template +inline +void +arrayops::clamp(std::complex* mem, const uword n_elem, const std::complex& min_val, const std::complex& max_val) + { + typedef typename std::complex eT; + + const T min_val_real = std::real(min_val); + const T min_val_imag = std::imag(min_val); + + const T max_val_real = std::real(max_val); + const T max_val_imag = std::imag(max_val); + + for(uword i=0; i max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + val = std::complex(val_real,val_imag); + } + } + + + +template +arma_inline +void +arrayops::convert_cx_scalar + ( + out_eT& out, + const in_eT& in, + const typename arma_not_cx::result* junk1, + const typename arma_not_cx< in_eT>::result* junk2 + ) + { + arma_ignore(junk1); + arma_ignore(junk2); + + out = out_eT(in); + } + + + +template +arma_inline +void +arrayops::convert_cx_scalar + ( + out_eT& out, + const std::complex& in, + const typename arma_not_cx::result* junk + ) + { + arma_ignore(junk); + + const in_T val = in.real(); + + const bool conversion_ok = (std::is_integral::value && std::is_floating_point::value) ? arma_isfinite(val) : true; + + out = conversion_ok ? out_eT(val) : out_eT(0); + } + + + +template +arma_inline +void +arrayops::convert_cx_scalar + ( + std::complex& out, + const std::complex< in_T>& in + ) + { + typedef std::complex out_eT; + + out = out_eT(in); + } + + + +template +inline +void +arrayops::convert(out_eT* dest, const in_eT* src, const uword n_elem) + { + if(is_same_type::value) + { + const out_eT* src2 = (const out_eT*)src; + + if(dest != src2) { arrayops::copy(dest, src2, n_elem); } + + return; + } + + const bool check_finite = (std::is_integral::value && std::is_floating_point::value); + + uword j; + + for(j=1; j::value) + ? out_eT( tmp_i ) + : ( cond_rel< is_signed::value >::lt(tmp_i, in_eT(0)) ? out_eT(0) : out_eT(tmp_i) ) + ) + : out_eT(0); + + dest++; + + (*dest) = ok_j + ? ( + (is_signed::value) + ? out_eT( tmp_j ) + : ( cond_rel< is_signed::value >::lt(tmp_j, in_eT(0)) ? out_eT(0) : out_eT(tmp_j) ) + ) + : out_eT(0); + dest++; + } + + if((j-1) < n_elem) + { + const in_eT tmp_i = (*src); + + // dest[i] = out_eT( tmp_i ); + + const bool ok_i = check_finite ? arma_isfinite(tmp_i) : true; + + (*dest) = ok_i + ? ( + (is_signed::value) + ? out_eT( tmp_i ) + : ( cond_rel< is_signed::value >::lt(tmp_i, in_eT(0)) ? out_eT(0) : out_eT(tmp_i) ) + ) + : out_eT(0); + } + } + + + +template +inline +void +arrayops::convert_cx(out_eT* dest, const in_eT* src, const uword n_elem) + { + uword j; + + for(j=1; j +inline +void +arrayops::inplace_plus(eT* dest, const eT* src, const uword n_elem) + { + if(memory::is_aligned(dest)) + { + memory::mark_as_aligned(dest); + + if(memory::is_aligned(src)) + { + memory::mark_as_aligned(src); + + arrayops::inplace_plus_base(dest, src, n_elem); + } + else + { + arrayops::inplace_plus_base(dest, src, n_elem); + } + } + else + { + if(memory::is_aligned(src)) + { + memory::mark_as_aligned(src); + + arrayops::inplace_plus_base(dest, src, n_elem); + } + else + { + arrayops::inplace_plus_base(dest, src, n_elem); + } + } + } + + + +template +inline +void +arrayops::inplace_minus(eT* dest, const eT* src, const uword n_elem) + { + if(memory::is_aligned(dest)) + { + memory::mark_as_aligned(dest); + + if(memory::is_aligned(src)) + { + memory::mark_as_aligned(src); + + arrayops::inplace_minus_base(dest, src, n_elem); + } + else + { + arrayops::inplace_minus_base(dest, src, n_elem); + } + } + else + { + if(memory::is_aligned(src)) + { + memory::mark_as_aligned(src); + + arrayops::inplace_minus_base(dest, src, n_elem); + } + else + { + arrayops::inplace_minus_base(dest, src, n_elem); + } + } + } + + + +template +inline +void +arrayops::inplace_mul(eT* dest, const eT* src, const uword n_elem) + { + if(memory::is_aligned(dest)) + { + memory::mark_as_aligned(dest); + + if(memory::is_aligned(src)) + { + memory::mark_as_aligned(src); + + arrayops::inplace_mul_base(dest, src, n_elem); + } + else + { + arrayops::inplace_mul_base(dest, src, n_elem); + } + } + else + { + if(memory::is_aligned(src)) + { + memory::mark_as_aligned(src); + + arrayops::inplace_mul_base(dest, src, n_elem); + } + else + { + arrayops::inplace_mul_base(dest, src, n_elem); + } + } + } + + + +template +inline +void +arrayops::inplace_div(eT* dest, const eT* src, const uword n_elem) + { + if(memory::is_aligned(dest)) + { + memory::mark_as_aligned(dest); + + if(memory::is_aligned(src)) + { + memory::mark_as_aligned(src); + + arrayops::inplace_div_base(dest, src, n_elem); + } + else + { + arrayops::inplace_div_base(dest, src, n_elem); + } + } + else + { + if(memory::is_aligned(src)) + { + memory::mark_as_aligned(src); + + arrayops::inplace_div_base(dest, src, n_elem); + } + else + { + arrayops::inplace_div_base(dest, src, n_elem); + } + } + } + + + +template +inline +void +arrayops::inplace_plus_base(eT* dest, const eT* src, const uword n_elem) + { + #if defined(ARMA_SIMPLE_LOOPS) + { + for(uword i=0; i +inline +void +arrayops::inplace_minus_base(eT* dest, const eT* src, const uword n_elem) + { + #if defined(ARMA_SIMPLE_LOOPS) + { + for(uword i=0; i +inline +void +arrayops::inplace_mul_base(eT* dest, const eT* src, const uword n_elem) + { + #if defined(ARMA_SIMPLE_LOOPS) + { + for(uword i=0; i +inline +void +arrayops::inplace_div_base(eT* dest, const eT* src, const uword n_elem) + { + #if defined(ARMA_SIMPLE_LOOPS) + { + for(uword i=0; i +inline +void +arrayops::inplace_set(eT* dest, const eT val, const uword n_elem) + { + if(val == eT(0)) + { + arrayops::fill_zeros(dest, n_elem); + } + else + { + arrayops::inplace_set_simple(dest, val, n_elem); + } + } + + + +template +inline +void +arrayops::inplace_set_simple(eT* dest, const eT val, const uword n_elem) + { + if(memory::is_aligned(dest)) + { + memory::mark_as_aligned(dest); + + arrayops::inplace_set_base(dest, val, n_elem); + } + else + { + arrayops::inplace_set_base(dest, val, n_elem); + } + } + + + +template +inline +void +arrayops::inplace_set_base(eT* dest, const eT val, const uword n_elem) + { + #if defined(ARMA_SIMPLE_LOOPS) + { + for(uword i=0; i +inline +void +arrayops::inplace_set_fixed(eT* dest, const eT val) + { + for(uword i=0; i +inline +void +arrayops::inplace_plus(eT* dest, const eT val, const uword n_elem) + { + if(memory::is_aligned(dest)) + { + memory::mark_as_aligned(dest); + + arrayops::inplace_plus_base(dest, val, n_elem); + } + else + { + arrayops::inplace_plus_base(dest, val, n_elem); + } + } + + + +template +inline +void +arrayops::inplace_minus(eT* dest, const eT val, const uword n_elem) + { + if(memory::is_aligned(dest)) + { + memory::mark_as_aligned(dest); + + arrayops::inplace_minus_base(dest, val, n_elem); + } + else + { + arrayops::inplace_minus_base(dest, val, n_elem); + } + } + + + +template +inline +void +arrayops::inplace_mul(eT* dest, const eT val, const uword n_elem) + { + if(memory::is_aligned(dest)) + { + memory::mark_as_aligned(dest); + + arrayops::inplace_mul_base(dest, val, n_elem); + } + else + { + arrayops::inplace_mul_base(dest, val, n_elem); + } + } + + + +template +inline +void +arrayops::inplace_div(eT* dest, const eT val, const uword n_elem) + { + if(memory::is_aligned(dest)) + { + memory::mark_as_aligned(dest); + + arrayops::inplace_div_base(dest, val, n_elem); + } + else + { + arrayops::inplace_div_base(dest, val, n_elem); + } + } + + + +template +inline +void +arrayops::inplace_plus_base(eT* dest, const eT val, const uword n_elem) + { + #if defined(ARMA_SIMPLE_LOOPS) + { + for(uword i=0; i +inline +void +arrayops::inplace_minus_base(eT* dest, const eT val, const uword n_elem) + { + #if defined(ARMA_SIMPLE_LOOPS) + { + for(uword i=0; i +inline +void +arrayops::inplace_mul_base(eT* dest, const eT val, const uword n_elem) + { + #if defined(ARMA_SIMPLE_LOOPS) + { + for(uword i=0; i +inline +void +arrayops::inplace_div_base(eT* dest, const eT val, const uword n_elem) + { + #if defined(ARMA_SIMPLE_LOOPS) + { + for(uword i=0; i +inline +eT +arrayops::accumulate(const eT* src, const uword n_elem) + { + #if defined(__FAST_MATH__) + { + eT acc = eT(0); + + if(memory::is_aligned(src)) + { + memory::mark_as_aligned(src); + for(uword i=0; i +inline +eT +arrayops::product(const eT* src, const uword n_elem) + { + eT val1 = eT(1); + eT val2 = eT(1); + + uword i,j; + + for(i=0, j=1; j +inline +bool +arrayops::is_zero(const eT* mem, const uword n_elem, const eT abs_limit, const typename arma_not_cx::result* junk) + { + arma_ignore(junk); + + if(n_elem == 0) { return false; } + + if(abs_limit == eT(0)) + { + for(uword i=0; i abs_limit) { return false; } + } + } + + return true; + } + + + +template +inline +bool +arrayops::is_zero(const std::complex* mem, const uword n_elem, const T abs_limit) + { + typedef typename std::complex eT; + + if(n_elem == 0) { return false; } + + if(abs_limit == T(0)) + { + for(uword i=0; i abs_limit) { return false; } + if(std::abs(std::imag(val)) > abs_limit) { return false; } + } + } + + return true; + } + + + +template +inline +bool +arrayops::is_finite(const eT* src, const uword n_elem) + { + uword j; + + for(j=1; j +inline +bool +arrayops::has_inf(const eT* src, const uword n_elem) + { + uword j; + + for(j=1; j +inline +bool +arrayops::has_nan(const eT* src, const uword n_elem) + { + uword j; + + for(j=1; j + inline static bool inv(Mat& A); + + template + inline static bool inv(Mat& out, const Mat& X); + + template + inline static bool inv_rcond(Mat& A, typename get_pod_type::result& out_rcond); + + template + inline static bool inv_tr(Mat& A, const uword layout); + + template + inline static bool inv_tr_rcond(Mat& A, typename get_pod_type::result& out_rcond, const uword layout); + + template + inline static bool inv_sympd(Mat& A, bool& out_sympd_state); + + template + inline static bool inv_sympd(Mat& out, const Mat& X); + + template + inline static bool inv_sympd_rcond(Mat& A, bool& out_sympd_state, eT& out_rcond); + + template + inline static bool inv_sympd_rcond(Mat< std::complex >& A, bool& out_sympd_state, T& out_rcond); + + + // + // det and log_det + + template + inline static bool det(eT& out_val, Mat& A); + + template + inline static bool log_det(eT& out_val, typename get_pod_type::result& out_sign, Mat& A); + + template + inline static bool log_det_sympd(typename get_pod_type::result& out_val, Mat& A); + + + // + // lu + + template + inline static bool lu(Mat& L, Mat& U, podarray& ipiv, const Base& X); + + template + inline static bool lu(Mat& L, Mat& U, Mat& P, const Base& X); + + template + inline static bool lu(Mat& L, Mat& U, const Base& X); + + + // + // eig_gen + + template + inline static bool eig_gen(Mat< std::complex >& vals, Mat< std::complex >& vecs, const bool vecs_on, const Base& expr); + + template + inline static bool eig_gen(Mat< std::complex >& vals, Mat< std::complex >& vecs, const bool vecs_on, const Base< std::complex, T1 >& expr); + + + // + // eig_gen_balance + + template + inline static bool eig_gen_balance(Mat< std::complex >& vals, Mat< std::complex >& vecs, const bool vecs_on, const Base& expr); + + template + inline static bool eig_gen_balance(Mat< std::complex >& vals, Mat< std::complex >& vecs, const bool vecs_on, const Base< std::complex, T1 >& expr); + + + // + // eig_gen_twosided + + template + inline static bool eig_gen_twosided(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base& expr); + + template + inline static bool eig_gen_twosided(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base< std::complex, T1 >& expr); + + + // + // eig_gen_twosided_balance + + template + inline static bool eig_gen_twosided_balance(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base& expr); + + template + inline static bool eig_gen_twosided_balance(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base< std::complex, T1 >& expr); + + + // + // eig_pair + + template + inline static bool eig_pair(Mat< std::complex >& vals, Mat< std::complex >& vecs, const bool vecs_on, const Base& A_expr, const Base& B_expr); + + template + inline static bool eig_pair(Mat< std::complex >& vals, Mat< std::complex >& vecs, const bool vecs_on, const Base< std::complex, T1 >& A_expr, const Base< std::complex, T2 >& B_expr); + + + // + // eig_pair_twosided + + template + inline static bool eig_pair_twosided(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base& A_expr, const Base& B_expr); + + template + inline static bool eig_pair_twosided(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base< std::complex, T1 >& A_expr, const Base< std::complex, T2 >& B_expr); + + + // + // eig_sym + + template + inline static bool eig_sym(Col& eigval, Mat& A); + + template + inline static bool eig_sym(Col& eigval, Mat< std::complex >& A); + + template + inline static bool eig_sym(Col& eigval, Mat& eigvec, const Mat& X); + + template + inline static bool eig_sym(Col& eigval, Mat< std::complex >& eigvec, const Mat< std::complex >& X); + + template + inline static bool eig_sym_dc(Col& eigval, Mat& eigvec, const Mat& X); + + template + inline static bool eig_sym_dc(Col& eigval, Mat< std::complex >& eigvec, const Mat< std::complex >& X); + + + // + // chol + + template + inline static bool chol_simple(Mat& X); + + template + inline static bool chol(Mat& X, const uword layout); + + template + inline static bool chol_band(Mat& X, const uword KD, const uword layout); + + template + inline static bool chol_band(Mat< std::complex >& X, const uword KD, const uword layout); + + template + inline static bool chol_band_common(Mat& X, const uword KD, const uword layout); + + template + inline static bool chol_pivot(Mat& X, Mat& P, const uword layout); + + + // + // hessenberg decomposition + + template + inline static bool hess(Mat& H, const Base& X, Col& tao); + + + // + // qr + + template + inline static bool qr(Mat& Q, Mat& R, const Base& X); + + template + inline static bool qr_econ(Mat& Q, Mat& R, const Base& X); + + template + inline static bool qr_pivot(Mat& Q, Mat& R, Mat& P, const Base& X); + + template + inline static bool qr_pivot(Mat< std::complex >& Q, Mat< std::complex >& R, Mat& P, const Base,T1>& X); + + + // + // svd + + template + inline static bool svd(Col& S, Mat& A); + + template + inline static bool svd(Col& S, Mat< std::complex >& A); + + + template + inline static bool svd(Mat& U, Col& S, Mat& V, Mat& A); + + template + inline static bool svd(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A); + + template + inline static bool svd_econ(Mat& U, Col& S, Mat& V, Mat& A, const char mode); + + template + inline static bool svd_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A, const char mode); + + + template + inline static bool svd_dc(Col& S, Mat& A); + + template + inline static bool svd_dc(Col& S, Mat< std::complex >& A); + + + template + inline static bool svd_dc(Mat& U, Col& S, Mat& V, Mat& A); + + template + inline static bool svd_dc(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A); + + template + inline static bool svd_dc_econ(Mat& U, Col& S, Mat& V, Mat& A); + + template + inline static bool svd_dc_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A); + + + // + // solve + + template + inline static bool solve_square_fast(Mat& out, Mat& A, const Base& B_expr); + + template + inline static bool solve_square_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr); + + template + inline static bool solve_square_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate); + + template + inline static bool solve_square_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate); + + // + + template + inline static bool solve_sympd_fast(Mat& out, Mat& A, const Base& B_expr); + + template + inline static bool solve_sympd_fast_common(Mat& out, Mat& A, const Base& B_expr); + + template + inline static bool solve_sympd_rcond(Mat& out, bool& out_sympd_state, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr); + + template + inline static bool solve_sympd_rcond(Mat< std::complex >& out, bool& out_sympd_state, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr); + + template + inline static bool solve_sympd_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate); + + template + inline static bool solve_sympd_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate); + + // + + template + inline static bool solve_rect_fast(Mat& out, Mat& A, const Base& B_expr); + + template + inline static bool solve_rect_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr); + + // + + template + inline static bool solve_approx_svd(Mat& out, Mat& A, const Base& B_expr); + + template + inline static bool solve_approx_svd(Mat< std::complex >& out, Mat< std::complex >& A, const Base,T1>& B_expr); + + // + + template + inline static bool solve_trimat_fast(Mat& out, const Mat& A, const Base& B_expr, const uword layout); + + template + inline static bool solve_trimat_rcond(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const Base& B_expr, const uword layout); + + // + + template + inline static bool solve_band_fast(Mat& out, Mat& A, const uword KL, const uword KU, const Base& B_expr); + + template + inline static bool solve_band_fast(Mat< std::complex >& out, Mat< std::complex >& A, const uword KL, const uword KU, const Base< std::complex,T1>& B_expr); + + template + inline static bool solve_band_fast_common(Mat& out, const Mat& A, const uword KL, const uword KU, const Base& B_expr); + + template + inline static bool solve_band_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr); + + template + inline static bool solve_band_rcond(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base< std::complex,T1>& B_expr); + + template + inline static bool solve_band_rcond_common(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const uword KL, const uword KU, const Base& B_expr); + + template + inline static bool solve_band_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr, const bool equilibrate); + + template + inline static bool solve_band_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base,T1>& B_expr, const bool equilibrate); + + // + + template + inline static bool solve_tridiag_fast(Mat& out, Mat& A, const Base& B_expr); + + template + inline static bool solve_tridiag_fast(Mat< std::complex >& out, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr); + + template + inline static bool solve_tridiag_fast_common(Mat& out, const Mat& A, const Base& B_expr); + + + // + // Schur decomposition + + template + inline static bool schur(Mat& U, Mat& S, const Base& X, const bool calc_U = true); + + template + inline static bool schur(Mat< std::complex >& U, Mat< std::complex >& S, const Base,T1>& X, const bool calc_U = true); + + template + inline static bool schur(Mat< std::complex >& U, Mat< std::complex >& S, const bool calc_U = true); + + // + // solve the Sylvester equation AX + XB = C + + template + inline static bool syl(Mat& X, const Mat& A, const Mat& B, const Mat& C); + + + // + // QZ decomposition + + template + inline static bool qz(Mat& A, Mat& B, Mat& vsl, Mat& vsr, const Base& X_expr, const Base& Y_expr, const char mode); + + template + inline static bool qz(Mat< std::complex >& A, Mat< std::complex >& B, Mat< std::complex >& vsl, Mat< std::complex >& vsr, const Base< std::complex, T1 >& X_expr, const Base< std::complex, T2 >& Y_expr, const char mode); + + + // + // rcond + + template + inline static eT rcond(Mat& A); + + template + inline static T rcond(Mat< std::complex >& A); + + template + inline static eT rcond_sympd(Mat& A, bool& calc_ok); + + template + inline static T rcond_sympd(Mat< std::complex >& A, bool& calc_ok); + + template + inline static eT rcond_trimat(const Mat& A, const uword layout); + + template + inline static T rcond_trimat(const Mat< std::complex >& A, const uword layout); + + + // + // lu_rcond (rcond from pre-computed LU decomposition) + + template + inline static eT lu_rcond(const Mat& A, const eT norm_val); + + template + inline static T lu_rcond(const Mat< std::complex >& A, const T norm_val); + + template + inline static eT lu_rcond_sympd(const Mat& A, const eT norm_val); + + template + inline static T lu_rcond_sympd(const Mat< std::complex >& A, const T norm_val); + + template + inline static eT lu_rcond_band(const Mat& AB, const uword KL, const uword KU, const podarray& ipiv, const eT norm_val); + + template + inline static T lu_rcond_band(const Mat< std::complex >& AB, const uword KL, const uword KU, const podarray& ipiv, const T norm_val); + + + // + // misc + + template + inline static bool crippled_lapack(const Base&); + + template + inline static bool rudimentary_sym_check(const Mat& X); + + template + inline static bool rudimentary_sym_check(const Mat< std::complex >& X); + + template + inline static typename get_pod_type::result norm1_gen(const Mat& A); + + template + inline static typename get_pod_type::result norm1_sym(const Mat& A); + + template + inline static typename get_pod_type::result norm1_band(const Mat& A, const uword KL, const uword KU); + }; + + + +namespace qz_helper + { + template inline blas_int select_lhp(const T* x_ptr, const T* y_ptr, const T* z_ptr); + template inline blas_int select_rhp(const T* x_ptr, const T* y_ptr, const T* z_ptr); + template inline blas_int select_iuc(const T* x_ptr, const T* y_ptr, const T* z_ptr); + template inline blas_int select_ouc(const T* x_ptr, const T* y_ptr, const T* z_ptr); + + template inline blas_int cx_select_lhp(const std::complex* x_ptr, const std::complex* y_ptr); + template inline blas_int cx_select_rhp(const std::complex* x_ptr, const std::complex* y_ptr); + template inline blas_int cx_select_iuc(const std::complex* x_ptr, const std::complex* y_ptr); + template inline blas_int cx_select_ouc(const std::complex* x_ptr, const std::complex* y_ptr); + + template inline void_ptr ptr_cast(blas_int (*function)(const T*, const T*, const T*)); + template inline void_ptr ptr_cast(blas_int (*function)(const std::complex*, const std::complex*)); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/auxlib_meat.hpp b/src/armadillo/include/armadillo_bits/auxlib_meat.hpp new file mode 100644 index 0000000..373aaa4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/auxlib_meat.hpp @@ -0,0 +1,7050 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup auxlib +//! @{ + + + +template +inline +bool +auxlib::inv(Mat& A) + { + arma_extra_debug_sigprint(); + + if(A.is_empty()) { return true; } + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); + + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), n); + blas_int info = 0; + + podarray ipiv(A.n_rows); + + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&n, &n, A.memptr(), &lda, ipiv.memptr(), &info); + + if(info != 0) { return false; } + + if(n > 16) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::getri()"); + lapack::getri(&n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_extra_debug_print("lapack::getri()"); + lapack::getri(&n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + return (info == 0); + } + #else + { + arma_ignore(A); + arma_stop_logic_error("inv(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::inv(Mat& out, const Mat& X) + { + arma_extra_debug_sigprint(); + + out = X; + + return auxlib::inv(out); + } + + + +template +inline +bool +auxlib::inv_rcond(Mat& A, typename get_pod_type::result& out_rcond) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + out_rcond = T(0); + + if(A.is_empty()) { return true; } + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); + + char norm_id = '1'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), n); + blas_int info = 0; + T norm_val = T(0); + + podarray junk(1); + podarray ipiv(A.n_rows); + + arma_extra_debug_print("lapack::lange()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_gen(A) : lapack::lange(&norm_id, &n, &n, A.memptr(), &lda, junk.memptr()); + + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&n, &n, A.memptr(), &lda, ipiv.memptr(), &info); + + if(info != 0) { return false; } + + out_rcond = auxlib::lu_rcond(A, norm_val); + + if(n > 16) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::getri()"); + lapack::getri(&n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_extra_debug_print("lapack::getri()"); + lapack::getri(&n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + return (info == 0); + } + #else + { + arma_ignore(A); + arma_stop_logic_error("inv_rcond(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::inv_tr(Mat& A, const uword layout) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + if(A.is_empty()) { return true; } + + arma_debug_assert_blas_size(A); + + char uplo = (layout == 0) ? 'U' : 'L'; + char diag = 'N'; + blas_int n = blas_int(A.n_rows); + blas_int info = 0; + + arma_extra_debug_print("lapack::trtri()"); + lapack::trtri(&uplo, &diag, &n, A.memptr(), &n, &info); + + if(info != 0) { return false; } + + return true; + } + #else + { + arma_ignore(A); + arma_ignore(layout); + arma_stop_logic_error("inv(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::inv_tr_rcond(Mat& A, typename get_pod_type::result& out_rcond, const uword layout) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename get_pod_type::result T; + + if(A.is_empty()) { return true; } + + out_rcond = auxlib::rcond_trimat(A, layout); + + arma_debug_assert_blas_size(A); + + char uplo = (layout == 0) ? 'U' : 'L'; + char diag = 'N'; + blas_int n = blas_int(A.n_rows); + blas_int info = 0; + + arma_extra_debug_print("lapack::trtri()"); + lapack::trtri(&uplo, &diag, &n, A.memptr(), &n, &info); + + if(info != 0) { out_rcond = T(0); return false; } + + return true; + } + #else + { + arma_ignore(A); + arma_ignore(out_rcond); + arma_ignore(layout); + arma_stop_logic_error("inv(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::inv_sympd(Mat& A, bool& out_sympd_state) + { + arma_extra_debug_sigprint(); + + out_sympd_state = false; + + if(A.is_empty()) { return true; } + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); + + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int info = 0; + + // NOTE: for complex matrices, zpotrf() assumes the matrix is hermitian (not simply symmetric) + + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, A.memptr(), &n, &info); + + if(info != 0) { return false; } + + out_sympd_state = true; + + arma_extra_debug_print("lapack::potri()"); + lapack::potri(&uplo, &n, A.memptr(), &n, &info); + + if(info != 0) { return false; } + + A = symmatl(A); + + return true; + } + #else + { + arma_ignore(A); + arma_ignore(out_sympd_state); + arma_stop_logic_error("inv_sympd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::inv_sympd(Mat& out, const Mat& X) + { + arma_extra_debug_sigprint(); + + out = X; + + bool sympd_state_junk = false; + + return auxlib::inv_sympd(out, sympd_state_junk); + } + + + +template +inline +bool +auxlib::inv_sympd_rcond(Mat& A, bool& out_sympd_state, eT& out_rcond) + { + arma_extra_debug_sigprint(); + + out_sympd_state = false; + + if(A.is_empty()) { return true; } + + #if defined(ARMA_USE_LAPACK) + { + typedef typename get_pod_type::result T; + + arma_debug_assert_blas_size(A); + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int info = 0; + T norm_val = T(0); + + podarray work(A.n_rows); + + arma_extra_debug_print("lapack::lansy()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); + + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, A.memptr(), &n, &info); + + if(info != 0) { out_rcond = eT(0); return false; } + + out_sympd_state = true; + + out_rcond = auxlib::lu_rcond_sympd(A, norm_val); + + if(arma_isnan(out_rcond)) { return false; } + + arma_extra_debug_print("lapack::potri()"); + lapack::potri(&uplo, &n, A.memptr(), &n, &info); + + if(info != 0) { return false; } + + A = symmatl(A); + + return true; + } + #else + { + arma_ignore(A); + arma_ignore(out_sympd_state); + arma_ignore(out_rcond); + arma_stop_logic_error("inv_sympd_rcond(): use LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::inv_sympd_rcond(Mat< std::complex >& A, bool& out_sympd_state, T& out_rcond) + { + arma_extra_debug_sigprint(); + + out_sympd_state = false; + + if(A.is_empty()) { return true; } + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_ignore(A); + arma_ignore(out_sympd_state); + arma_ignore(out_rcond); + return false; + } + #elif defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int info = 0; + T norm_val = T(0); + + podarray work(A.n_rows); + + arma_extra_debug_print("lapack::lanhe()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); + + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, A.memptr(), &n, &info); + + if(info != 0) { out_rcond = T(0); return false; } + + out_sympd_state = true; + + out_rcond = auxlib::lu_rcond_sympd(A, norm_val); + + if(arma_isnan(out_rcond)) { return false; } + + arma_extra_debug_print("lapack::potri()"); + lapack::potri(&uplo, &n, A.memptr(), &n, &info); + + if(info != 0) { return false; } + + A = symmatl(A); + + return true; + } + #else + { + arma_ignore(A); + arma_ignore(out_sympd_state); + arma_ignore(out_rcond); + arma_stop_logic_error("inv_sympd_rcond(): use LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! determinant of a matrix +template +inline +bool +auxlib::det(eT& out_val, Mat& A) + { + arma_extra_debug_sigprint(); + + if(A.is_empty()) { out_val = eT(1); return true; } + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); + + podarray ipiv(A.n_rows); + + blas_int info = 0; + blas_int n_rows = blas_int(A.n_rows); + blas_int n_cols = blas_int(A.n_cols); + + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&n_rows, &n_cols, A.memptr(), &n_rows, ipiv.memptr(), &info); + + if(info < 0) { return false; } + + // on output A appears to be L+U_alt, where U_alt is U with the main diagonal set to zero + eT val = A.at(0,0); + for(uword i=1; i < A.n_rows; ++i) { val *= A.at(i,i); } + + blas_int sign = +1; + for(uword i=0; i < A.n_rows; ++i) + { + // NOTE: adjustment of -1 is required as Fortran counts from 1 + if( blas_int(i) != (ipiv.mem[i] - 1) ) { sign *= -1; } + } + + out_val = (sign < 0) ? eT(-val) : eT(val); + + return true; + } + #else + { + arma_ignore(out_val); + arma_ignore(A); + arma_stop_logic_error("det(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! log determinant of a matrix +template +inline +bool +auxlib::log_det(eT& out_val, typename get_pod_type::result& out_sign, Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.is_empty()) { out_val = eT(0); out_sign = T(1); return true; } + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); + + podarray ipiv(A.n_rows); + + blas_int info = 0; + blas_int n_rows = blas_int(A.n_rows); + blas_int n_cols = blas_int(A.n_cols); + + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&n_rows, &n_cols, A.memptr(), &n_rows, ipiv.memptr(), &info); + + if(info < 0) { return false; } + + // on output A appears to be L+U_alt, where U_alt is U with the main diagonal set to zero + + sword sign = (is_cx::no) ? ( (access::tmp_real( A.at(0,0) ) < T(0)) ? -1 : +1 ) : +1; + eT val = (is_cx::no) ? std::log( (access::tmp_real( A.at(0,0) ) < T(0)) ? A.at(0,0)*T(-1) : A.at(0,0) ) : std::log( A.at(0,0) ); + + for(uword i=1; i < A.n_rows; ++i) + { + const eT x = A.at(i,i); + + sign *= (is_cx::no) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1; + val += (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); + } + + for(uword i=0; i < A.n_rows; ++i) + { + if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1 + { + sign *= -1; + } + } + + out_val = val; + out_sign = T(sign); + + return true; + } + #else + { + arma_ignore(A); + arma_ignore(out_val); + arma_ignore(out_sign); + arma_stop_logic_error("log_det(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::log_det_sympd(typename get_pod_type::result& out_val, Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.is_empty()) { out_val = T(0); return true; } + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); + + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int info = 0; + + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, A.memptr(), &n, &info); + + if(info != 0) { return false; } + + T val = T(0); + + for(uword i=0; i < A.n_rows; ++i) { val += std::log( access::tmp_real(A.at(i,i)) ); } + + out_val = T(2) * val; + + return true; + } + #else + { + arma_ignore(out_val); + arma_ignore(A); + arma_stop_logic_error("log_det_sympd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! LU decomposition of a matrix +template +inline +bool +auxlib::lu(Mat& L, Mat& U, podarray& ipiv, const Base& X) + { + arma_extra_debug_sigprint(); + + U = X.get_ref(); + + const uword U_n_rows = U.n_rows; + const uword U_n_cols = U.n_cols; + + if(U.is_empty()) { L.set_size(U_n_rows, 0); U.set_size(0, U_n_cols); ipiv.reset(); return true; } + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(U); + + ipiv.set_size( (std::min)(U_n_rows, U_n_cols) ); + + blas_int info = 0; + + blas_int n_rows = blas_int(U_n_rows); + blas_int n_cols = blas_int(U_n_cols); + + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info); + + if(info < 0) { return false; } + + // take into account that Fortran counts from 1 + arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem); + + L.copy_size(U); + + for(uword col=0; col < U_n_cols; ++col) + { + for(uword row=0; (row < col) && (row < U_n_rows); ++row) + { + L.at(row,col) = eT(0); + } + + if( L.in_range(col,col) ) + { + L.at(col,col) = eT(1); + } + + for(uword row = (col+1); row < U_n_rows; ++row) + { + L.at(row,col) = U.at(row,col); + U.at(row,col) = eT(0); + } + } + + return true; + } + #else + { + arma_stop_logic_error("lu(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::lu(Mat& L, Mat& U, Mat& P, const Base& X) + { + arma_extra_debug_sigprint(); + + podarray ipiv1; + const bool status = auxlib::lu(L, U, ipiv1, X); + + if(status == false) { return false; } + + if(U.is_empty()) + { + // L and U have been already set to the correct empty matrices + P.eye(L.n_rows, L.n_rows); + return true; + } + + const uword n = ipiv1.n_elem; + const uword P_rows = U.n_rows; + + podarray ipiv2(P_rows); + + const blas_int* ipiv1_mem = ipiv1.memptr(); + blas_int* ipiv2_mem = ipiv2.memptr(); + + for(uword i=0; i(ipiv1_mem[i]); + + if( ipiv2_mem[i] != ipiv2_mem[k] ) + { + std::swap( ipiv2_mem[i], ipiv2_mem[k] ); + } + } + + P.zeros(P_rows, P_rows); + + for(uword row=0; row(ipiv2_mem[row])) = eT(1); + } + + if(L.n_cols > U.n_rows) + { + L.shed_cols(U.n_rows, L.n_cols-1); + } + + if(U.n_rows > L.n_cols) + { + U.shed_rows(L.n_cols, U.n_rows-1); + } + + return true; + } + + + +template +inline +bool +auxlib::lu(Mat& L, Mat& U, const Base& X) + { + arma_extra_debug_sigprint(); + + podarray ipiv1; + const bool status = auxlib::lu(L, U, ipiv1, X); + + if(status == false) { return false; } + + if(U.is_empty()) + { + // L and U have been already set to the correct empty matrices + return true; + } + + const uword n = ipiv1.n_elem; + const uword P_rows = U.n_rows; + + podarray ipiv2(P_rows); + + const blas_int* ipiv1_mem = ipiv1.memptr(); + blas_int* ipiv2_mem = ipiv2.memptr(); + + for(uword i=0; i(ipiv1_mem[i]); + + if( ipiv2_mem[i] != ipiv2_mem[k] ) + { + std::swap( ipiv2_mem[i], ipiv2_mem[k] ); + L.swap_rows( static_cast(ipiv2_mem[i]), static_cast(ipiv2_mem[k]) ); + } + } + + if(L.n_cols > U.n_rows) + { + L.shed_cols(U.n_rows, L.n_cols-1); + } + + if(U.n_rows > L.n_cols) + { + U.shed_rows(L.n_cols, U.n_rows-1); + } + + return true; + } + + + +//! eigen decomposition of general square matrix (real) +template +inline +bool +auxlib::eig_gen + ( + Mat< std::complex >& vals, + Mat< std::complex >& vecs, + const bool vecs_on, + const Base& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); vecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + Mat tmp(1, 1, arma_nozeros_indicator()); + + if(vecs_on) + { + vecs.set_size(X.n_rows, X.n_rows); + tmp.set_size(X.n_rows, X.n_rows); + } + + podarray junk(1); + + char jobvl = 'N'; + char jobvr = (vecs_on) ? 'V' : 'N'; + blas_int N = blas_int(X.n_rows); + T* vl = junk.memptr(); + T* vr = (vecs_on) ? tmp.memptr() : junk.memptr(); + blas_int ldvl = blas_int(1); + blas_int ldvr = (vecs_on) ? blas_int(tmp.n_rows) : blas_int(1); + blas_int lwork = 64*N; // lwork_min = (vecs_on) ? (std::max)(blas_int(1), 4*N) : (std::max)(blas_int(1), 3*N) + blas_int info = 0; + + podarray work( static_cast(lwork) ); + + podarray vals_real(X.n_rows); + podarray vals_imag(X.n_rows); + + arma_extra_debug_print("lapack::geev() -- START"); + lapack::geev(&jobvl, &jobvr, &N, X.memptr(), &N, vals_real.memptr(), vals_imag.memptr(), vl, &ldvl, vr, &ldvr, work.memptr(), &lwork, &info); + arma_extra_debug_print("lapack::geev() -- END"); + + if(info != 0) { return false; } + + arma_extra_debug_print("reformatting eigenvalues and eigenvectors"); + + std::complex* vals_mem = vals.memptr(); + + for(uword i=0; i < X.n_rows; ++i) { vals_mem[i] = std::complex(vals_real[i], vals_imag[i]); } + + if(vecs_on) + { + for(uword j=0; j < X.n_rows; ++j) + { + if( (j < (X.n_rows-1)) && (vals_mem[j] == std::conj(vals_mem[j+1])) ) + { + for(uword i=0; i < X.n_rows; ++i) + { + vecs.at(i,j) = std::complex( tmp.at(i,j), tmp.at(i,j+1) ); + vecs.at(i,j+1) = std::complex( tmp.at(i,j), -tmp.at(i,j+1) ); + } + + ++j; + } + else + { + for(uword i=0; i(tmp.at(i,j), T(0)); + } + } + } + } + + return true; + } + #else + { + arma_ignore(vals); + arma_ignore(vecs); + arma_ignore(vecs_on); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! eigen decomposition of general square matrix (complex) +template +inline +bool +auxlib::eig_gen + ( + Mat< std::complex >& vals, + Mat< std::complex >& vecs, + const bool vecs_on, + const Base< std::complex, T1 >& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); vecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + if(vecs_on) { vecs.set_size(X.n_rows, X.n_rows); } + + podarray junk(1); + + char jobvl = 'N'; + char jobvr = (vecs_on) ? 'V' : 'N'; + blas_int N = blas_int(X.n_rows); + eT* vl = junk.memptr(); + eT* vr = (vecs_on) ? vecs.memptr() : junk.memptr(); + blas_int ldvl = blas_int(1); + blas_int ldvr = (vecs_on) ? blas_int(vecs.n_rows) : blas_int(1); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), 2*N) + blas_int info = 0; + + podarray work( static_cast(lwork) ); + podarray< T> rwork( static_cast(2*N) ); + + arma_extra_debug_print("lapack::cx_geev() -- START"); + lapack::cx_geev(&jobvl, &jobvr, &N, X.memptr(), &N, vals.memptr(), vl, &ldvl, vr, &ldvr, work.memptr(), &lwork, rwork.memptr(), &info); + arma_extra_debug_print("lapack::cx_geev() -- END"); + + return (info == 0); + } + #else + { + arma_ignore(vals); + arma_ignore(vecs); + arma_ignore(vecs_on); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! eigen decomposition of general square matrix (real, balance given matrix) +template +inline +bool +auxlib::eig_gen_balance + ( + Mat< std::complex >& vals, + Mat< std::complex >& vecs, + const bool vecs_on, + const Base& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); vecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + Mat tmp(1, 1, arma_nozeros_indicator()); + + if(vecs_on) + { + vecs.set_size(X.n_rows, X.n_rows); + tmp.set_size(X.n_rows, X.n_rows); + } + + podarray junk(1); + + char bal = 'B'; + char jobvl = 'N'; + char jobvr = (vecs_on) ? 'V' : 'N'; + char sense = 'N'; + blas_int N = blas_int(X.n_rows); + T* vl = junk.memptr(); + T* vr = (vecs_on) ? tmp.memptr() : junk.memptr(); + blas_int ldvl = blas_int(1); + blas_int ldvr = (vecs_on) ? blas_int(tmp.n_rows) : blas_int(1); + blas_int ilo = blas_int(0); + blas_int ihi = blas_int(0); + T abnrm = T(0); + blas_int lwork = 64*N; // lwork_min = (vecs_on) ? (std::max)(blas_int(1), 2*N) : (std::max)(blas_int(1), 3*N) + blas_int info = blas_int(0); + + podarray scale(X.n_rows); + podarray rconde(X.n_rows); + podarray rcondv(X.n_rows); + + podarray work( static_cast(lwork) ); + podarray iwork( uword(1) ); // iwork not used by lapack::geevx() as sense = 'N' + + podarray vals_real(X.n_rows); + podarray vals_imag(X.n_rows); + + arma_extra_debug_print("lapack::geevx() -- START"); + lapack::geevx(&bal, &jobvl, &jobvr, &sense, &N, X.memptr(), &N, vals_real.memptr(), vals_imag.memptr(), vl, &ldvl, vr, &ldvr, &ilo, &ihi, scale.memptr(), &abnrm, rconde.memptr(), rcondv.memptr(), work.memptr(), &lwork, iwork.memptr(), &info); + arma_extra_debug_print("lapack::geevx() -- END"); + + if(info != 0) { return false; } + + arma_extra_debug_print("reformatting eigenvalues and eigenvectors"); + + std::complex* vals_mem = vals.memptr(); + + for(uword i=0; i < X.n_rows; ++i) { vals_mem[i] = std::complex(vals_real[i], vals_imag[i]); } + + if(vecs_on) + { + for(uword j=0; j < X.n_rows; ++j) + { + if( (j < (X.n_rows-1)) && (vals_mem[j] == std::conj(vals_mem[j+1])) ) + { + for(uword i=0; i < X.n_rows; ++i) + { + vecs.at(i,j) = std::complex( tmp.at(i,j), tmp.at(i,j+1) ); + vecs.at(i,j+1) = std::complex( tmp.at(i,j), -tmp.at(i,j+1) ); + } + + ++j; + } + else + { + for(uword i=0; i(tmp.at(i,j), T(0)); + } + } + } + } + + return true; + } + #else + { + arma_ignore(vals); + arma_ignore(vecs); + arma_ignore(vecs_on); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! eigen decomposition of general square matrix (complex, balance given matrix) +template +inline +bool +auxlib::eig_gen_balance + ( + Mat< std::complex >& vals, + Mat< std::complex >& vecs, + const bool vecs_on, + const Base< std::complex, T1 >& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::eig_gen_balance(): redirecting to auxlib::eig_gen() due to crippled LAPACK"); + + return auxlib::eig_gen(vals, vecs, vecs_on, expr); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); vecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + if(vecs_on) { vecs.set_size(X.n_rows, X.n_rows); } + + podarray junk(1); + + char bal = 'B'; + char jobvl = 'N'; + char jobvr = (vecs_on) ? 'V' : 'N'; + char sense = 'N'; + blas_int N = blas_int(X.n_rows); + eT* vl = junk.memptr(); + eT* vr = (vecs_on) ? vecs.memptr() : junk.memptr(); + blas_int ldvl = blas_int(1); + blas_int ldvr = (vecs_on) ? blas_int(vecs.n_rows) : blas_int(1); + blas_int ilo = blas_int(0); + blas_int ihi = blas_int(0); + T abnrm = T(0); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), blas_int(2*N)) + blas_int info = blas_int(0); + + podarray scale(X.n_rows); + podarray rconde(X.n_rows); + podarray rcondv(X.n_rows); + + podarray work( static_cast(lwork) ); + podarray< T> rwork( static_cast(2*N) ); + + arma_extra_debug_print("lapack::cx_geevx() -- START"); + lapack::cx_geevx(&bal, &jobvl, &jobvr, &sense, &N, X.memptr(), &N, vals.memptr(), vl, &ldvl, vr, &ldvr, &ilo, &ihi, scale.memptr(), &abnrm, rconde.memptr(), rcondv.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); + arma_extra_debug_print("lapack::cx_geevx() -- END"); + + return (info == 0); + } + #else + { + arma_ignore(vals); + arma_ignore(vecs); + arma_ignore(vecs_on); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! two-sided eigen decomposition of general square matrix (real) +template +inline +bool +auxlib::eig_gen_twosided + ( + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, + const Base& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + lvecs.set_size(X.n_rows, X.n_rows); + rvecs.set_size(X.n_rows, X.n_rows); + + Mat ltmp(X.n_rows, X.n_rows, arma_nozeros_indicator()); + Mat rtmp(X.n_rows, X.n_rows, arma_nozeros_indicator()); + + char jobvl = 'V'; + char jobvr = 'V'; + blas_int N = blas_int(X.n_rows); + blas_int ldvl = blas_int(ltmp.n_rows); + blas_int ldvr = blas_int(rtmp.n_rows); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), 4*N) + blas_int info = 0; + + podarray work( static_cast(lwork) ); + + podarray vals_real(X.n_rows); + podarray vals_imag(X.n_rows); + + arma_extra_debug_print("lapack::geev() -- START"); + lapack::geev(&jobvl, &jobvr, &N, X.memptr(), &N, vals_real.memptr(), vals_imag.memptr(), ltmp.memptr(), &ldvl, rtmp.memptr(), &ldvr, work.memptr(), &lwork, &info); + arma_extra_debug_print("lapack::geev() -- END"); + + if(info != 0) { return false; } + + arma_extra_debug_print("reformatting eigenvalues and eigenvectors"); + + std::complex* vals_mem = vals.memptr(); + + for(uword i=0; i < X.n_rows; ++i) { vals_mem[i] = std::complex(vals_real[i], vals_imag[i]); } + + for(uword j=0; j < X.n_rows; ++j) + { + if( (j < (X.n_rows-1)) && (vals_mem[j] == std::conj(vals_mem[j+1])) ) + { + for(uword i=0; i < X.n_rows; ++i) + { + lvecs.at(i,j) = std::complex( ltmp.at(i,j), ltmp.at(i,j+1) ); + lvecs.at(i,j+1) = std::complex( ltmp.at(i,j), -ltmp.at(i,j+1) ); + rvecs.at(i,j) = std::complex( rtmp.at(i,j), rtmp.at(i,j+1) ); + rvecs.at(i,j+1) = std::complex( rtmp.at(i,j), -rtmp.at(i,j+1) ); + } + ++j; + } + else + { + for(uword i=0; i(ltmp.at(i,j), T(0)); + rvecs.at(i,j) = std::complex(rtmp.at(i,j), T(0)); + } + } + } + + return true; + } + #else + { + arma_ignore(vals); + arma_ignore(lvecs); + arma_ignore(rvecs); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! two-sided eigen decomposition of general square matrix (complex) +template +inline +bool +auxlib::eig_gen_twosided + ( + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, + const Base< std::complex, T1 >& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + lvecs.set_size(X.n_rows, X.n_rows); + rvecs.set_size(X.n_rows, X.n_rows); + + char jobvl = 'V'; + char jobvr = 'V'; + blas_int N = blas_int(X.n_rows); + blas_int ldvl = blas_int(lvecs.n_rows); + blas_int ldvr = blas_int(rvecs.n_rows); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), 2*N) + blas_int info = 0; + + podarray work( static_cast(lwork) ); + podarray< T> rwork( static_cast(2*N) ); + + arma_extra_debug_print("lapack::cx_geev() -- START"); + lapack::cx_geev(&jobvl, &jobvr, &N, X.memptr(), &N, vals.memptr(), lvecs.memptr(), &ldvl, rvecs.memptr(), &ldvr, work.memptr(), &lwork, rwork.memptr(), &info); + arma_extra_debug_print("lapack::cx_geev() -- END"); + + return (info == 0); + } + #else + { + arma_ignore(vals); + arma_ignore(lvecs); + arma_ignore(rvecs); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! two-sided eigen decomposition of general square matrix (real, balance given matrix) +template +inline +bool +auxlib::eig_gen_twosided_balance + ( + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, + const Base& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + lvecs.set_size(X.n_rows, X.n_rows); + rvecs.set_size(X.n_rows, X.n_rows); + + Mat ltmp(X.n_rows, X.n_rows, arma_nozeros_indicator()); + Mat rtmp(X.n_rows, X.n_rows, arma_nozeros_indicator()); + + char bal = 'B'; + char jobvl = 'V'; + char jobvr = 'V'; + char sense = 'N'; + blas_int N = blas_int(X.n_rows); + blas_int ldvl = blas_int(ltmp.n_rows); + blas_int ldvr = blas_int(rtmp.n_rows); + blas_int ilo = blas_int(0); + blas_int ihi = blas_int(0); + T abnrm = T(0); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), blas_int(3*N)) + blas_int info = blas_int(0); + + podarray scale(X.n_rows); + podarray rconde(X.n_rows); + podarray rcondv(X.n_rows); + + podarray work( static_cast(lwork) ); + podarray iwork( uword(1) ); // iwork not used by lapack::geevx() as sense = 'N' + + podarray vals_real(X.n_rows); + podarray vals_imag(X.n_rows); + + arma_extra_debug_print("lapack::geevx() -- START"); + lapack::geevx(&bal, &jobvl, &jobvr, &sense, &N, X.memptr(), &N, vals_real.memptr(), vals_imag.memptr(), ltmp.memptr(), &ldvl, rtmp.memptr(), &ldvr, &ilo, &ihi, scale.memptr(), &abnrm, rconde.memptr(), rcondv.memptr(), work.memptr(), &lwork, iwork.memptr(), &info); + arma_extra_debug_print("lapack::geevx() -- END"); + + if(info != 0) { return false; } + + arma_extra_debug_print("reformatting eigenvalues and eigenvectors"); + + std::complex* vals_mem = vals.memptr(); + + for(uword i=0; i < X.n_rows; ++i) { vals_mem[i] = std::complex(vals_real[i], vals_imag[i]); } + + for(uword j=0; j < X.n_rows; ++j) + { + if( (j < (X.n_rows-1)) && (vals_mem[j] == std::conj(vals_mem[j+1])) ) + { + for(uword i=0; i < X.n_rows; ++i) + { + lvecs.at(i,j) = std::complex( ltmp.at(i,j), ltmp.at(i,j+1) ); + lvecs.at(i,j+1) = std::complex( ltmp.at(i,j), -ltmp.at(i,j+1) ); + rvecs.at(i,j) = std::complex( rtmp.at(i,j), rtmp.at(i,j+1) ); + rvecs.at(i,j+1) = std::complex( rtmp.at(i,j), -rtmp.at(i,j+1) ); + } + ++j; + } + else + { + for(uword i=0; i(ltmp.at(i,j), T(0)); + rvecs.at(i,j) = std::complex(rtmp.at(i,j), T(0)); + } + } + } + + return true; + } + #else + { + arma_ignore(vals); + arma_ignore(lvecs); + arma_ignore(rvecs); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! two-sided eigen decomposition of general square matrix (complex, balance given matrix) +template +inline +bool +auxlib::eig_gen_twosided_balance + ( + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, + const Base< std::complex, T1 >& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::eig_gen_twosided_balance(): redirecting to auxlib::eig_gen() due to crippled LAPACK"); + + return auxlib::eig_gen(vals, lvecs, rvecs, expr); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + lvecs.set_size(X.n_rows, X.n_rows); + rvecs.set_size(X.n_rows, X.n_rows); + + char bal = 'B'; + char jobvl = 'V'; + char jobvr = 'V'; + char sense = 'N'; + blas_int N = blas_int(X.n_rows); + blas_int ldvl = blas_int(lvecs.n_rows); + blas_int ldvr = blas_int(rvecs.n_rows); + blas_int ilo = blas_int(0); + blas_int ihi = blas_int(0); + T abnrm = T(0); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), blas_int(2*N)) + blas_int info = blas_int(0); + + podarray scale(X.n_rows); + podarray rconde(X.n_rows); + podarray rcondv(X.n_rows); + + podarray work( static_cast(lwork) ); + podarray< T> rwork( static_cast(2*N) ); + + arma_extra_debug_print("lapack::cx_geevx() -- START"); + lapack::cx_geevx(&bal, &jobvl, &jobvr, &sense, &N, X.memptr(), &N, vals.memptr(), lvecs.memptr(), &ldvl, rvecs.memptr(), &ldvr, &ilo, &ihi, scale.memptr(), &abnrm, rconde.memptr(), rcondv.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); + arma_extra_debug_print("lapack::cx_geevx() -- END"); + + return (info == 0); + } + #else + { + arma_ignore(vals); + arma_ignore(lvecs); + arma_ignore(rvecs); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! eigendecomposition of general square matrix pair (real) +template +inline +bool +auxlib::eig_pair + ( + Mat< std::complex >& vals, + Mat< std::complex >& vecs, + const bool vecs_on, + const Base& A_expr, + const Base& B_expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef std::complex eT; + + Mat A(A_expr.get_ref()); + Mat B(B_expr.get_ref()); + + arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "eig_pair(): given matrices must be square sized" ); + + arma_debug_check( (A.n_rows != B.n_rows), "eig_pair(): given matrices must have the same size" ); + + arma_debug_assert_blas_size(A); + + if(A.is_empty()) { vals.reset(); vecs.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } + + vals.set_size(A.n_rows, 1); + + Mat tmp(1, 1, arma_nozeros_indicator()); + + if(vecs_on) + { + vecs.set_size(A.n_rows, A.n_rows); + tmp.set_size(A.n_rows, A.n_rows); + } + + podarray junk(1); + + char jobvl = 'N'; + char jobvr = (vecs_on) ? 'V' : 'N'; + blas_int N = blas_int(A.n_rows); + T* vl = junk.memptr(); + T* vr = (vecs_on) ? tmp.memptr() : junk.memptr(); + blas_int ldvl = blas_int(1); + blas_int ldvr = (vecs_on) ? blas_int(tmp.n_rows) : blas_int(1); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), 8*N) + blas_int info = 0; + + podarray alphar(A.n_rows); + podarray alphai(A.n_rows); + podarray beta(A.n_rows); + + podarray work( static_cast(lwork) ); + + arma_extra_debug_print("lapack::ggev()"); + lapack::ggev(&jobvl, &jobvr, &N, A.memptr(), &N, B.memptr(), &N, alphar.memptr(), alphai.memptr(), beta.memptr(), vl, &ldvl, vr, &ldvr, work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_extra_debug_print("reformatting eigenvalues and eigenvectors"); + + eT* vals_mem = vals.memptr(); + const T* alphar_mem = alphar.memptr(); + const T* alphai_mem = alphai.memptr(); + const T* beta_mem = beta.memptr(); + + bool beta_has_zero = false; + + for(uword j=0; j(re, im); + + if( (alphai_val > T(0)) && (j < (A.n_rows-1)) ) + { + ++j; + vals_mem[j] = std::complex(re,-im); // force exact conjugate + } + } + + if(beta_has_zero) { arma_debug_warn_level(1, "eig_pair(): given matrices appear ill-conditioned"); } + + if(vecs_on) + { + for(uword j=0; j( tmp.at(i,j), tmp.at(i,j+1) ); + vecs.at(i,j+1) = std::complex( tmp.at(i,j), -tmp.at(i,j+1) ); + } + + ++j; + } + else + { + for(uword i=0; i(tmp.at(i,j), T(0)); + } + } + } + } + + return true; + } + #else + { + arma_ignore(vals); + arma_ignore(vecs); + arma_ignore(vecs_on); + arma_ignore(A_expr); + arma_ignore(B_expr); + arma_stop_logic_error("eig_pair(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! eigendecomposition of general square matrix pair (complex) +template +inline +bool +auxlib::eig_pair + ( + Mat< std::complex >& vals, + Mat< std::complex >& vecs, + const bool vecs_on, + const Base< std::complex, T1 >& A_expr, + const Base< std::complex, T2 >& B_expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + Mat A(A_expr.get_ref()); + Mat B(B_expr.get_ref()); + + arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "eig_pair(): given matrices must be square sized" ); + + arma_debug_check( (A.n_rows != B.n_rows), "eig_pair(): given matrices must have the same size" ); + + arma_debug_assert_blas_size(A); + + if(A.is_empty()) { vals.reset(); vecs.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } + + vals.set_size(A.n_rows, 1); + + if(vecs_on) { vecs.set_size(A.n_rows, A.n_rows); } + + podarray junk(1); + + char jobvl = 'N'; + char jobvr = (vecs_on) ? 'V' : 'N'; + blas_int N = blas_int(A.n_rows); + eT* vl = junk.memptr(); + eT* vr = (vecs_on) ? vecs.memptr() : junk.memptr(); + blas_int ldvl = blas_int(1); + blas_int ldvr = (vecs_on) ? blas_int(vecs.n_rows) : blas_int(1); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1),2*N) + blas_int info = 0; + + podarray alpha(A.n_rows); + podarray beta(A.n_rows); + + podarray work( static_cast(lwork) ); + podarray rwork( static_cast(8*N) ); + + arma_extra_debug_print("lapack::cx_ggev()"); + lapack::cx_ggev(&jobvl, &jobvr, &N, A.memptr(), &N, B.memptr(), &N, alpha.memptr(), beta.memptr(), vl, &ldvl, vr, &ldvr, work.memptr(), &lwork, rwork.memptr(), &info); + + if(info != 0) { return false; } + + eT* vals_mem = vals.memptr(); + const eT* alpha_mem = alpha.memptr(); + const eT* beta_mem = beta.memptr(); + + const std::complex zero(T(0), T(0)); + + bool beta_has_zero = false; + + for(uword i=0; i +inline +bool +auxlib::eig_pair_twosided + ( + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, + const Base& A_expr, + const Base& B_expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef std::complex eT; + + Mat A(A_expr.get_ref()); + Mat B(B_expr.get_ref()); + + arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "eig_pair(): given matrices must be square sized" ); + + arma_debug_check( (A.n_rows != B.n_rows), "eig_pair(): given matrices must have the same size" ); + + arma_debug_assert_blas_size(A); + + if(A.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } + + vals.set_size(A.n_rows, 1); + + lvecs.set_size(A.n_rows, A.n_rows); + rvecs.set_size(A.n_rows, A.n_rows); + + Mat ltmp(A.n_rows, A.n_rows, arma_nozeros_indicator()); + Mat rtmp(A.n_rows, A.n_rows, arma_nozeros_indicator()); + + char jobvl = 'V'; + char jobvr = 'V'; + blas_int N = blas_int(A.n_rows); + blas_int ldvl = blas_int(ltmp.n_rows); + blas_int ldvr = blas_int(rtmp.n_rows); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), 8*N) + blas_int info = 0; + + podarray alphar(A.n_rows); + podarray alphai(A.n_rows); + podarray beta(A.n_rows); + + podarray work( static_cast(lwork) ); + + arma_extra_debug_print("lapack::ggev()"); + lapack::ggev(&jobvl, &jobvr, &N, A.memptr(), &N, B.memptr(), &N, alphar.memptr(), alphai.memptr(), beta.memptr(), ltmp.memptr(), &ldvl, rtmp.memptr(), &ldvr, work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_extra_debug_print("reformatting eigenvalues and eigenvectors"); + + eT* vals_mem = vals.memptr(); + const T* alphar_mem = alphar.memptr(); + const T* alphai_mem = alphai.memptr(); + const T* beta_mem = beta.memptr(); + + bool beta_has_zero = false; + + for(uword j=0; j(re, im); + + if( (alphai_val > T(0)) && (j < (A.n_rows-1)) ) + { + ++j; + vals_mem[j] = std::complex(re,-im); // force exact conjugate + } + } + + if(beta_has_zero) { arma_debug_warn_level(1, "eig_pair(): given matrices appear ill-conditioned"); } + + for(uword j=0; j < A.n_rows; ++j) + { + if( (j < (A.n_rows-1)) && (vals_mem[j] == std::conj(vals_mem[j+1])) ) + { + for(uword i=0; i < A.n_rows; ++i) + { + lvecs.at(i,j) = std::complex( ltmp.at(i,j), ltmp.at(i,j+1) ); + lvecs.at(i,j+1) = std::complex( ltmp.at(i,j), -ltmp.at(i,j+1) ); + rvecs.at(i,j) = std::complex( rtmp.at(i,j), rtmp.at(i,j+1) ); + rvecs.at(i,j+1) = std::complex( rtmp.at(i,j), -rtmp.at(i,j+1) ); + } + ++j; + } + else + { + for(uword i=0; i(ltmp.at(i,j), T(0)); + rvecs.at(i,j) = std::complex(rtmp.at(i,j), T(0)); + } + } + } + + return true; + } + #else + { + arma_ignore(vals); + arma_ignore(lvecs); + arma_ignore(rvecs); + arma_ignore(A_expr); + arma_ignore(B_expr); + arma_stop_logic_error("eig_pair(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! two-sided eigendecomposition of general square matrix pair (complex) +template +inline +bool +auxlib::eig_pair_twosided + ( + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, + const Base< std::complex, T1 >& A_expr, + const Base< std::complex, T2 >& B_expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + Mat A(A_expr.get_ref()); + Mat B(B_expr.get_ref()); + + arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "eig_pair(): given matrices must be square sized" ); + + arma_debug_check( (A.n_rows != B.n_rows), "eig_pair(): given matrices must have the same size" ); + + arma_debug_assert_blas_size(A); + + if(A.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } + + vals.set_size(A.n_rows, 1); + + lvecs.set_size(A.n_rows, A.n_rows); + rvecs.set_size(A.n_rows, A.n_rows); + + char jobvl = 'V'; + char jobvr = 'V'; + blas_int N = blas_int(A.n_rows); + blas_int ldvl = blas_int(lvecs.n_rows); + blas_int ldvr = blas_int(rvecs.n_rows); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1),2*N) + blas_int info = 0; + + podarray alpha(A.n_rows); + podarray beta(A.n_rows); + + podarray work( static_cast(lwork) ); + podarray rwork( static_cast(8*N) ); + + arma_extra_debug_print("lapack::cx_ggev()"); + lapack::cx_ggev(&jobvl, &jobvr, &N, A.memptr(), &N, B.memptr(), &N, alpha.memptr(), beta.memptr(), lvecs.memptr(), &ldvl, rvecs.memptr(), &ldvr, work.memptr(), &lwork, rwork.memptr(), &info); + + if(info != 0) { return false; } + + eT* vals_mem = vals.memptr(); + const eT* alpha_mem = alpha.memptr(); + const eT* beta_mem = beta.memptr(); + + const std::complex zero(T(0), T(0)); + + bool beta_has_zero = false; + + for(uword i=0; i +inline +bool +auxlib::eig_sym(Col& eigval, Mat& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix must be square sized" ); + + if(A.is_empty()) { eigval.reset(); return true; } + + if((arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false)) + { + arma_debug_warn_level(1, "eig_sym(): given matrix is not symmetric"); + } + + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(A)) { return false; } + + arma_debug_assert_blas_size(A); + + eigval.set_size(A.n_rows); + + char jobz = 'N'; + char uplo = 'U'; + + blas_int N = blas_int(A.n_rows); + blas_int lwork = (64+2)*N; // lwork_min = (std::max)(blas_int(1), 3*N-1) + blas_int info = 0; + + podarray work( static_cast(lwork) ); + + arma_extra_debug_print("lapack::syev()"); + lapack::syev(&jobz, &uplo, &N, A.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, &info); + + return (info == 0); + } + #else + { + arma_ignore(eigval); + arma_ignore(A); + arma_stop_logic_error("eig_sym(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! eigenvalues of a hermitian complex matrix +template +inline +bool +auxlib::eig_sym(Col& eigval, Mat< std::complex >& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix must be square sized" ); + + if(A.is_empty()) { eigval.reset(); return true; } + + if((arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false)) + { + arma_debug_warn_level(1, "eig_sym(): given matrix is not hermitian"); + } + + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(A)) { return false; } + + arma_debug_assert_blas_size(A); + + eigval.set_size(A.n_rows); + + char jobz = 'N'; + char uplo = 'U'; + + blas_int N = blas_int(A.n_rows); + blas_int lwork = (64+1)*N; // lwork_min = (std::max)(blas_int(1), 2*N-1) + blas_int info = 0; + + podarray work( static_cast(lwork) ); + podarray rwork( static_cast( (std::max)(blas_int(1), 3*N) ) ); + + arma_extra_debug_print("lapack::heev()"); + lapack::heev(&jobz, &uplo, &N, A.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); + + return (info == 0); + } + #else + { + arma_ignore(eigval); + arma_ignore(A); + arma_stop_logic_error("eig_sym(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! eigenvalues and eigenvectors of a symmetric real matrix +template +inline +bool +auxlib::eig_sym(Col& eigval, Mat& eigvec, const Mat& X) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_check( (X.is_square() == false), "eig_sym(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(X)) { return false; } + + eigvec = X; + + if(eigvec.is_empty()) { eigval.reset(); eigvec.reset(); return true; } + + arma_debug_assert_blas_size(eigvec); + + eigval.set_size(eigvec.n_rows); + + char jobz = 'V'; + char uplo = 'U'; + + blas_int N = blas_int(eigvec.n_rows); + blas_int lwork = (64+2)*N; // lwork_min = (std::max)(blas_int(1), 3*N-1) + blas_int info = 0; + + podarray work( static_cast(lwork) ); + + arma_extra_debug_print("lapack::syev()"); + lapack::syev(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, &info); + + return (info == 0); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(X); + arma_stop_logic_error("eig_sym(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! eigenvalues and eigenvectors of a hermitian complex matrix +template +inline +bool +auxlib::eig_sym(Col& eigval, Mat< std::complex >& eigvec, const Mat< std::complex >& X) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + arma_debug_check( (X.is_square() == false), "eig_sym(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(X)) { return false; } + + eigvec = X; + + if(eigvec.is_empty()) { eigval.reset(); eigvec.reset(); return true; } + + arma_debug_assert_blas_size(eigvec); + + eigval.set_size(eigvec.n_rows); + + char jobz = 'V'; + char uplo = 'U'; + + blas_int N = blas_int(eigvec.n_rows); + blas_int lwork = (64+1)*N; // lwork_min = (std::max)(blas_int(1), 2*N-1) + blas_int info = 0; + + podarray work( static_cast(lwork) ); + podarray rwork( static_cast((std::max)(blas_int(1), 3*N)) ); + + arma_extra_debug_print("lapack::heev()"); + lapack::heev(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); + + return (info == 0); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(X); + arma_stop_logic_error("eig_sym(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! eigenvalues and eigenvectors of a symmetric real matrix (divide and conquer algorithm) +template +inline +bool +auxlib::eig_sym_dc(Col& eigval, Mat& eigvec, const Mat& X) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_check( (X.is_square() == false), "eig_sym(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(X)) { return false; } + + eigvec = X; + + if(eigvec.is_empty()) { eigval.reset(); eigvec.reset(); return true; } + + arma_debug_assert_blas_size(eigvec); + + eigval.set_size(eigvec.n_rows); + + char jobz = 'V'; + char uplo = 'U'; + + blas_int N = blas_int(eigvec.n_rows); + blas_int lwork_min = 1 + 6*N + 2*(N*N); + blas_int liwork_min = 3 + 5*N; + blas_int info = 0; + + blas_int lwork_proposed = 0; + blas_int liwork_proposed = 0; + + if(N >= 32) + { + eT work_query[2] = {}; + blas_int iwork_query[2] = {}; + + blas_int lwork_query = -1; + blas_int liwork_query = -1; + + arma_extra_debug_print("lapack::syevd()"); + lapack::syevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), &work_query[0], &lwork_query, &iwork_query[0], &liwork_query, &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( work_query[0] ); + liwork_proposed = iwork_query[0]; + } + + blas_int lwork_final = (std::max)( lwork_proposed, lwork_min); + blas_int liwork_final = (std::max)(liwork_proposed, liwork_min); + + podarray work( static_cast( lwork_final) ); + podarray iwork( static_cast(liwork_final) ); + + arma_extra_debug_print("lapack::syevd()"); + lapack::syevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork_final, iwork.memptr(), &liwork_final, &info); + + return (info == 0); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(X); + arma_stop_logic_error("eig_sym(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! eigenvalues and eigenvectors of a hermitian complex matrix (divide and conquer algorithm) +template +inline +bool +auxlib::eig_sym_dc(Col& eigval, Mat< std::complex >& eigvec, const Mat< std::complex >& X) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + arma_debug_check( (X.is_square() == false), "eig_sym(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(X)) { return false; } + + eigvec = X; + + if(eigvec.is_empty()) { eigval.reset(); eigvec.reset(); return true; } + + arma_debug_assert_blas_size(eigvec); + + eigval.set_size(eigvec.n_rows); + + char jobz = 'V'; + char uplo = 'U'; + + blas_int N = blas_int(eigvec.n_rows); + blas_int lwork_min = 2*N + N*N; + blas_int lrwork_min = 1 + 5*N + 2*(N*N); + blas_int liwork_min = 3 + 5*N; + blas_int info = 0; + + blas_int lwork_proposed = 0; + blas_int lrwork_proposed = 0; + blas_int liwork_proposed = 0; + + if(N >= 32) + { + eT work_query[2] = {}; + T rwork_query[2] = {}; + blas_int iwork_query[2] = {}; + + blas_int lwork_query = -1; + blas_int lrwork_query = -1; + blas_int liwork_query = -1; + + arma_extra_debug_print("lapack::heevd()"); + lapack::heevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), &work_query[0], &lwork_query, &rwork_query[0], &lrwork_query, &iwork_query[0], &liwork_query, &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + lrwork_proposed = static_cast( rwork_query[0] ); + liwork_proposed = iwork_query[0]; + } + + blas_int lwork_final = (std::max)( lwork_proposed, lwork_min); + blas_int lrwork_final = (std::max)(lrwork_proposed, lrwork_min); + blas_int liwork_final = (std::max)(liwork_proposed, liwork_min); + + podarray work( static_cast( lwork_final) ); + podarray< T> rwork( static_cast(lrwork_final) ); + podarray iwork( static_cast(liwork_final) ); + + arma_extra_debug_print("lapack::heevd()"); + lapack::heevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork_final, rwork.memptr(), &lrwork_final, iwork.memptr(), &liwork_final, &info); + + return (info == 0); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(X); + arma_stop_logic_error("eig_sym(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::chol_simple(Mat& X) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(X); + + char uplo = 'U'; + blas_int n = blas_int(X.n_rows); + blas_int info = 0; + + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, X.memptr(), &n, &info); + + return (info == 0); + } + #else + { + arma_ignore(X); + + arma_stop_logic_error("chol(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::chol(Mat& X, const uword layout) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(X); + + char uplo = (layout == 0) ? 'U' : 'L'; + blas_int n = blas_int(X.n_rows); + blas_int info = 0; + + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, X.memptr(), &n, &info); + + if(info != 0) { return false; } + + X = (layout == 0) ? trimatu(X) : trimatl(X); // trimatu() and trimatl() return the same type + + return true; + } + #else + { + arma_ignore(X); + arma_ignore(layout); + + arma_stop_logic_error("chol(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::chol_band(Mat& X, const uword KD, const uword layout) + { + arma_extra_debug_sigprint(); + + return auxlib::chol_band_common(X, KD, layout); + } + + + +template +inline +bool +auxlib::chol_band(Mat< std::complex >& X, const uword KD, const uword layout) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::chol_band(): redirecting to auxlib::chol() due to crippled LAPACK"); + + arma_ignore(KD); + + return auxlib::chol(X, layout); + } + #else + { + return auxlib::chol_band_common(X, KD, layout); + } + #endif + } + + + +template +inline +bool +auxlib::chol_band_common(Mat& X, const uword KD, const uword layout) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + const uword N = X.n_rows; + + const uword KL = (layout == 0) ? uword(0) : KD; + const uword KU = (layout == 0) ? KD : uword(0); + + Mat AB; + band_helper::compress(AB, X, KL, KU, false); + + arma_debug_assert_blas_size(AB); + + char uplo = (layout == 0) ? 'U' : 'L'; + blas_int n = blas_int(N); + blas_int kd = blas_int(KD); + blas_int ldab = blas_int(AB.n_rows); + blas_int info = 0; + + arma_extra_debug_print("lapack::pbtrf()"); + lapack::pbtrf(&uplo, &n, &kd, AB.memptr(), &ldab, &info); + + if(info != 0) { return false; } + + band_helper::uncompress(X, AB, KL, KU, false); + + return true; + } + #else + { + arma_ignore(X); + arma_ignore(KD); + arma_ignore(layout); + + arma_stop_logic_error("chol(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::chol_pivot(Mat& X, Mat& P, const uword layout) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename get_pod_type::result T; + + arma_debug_assert_blas_size(X); + + char uplo = (layout == 0) ? 'U' : 'L'; + blas_int n = blas_int(X.n_rows); + blas_int rank = 0; + T tol = T(-1); + blas_int info = 0; + + podarray ipiv( X.n_rows); + podarray work(2*X.n_rows); + + ipiv.zeros(); + + arma_extra_debug_print("lapack::pstrf()"); + lapack::pstrf(&uplo, &n, X.memptr(), &n, ipiv.memptr(), &rank, &tol, work.memptr(), &info); + + if(info != 0) { return false; } + + X = (layout == 0) ? trimatu(X) : trimatl(X); // trimatu() and trimatl() return the same type + + P.set_size(X.n_rows, 1); + + for(uword i=0; i < X.n_rows; ++i) + { + P[i] = uword(ipiv[i] - 1); // take into account that Fortran counts from 1 + } + + return true; + } + #else + { + arma_ignore(X); + arma_ignore(P); + arma_ignore(layout); + + arma_stop_logic_error("chol(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +// +// hessenberg decomposition +template +inline +bool +auxlib::hess(Mat& H, const Base& X, Col& tao) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + H = X.get_ref(); + + arma_debug_check( (H.is_square() == false), "hess(): given matrix must be square sized" ); + + if(H.is_empty()) { return true; } + + arma_debug_assert_blas_size(H); + + if(H.n_rows > 2) + { + tao.set_size(H.n_rows-1); + + blas_int n = blas_int(H.n_rows); + blas_int ilo = 1; + blas_int ihi = blas_int(H.n_rows); + blas_int lda = blas_int(H.n_rows); + blas_int lwork = blas_int(H.n_rows) * 64; + blas_int info = 0; + + podarray work(static_cast(lwork)); + + arma_extra_debug_print("lapack::gehrd()"); + lapack::gehrd(&n, &ilo, &ihi, H.memptr(), &lda, tao.memptr(), work.memptr(), &lwork, &info); + + return (info == 0); + } + + return true; + } + #else + { + arma_ignore(H); + arma_ignore(X); + arma_ignore(tao); + arma_stop_logic_error("hess(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::qr(Mat& Q, Mat& R, const Base& X) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + R = X.get_ref(); + + const uword R_n_rows = R.n_rows; + const uword R_n_cols = R.n_cols; + + if(R.is_empty()) { Q.eye(R_n_rows, R_n_rows); return true; } + + arma_debug_assert_blas_size(R); + + blas_int m = static_cast(R_n_rows); + blas_int n = static_cast(R_n_cols); + blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr() + blas_int k = (std::min)(m,n); + blas_int info = 0; + + podarray tau( static_cast(k) ); + + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::geqrf()"); + lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::geqrf()"); + lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + + if(info != 0) { return false; } + + Q.set_size(R_n_rows, R_n_rows); + + arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) ); + + // + // construct R + + for(uword col=0; col < R_n_cols; ++col) + { + for(uword row=(col+1); row < R_n_rows; ++row) + { + R.at(row,col) = eT(0); + } + } + + + if( (is_float::value) || (is_double::value) ) + { + arma_extra_debug_print("lapack::orgqr()"); + lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + } + else + if( (is_cx_float::value) || (is_cx_double::value) ) + { + arma_extra_debug_print("lapack::ungqr()"); + lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + } + + return (info == 0); + } + #else + { + arma_ignore(Q); + arma_ignore(R); + arma_ignore(X); + arma_stop_logic_error("qr(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::qr_econ(Mat& Q, Mat& R, const Base& X) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + if(is_Mat::value) + { + const unwrap tmp(X.get_ref()); + const Mat& M = tmp.M; + + if(M.n_rows < M.n_cols) { return auxlib::qr(Q, R, X); } + } + + Q = X.get_ref(); + + const uword Q_n_rows = Q.n_rows; + const uword Q_n_cols = Q.n_cols; + + if( Q_n_rows <= Q_n_cols ) { return auxlib::qr(Q, R, Q); } + + if(Q.is_empty()) { Q.set_size(Q_n_rows, 0); R.set_size(0, Q_n_cols); return true; } + + arma_debug_assert_blas_size(Q); + + blas_int m = static_cast(Q_n_rows); + blas_int n = static_cast(Q_n_cols); + blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr() + blas_int k = (std::min)(m,n); + blas_int info = 0; + + podarray tau( static_cast(k) ); + + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::geqrf()"); + lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::geqrf()"); + lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + + if(info != 0) { return false; } + + R.zeros(Q_n_cols, Q_n_cols); + + // + // construct R + + for(uword col=0; col < Q_n_cols; ++col) + { + for(uword row=0; row <= col; ++row) + { + R.at(row,col) = Q.at(row,col); + } + } + + if( (is_float::value) || (is_double::value) ) + { + arma_extra_debug_print("lapack::orgqr()"); + lapack::orgqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + } + else + if( (is_cx_float::value) || (is_cx_double::value) ) + { + arma_extra_debug_print("lapack::ungqr()"); + lapack::ungqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + } + + return (info == 0); + } + #else + { + arma_ignore(Q); + arma_ignore(R); + arma_ignore(X); + arma_stop_logic_error("qr_econ(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::qr_pivot(Mat& Q, Mat& R, Mat& P, const Base& X) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + R = X.get_ref(); + + const uword R_n_rows = R.n_rows; + const uword R_n_cols = R.n_cols; + + if(R.is_empty()) + { + Q.eye(R_n_rows, R_n_rows); + + P.set_size(R_n_cols, 1); + + for(uword col=0; col < R_n_cols; ++col) { P.at(col) = col; } + + return true; + } + + arma_debug_assert_blas_size(R); + + blas_int m = static_cast(R_n_rows); + blas_int n = static_cast(R_n_cols); + blas_int lwork_min = (std::max)(blas_int(3*n + 1), (std::max)(m,n)); // take into account requirements of geqp3() and orgqr() + blas_int k = (std::min)(m,n); + blas_int info = 0; + + podarray tau( static_cast(k) ); + podarray jpvt( R_n_cols ); + + jpvt.zeros(); + + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::geqp3()"); + lapack::geqp3(&m, &n, R.memptr(), &m, jpvt.memptr(), tau.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::geqp3()"); + lapack::geqp3(&m, &n, R.memptr(), &m, jpvt.memptr(), tau.memptr(), work.memptr(), &lwork_final, &info); + + if(info != 0) { return false; } + + Q.set_size(R_n_rows, R_n_rows); + + arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) ); + + // + // construct R and P + + P.set_size(R_n_cols, 1); + + for(uword col=0; col < R_n_cols; ++col) + { + for(uword row=(col+1); row < R_n_rows; ++row) { R.at(row,col) = eT(0); } + + P.at(col) = jpvt[col] - 1; // take into account that Fortran counts from 1 + } + + arma_extra_debug_print("lapack::orgqr()"); + lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + + return (info == 0); + } + #else + { + arma_ignore(Q); + arma_ignore(R); + arma_ignore(P); + arma_ignore(X); + arma_stop_logic_error("qr(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::qr_pivot(Mat< std::complex >& Q, Mat< std::complex >& R, Mat& P, const Base,T1>& X) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + R = X.get_ref(); + + const uword R_n_rows = R.n_rows; + const uword R_n_cols = R.n_cols; + + if(R.is_empty()) + { + Q.eye(R_n_rows, R_n_rows); + + P.set_size(R_n_cols, 1); + + for(uword col=0; col < R_n_cols; ++col) { P.at(col) = col; } + + return true; + } + + arma_debug_assert_blas_size(R); + + blas_int m = static_cast(R_n_rows); + blas_int n = static_cast(R_n_cols); + blas_int lwork_min = (std::max)(blas_int(3*n + 1), (std::max)(m,n)); // take into account requirements of geqp3() and ungqr() + blas_int k = (std::min)(m,n); + blas_int info = 0; + + podarray tau( static_cast(k) ); + podarray< T> rwork( 2*R_n_cols ); + podarray jpvt( R_n_cols ); + + jpvt.zeros(); + + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::geqp3()"); + lapack::cx_geqp3(&m, &n, R.memptr(), &m, jpvt.memptr(), tau.memptr(), &work_query[0], &lwork_query, rwork.memptr(), &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::geqp3()"); + lapack::cx_geqp3(&m, &n, R.memptr(), &m, jpvt.memptr(), tau.memptr(), work.memptr(), &lwork_final, rwork.memptr(), &info); + + if(info != 0) { return false; } + + Q.set_size(R_n_rows, R_n_rows); + + arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) ); + + // + // construct R and P + + P.set_size(R_n_cols, 1); + + for(uword col=0; col < R_n_cols; ++col) + { + for(uword row=(col+1); row < R_n_rows; ++row) { R.at(row,col) = eT(0); } + + P.at(col) = jpvt[col] - 1; // take into account that Fortran counts from 1 + } + + arma_extra_debug_print("lapack::ungqr()"); + lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + + return (info == 0); + } + #else + { + arma_ignore(Q); + arma_ignore(R); + arma_ignore(P); + arma_ignore(X); + arma_stop_logic_error("qr(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd(Col& S, Mat& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + if(A.is_empty()) { S.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + Mat U(1, 1, arma_nozeros_indicator()); + Mat V(1, A.n_cols, arma_nozeros_indicator()); + + char jobu = 'N'; + char jobvt = 'N'; + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); + blas_int info = 0; + + S.set_size( static_cast(min_mn) ); + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 1024) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::gesvd()"); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( work_query[0] ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::gesvd()"); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, &info); + + return (info == 0); + } + #else + { + arma_ignore(S); + arma_ignore(A); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd(Col& S, Mat< std::complex >& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef std::complex eT; + + if(A.is_empty()) { S.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + Mat U(1, 1, arma_nozeros_indicator()); + Mat V(1, A.n_cols, arma_nozeros_indicator()); + + char jobu = 'N'; + char jobvt = 'N'; + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = (std::max)( blas_int(1), 2*min_mn+(std::max)(m,n) ); + blas_int info = 0; + + S.set_size( static_cast(min_mn) ); + + podarray rwork( static_cast(5*min_mn) ); + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 256) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; // query to find optimum size of workspace + + arma_extra_debug_print("lapack::cx_gesvd()"); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::cx_gesvd()"); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), &info); + + return (info == 0); + } + #else + { + arma_ignore(S); + arma_ignore(A); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd(Mat& U, Col& S, Mat& V, Mat& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + if(A.is_empty()) { U.eye(A.n_rows, A.n_rows); S.reset(); V.eye(A.n_cols, A.n_cols); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + U.set_size(A.n_rows, A.n_rows); + V.set_size(A.n_cols, A.n_cols); + + char jobu = 'A'; + char jobvt = 'A'; + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); + blas_int info = 0; + + S.set_size( static_cast(min_mn) ); + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 1024) + { + // query to find optimum size of workspace + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::gesvd()"); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( work_query[0] ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::gesvd()"); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, &info); + + if(info != 0) { return false; } + + op_strans::apply_mat_inplace(V); + + return true; + } + #else + { + arma_ignore(U); + arma_ignore(S); + arma_ignore(V); + arma_ignore(A); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef std::complex eT; + + if(A.is_empty()) { U.eye(A.n_rows, A.n_rows); S.reset(); V.eye(A.n_cols, A.n_cols); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + U.set_size(A.n_rows, A.n_rows); + V.set_size(A.n_cols, A.n_cols); + + char jobu = 'A'; + char jobvt = 'A'; + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = (std::max)( blas_int(1), 2*min_mn + (std::max)(m,n) ); + blas_int info = 0; + + S.set_size( static_cast(min_mn) ); + + podarray rwork( static_cast(5*min_mn) ); + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 256) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; // query to find optimum size of workspace + + arma_extra_debug_print("lapack::cx_gesvd()"); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::cx_gesvd()"); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), &info); + + if(info != 0) { return false; } + + op_htrans::apply_mat_inplace(V); + + return true; + } + #else + { + arma_ignore(U); + arma_ignore(S); + arma_ignore(V); + arma_ignore(A); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd_econ(Mat& U, Col& S, Mat& V, Mat& A, const char mode) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + if(A.is_empty()) { U.eye(); S.reset(); V.eye(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lda = blas_int(A.n_rows); + + S.set_size( static_cast(min_mn) ); + + blas_int ldu = 0; + blas_int ldvt = 0; + + char jobu = char(0); + char jobvt = char(0); + + if(mode == 'l') + { + jobu = 'S'; + jobvt = 'N'; + + ldu = m; + ldvt = 1; + + U.set_size( static_cast(ldu), static_cast(min_mn) ); + V.reset(); + } + + if(mode == 'r') + { + jobu = 'N'; + jobvt = 'S'; + + ldu = 1; + ldvt = (std::min)(m,n); + + U.reset(); + V.set_size( static_cast(ldvt), static_cast(n) ); + } + + if(mode == 'b') + { + jobu = 'S'; + jobvt = 'S'; + + ldu = m; + ldvt = (std::min)(m,n); + + U.set_size( static_cast(ldu), static_cast(min_mn) ); + V.set_size( static_cast(ldvt), static_cast(n ) ); + } + + + blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); + blas_int info = 0; + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 1024) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; // query to find optimum size of workspace + + arma_extra_debug_print("lapack::gesvd()"); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast(work_query[0]); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::gesvd()"); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, &info); + + if(info != 0) { return false; } + + op_strans::apply_mat_inplace(V); + + return true; + } + #else + { + arma_ignore(U); + arma_ignore(S); + arma_ignore(V); + arma_ignore(A); + arma_ignore(mode); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A, const char mode) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef std::complex eT; + + if(A.is_empty()) { U.eye(); S.reset(); V.eye(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lda = blas_int(A.n_rows); + + S.set_size( static_cast(min_mn) ); + + blas_int ldu = 0; + blas_int ldvt = 0; + + char jobu = char(0); + char jobvt = char(0); + + if(mode == 'l') + { + jobu = 'S'; + jobvt = 'N'; + + ldu = m; + ldvt = 1; + + U.set_size( static_cast(ldu), static_cast(min_mn) ); + V.reset(); + } + + if(mode == 'r') + { + jobu = 'N'; + jobvt = 'S'; + + ldu = 1; + ldvt = (std::min)(m,n); + + U.reset(); + V.set_size( static_cast(ldvt), static_cast(n) ); + } + + if(mode == 'b') + { + jobu = 'S'; + jobvt = 'S'; + + ldu = m; + ldvt = (std::min)(m,n); + + U.set_size( static_cast(ldu), static_cast(min_mn) ); + V.set_size( static_cast(ldvt), static_cast(n) ); + } + + blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); + blas_int info = 0; + + podarray rwork( static_cast(5*min_mn) ); + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 256) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; // query to find optimum size of workspace + + arma_extra_debug_print("lapack::cx_gesvd()"); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::cx_gesvd()"); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), &info); + + if(info != 0) { return false; } + + op_htrans::apply_mat_inplace(V); + + return true; + } + #else + { + arma_ignore(U); + arma_ignore(S); + arma_ignore(V); + arma_ignore(A); + arma_ignore(mode); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd_dc(Col& S, Mat& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + if(A.is_empty()) { S.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + Mat U(1, 1, arma_nozeros_indicator()); + Mat V(1, 1, arma_nozeros_indicator()); + + char jobz = 'N'; + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = 3*min_mn + (std::max)( max_mn, 7*min_mn ); + blas_int info = 0; + + S.set_size( static_cast(min_mn) ); + + podarray iwork( static_cast(8*min_mn) ); + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 1024) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::gesdd()"); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( work_query[0] ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::gesdd()"); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, iwork.memptr(), &info); + + return (info == 0); + } + #else + { + arma_ignore(S); + arma_ignore(A); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd_dc(Col& S, Mat< std::complex >& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef std::complex eT; + + if(A.is_empty()) { S.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + Mat U(1, 1, arma_nozeros_indicator()); + Mat V(1, 1, arma_nozeros_indicator()); + + char jobz = 'N'; + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = 2*min_mn + max_mn; + blas_int info = 0; + + S.set_size( static_cast(min_mn) ); + + podarray rwork( static_cast(7*min_mn) ); // from LAPACK 3.8 docs: LAPACK <= v3.6 needs 7*mn + podarray iwork( static_cast(8*min_mn) ); + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 256) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::cx_gesdd()"); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::cx_gesdd()"); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), iwork.memptr(), &info); + + return (info == 0); + } + #else + { + arma_ignore(S); + arma_ignore(A); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd_dc(Mat& U, Col& S, Mat& V, Mat& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + if(A.is_empty()) { U.eye(A.n_rows, A.n_rows); S.reset(); V.eye(A.n_cols, A.n_cols); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + U.set_size(A.n_rows, A.n_rows); + V.set_size(A.n_cols, A.n_cols); + + char jobz = 'A'; + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork1 = 3*min_mn*min_mn + (std::max)(max_mn, 4*min_mn*min_mn + 4*min_mn); // as per LAPACK 3.2 docs + blas_int lwork2 = 4*min_mn*min_mn + 6*min_mn + max_mn; // as per LAPACK 3.8 docs; consistent with LAPACK 3.4 docs + blas_int lwork_min = (std::max)(lwork1, lwork2); // due to differences between LAPACK 3.2 and 3.8 + blas_int info = 0; + + S.set_size( static_cast(min_mn) ); + + podarray iwork( static_cast(8*min_mn) ); + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 1024) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::gesdd()"); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast(work_query[0]); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::gesdd()"); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, iwork.memptr(), &info); + + if(info != 0) { return false; } + + op_strans::apply_mat_inplace(V); + + return true; + } + #else + { + arma_ignore(U); + arma_ignore(S); + arma_ignore(V); + arma_ignore(A); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd_dc(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef std::complex eT; + + if(A.is_empty()) { U.eye(A.n_rows, A.n_rows); S.reset(); V.eye(A.n_cols, A.n_cols); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + U.set_size(A.n_rows, A.n_rows); + V.set_size(A.n_cols, A.n_cols); + + char jobz = 'A'; + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = min_mn*min_mn + 2*min_mn + max_mn; // as per LAPACK 3.2, 3.4, 3.8 docs + blas_int lrwork = min_mn * ((std::max)(5*min_mn+7, 2*max_mn + 2*min_mn+1)); // as per LAPACK 3.4 docs; LAPACK 3.8 uses 5*min_mn+5 instead of 5*min_mn+7 + blas_int info = 0; + + S.set_size( static_cast(min_mn) ); + + podarray rwork( static_cast(lrwork ) ); + podarray iwork( static_cast(8*min_mn) ); + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 256) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::cx_gesdd()"); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::cx_gesdd()"); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), iwork.memptr(), &info); + + if(info != 0) { return false; } + + op_htrans::apply_mat_inplace(V); + + return true; + } + #else + { + arma_ignore(U); + arma_ignore(S); + arma_ignore(V); + arma_ignore(A); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd_dc_econ(Mat& U, Col& S, Mat& V, Mat& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + char jobz = 'S'; + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = m; + blas_int ldvt = min_mn; + blas_int lwork1 = 3*min_mn*min_mn + (std::max)( max_mn, 4*min_mn*min_mn + 4*min_mn ); // as per LAPACK 3.2 docs + blas_int lwork2 = 4*min_mn*min_mn + 6*min_mn + max_mn; // as per LAPACK 3.4 docs; LAPACK 3.8 requires 4*min_mn*min_mn + 7*min_mn + blas_int lwork_min = (std::max)(lwork1, lwork2); // due to differences between LAPACK 3.2 and 3.4 + blas_int info = 0; + + if(A.is_empty()) + { + U.eye(); + S.reset(); + V.eye( static_cast(n), static_cast(min_mn) ); + return true; + } + + S.set_size( static_cast(min_mn) ); + + U.set_size( static_cast(m), static_cast(min_mn) ); + + V.set_size( static_cast(min_mn), static_cast(n) ); + + podarray iwork( static_cast(8*min_mn) ); + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 1024) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::gesdd()"); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast(work_query[0]); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::gesdd()"); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, iwork.memptr(), &info); + + if(info != 0) { return false; } + + op_strans::apply_mat_inplace(V); + + return true; + } + #else + { + arma_ignore(U); + arma_ignore(S); + arma_ignore(V); + arma_ignore(A); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::svd_dc_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef std::complex eT; + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + char jobz = 'S'; + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = m; + blas_int ldvt = min_mn; + blas_int lwork_min = min_mn*min_mn + 2*min_mn + max_mn; // as per LAPACK 3.2 docs + blas_int lrwork = min_mn * ((std::max)(5*min_mn+7, 2*max_mn + 2*min_mn+1)); // LAPACK 3.8 uses 5*min_mn+5 instead of 5*min_mn+7 + blas_int info = 0; + + if(A.is_empty()) + { + U.eye(); + S.reset(); + V.eye( static_cast(n), static_cast(min_mn) ); + return true; + } + + S.set_size( static_cast(min_mn) ); + + U.set_size( static_cast(m), static_cast(min_mn) ); + + V.set_size( static_cast(min_mn), static_cast(n) ); + + podarray rwork( static_cast(lrwork ) ); + podarray iwork( static_cast(8*min_mn) ); + + blas_int lwork_proposed = 0; + + if(A.n_elem >= 256) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::cx_gesdd()"); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::cx_gesdd()"); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), iwork.memptr(), &info); + + if(info != 0) { return false; } + + op_htrans::apply_mat_inplace(V); + + return true; + } + #else + { + arma_ignore(U); + arma_ignore(S); + arma_ignore(V); + arma_ignore(A); + arma_stop_logic_error("svd(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via LU decomposition +template +inline +bool +auxlib::solve_square_fast(Mat& out, Mat& A, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + + arma_debug_assert_blas_size(A); + + blas_int n = blas_int(A.n_rows); // assuming A is square + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(B_n_rows); + blas_int nrhs = blas_int(B_n_cols); + blas_int info = blas_int(0); + + podarray ipiv(A.n_rows + 2); // +2 for paranoia: some versions of Lapack might be trashing memory + + arma_extra_debug_print("lapack::gesv()"); + lapack::gesv(&n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); + + return (info == 0); + } + #else + { + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via LU decomposition with rcond estimate +template +inline +bool +auxlib::solve_square_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + out_rcond = T(0); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + arma_debug_assert_blas_size(A); + + char norm_id = '1'; + char trans = 'N'; + blas_int n = blas_int(A.n_rows); // assuming A is square + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(B_n_rows); + blas_int nrhs = blas_int(B_n_cols); + blas_int info = blas_int(0); + T norm_val = T(0); + + podarray junk(1); + podarray ipiv(A.n_rows + 2); // +2 for paranoia + + arma_extra_debug_print("lapack::lange()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_gen(A) : lapack::lange(&norm_id, &n, &n, A.memptr(), &lda, junk.memptr()); + + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&n, &n, A.memptr(), &n, ipiv.memptr(), &info); + + if(info != blas_int(0)) { return false; } + + arma_extra_debug_print("lapack::getrs()"); + lapack::getrs(&trans, &n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); + + if(info != blas_int(0)) { return false; } + + out_rcond = auxlib::lu_rcond(A, norm_val); + + return true; + } + #else + { + arma_ignore(out); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via LU decomposition with refinement (real matrices) +template +inline +bool +auxlib::solve_square_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type eT; + + // Mat B = B_expr.get_ref(); // B is overwritten by lapack::gesvx() if equilibrate is enabled + + quasi_unwrap UB(B_expr.get_ref()); // deliberately not declaring as const + + const Mat& UB_M_as_Mat = UB.M; // so we don't confuse the ?: operator below + + const bool use_copy = ((equilibrate && UB.is_const) || UB.is_alias(out)); + + Mat B_tmp; if(use_copy) { B_tmp = UB_M_as_Mat; } + + const Mat& B = (use_copy) ? B_tmp : UB_M_as_Mat; + + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); + + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } + + arma_debug_assert_blas_size(A,B); + + out.set_size(A.n_rows, B.n_cols); + + char fact = (equilibrate) ? 'E' : 'N'; + char trans = 'N'; + char equed = char(0); + blas_int n = blas_int(A.n_rows); + blas_int nrhs = blas_int(B.n_cols); + blas_int lda = blas_int(A.n_rows); + blas_int ldaf = blas_int(A.n_rows); + blas_int ldb = blas_int(A.n_rows); + blas_int ldx = blas_int(A.n_rows); + blas_int info = blas_int(0); + eT rcond = eT(0); + + Mat AF(A.n_rows, A.n_rows, arma_nozeros_indicator()); + + podarray IPIV( A.n_rows); + podarray R( A.n_rows); + podarray C( A.n_rows); + podarray FERR( B.n_cols); + podarray BERR( B.n_cols); + podarray WORK(4*A.n_rows); + podarray IWORK( A.n_rows); + + arma_extra_debug_print("lapack::gesvx()"); + lapack::gesvx + ( + &fact, &trans, &n, &nrhs, + A.memptr(), &lda, + AF.memptr(), &ldaf, + IPIV.memptr(), + &equed, + R.memptr(), + C.memptr(), + const_cast(B.memptr()), &ldb, + out.memptr(), &ldx, + &rcond, + FERR.memptr(), + BERR.memptr(), + WORK.memptr(), + IWORK.memptr(), + &info + ); + + // NOTE: using const_cast(B.memptr()) to allow B to be overwritten for equilibration; + // NOTE: B is created as a copy of B_expr if equilibration is enabled; otherwise B is a reference to B_expr + + out_rcond = rcond; + + return ((info == 0) || (info == (n+1))); + } + #else + { + arma_ignore(out); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(B_expr); + arma_ignore(equilibrate); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via LU decomposition with refinement (complex matrices) +template +inline +bool +auxlib::solve_square_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + // Mat B = B_expr.get_ref(); // B is overwritten by lapack::cx_gesvx() if equilibrate is enabled + + quasi_unwrap UB(B_expr.get_ref()); // deliberately not declaring as const + + const Mat& UB_M_as_Mat = UB.M; // so we don't confuse the ?: operator below + + const bool use_copy = ((equilibrate && UB.is_const) || UB.is_alias(out)); + + Mat B_tmp; if(use_copy) { B_tmp = UB_M_as_Mat; } + + const Mat& B = (use_copy) ? B_tmp : UB_M_as_Mat; + + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); + + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } + + arma_debug_assert_blas_size(A,B); + + out.set_size(A.n_rows, B.n_cols); + + char fact = (equilibrate) ? 'E' : 'N'; + char trans = 'N'; + char equed = char(0); + blas_int n = blas_int(A.n_rows); + blas_int nrhs = blas_int(B.n_cols); + blas_int lda = blas_int(A.n_rows); + blas_int ldaf = blas_int(A.n_rows); + blas_int ldb = blas_int(A.n_rows); + blas_int ldx = blas_int(A.n_rows); + blas_int info = blas_int(0); + T rcond = T(0); + + Mat AF(A.n_rows, A.n_rows, arma_nozeros_indicator()); + + podarray IPIV( A.n_rows); + podarray< T> R( A.n_rows); + podarray< T> C( A.n_rows); + podarray< T> FERR( B.n_cols); + podarray< T> BERR( B.n_cols); + podarray WORK(2*A.n_rows); + podarray< T> RWORK(2*A.n_rows); + + arma_extra_debug_print("lapack::cx_gesvx()"); + lapack::cx_gesvx + ( + &fact, &trans, &n, &nrhs, + A.memptr(), &lda, + AF.memptr(), &ldaf, + IPIV.memptr(), + &equed, + R.memptr(), + C.memptr(), + const_cast(B.memptr()), &ldb, + out.memptr(), &ldx, + &rcond, + FERR.memptr(), + BERR.memptr(), + WORK.memptr(), + RWORK.memptr(), + &info + ); + + // NOTE: using const_cast(B.memptr()) to allow B to be overwritten for equilibration; + // NOTE: B is created as a copy of B_expr if equilibration is enabled; otherwise B is a reference to B_expr + + out_rcond = rcond; + + return ((info == 0) || (info == (n+1))); + } + #else + { + arma_ignore(out); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(B_expr); + arma_ignore(equilibrate); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::solve_sympd_fast(Mat& out, Mat& A, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::solve_sympd_fast(): redirecting to auxlib::solve_square_fast() due to crippled LAPACK"); + + return auxlib::solve_square_fast(out, A, B_expr); + } + #else + { + return auxlib::solve_sympd_fast_common(out, A, B_expr); + } + #endif + } + + + +template +inline +bool +auxlib::solve_sympd_fast_common(Mat& out, Mat& A, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + + arma_debug_assert_blas_size(A, out); + + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); // assuming A is square + blas_int nrhs = blas_int(B_n_cols); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(B_n_rows); + blas_int info = blas_int(0); + + arma_extra_debug_print("lapack::posv()"); + lapack::posv(&uplo, &n, &nrhs, A.memptr(), &lda, out.memptr(), &ldb, &info); + + return (info == 0); + } + #else + { + arma_ignore(out); + arma_ignore(A); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via Cholesky decomposition with rcond estimate (real matrices) +template +inline +bool +auxlib::solve_sympd_rcond(Mat& out, bool& out_sympd_state, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + out_sympd_state = false; + out_rcond = T(0); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + arma_debug_assert_blas_size(A, out); + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); // assuming A is square + blas_int nrhs = blas_int(B_n_cols); + blas_int info = blas_int(0); + T norm_val = T(0); + + podarray work(A.n_rows); + + arma_extra_debug_print("lapack::lansy()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); + + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, A.memptr(), &n, &info); + + if(info != 0) { return false; } + + out_sympd_state = true; + + arma_extra_debug_print("lapack::potrs()"); + lapack::potrs(&uplo, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info); + + if(info != 0) { return false; } + + out_rcond = auxlib::lu_rcond_sympd(A, norm_val); + + return true; + } + #else + { + arma_ignore(out); + arma_ignore(out_sympd_state); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via Cholesky decomposition with rcond estimate (complex matrices) +template +inline +bool +auxlib::solve_sympd_rcond(Mat< std::complex >& out, bool& out_sympd_state, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::solve_sympd_rcond(): redirecting to auxlib::solve_square_rcond() due to crippled LAPACK"); + + out_sympd_state = false; + + return auxlib::solve_square_rcond(out, out_rcond, A, B_expr); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + out_sympd_state = false; + out_rcond = T(0); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + arma_debug_assert_blas_size(A, out); + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); // assuming A is square + blas_int nrhs = blas_int(B_n_cols); + blas_int info = blas_int(0); + T norm_val = T(0); + + podarray work(A.n_rows); + + arma_extra_debug_print("lapack::lanhe()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); + + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, A.memptr(), &n, &info); + + if(info != 0) { return false; } + + out_sympd_state = true; + + arma_extra_debug_print("lapack::potrs()"); + lapack::potrs(&uplo, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info); + + if(info != 0) { return false; } + + out_rcond = auxlib::lu_rcond_sympd(A, norm_val); + + return true; + } + #else + { + arma_ignore(out); + arma_ignore(out_sympd_state); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via Cholesky decomposition with refinement (real matrices) +template +inline +bool +auxlib::solve_sympd_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type eT; + + // Mat B = B_expr.get_ref(); // B is overwritten by lapack::posvx() if equilibrate is enabled + + quasi_unwrap UB(B_expr.get_ref()); // deliberately not declaring as const + + const Mat& UB_M_as_Mat = UB.M; // so we don't confuse the ?: operator below + + const bool use_copy = ((equilibrate && UB.is_const) || UB.is_alias(out)); + + Mat B_tmp; if(use_copy) { B_tmp = UB_M_as_Mat; } + + const Mat& B = (use_copy) ? B_tmp : UB_M_as_Mat; + + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); + + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } + + arma_debug_assert_blas_size(A,B); + + out.set_size(A.n_rows, B.n_cols); + + char fact = (equilibrate) ? 'E' : 'N'; + char uplo = 'L'; + char equed = char(0); + blas_int n = blas_int(A.n_rows); + blas_int nrhs = blas_int(B.n_cols); + blas_int lda = blas_int(A.n_rows); + blas_int ldaf = blas_int(A.n_rows); + blas_int ldb = blas_int(A.n_rows); + blas_int ldx = blas_int(A.n_rows); + blas_int info = blas_int(0); + eT rcond = eT(0); + + Mat AF(A.n_rows, A.n_rows, arma_nozeros_indicator()); + + podarray S( A.n_rows); + podarray FERR( B.n_cols); + podarray BERR( B.n_cols); + podarray WORK(3*A.n_rows); + podarray IWORK( A.n_rows); + + arma_extra_debug_print("lapack::posvx()"); + lapack::posvx(&fact, &uplo, &n, &nrhs, A.memptr(), &lda, AF.memptr(), &ldaf, &equed, S.memptr(), const_cast(B.memptr()), &ldb, out.memptr(), &ldx, &rcond, FERR.memptr(), BERR.memptr(), WORK.memptr(), IWORK.memptr(), &info); + + // NOTE: using const_cast(B.memptr()) to allow B to be overwritten for equilibration; + // NOTE: B is created as a copy of B_expr if equilibration is enabled; otherwise B is a reference to B_expr + + // NOTE: lapack::posvx() sets rcond to zero if A is not sympd + out_rcond = rcond; + + return ((info == 0) || (info == (n+1))); + } + #else + { + arma_ignore(out); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(B_expr); + arma_ignore(equilibrate); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via Cholesky decomposition with refinement (complex matrices) +template +inline +bool +auxlib::solve_sympd_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::solve_sympd_refine(): redirecting to auxlib::solve_square_refine() due to crippled LAPACK"); + + return auxlib::solve_square_refine(out, out_rcond, A, B_expr, equilibrate); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + // Mat B = B_expr.get_ref(); // B is overwritten by lapack::cx_posvx() if equilibrate is enabled + + quasi_unwrap UB(B_expr.get_ref()); // deliberately not declaring as const + + const Mat& UB_M_as_Mat = UB.M; // so we don't confuse the ?: operator below + + const bool use_copy = ((equilibrate && UB.is_const) || UB.is_alias(out)); + + Mat B_tmp; if(use_copy) { B_tmp = UB_M_as_Mat; } + + const Mat& B = (use_copy) ? B_tmp : UB_M_as_Mat; + + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); + + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } + + arma_debug_assert_blas_size(A,B); + + out.set_size(A.n_rows, B.n_cols); + + char fact = (equilibrate) ? 'E' : 'N'; + char uplo = 'L'; + char equed = char(0); + blas_int n = blas_int(A.n_rows); + blas_int nrhs = blas_int(B.n_cols); + blas_int lda = blas_int(A.n_rows); + blas_int ldaf = blas_int(A.n_rows); + blas_int ldb = blas_int(A.n_rows); + blas_int ldx = blas_int(A.n_rows); + blas_int info = blas_int(0); + T rcond = T(0); + + Mat AF(A.n_rows, A.n_rows, arma_nozeros_indicator()); + + podarray< T> S( A.n_rows); + podarray< T> FERR( B.n_cols); + podarray< T> BERR( B.n_cols); + podarray WORK(2*A.n_rows); + podarray< T> RWORK( A.n_rows); + + arma_extra_debug_print("lapack::cx_posvx()"); + lapack::cx_posvx(&fact, &uplo, &n, &nrhs, A.memptr(), &lda, AF.memptr(), &ldaf, &equed, S.memptr(), const_cast(B.memptr()), &ldb, out.memptr(), &ldx, &rcond, FERR.memptr(), BERR.memptr(), WORK.memptr(), RWORK.memptr(), &info); + + // NOTE: using const_cast(B.memptr()) to allow B to be overwritten for equilibration; + // NOTE: B is created as a copy of B_expr if equilibration is enabled; otherwise B is a reference to B_expr + + // NOTE: lapack::cx_posvx() sets rcond to zero if A is not sympd + out_rcond = rcond; + + return ((info == 0) || (info == (n+1))); + } + #else + { + arma_ignore(out); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(B_expr); + arma_ignore(equilibrate); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a non-square full-rank system via QR or LQ decomposition +template +inline +bool +auxlib::solve_rect_fast(Mat& out, Mat& A, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + + const unwrap U(B_expr.get_ref()); + const Mat& B = U.M; + + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); + + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_cols, B.n_cols); return true; } + + arma_debug_assert_blas_size(A,B); + + Mat tmp( (std::max)(A.n_rows, A.n_cols), B.n_cols, arma_nozeros_indicator() ); + + if(arma::size(tmp) == arma::size(B)) + { + tmp = B; + } + else + { + tmp.zeros(); + tmp(0,0, arma::size(B)) = B; + } + + char trans = 'N'; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(tmp.n_rows); + blas_int nrhs = blas_int(B.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lwork_min = (std::max)(blas_int(1), min_mn + (std::max)(min_mn, nrhs)); + blas_int info = 0; + + blas_int lwork_proposed = 0; + + if(A.n_elem >= ((is_cx::yes) ? uword(256) : uword(1024))) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::gels()"); + lapack::gels( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, &work_query[0], &lwork_query, &info ); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::gels()"); + lapack::gels( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, work.memptr(), &lwork_final, &info ); + + if(info != 0) { return false; } + + if(tmp.n_rows == A.n_cols) + { + out.steal_mem(tmp); + } + else + { + out = tmp.head_rows(A.n_cols); + } + + return true; + } + #else + { + arma_ignore(out); + arma_ignore(A); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a non-square full-rank system via QR or LQ decomposition with rcond estimate (experimental) +template +inline +bool +auxlib::solve_rect_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + out_rcond = T(0); + + const unwrap U(B_expr.get_ref()); + const Mat& B = U.M; + + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); + + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_cols, B.n_cols); return true; } + + arma_debug_assert_blas_size(A,B); + + Mat tmp( (std::max)(A.n_rows, A.n_cols), B.n_cols, arma_nozeros_indicator() ); + + if(arma::size(tmp) == arma::size(B)) + { + tmp = B; + } + else + { + tmp.zeros(); + tmp(0,0, arma::size(B)) = B; + } + + char trans = 'N'; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(tmp.n_rows); + blas_int nrhs = blas_int(B.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lwork_min = (std::max)(blas_int(1), min_mn + (std::max)(min_mn, nrhs)); + blas_int info = 0; + + blas_int lwork_proposed = 0; + + if(A.n_elem >= ((is_cx::yes) ? uword(256) : uword(1024))) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::gels()"); + lapack::gels( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, &work_query[0], &lwork_query, &info ); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::gels()"); + lapack::gels( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, work.memptr(), &lwork_final, &info ); + + if(info != 0) { return false; } + + if(A.n_rows >= A.n_cols) + { + arma_extra_debug_print("estimating rcond via R"); + + // xGELS docs: for M >= N, A contains details of its QR decomposition as returned by xGEQRF + // xGEQRF docs: elements on and above the diagonal contain the min(M,N)-by-N upper trapezoidal matrix R + + Mat R(A.n_cols, A.n_cols, arma_zeros_indicator()); + + for(uword col=0; col < A.n_cols; ++col) + { + for(uword row=0; row <= col; ++row) + { + R.at(row,col) = A.at(row,col); + } + } + + // determine quality of solution + out_rcond = auxlib::rcond_trimat(R, 0); // 0: upper triangular; 1: lower triangular + } + else + if(A.n_rows < A.n_cols) + { + arma_extra_debug_print("estimating rcond via L"); + + // xGELS docs: for M < N, A contains details of its LQ decomposition as returned by xGELQF + // xGELQF docs: elements on and below the diagonal contain the M-by-min(M,N) lower trapezoidal matrix L + + Mat L(A.n_rows, A.n_rows, arma_zeros_indicator()); + + for(uword col=0; col < A.n_rows; ++col) + { + for(uword row=col; row < A.n_rows; ++row) + { + L.at(row,col) = A.at(row,col); + } + } + + // determine quality of solution + out_rcond = auxlib::rcond_trimat(L, 1); // 0: upper triangular; 1: lower triangular + } + + if(tmp.n_rows == A.n_cols) + { + out.steal_mem(tmp); + } + else + { + out = tmp.head_rows(A.n_cols); + } + + return true; + } + #else + { + arma_ignore(out); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::solve_approx_svd(Mat& out, Mat& A, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type eT; + + const unwrap U(B_expr.get_ref()); + const Mat& B = U.M; + + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); + + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_cols, B.n_cols); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A,B); + + Mat tmp( (std::max)(A.n_rows, A.n_cols), B.n_cols, arma_nozeros_indicator() ); + + if(arma::size(tmp) == arma::size(B)) + { + tmp = B; + } + else + { + tmp.zeros(); + tmp(0,0, arma::size(B)) = B; + } + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m, n); + blas_int nrhs = blas_int(B.n_cols); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(tmp.n_rows); + //eT rcond = eT(-1); // -1 means "use machine precision" + eT rcond = (std::max)(A.n_rows, A.n_cols) * std::numeric_limits::epsilon(); + blas_int rank = blas_int(0); + blas_int info = blas_int(0); + + podarray S( static_cast(min_mn) ); + + // NOTE: with LAPACK 3.8, can use the workspace query to also obtain liwork, + // NOTE: which makes the call to lapack::laenv() redundant + + blas_int ispec = blas_int(9); + + const char* const_name = (is_float::value) ? "SGELSD" : "DGELSD"; + const char* const_opts = " "; + + char* name = const_cast(const_name); + char* opts = const_cast(const_opts); + + blas_int n1 = m; + blas_int n2 = n; + blas_int n3 = nrhs; + blas_int n4 = lda; + + blas_int laenv_result = (arma_config::hidden_args) ? blas_int(lapack::laenv(&ispec, name, opts, &n1, &n2, &n3, &n4, 6, 1)) : blas_int(0); + + blas_int smlsiz = (std::max)( blas_int(25), laenv_result ); + blas_int smlsiz_p1 = blas_int(1) + smlsiz; + + blas_int nlvl = (std::max)( blas_int(0), blas_int(1) + blas_int( std::log2( double(min_mn)/double(smlsiz_p1) ) ) ); + blas_int liwork = (std::max)( blas_int(1), (blas_int(3)*min_mn*nlvl + blas_int(11)*min_mn) ); + + podarray iwork( static_cast(liwork) ); + + blas_int lwork_min = blas_int(12)*min_mn + blas_int(2)*min_mn*smlsiz + blas_int(8)*min_mn*nlvl + min_mn*nrhs + smlsiz_p1*smlsiz_p1; + + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::gelsd()"); + lapack::gelsd(&m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, S.memptr(), &rcond, &rank, &work_query[0], &lwork_query, iwork.memptr(), &info); + + if(info != 0) { return false; } + + // NOTE: in LAPACK 3.8, iwork[0] returns the minimum liwork + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::gelsd()"); + lapack::gelsd(&m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, S.memptr(), &rcond, &rank, work.memptr(), &lwork_final, iwork.memptr(), &info); + + if(info != 0) { return false; } + + if(tmp.n_rows == A.n_cols) + { + out.steal_mem(tmp); + } + else + { + out = tmp.head_rows(A.n_cols); + } + + return true; + } + #else + { + arma_ignore(out); + arma_ignore(A); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::solve_approx_svd(Mat< std::complex >& out, Mat< std::complex >& A, const Base,T1>& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + const unwrap U(B_expr.get_ref()); + const Mat& B = U.M; + + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); + + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_cols, B.n_cols); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A,B); + + Mat tmp( (std::max)(A.n_rows, A.n_cols), B.n_cols, arma_nozeros_indicator() ); + + if(arma::size(tmp) == arma::size(B)) + { + tmp = B; + } + else + { + tmp.zeros(); + tmp(0,0, arma::size(B)) = B; + } + + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m, n); + blas_int nrhs = blas_int(B.n_cols); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(tmp.n_rows); + //T rcond = T(-1); // -1 means "use machine precision" + T rcond = (std::max)(A.n_rows, A.n_cols) * std::numeric_limits::epsilon(); + blas_int rank = blas_int(0); + blas_int info = blas_int(0); + + podarray S( static_cast(min_mn) ); + + blas_int ispec = blas_int(9); + + const char* const_name = (is_float::value) ? "CGELSD" : "ZGELSD"; + const char* const_opts = " "; + + char* name = const_cast(const_name); + char* opts = const_cast(const_opts); + + blas_int n1 = m; + blas_int n2 = n; + blas_int n3 = nrhs; + blas_int n4 = lda; + + blas_int laenv_result = (arma_config::hidden_args) ? blas_int(lapack::laenv(&ispec, name, opts, &n1, &n2, &n3, &n4, 6, 1)) : blas_int(0); + + blas_int smlsiz = (std::max)( blas_int(25), laenv_result ); + blas_int smlsiz_p1 = blas_int(1) + smlsiz; + + blas_int nlvl = (std::max)( blas_int(0), blas_int(1) + blas_int( std::log2( double(min_mn)/double(smlsiz_p1) ) ) ); + + blas_int lrwork = (m >= n) + ? blas_int(10)*n + blas_int(2)*n*smlsiz + blas_int(8)*n*nlvl + blas_int(3)*smlsiz*nrhs + (std::max)( (smlsiz_p1)*(smlsiz_p1), n*(blas_int(1)+nrhs) + blas_int(2)*nrhs ) + : blas_int(10)*m + blas_int(2)*m*smlsiz + blas_int(8)*m*nlvl + blas_int(3)*smlsiz*nrhs + (std::max)( (smlsiz_p1)*(smlsiz_p1), n*(blas_int(1)+nrhs) + blas_int(2)*nrhs ); + + blas_int liwork = (std::max)( blas_int(1), (blas_int(3)*blas_int(min_mn)*nlvl + blas_int(11)*blas_int(min_mn)) ); + + podarray rwork( static_cast(lrwork) ); + podarray iwork( static_cast(liwork) ); + + blas_int lwork_min = 2*min_mn + min_mn*nrhs; + + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::cx_gelsd()"); + lapack::cx_gelsd(&m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, S.memptr(), &rcond, &rank, &work_query[0], &lwork_query, rwork.memptr(), iwork.memptr(), &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real( work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::cx_gelsd()"); + lapack::cx_gelsd(&m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, S.memptr(), &rcond, &rank, work.memptr(), &lwork_final, rwork.memptr(), iwork.memptr(), &info); + + if(info != 0) { return false; } + + if(tmp.n_rows == A.n_cols) + { + out.steal_mem(tmp); + } + else + { + out = tmp.head_rows(A.n_cols); + } + + return true; + } + #else + { + arma_ignore(out); + arma_ignore(A); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::solve_trimat_fast(Mat& out, const Mat& A, const Base& B_expr, const uword layout) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + arma_debug_assert_blas_size(A,out); + + char uplo = (layout == 0) ? 'U' : 'L'; + char trans = 'N'; + char diag = 'N'; + blas_int n = blas_int(A.n_rows); + blas_int nrhs = blas_int(B_n_cols); + blas_int info = 0; + + arma_extra_debug_print("lapack::trtrs()"); + lapack::trtrs(&uplo, &trans, &diag, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info); + + return (info == 0); + } + #else + { + arma_ignore(out); + arma_ignore(A); + arma_ignore(B_expr); + arma_ignore(layout); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::solve_trimat_rcond(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const Base& B_expr, const uword layout) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + + out_rcond = T(0); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + arma_debug_assert_blas_size(A,out); + + char uplo = (layout == 0) ? 'U' : 'L'; + char trans = 'N'; + char diag = 'N'; + blas_int n = blas_int(A.n_rows); + blas_int nrhs = blas_int(B_n_cols); + blas_int info = 0; + + arma_extra_debug_print("lapack::trtrs()"); + lapack::trtrs(&uplo, &trans, &diag, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info); + + if(info != 0) { return false; } + + // determine quality of solution + out_rcond = auxlib::rcond_trimat(A, layout); + + return true; + } + #else + { + arma_ignore(out); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(B_expr); + arma_ignore(layout); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via LU decomposition (real band matrix) +template +inline +bool +auxlib::solve_band_fast(Mat& out, Mat& A, const uword KL, const uword KU, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + return auxlib::solve_band_fast_common(out, A, KL, KU, B_expr); + } + + + +//! solve a system of linear equations via LU decomposition (complex band matrix) +template +inline +bool +auxlib::solve_band_fast(Mat< std::complex >& out, Mat< std::complex >& A, const uword KL, const uword KU, const Base< std::complex,T1>& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::solve_band_fast(): redirecting to auxlib::solve_square_fast() due to crippled LAPACK"); + + arma_ignore(KL); + arma_ignore(KU); + + return auxlib::solve_square_fast(out, A, B_expr); + } + #else + { + return auxlib::solve_band_fast_common(out, A, KL, KU, B_expr); + } + #endif + } + + + +//! solve a system of linear equations via LU decomposition (band matrix) +template +inline +bool +auxlib::solve_band_fast_common(Mat& out, const Mat& A, const uword KL, const uword KU, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_rows, B_n_cols); return true; } + + // for gbsv, matrix AB size: 2*KL+KU+1 x N; band representation of A stored in rows KL+1 to 2*KL+KU+1 (note: fortran counts from 1) + + Mat AB; + band_helper::compress(AB, A, KL, KU, true); + + const uword N = AB.n_cols; // order of the original square matrix A + + arma_debug_assert_blas_size(AB,out); + + blas_int n = blas_int(N); + blas_int kl = blas_int(KL); + blas_int ku = blas_int(KU); + blas_int nrhs = blas_int(B_n_cols); + blas_int ldab = blas_int(AB.n_rows); + blas_int ldb = blas_int(B_n_rows); + blas_int info = blas_int(0); + + podarray ipiv(N + 2); // +2 for paranoia + + // NOTE: AB is overwritten + + arma_extra_debug_print("lapack::gbsv()"); + lapack::gbsv(&n, &kl, &ku, &nrhs, AB.memptr(), &ldab, ipiv.memptr(), out.memptr(), &ldb, &info); + + return (info == 0); + } + #else + { + arma_ignore(out); + arma_ignore(A); + arma_ignore(KL); + arma_ignore(KU); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via LU decomposition (real band matrix) +template +inline +bool +auxlib::solve_band_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + return auxlib::solve_band_rcond_common(out, out_rcond, A, KL, KU, B_expr); + } + + + +//! solve a system of linear equations via LU decomposition (complex band matrix) +template +inline +bool +auxlib::solve_band_rcond(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base< std::complex,T1>& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::solve_band_rcond(): redirecting to auxlib::solve_square_rcond() due to crippled LAPACK"); + + arma_ignore(KL); + arma_ignore(KU); + + return auxlib::solve_square_rcond(out, out_rcond, A, B_expr); + } + #else + { + return auxlib::solve_band_rcond_common(out, out_rcond, A, KL, KU, B_expr); + } + #endif + } + + + +//! solve a system of linear equations via LU decomposition (band matrix) +template +inline +bool +auxlib::solve_band_rcond_common(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const uword KL, const uword KU, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + out_rcond = T(0); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_rows, B_n_cols); return true; } + + // for gbtrf, matrix AB size: 2*KL+KU+1 x N; band representation of A stored in rows KL+1 to 2*KL+KU+1 (note: fortran counts from 1) + + Mat AB; + band_helper::compress(AB, A, KL, KU, true); + + const uword N = AB.n_cols; // order of the original square matrix A + + arma_debug_assert_blas_size(AB,out); + + //char norm_id = '1'; + char trans = 'N'; + blas_int n = blas_int(N); // assuming square matrix + blas_int kl = blas_int(KL); + blas_int ku = blas_int(KU); + blas_int nrhs = blas_int(B_n_cols); + blas_int ldab = blas_int(AB.n_rows); + blas_int ldb = blas_int(B_n_rows); + blas_int info = blas_int(0); + T norm_val = T(0); + + //podarray junk(1); + podarray ipiv(N + 2); // +2 for paranoia + + // // NOTE: lapack::langb() and lapack::gbtrf() use incompatible storage formats for the band matrix + // arma_extra_debug_print("lapack::langb()"); + // norm_val = lapack::langb(&norm_id, &n, &kl, &ku, AB.memptr(), &ldab, junk.memptr()); + + norm_val = auxlib::norm1_band(A,KL,KU); + + arma_extra_debug_print("lapack::gbtrf()"); + lapack::gbtrf(&n, &n, &kl, &ku, AB.memptr(), &ldab, ipiv.memptr(), &info); + + if(info != 0) { return false; } + + arma_extra_debug_print("lapack::gbtrs()"); + lapack::gbtrs(&trans, &n, &kl, &ku, &nrhs, AB.memptr(), &ldab, ipiv.memptr(), out.memptr(), &ldb, &info); + + if(info != 0) { return false; } + + out_rcond = auxlib::lu_rcond_band(AB, KL, KU, ipiv, norm_val); + + return true; + } + #else + { + arma_ignore(out); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(KL); + arma_ignore(KU); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via LU decomposition with refinement (real band matrices) +template +inline +bool +auxlib::solve_band_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr, const bool equilibrate) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type eT; + + Mat B = B_expr.get_ref(); // B is overwritten + + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); + + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } + + // for gbsvx, matrix AB size: KL+KU+1 x N; band representation of A stored in rows 1 to KL+KU+1 (note: fortran counts from 1) + + Mat AB; + band_helper::compress(AB, A, KL, KU, false); + + const uword N = AB.n_cols; + + arma_debug_assert_blas_size(AB,B); + + out.set_size(N, B.n_cols); + + Mat AFB(2*KL+KU+1, N, arma_nozeros_indicator()); + + char fact = (equilibrate) ? 'E' : 'N'; + char trans = 'N'; + char equed = char(0); + blas_int n = blas_int(N); + blas_int kl = blas_int(KL); + blas_int ku = blas_int(KU); + blas_int nrhs = blas_int(B.n_cols); + blas_int ldab = blas_int(AB.n_rows); + blas_int ldafb = blas_int(AFB.n_rows); + blas_int ldb = blas_int(B.n_rows); + blas_int ldx = blas_int(N); + blas_int info = blas_int(0); + eT rcond = eT(0); + + podarray IPIV( N); + podarray R( N); + podarray C( N); + podarray FERR( B.n_cols); + podarray BERR( B.n_cols); + podarray WORK(3*N); + podarray IWORK( N); + + arma_extra_debug_print("lapack::gbsvx()"); + lapack::gbsvx + ( + &fact, &trans, &n, &kl, &ku, &nrhs, + AB.memptr(), &ldab, + AFB.memptr(), &ldafb, + IPIV.memptr(), + &equed, + R.memptr(), + C.memptr(), + B.memptr(), &ldb, + out.memptr(), &ldx, + &rcond, + FERR.memptr(), + BERR.memptr(), + WORK.memptr(), + IWORK.memptr(), + &info + ); + + out_rcond = rcond; + + return ((info == 0) || (info == (n+1))); + } + #else + { + arma_ignore(out); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(KL); + arma_ignore(KU); + arma_ignore(B_expr); + arma_ignore(equilibrate); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via LU decomposition with refinement (complex band matrices) +template +inline +bool +auxlib::solve_band_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base,T1>& B_expr, const bool equilibrate) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::solve_band_refine(): redirecting to auxlib::solve_square_refine() due to crippled LAPACK"); + + arma_ignore(KL); + arma_ignore(KU); + + return auxlib::solve_square_refine(out, out_rcond, A, B_expr, equilibrate); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + Mat B = B_expr.get_ref(); // B is overwritten + + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); + + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } + + // for gbsvx, matrix AB size: KL+KU+1 x N; band representation of A stored in rows 1 to KL+KU+1 (note: fortran counts from 1) + + Mat AB; + band_helper::compress(AB, A, KL, KU, false); + + const uword N = AB.n_cols; + + arma_debug_assert_blas_size(AB,B); + + out.set_size(N, B.n_cols); + + Mat AFB(2*KL+KU+1, N, arma_nozeros_indicator()); + + char fact = (equilibrate) ? 'E' : 'N'; + char trans = 'N'; + char equed = char(0); + blas_int n = blas_int(N); + blas_int kl = blas_int(KL); + blas_int ku = blas_int(KU); + blas_int nrhs = blas_int(B.n_cols); + blas_int ldab = blas_int(AB.n_rows); + blas_int ldafb = blas_int(AFB.n_rows); + blas_int ldb = blas_int(B.n_rows); + blas_int ldx = blas_int(N); + blas_int info = blas_int(0); + T rcond = T(0); + + podarray IPIV( N); + podarray< T> R( N); + podarray< T> C( N); + podarray< T> FERR( B.n_cols); + podarray< T> BERR( B.n_cols); + podarray WORK(2*N); + podarray< T> RWORK( N); // NOTE: according to lapack 3.6.1 docs, the size of RWORK in zgbsvx is different to RWORK in dgesvx + + arma_extra_debug_print("lapack::cx_gbsvx()"); + lapack::cx_gbsvx + ( + &fact, &trans, &n, &kl, &ku, &nrhs, + AB.memptr(), &ldab, + AFB.memptr(), &ldafb, + IPIV.memptr(), + &equed, + R.memptr(), + C.memptr(), + B.memptr(), &ldb, + out.memptr(), &ldx, + &rcond, + FERR.memptr(), + BERR.memptr(), + WORK.memptr(), + RWORK.memptr(), + &info + ); + + out_rcond = rcond; + + return ((info == 0) || (info == (n+1))); + } + #else + { + arma_ignore(out); + arma_ignore(out_rcond); + arma_ignore(A); + arma_ignore(KL); + arma_ignore(KU); + arma_ignore(B_expr); + arma_ignore(equilibrate); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a system of linear equations via Gaussian elimination with partial pivoting (real tridiagonal band matrix) +template +inline +bool +auxlib::solve_tridiag_fast(Mat& out, Mat& A, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + return auxlib::solve_tridiag_fast_common(out, A, B_expr); + } + + + +//! solve a system of linear equations via Gaussian elimination with partial pivoting (complex tridiagonal band matrix) +template +inline +bool +auxlib::solve_tridiag_fast(Mat< std::complex >& out, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::solve_tridiag_fast(): redirecting to auxlib::solve_square_fast() due to crippled LAPACK"); + + return auxlib::solve_square_fast(out, A, B_expr); + } + #else + { + return auxlib::solve_tridiag_fast_common(out, A, B_expr); + } + #endif + } + + + +//! solve a system of linear equations via Gaussian elimination with partial pivoting (tridiagonal band matrix) +template +inline +bool +auxlib::solve_tridiag_fast_common(Mat& out, const Mat& A, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_rows, B_n_cols); return true; } + + Mat tridiag; + band_helper::extract_tridiag(tridiag, A); + + arma_debug_assert_blas_size(tridiag, out); + + blas_int n = blas_int(A.n_rows); + blas_int nrhs = blas_int(B_n_cols); + blas_int ldb = blas_int(B_n_rows); + blas_int info = blas_int(0); + + arma_extra_debug_print("lapack::gtsv()"); + lapack::gtsv(&n, &nrhs, tridiag.colptr(0), tridiag.colptr(1), tridiag.colptr(2), out.memptr(), &ldb, &info); + + return (info == 0); + } + #else + { + arma_ignore(out); + arma_ignore(A); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +// +// Schur decomposition + +template +inline +bool +auxlib::schur(Mat& U, Mat& S, const Base& X, const bool calc_U) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + S = X.get_ref(); + + arma_debug_check( (S.is_square() == false), "schur(): given matrix must be square sized" ); + + if(S.is_empty()) { U.reset(); S.reset(); return true; } + + arma_debug_assert_blas_size(S); + + const uword S_n_rows = S.n_rows; + + if(calc_U) { U.set_size(S_n_rows, S_n_rows); } else { U.set_size(1,1); } + + char jobvs = calc_U ? 'V' : 'N'; + char sort = 'N'; + void* select = 0; + blas_int n = blas_int(S_n_rows); + blas_int sdim = 0; + blas_int ldvs = calc_U ? n : blas_int(1); + blas_int lwork = 64*n; // lwork_min = (std::max)(blas_int(1), 3*n) + blas_int info = 0; + + podarray wr(S_n_rows); + podarray wi(S_n_rows); + + podarray work( static_cast(lwork) ); + podarray bwork(S_n_rows); + + arma_extra_debug_print("lapack::gees()"); + lapack::gees(&jobvs, &sort, select, &n, S.memptr(), &n, &sdim, wr.memptr(), wi.memptr(), U.memptr(), &ldvs, work.memptr(), &lwork, bwork.memptr(), &info); + + return (info == 0); + } + #else + { + arma_ignore(U); + arma_ignore(S); + arma_ignore(X); + arma_ignore(calc_U); + arma_stop_logic_error("schur(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::schur(Mat< std::complex >& U, Mat< std::complex >& S, const Base,T1>& X, const bool calc_U) + { + arma_extra_debug_sigprint(); + + S = X.get_ref(); + + arma_debug_check( (S.is_square() == false), "schur(): given matrix must be square sized" ); + + return auxlib::schur(U,S,calc_U); + } + + + +template +inline +bool +auxlib::schur(Mat< std::complex >& U, Mat< std::complex >& S, const bool calc_U) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef std::complex eT; + + if(S.is_empty()) { U.reset(); S.reset(); return true; } + + arma_debug_assert_blas_size(S); + + const uword S_n_rows = S.n_rows; + + if(calc_U) { U.set_size(S_n_rows, S_n_rows); } else { U.set_size(1,1); } + + char jobvs = calc_U ? 'V' : 'N'; + char sort = 'N'; + void* select = 0; + blas_int n = blas_int(S_n_rows); + blas_int sdim = 0; + blas_int ldvs = calc_U ? n : blas_int(1); + blas_int lwork = 64*n; // lwork_min = (std::max)(blas_int(1), 2*n) + blas_int info = 0; + + podarray w(S_n_rows); + podarray work( static_cast(lwork) ); + podarray< T> rwork(S_n_rows); + podarray bwork(S_n_rows); + + arma_extra_debug_print("lapack::cx_gees()"); + lapack::cx_gees(&jobvs, &sort, select, &n, S.memptr(), &n, &sdim, w.memptr(), U.memptr(), &ldvs, work.memptr(), &lwork, rwork.memptr(), bwork.memptr(), &info); + + return (info == 0); + } + #else + { + arma_ignore(U); + arma_ignore(S); + arma_ignore(calc_U); + arma_stop_logic_error("schur(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +// +// solve the Sylvester equation AX + XB = C + +template +inline +bool +auxlib::syl(Mat& X, const Mat& A, const Mat& B, const Mat& C) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_check( (A.is_square() == false) || (B.is_square() == false), "syl(): given matrices must be square sized" ); + + arma_debug_check( (C.n_rows != A.n_rows) || (C.n_cols != B.n_cols), "syl(): matrices are not conformant" ); + + if(A.is_empty() || B.is_empty() || C.is_empty()) { X.reset(); return true; } + + Mat Z1, Z2, T1, T2; + + const bool status_sd1 = auxlib::schur(Z1, T1, A); + const bool status_sd2 = auxlib::schur(Z2, T2, B); + + if( (status_sd1 == false) || (status_sd2 == false) ) { return false; } + + char trana = 'N'; + char tranb = 'N'; + blas_int isgn = +1; + blas_int m = blas_int(T1.n_rows); + blas_int n = blas_int(T2.n_cols); + + eT scale = eT(0); + blas_int info = 0; + + Mat Y = trans(Z1) * C * Z2; + + arma_extra_debug_print("lapack::trsyl()"); + lapack::trsyl(&trana, &tranb, &isgn, &m, &n, T1.memptr(), &m, T2.memptr(), &n, Y.memptr(), &m, &scale, &info); + + if(info < 0) { return false; } + + //Y /= scale; + Y /= (-scale); + + X = Z1 * Y * trans(Z2); + + return true; + } + #else + { + arma_ignore(X); + arma_ignore(A); + arma_ignore(B); + arma_ignore(C); + arma_stop_logic_error("syl(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +// +// QZ decomposition of general square real matrix pair + +template +inline +bool +auxlib::qz(Mat& A, Mat& B, Mat& vsl, Mat& vsr, const Base& X_expr, const Base& Y_expr, const char mode) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + A = X_expr.get_ref(); + B = Y_expr.get_ref(); + + arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "qz(): given matrices must be square sized", [&](){ A.soft_reset(); B.soft_reset(); } ); + + arma_debug_check( (A.n_rows != B.n_rows), "qz(): given matrices must have the same size" ); + + if(A.is_empty()) { A.reset(); B.reset(); vsl.reset(); vsr.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + vsl.set_size(A.n_rows, A.n_rows); + vsr.set_size(A.n_rows, A.n_rows); + + char jobvsl = 'V'; + char jobvsr = 'V'; + char eigsort = 'N'; + void* selctg = 0; + blas_int N = blas_int(A.n_rows); + blas_int sdim = 0; + blas_int lwork = 64*N+16; // lwork_min = (std::max)(blas_int(1),8*N+16) + blas_int info = 0; + + if(mode == 'l') { eigsort = 'S'; selctg = qz_helper::ptr_cast(&(qz_helper::select_lhp)); } + else if(mode == 'r') { eigsort = 'S'; selctg = qz_helper::ptr_cast(&(qz_helper::select_rhp)); } + else if(mode == 'i') { eigsort = 'S'; selctg = qz_helper::ptr_cast(&(qz_helper::select_iuc)); } + else if(mode == 'o') { eigsort = 'S'; selctg = qz_helper::ptr_cast(&(qz_helper::select_ouc)); } + + podarray alphar(A.n_rows); + podarray alphai(A.n_rows); + podarray beta(A.n_rows); + + podarray work( static_cast(lwork) ); + podarray bwork( static_cast(N) ); + + arma_extra_debug_print("lapack::gges()"); + + lapack::gges + ( + &jobvsl, &jobvsr, &eigsort, selctg, &N, + A.memptr(), &N, B.memptr(), &N, &sdim, + alphar.memptr(), alphai.memptr(), beta.memptr(), + vsl.memptr(), &N, vsr.memptr(), &N, + work.memptr(), &lwork, bwork.memptr(), + &info + ); + + if(info != 0) { return false; } + + op_strans::apply_mat_inplace(vsl); + + return true; + } + #else + { + arma_ignore(A); + arma_ignore(B); + arma_ignore(vsl); + arma_ignore(vsr); + arma_ignore(X_expr); + arma_ignore(Y_expr); + arma_ignore(mode); + arma_stop_logic_error("qz(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +// +// QZ decomposition of general square complex matrix pair + +template +inline +bool +auxlib::qz(Mat< std::complex >& A, Mat< std::complex >& B, Mat< std::complex >& vsl, Mat< std::complex >& vsr, const Base< std::complex, T1 >& X_expr, const Base< std::complex, T2 >& Y_expr, const char mode) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + A = X_expr.get_ref(); + B = Y_expr.get_ref(); + + arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "qz(): given matrices must be square sized", [&](){ A.soft_reset(); B.soft_reset(); } ); + + arma_debug_check( (A.n_rows != B.n_rows), "qz(): given matrices must have the same size" ); + + if(A.is_empty()) { A.reset(); B.reset(); vsl.reset(); vsr.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } + + arma_debug_assert_blas_size(A); + + vsl.set_size(A.n_rows, A.n_rows); + vsr.set_size(A.n_rows, A.n_rows); + + char jobvsl = 'V'; + char jobvsr = 'V'; + char eigsort = 'N'; + void* selctg = 0; + blas_int N = blas_int(A.n_rows); + blas_int sdim = 0; + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1),2*N) + blas_int info = 0; + + if(mode == 'l') { eigsort = 'S'; selctg = qz_helper::ptr_cast(&(qz_helper::cx_select_lhp)); } + else if(mode == 'r') { eigsort = 'S'; selctg = qz_helper::ptr_cast(&(qz_helper::cx_select_rhp)); } + else if(mode == 'i') { eigsort = 'S'; selctg = qz_helper::ptr_cast(&(qz_helper::cx_select_iuc)); } + else if(mode == 'o') { eigsort = 'S'; selctg = qz_helper::ptr_cast(&(qz_helper::cx_select_ouc)); } + + podarray alpha(A.n_rows); + podarray beta(A.n_rows); + + podarray work( static_cast(lwork) ); + podarray< T> rwork( static_cast(8*N) ); + podarray bwork( static_cast(N) ); + + arma_extra_debug_print("lapack::cx_gges()"); + + lapack::cx_gges + ( + &jobvsl, &jobvsr, &eigsort, selctg, &N, + A.memptr(), &N, B.memptr(), &N, &sdim, + alpha.memptr(), beta.memptr(), + vsl.memptr(), &N, vsr.memptr(), &N, + work.memptr(), &lwork, rwork.memptr(), bwork.memptr(), + &info + ); + + if(info != 0) { return false; } + + op_htrans::apply_mat_inplace(vsl); + + return true; + } + #else + { + arma_ignore(A); + arma_ignore(B); + arma_ignore(vsl); + arma_ignore(vsr); + arma_ignore(X_expr); + arma_ignore(Y_expr); + arma_ignore(mode); + arma_stop_logic_error("qz(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +eT +auxlib::rcond(Mat& A) + { + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); + + char norm_id = '1'; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_rows); // assuming square matrix + blas_int lda = blas_int(A.n_rows); + eT norm_val = eT(0); + eT rcond = eT(0); + blas_int info = blas_int(0); + + podarray work(4*A.n_rows); + podarray iwork( A.n_rows); + podarray ipiv( (std::min)(A.n_rows, A.n_cols) ); + + arma_extra_debug_print("lapack::lange()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_gen(A) : lapack::lange(&norm_id, &m, &n, A.memptr(), &lda, work.memptr()); + + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&m, &n, A.memptr(), &lda, ipiv.memptr(), &info); + + if(info != blas_int(0)) { return eT(0); } + + arma_extra_debug_print("lapack::gecon()"); + lapack::gecon(&norm_id, &n, A.memptr(), &lda, &norm_val, &rcond, work.memptr(), iwork.memptr(), &info); + + if(info != blas_int(0)) { return eT(0); } + + return rcond; + } + #else + { + arma_ignore(A); + arma_stop_logic_error("rcond(): use of LAPACK must be enabled"); + return eT(0); + } + #endif + } + + + +template +inline +T +auxlib::rcond(Mat< std::complex >& A) + { + #if defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + arma_debug_assert_blas_size(A); + + char norm_id = '1'; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_rows); // assuming square matrix + blas_int lda = blas_int(A.n_rows); + T norm_val = T(0); + T rcond = T(0); + blas_int info = blas_int(0); + + podarray< T> junk(1); + podarray work(2*A.n_rows); + podarray< T> rwork(2*A.n_rows); + podarray ipiv( (std::min)(A.n_rows, A.n_cols) ); + + arma_extra_debug_print("lapack::lange()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_gen(A) : lapack::lange(&norm_id, &m, &n, A.memptr(), &lda, junk.memptr()); + + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&m, &n, A.memptr(), &lda, ipiv.memptr(), &info); + + if(info != blas_int(0)) { return T(0); } + + arma_extra_debug_print("lapack::cx_gecon()"); + lapack::cx_gecon(&norm_id, &n, A.memptr(), &lda, &norm_val, &rcond, work.memptr(), rwork.memptr(), &info); + + if(info != blas_int(0)) { return T(0); } + + return rcond; + } + #else + { + arma_ignore(A); + arma_stop_logic_error("rcond(): use of LAPACK must be enabled"); + return T(0); + } + #endif + } + + + +template +inline +eT +auxlib::rcond_sympd(Mat& A, bool& calc_ok) + { + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); + + calc_ok = false; + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); // assuming square matrix + blas_int lda = blas_int(A.n_rows); + eT norm_val = eT(0); + eT rcond = eT(0); + blas_int info = blas_int(0); + + podarray work(3*A.n_rows); + podarray iwork( A.n_rows); + + arma_extra_debug_print("lapack::lansy()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &lda, work.memptr()); + + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, A.memptr(), &lda, &info); + + if(info != blas_int(0)) { return eT(0); } + + arma_extra_debug_print("lapack::pocon()"); + lapack::pocon(&uplo, &n, A.memptr(), &lda, &norm_val, &rcond, work.memptr(), iwork.memptr(), &info); + + if(info != blas_int(0)) { return eT(0); } + + calc_ok = true; + + return rcond; + } + #else + { + arma_ignore(A); + calc_ok = false; + arma_stop_logic_error("rcond(): use of LAPACK must be enabled"); + return eT(0); + } + #endif + } + + + +template +inline +T +auxlib::rcond_sympd(Mat< std::complex >& A, bool& calc_ok) + { + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::rcond_sympd(): redirecting to auxlib::rcond() due to crippled LAPACK"); + + calc_ok = true; + + return auxlib::rcond(A); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + arma_debug_assert_blas_size(A); + + calc_ok = false; + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); // assuming square matrix + blas_int lda = blas_int(A.n_rows); + T norm_val = T(0); + T rcond = T(0); + blas_int info = blas_int(0); + + podarray work(2*A.n_rows); + podarray< T> rwork( A.n_rows); + + arma_extra_debug_print("lapack::lanhe()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &lda, rwork.memptr()); + + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, A.memptr(), &lda, &info); + + if(info != blas_int(0)) { return T(0); } + + arma_extra_debug_print("lapack::cx_pocon()"); + lapack::cx_pocon(&uplo, &n, A.memptr(), &lda, &norm_val, &rcond, work.memptr(), rwork.memptr(), &info); + + if(info != blas_int(0)) { return T(0); } + + calc_ok = true; + + return rcond; + } + #else + { + arma_ignore(A); + calc_ok = false; + arma_stop_logic_error("rcond(): use of LAPACK must be enabled"); + return T(0); + } + #endif + } + + + +template +inline +eT +auxlib::rcond_trimat(const Mat& A, const uword layout) + { + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); + + char norm_id = '1'; + char uplo = (layout == 0) ? 'U' : 'L'; + char diag = 'N'; + blas_int n = blas_int(A.n_rows); // assuming square matrix + eT rcond = eT(0); + blas_int info = blas_int(0); + + podarray work(3*A.n_rows); + podarray iwork( A.n_rows); + + arma_extra_debug_print("lapack::trcon()"); + lapack::trcon(&norm_id, &uplo, &diag, &n, A.memptr(), &n, &rcond, work.memptr(), iwork.memptr(), &info); + + if(info != blas_int(0)) { return eT(0); } + + return rcond; + } + #else + { + arma_ignore(A); + arma_ignore(layout); + arma_stop_logic_error("rcond(): use of LAPACK must be enabled"); + return eT(0); + } + #endif + } + + + +template +inline +T +auxlib::rcond_trimat(const Mat< std::complex >& A, const uword layout) + { + #if defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + arma_debug_assert_blas_size(A); + + char norm_id = '1'; + char uplo = (layout == 0) ? 'U' : 'L'; + char diag = 'N'; + blas_int n = blas_int(A.n_rows); // assuming square matrix + T rcond = T(0); + blas_int info = blas_int(0); + + podarray work(2*A.n_rows); + podarray< T> rwork( A.n_rows); + + arma_extra_debug_print("lapack::cx_trcon()"); + lapack::cx_trcon(&norm_id, &uplo, &diag, &n, A.memptr(), &n, &rcond, work.memptr(), rwork.memptr(), &info); + + if(info != blas_int(0)) { return T(0); } + + return rcond; + } + #else + { + arma_ignore(A); + arma_ignore(layout); + arma_stop_logic_error("rcond(): use of LAPACK must be enabled"); + return T(0); + } + #endif + } + + + +template +inline +eT +auxlib::lu_rcond(const Mat& A, const eT norm_val) + { + #if defined(ARMA_USE_LAPACK) + { + char norm_id = '1'; + blas_int n = blas_int(A.n_rows); // assuming square matrix + blas_int lda = blas_int(A.n_rows); + eT rcond = eT(0); + blas_int info = blas_int(0); + + podarray work(4*A.n_rows); + podarray iwork( A.n_rows); + + arma_extra_debug_print("lapack::gecon()"); + lapack::gecon(&norm_id, &n, A.memptr(), &lda, &norm_val, &rcond, work.memptr(), iwork.memptr(), &info); + + if(info != blas_int(0)) { return eT(0); } + + return rcond; + } + #else + { + arma_ignore(A); + arma_ignore(norm_val); + return eT(0); + } + #endif + } + + + +template +inline +T +auxlib::lu_rcond(const Mat< std::complex >& A, const T norm_val) + { + #if defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + char norm_id = '1'; + blas_int n = blas_int(A.n_rows); // assuming square matrix + blas_int lda = blas_int(A.n_rows); + T rcond = T(0); + blas_int info = blas_int(0); + + podarray work(2*A.n_rows); + podarray< T> rwork(2*A.n_rows); + + arma_extra_debug_print("lapack::cx_gecon()"); + lapack::cx_gecon(&norm_id, &n, A.memptr(), &lda, &norm_val, &rcond, work.memptr(), rwork.memptr(), &info); + + if(info != blas_int(0)) { return T(0); } + + return rcond; + } + #else + { + arma_ignore(A); + arma_ignore(norm_val); + return T(0); + } + #endif + } + + + +template +inline +eT +auxlib::lu_rcond_sympd(const Mat& A, const eT norm_val) + { + #if defined(ARMA_USE_LAPACK) + { + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); // assuming square matrix + eT rcond = eT(0); + blas_int info = blas_int(0); + + podarray work(3*A.n_rows); + podarray iwork( A.n_rows); + + arma_extra_debug_print("lapack::pocon()"); + lapack::pocon(&uplo, &n, A.memptr(), &n, &norm_val, &rcond, work.memptr(), iwork.memptr(), &info); + + if(info != blas_int(0)) { return eT(0); } + + return rcond; + } + #else + { + arma_ignore(A); + arma_ignore(norm_val); + return eT(0); + } + #endif + } + + + +template +inline +T +auxlib::lu_rcond_sympd(const Mat< std::complex >& A, const T norm_val) + { + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_ignore(A); + arma_ignore(norm_val); + return T(0); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); // assuming square matrix + T rcond = T(0); + blas_int info = blas_int(0); + + podarray work(2*A.n_rows); + podarray< T> rwork( A.n_rows); + + arma_extra_debug_print("lapack::cx_pocon()"); + lapack::cx_pocon(&uplo, &n, A.memptr(), &n, &norm_val, &rcond, work.memptr(), rwork.memptr(), &info); + + if(info != blas_int(0)) { return T(0); } + + return rcond; + } + #else + { + arma_ignore(A); + arma_ignore(norm_val); + return T(0); + } + #endif + } + + + +template +inline +eT +auxlib::lu_rcond_band(const Mat& AB, const uword KL, const uword KU, const podarray& ipiv, const eT norm_val) + { + #if defined(ARMA_USE_LAPACK) + { + const uword N = AB.n_cols; // order of the original square matrix A + + char norm_id = '1'; + blas_int n = blas_int(N); + blas_int kl = blas_int(KL); + blas_int ku = blas_int(KU); + blas_int ldab = blas_int(AB.n_rows); + eT rcond = eT(0); + blas_int info = blas_int(0); + + podarray work(3*N); + podarray iwork( N); + + arma_extra_debug_print("lapack::gbcon()"); + lapack::gbcon(&norm_id, &n, &kl, &ku, AB.memptr(), &ldab, ipiv.memptr(), &norm_val, &rcond, work.memptr(), iwork.memptr(), &info); + + if(info != blas_int(0)) { return eT(0); } + + return rcond; + } + #else + { + arma_ignore(AB); + arma_ignore(KL); + arma_ignore(KU); + arma_ignore(ipiv); + arma_ignore(norm_val); + return eT(0); + } + #endif + } + + + +template +inline +T +auxlib::lu_rcond_band(const Mat< std::complex >& AB, const uword KL, const uword KU, const podarray& ipiv, const T norm_val) + { + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_ignore(AB); + arma_ignore(KL); + arma_ignore(KU); + arma_ignore(ipiv); + arma_ignore(norm_val); + return T(0); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + const uword N = AB.n_cols; // order of the original square matrix A + + char norm_id = '1'; + blas_int n = blas_int(N); + blas_int kl = blas_int(KL); + blas_int ku = blas_int(KU); + blas_int ldab = blas_int(AB.n_rows); + T rcond = T(0); + blas_int info = blas_int(0); + + podarray work(2*N); + podarray< T> rwork( N); + + arma_extra_debug_print("lapack::cx_gbcon()"); + lapack::cx_gbcon(&norm_id, &n, &kl, &ku, AB.memptr(), &ldab, ipiv.memptr(), &norm_val, &rcond, work.memptr(), rwork.memptr(), &info); + + if(info != blas_int(0)) { return T(0); } + + return rcond; + } + #else + { + arma_ignore(AB); + arma_ignore(KL); + arma_ignore(KU); + arma_ignore(ipiv); + arma_ignore(norm_val); + return T(0); + } + #endif + } + + + +template +inline +bool +auxlib::crippled_lapack(const Base&) + { + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::crippled_lapack(): true"); + + return (is_cx::yes); + } + #else + { + return false; + } + #endif + } + + + +template +inline +bool +auxlib::rudimentary_sym_check(const Mat& X) + { + arma_extra_debug_sigprint(); + + const uword N = X.n_rows; + const uword Nm2 = N-2; + + if(N != X.n_cols) { return false; } + if(N <= uword(1)) { return true; } + + const eT* X_mem = X.memptr(); + + const eT* X_offsetA = &(X_mem[Nm2 ]); + const eT* X_offsetB = &(X_mem[Nm2*N]); + + const eT A1 = *(X_offsetA ); + const eT A2 = *(X_offsetA+1); // bottom-left corner (ie. last value in first column) + const eT B1 = *(X_offsetB ); + const eT B2 = *(X_offsetB+N); // top-right corner (ie. first value in last column) + + const eT C1 = (std::max)(std::abs(A1), std::abs(B1)); + const eT C2 = (std::max)(std::abs(A2), std::abs(B2)); + + const eT delta1 = std::abs(A1 - B1); + const eT delta2 = std::abs(A2 - B2); + + const eT tol = eT(10000)*std::numeric_limits::epsilon(); // allow some leeway + + const bool okay1 = ( (delta1 <= tol) || (delta1 <= (C1 * tol)) ); + const bool okay2 = ( (delta2 <= tol) || (delta2 <= (C2 * tol)) ); + + return (okay1 && okay2); + } + + + +template +inline +bool +auxlib::rudimentary_sym_check(const Mat< std::complex >& X) + { + arma_extra_debug_sigprint(); + + // NOTE: the function name is a misnomer, as it checks for hermitian complex matrices; + // NOTE: for simplicity of use, the function name is the same as for real matrices + + typedef typename std::complex eT; + + const uword N = X.n_rows; + const uword Nm1 = N-1; + + if(N != X.n_cols) { return false; } + if(N == uword(0)) { return true; } + + const eT* X_mem = X.memptr(); + + const T tol = T(10000)*std::numeric_limits::epsilon(); // allow some leeway + + if(std::abs(X_mem[0 ].imag()) > tol) { return false; } // check top-left + if(std::abs(X_mem[X.n_elem-1].imag()) > tol) { return false; } // check bottom-right + + const eT& A = X_mem[Nm1 ]; // bottom-left corner (ie. last value in first column) + const eT& B = X_mem[Nm1*N]; // top-right corner (ie. first value in last column) + + const T C_real = (std::max)(std::abs(A.real()), std::abs(B.real())); + const T C_imag = (std::max)(std::abs(A.imag()), std::abs(B.imag())); + + const T delta_real = std::abs(A.real() - B.real()); + const T delta_imag = std::abs(A.imag() + B.imag()); // take into account the conjugate + + const bool okay_real = ( (delta_real <= tol) || (delta_real <= (C_real * tol)) ); + const bool okay_imag = ( (delta_imag <= tol) || (delta_imag <= (C_imag * tol)) ); + + return (okay_real && okay_imag); + } + + + +template +inline +typename get_pod_type::result +auxlib::norm1_gen(const Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.n_elem == 0) { return T(0); } + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + T max_val = T(0); + + for(uword c=0; c < n_cols; ++c) + { + const eT* colmem = A.colptr(c); + T acc_val = T(0); + + for(uword r=0; r < n_rows; ++r) { acc_val += std::abs(colmem[r]); } + + max_val = (acc_val > max_val) ? acc_val : max_val; + } + + return max_val; + } + + + +template +inline +typename get_pod_type::result +auxlib::norm1_sym(const Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.n_elem == 0) { return T(0); } + + const uword N = (std::min)(A.n_rows, A.n_cols); + + T max_val = T(0); + + for(uword col=0; col < N; ++col) + { + const eT* colmem = A.colptr(col); + T acc_val = T(0); + + for(uword c=0; c < col; ++c) { acc_val += std::abs(A.at(col,c)); } + + for(uword r=col; r < N; ++r) { acc_val += std::abs(colmem[r]); } + + max_val = (acc_val > max_val) ? acc_val : max_val; + } + + return max_val; + } + + + +template +inline +typename get_pod_type::result +auxlib::norm1_band(const Mat& A, const uword KL, const uword KU) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.n_elem == 0) { return T(0); } + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + T max_val = T(0); + + for(uword c=0; c < n_cols; ++c) + { + const eT* colmem = A.colptr(c); + T acc_val = T(0); + + // use values only from main diagonal + KU upper diagonals + KL lower diagonals + + const uword start = ( c > KU ) ? (c - KU) : 0; + const uword end = ((c + KL) < n_rows) ? (c + KL) : (n_rows-1); + + for(uword r=start; r <= end; ++r) { acc_val += std::abs(colmem[r]); } + + max_val = (acc_val > max_val) ? acc_val : max_val; + } + + return max_val; + } + + + +// + + + +namespace qz_helper +{ + +// sgges() and dgges() require an external function with three arguments: +// select(alpha_real, alpha_imag, beta) +// where the eigenvalue is defined as complex(alpha_real, alpha_imag) / beta + +template +inline +blas_int +select_lhp(const T* x_ptr, const T* y_ptr, const T* z_ptr) + { + arma_extra_debug_sigprint(); + + // cout << "select_lhp(): (*x_ptr) = " << (*x_ptr) << endl; + // cout << "select_lhp(): (*y_ptr) = " << (*y_ptr) << endl; + // cout << "select_lhp(): (*z_ptr) = " << (*z_ptr) << endl; + + arma_ignore(y_ptr); // ignore imaginary part + + const T x = (*x_ptr); + const T z = (*z_ptr); + + if(z == T(0)) { return blas_int(0); } // consider an infinite eig value not to lie in either lhp or rhp + + return ((x/z) < T(0)) ? blas_int(1) : blas_int(0); + } + + + +template +inline +blas_int +select_rhp(const T* x_ptr, const T* y_ptr, const T* z_ptr) + { + arma_extra_debug_sigprint(); + + // cout << "select_rhp(): (*x_ptr) = " << (*x_ptr) << endl; + // cout << "select_rhp(): (*y_ptr) = " << (*y_ptr) << endl; + // cout << "select_rhp(): (*z_ptr) = " << (*z_ptr) << endl; + + arma_ignore(y_ptr); // ignore imaginary part + + const T x = (*x_ptr); + const T z = (*z_ptr); + + if(z == T(0)) { return blas_int(0); } // consider an infinite eig value not to lie in either lhp or rhp + + return ((x/z) > T(0)) ? blas_int(1) : blas_int(0); + } + + + +template +inline +blas_int +select_iuc(const T* x_ptr, const T* y_ptr, const T* z_ptr) + { + arma_extra_debug_sigprint(); + + // cout << "select_iuc(): (*x_ptr) = " << (*x_ptr) << endl; + // cout << "select_iuc(): (*y_ptr) = " << (*y_ptr) << endl; + // cout << "select_iuc(): (*z_ptr) = " << (*z_ptr) << endl; + + const T x = (*x_ptr); + const T y = (*y_ptr); + const T z = (*z_ptr); + + if(z == T(0)) { return blas_int(0); } // consider an infinite eig value to be outside of the unit circle + + //return (std::abs(std::complex(x,y) / z) < T(1)) ? blas_int(1) : blas_int(0); + return (std::sqrt(x*x + y*y) < std::abs(z)) ? blas_int(1) : blas_int(0); + } + + + +template +inline +blas_int +select_ouc(const T* x_ptr, const T* y_ptr, const T* z_ptr) + { + arma_extra_debug_sigprint(); + + // cout << "select_ouc(): (*x_ptr) = " << (*x_ptr) << endl; + // cout << "select_ouc(): (*y_ptr) = " << (*y_ptr) << endl; + // cout << "select_ouc(): (*z_ptr) = " << (*z_ptr) << endl; + + const T x = (*x_ptr); + const T y = (*y_ptr); + const T z = (*z_ptr); + + if(z == T(0)) + { + return (x == T(0)) ? blas_int(0) : blas_int(1); // consider an infinite eig value to be outside of the unit circle + } + + //return (std::abs(std::complex(x,y) / z) > T(1)) ? blas_int(1) : blas_int(0); + return (std::sqrt(x*x + y*y) > std::abs(z)) ? blas_int(1) : blas_int(0); + } + + + +// cgges() and zgges() require an external function with two arguments: +// select(alpha, beta) +// where the complex eigenvalue is defined as (alpha / beta) + +template +inline +blas_int +cx_select_lhp(const std::complex* x_ptr, const std::complex* y_ptr) + { + arma_extra_debug_sigprint(); + + // cout << "cx_select_lhp(): (*x_ptr) = " << (*x_ptr) << endl; + // cout << "cx_select_lhp(): (*y_ptr) = " << (*y_ptr) << endl; + + const std::complex& x = (*x_ptr); + const std::complex& y = (*y_ptr); + + if( (y.real() == T(0)) && (y.imag() == T(0)) ) { return blas_int(0); } // consider an infinite eig value not to lie in either lhp or rhp + + return (std::real(x / y) < T(0)) ? blas_int(1) : blas_int(0); + } + + + +template +inline +blas_int +cx_select_rhp(const std::complex* x_ptr, const std::complex* y_ptr) + { + arma_extra_debug_sigprint(); + + // cout << "cx_select_rhp(): (*x_ptr) = " << (*x_ptr) << endl; + // cout << "cx_select_rhp(): (*y_ptr) = " << (*y_ptr) << endl; + + const std::complex& x = (*x_ptr); + const std::complex& y = (*y_ptr); + + if( (y.real() == T(0)) && (y.imag() == T(0)) ) { return blas_int(0); } // consider an infinite eig value not to lie in either lhp or rhp + + return (std::real(x / y) > T(0)) ? blas_int(1) : blas_int(0); + } + + + +template +inline +blas_int +cx_select_iuc(const std::complex* x_ptr, const std::complex* y_ptr) + { + arma_extra_debug_sigprint(); + + // cout << "cx_select_iuc(): (*x_ptr) = " << (*x_ptr) << endl; + // cout << "cx_select_iuc(): (*y_ptr) = " << (*y_ptr) << endl; + + const std::complex& x = (*x_ptr); + const std::complex& y = (*y_ptr); + + if( (y.real() == T(0)) && (y.imag() == T(0)) ) { return blas_int(0); } // consider an infinite eig value to be outside of the unit circle + + return (std::abs(x / y) < T(1)) ? blas_int(1) : blas_int(0); + } + + + +template +inline +blas_int +cx_select_ouc(const std::complex* x_ptr, const std::complex* y_ptr) + { + arma_extra_debug_sigprint(); + + // cout << "cx_select_ouc(): (*x_ptr) = " << (*x_ptr) << endl; + // cout << "cx_select_ouc(): (*y_ptr) = " << (*y_ptr) << endl; + + const std::complex& x = (*x_ptr); + const std::complex& y = (*y_ptr); + + if( (y.real() == T(0)) && (y.imag() == T(0)) ) + { + return ((x.real() == T(0)) && (x.imag() == T(0))) ? blas_int(0) : blas_int(1); // consider an infinite eig value to be outside of the unit circle + } + + return (std::abs(x / y) > T(1)) ? blas_int(1) : blas_int(0); + } + + + +// need to do shenanigans with pointers due to: +// - we're using LAPACK ?gges() defined to expect pointer-to-function to be passed as pointer-to-object +// - explicit casting between pointer-to-function and pointer-to-object is a non-standard extension in C +// - the extension is essentially mandatory on POSIX systems +// - some compilers will complain about the extension in pedantic mode + +template +inline +void_ptr +ptr_cast(blas_int (*function)(const T*, const T*, const T*)) + { + union converter + { + blas_int (*fn)(const T*, const T*, const T*); + void_ptr obj; + }; + + converter tmp; + + tmp.obj = 0; + tmp.fn = function; + + return tmp.obj; + } + + + +template +inline +void_ptr +ptr_cast(blas_int (*function)(const std::complex*, const std::complex*)) + { + union converter + { + blas_int (*fn)(const std::complex*, const std::complex*); + void_ptr obj; + }; + + converter tmp; + + tmp.obj = 0; + tmp.fn = function; + + return tmp.obj; + } + + + +} // end of namespace qz_helper + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/band_helper.hpp b/src/armadillo/include/armadillo_bits/band_helper.hpp new file mode 100644 index 0000000..4493c70 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/band_helper.hpp @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup band_helper +//! @{ + + +namespace band_helper +{ + + + +template +inline +bool +is_band(uword& out_KL, uword& out_KU, const Mat& A, const uword N_min) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that A has a square size + // NOTE: assuming that N_min is >= 4 + + const uword N = A.n_rows; + + if(N < N_min) { return false; } + + // first, quickly check bottom-left and top-right corners + + const eT eT_zero = eT(0); + + const eT* A_col0 = A.memptr(); + const eT* A_col1 = A_col0 + N; + + if( (A_col0[N-2] != eT_zero) || (A_col0[N-1] != eT_zero) || (A_col1[N-2] != eT_zero) || (A_col1[N-1] != eT_zero) ) { return false; } + + const eT* A_colNm2 = A.colptr(N-2); + const eT* A_colNm1 = A_colNm2 + N; + + if( (A_colNm2[0] != eT_zero) || (A_colNm2[1] != eT_zero) || (A_colNm1[0] != eT_zero) || (A_colNm1[1] != eT_zero) ) { return false; } + + // if we reached this point, go through the entire matrix to work out number of subdiagonals and superdiagonals + + const uword n_nonzero_threshold = (N*N)/4; // empirically determined + + uword KL = 0; // number of subdiagonals + uword KU = 0; // number of superdiagonals + + const eT* A_colptr = A.memptr(); + + for(uword col=0; col < N; ++col) + { + uword first_nonzero_row = col; + uword last_nonzero_row = col; + + for(uword row=0; row < col; ++row) + { + if( A_colptr[row] != eT_zero ) { first_nonzero_row = row; break; } + } + + for(uword row=(col+1); row < N; ++row) + { + last_nonzero_row = (A_colptr[row] != eT_zero) ? row : last_nonzero_row; + } + + const uword L_count = last_nonzero_row - col; + const uword U_count = col - first_nonzero_row; + + if( (L_count > KL) || (U_count > KU) ) + { + KL = (std::max)(KL, L_count); + KU = (std::max)(KU, U_count); + + const uword n_nonzero = N*(KL+KU+1) - (KL*(KL+1) + KU*(KU+1))/2; + + // return as soon as we know that it's not worth analysing the matrix any further + + if(n_nonzero > n_nonzero_threshold) { return false; } + } + + A_colptr += N; + } + + out_KL = KL; + out_KU = KU; + + return true; + } + + + +template +inline +bool +is_band_lower(uword& out_KD, const Mat& A, const uword N_min) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that A has a square size + // NOTE: assuming that N_min is >= 4 + + const uword N = A.n_rows; + + if(N < N_min) { return false; } + + // first, quickly check bottom-left corner + + const eT eT_zero = eT(0); + + const eT* A_col0 = A.memptr(); + const eT* A_col1 = A_col0 + N; + + if( (A_col0[N-2] != eT_zero) || (A_col0[N-1] != eT_zero) || (A_col1[N-2] != eT_zero) || (A_col1[N-1] != eT_zero) ) { return false; } + + // if we reached this point, go through the bottom triangle to work out number of subdiagonals + + const uword n_nonzero_threshold = ( N*N - (N*(N-1))/2 ) / 4; // empirically determined + + uword KL = 0; // number of subdiagonals + + const eT* A_colptr = A.memptr(); + + for(uword col=0; col < N; ++col) + { + uword last_nonzero_row = col; + + for(uword row=(col+1); row < N; ++row) + { + last_nonzero_row = (A_colptr[row] != eT_zero) ? row : last_nonzero_row; + } + + const uword L_count = last_nonzero_row - col; + + if(L_count > KL) + { + KL = L_count; + + const uword n_nonzero = N*(KL+1) - (KL*(KL+1))/2; + + // return as soon as we know that it's not worth analysing the matrix any further + + if(n_nonzero > n_nonzero_threshold) { return false; } + } + + A_colptr += N; + } + + out_KD = KL; + + return true; + } + + + +template +inline +bool +is_band_upper(uword& out_KD, const Mat& A, const uword N_min) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that A has a square size + // NOTE: assuming that N_min is >= 4 + + const uword N = A.n_rows; + + if(N < N_min) { return false; } + + // first, quickly check top-right corner + + const eT eT_zero = eT(0); + + const eT* A_colNm2 = A.colptr(N-2); + const eT* A_colNm1 = A_colNm2 + N; + + if( (A_colNm2[0] != eT_zero) || (A_colNm2[1] != eT_zero) || (A_colNm1[0] != eT_zero) || (A_colNm1[1] != eT_zero) ) { return false; } + + // if we reached this point, go through the entire matrix to work out number of superdiagonals + + const uword n_nonzero_threshold = ( N*N - (N*(N-1))/2 ) / 4; // empirically determined + + uword KU = 0; // number of superdiagonals + + const eT* A_colptr = A.memptr(); + + for(uword col=0; col < N; ++col) + { + uword first_nonzero_row = col; + + for(uword row=0; row < col; ++row) + { + if( A_colptr[row] != eT_zero ) { first_nonzero_row = row; break; } + } + + const uword U_count = col - first_nonzero_row; + + if(U_count > KU) + { + KU = U_count; + + const uword n_nonzero = N*(KU+1) - (KU*(KU+1))/2; + + // return as soon as we know that it's not worth analysing the matrix any further + + if(n_nonzero > n_nonzero_threshold) { return false; } + } + + A_colptr += N; + } + + out_KD = KU; + + return true; + } + + + +template +inline +void +compress(Mat& AB, const Mat& A, const uword KL, const uword KU, const bool use_offset) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that A has a square size + + // band matrix storage format + // http://www.netlib.org/lapack/lug/node124.html + + // for ?gbsv, matrix AB size: 2*KL+KU+1 x N; band representation of A stored in rows KL+1 to 2*KL+KU+1 (note: fortran counts from 1) + // for ?gbsvx, matrix AB size: KL+KU+1 x N; band representaiton of A stored in rows 1 to KL+KU+1 (note: fortran counts from 1) + // + // the +1 in the above formulas is to take into account the main diagonal + + const uword AB_n_rows = (use_offset) ? uword(2*KL + KU + 1) : uword(KL + KU + 1); + const uword N = A.n_rows; + + AB.set_size(AB_n_rows, N); + + if(A.is_empty()) { AB.zeros(); return; } + + if(AB_n_rows == uword(1)) + { + eT* AB_mem = AB.memptr(); + + for(uword i=0; i KU) ? uword(j - KU) : uword(0); + const uword A_row_endp1 = (std::min)(N, j+KL+1); + + const uword length = A_row_endp1 - A_row_start; + + const uword AB_row_start = (KU > j) ? (KU - j) : uword(0); + + const eT* A_colptr = A.colptr(j) + A_row_start; + eT* AB_colptr = AB.colptr(j) + AB_row_start + ( (use_offset) ? KL : uword(0) ); + + arrayops::copy( AB_colptr, A_colptr, length ); + } + } + } + + + +template +inline +void +uncompress(Mat& A, const Mat& AB, const uword KL, const uword KU, const bool use_offset) + { + arma_extra_debug_sigprint(); + + const uword AB_n_rows = AB.n_rows; + const uword N = AB.n_cols; + + arma_debug_check( (AB_n_rows != ((use_offset) ? uword(2*KL + KU + 1) : uword(KL + KU + 1))), "band_helper::uncompress(): detected inconsistency" ); + + A.zeros(N,N); // assuming there is no aliasing between A and AB + + if(AB_n_rows == uword(1)) + { + const eT* AB_mem = AB.memptr(); + + for(uword i=0; i KU) ? uword(j - KU) : uword(0); + const uword A_row_endp1 = (std::min)(N, j+KL+1); + + const uword length = A_row_endp1 - A_row_start; + + const uword AB_row_start = (KU > j) ? (KU - j) : uword(0); + + const eT* AB_colptr = AB.colptr(j) + AB_row_start + ( (use_offset) ? KL : uword(0) ); + eT* A_colptr = A.colptr(j) + A_row_start; + + arrayops::copy( A_colptr, AB_colptr, length ); + } + } + } + + + +template +inline +void +extract_tridiag(Mat& out, const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that A has a square size and is at least 2x2 + + const uword N = A.n_rows; + + out.set_size(N, 3); // assuming there is no aliasing between 'out' and 'A' + + if(N < 2) { return; } + + eT* DL = out.colptr(0); + eT* DD = out.colptr(1); + eT* DU = out.colptr(2); + + DD[0] = A[0]; + DL[0] = A[1]; + + const uword Nm1 = N-1; + const uword Nm2 = N-2; + + for(uword i=0; i < Nm2; ++i) + { + const uword ip1 = i+1; + + const eT* data = &(A.at(i, ip1)); + + const eT tmp0 = data[0]; + const eT tmp1 = data[1]; + const eT tmp2 = data[2]; + + DL[ip1] = tmp2; + DD[ip1] = tmp1; + DU[i ] = tmp0; + } + + const eT* data = &(A.at(Nm2, Nm1)); + + DL[Nm1] = 0; + DU[Nm2] = data[0]; + DU[Nm1] = 0; + DD[Nm1] = data[1]; + } + + + +} // end of namespace band_helper + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/compiler_check.hpp b/src/armadillo/include/armadillo_bits/compiler_check.hpp new file mode 100644 index 0000000..8a653d2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/compiler_check.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +#undef ARMA_HAVE_CXX11 +#undef ARMA_HAVE_CXX14 +#undef ARMA_HAVE_CXX17 +#undef ARMA_HAVE_CXX20 + +#if (__cplusplus >= 201103L) + #define ARMA_HAVE_CXX11 +#endif + +#if (__cplusplus >= 201402L) + #define ARMA_HAVE_CXX14 +#endif + +#if (__cplusplus >= 201703L) + #define ARMA_HAVE_CXX17 +#endif + +#if (__cplusplus >= 202002L) + #define ARMA_HAVE_CXX20 +#endif + + +// MS really can't get its proverbial shit together +#if defined(_MSVC_LANG) + + #if (_MSVC_LANG >= 201402L) + #undef ARMA_HAVE_CXX11 + #define ARMA_HAVE_CXX11 + + #undef ARMA_HAVE_CXX14 + #define ARMA_HAVE_CXX14 + #endif + + #if (_MSVC_LANG >= 201703L) + #undef ARMA_HAVE_CXX17 + #define ARMA_HAVE_CXX17 + #endif + + #if (_MSVC_LANG >= 202002L) + #undef ARMA_HAVE_CXX20 + #define ARMA_HAVE_CXX20 + #endif + +#endif + + +// warn about ignored option used in old versions of Armadillo +#if defined(ARMA_DONT_USE_CXX11) + #pragma message ("WARNING: option ARMA_DONT_USE_CXX11 ignored") +#endif + + +#if !defined(ARMA_HAVE_CXX11) + #error "*** C++11 compiler required; enable C++11 mode in your compiler, or use an earlier version of Armadillo" +#endif + + +// for compatibility with earlier versions of Armadillo +#undef ARMA_USE_CXX11 +#define ARMA_USE_CXX11 diff --git a/src/armadillo/include/armadillo_bits/compiler_setup.hpp b/src/armadillo/include/armadillo_bits/compiler_setup.hpp new file mode 100644 index 0000000..775b82a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/compiler_setup.hpp @@ -0,0 +1,511 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +#undef arma_hot +#undef arma_cold +#undef arma_aligned +#undef arma_align_mem +#undef arma_warn_unused +#undef arma_deprecated +#undef arma_frown +#undef arma_malloc +#undef arma_inline +#undef arma_noinline +#undef arma_ignore + +#define arma_hot +#define arma_cold +#define arma_aligned +#define arma_align_mem +#define arma_warn_unused +#define arma_deprecated +#define arma_frown(msg) +#define arma_malloc +#define arma_inline inline +#define arma_noinline +#define arma_ignore(variable) ((void)(variable)) + +#undef arma_fortran_sans_prefix_B +#undef arma_fortran_with_prefix_B + +#if defined(ARMA_BLAS_UNDERSCORE) + #define arma_fortran_sans_prefix_B(function) function##_ + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + #define arma_fortran_with_prefix_B(function) wrapper2_##function##_ + #else + #define arma_fortran_with_prefix_B(function) wrapper_##function##_ + #endif +#else + #define arma_fortran_sans_prefix_B(function) function + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + #define arma_fortran_with_prefix_B(function) wrapper2_##function + #else + #define arma_fortran_with_prefix_B(function) wrapper_##function + #endif +#endif + +#undef arma_fortran +#undef arma_wrapper + +#if defined(ARMA_USE_WRAPPER) + #define arma_fortran(function) arma_fortran_with_prefix_B(function) + #define arma_wrapper(function) wrapper_##function +#else + #define arma_fortran(function) arma_fortran_sans_prefix_B(function) + #define arma_wrapper(function) function +#endif + +#undef arma_fortran_sans_prefix +#undef arma_fortran_with_prefix + +#define arma_fortran_sans_prefix(function) arma_fortran_sans_prefix_B(function) +#define arma_fortran_with_prefix(function) arma_fortran_with_prefix_B(function) + +#undef ARMA_INCFILE_WRAP +#define ARMA_INCFILE_WRAP(x) + + +#if !defined(ARMA_32BIT_WORD) + #undef ARMA_64BIT_WORD + #define ARMA_64BIT_WORD +#endif + +#if defined(ARMA_64BIT_WORD) && defined(SIZE_MAX) + #if (SIZE_MAX < 0xFFFFFFFFFFFFFFFFull) + // #pragma message ("WARNING: disabled use of 64 bit integers, as std::size_t is smaller than 64 bits") + #undef ARMA_64BIT_WORD + #endif +#endif + + +// most compilers can't vectorise slightly elaborate loops; +// for example clang: http://llvm.org/bugs/show_bug.cgi?id=16358 +#undef ARMA_SIMPLE_LOOPS +#define ARMA_SIMPLE_LOOPS + +#undef ARMA_GOOD_COMPILER + +// posix_memalign() is part of IEEE standard 1003.1 +// http://pubs.opengroup.org/onlinepubs/009696899/functions/posix_memalign.html +// http://pubs.opengroup.org/onlinepubs/9699919799/basedefs/unistd.h.html +// http://sourceforge.net/p/predef/wiki/Standards/ +#if ( defined(_POSIX_ADVISORY_INFO) && (_POSIX_ADVISORY_INFO >= 200112L) ) + #undef ARMA_HAVE_POSIX_MEMALIGN + #define ARMA_HAVE_POSIX_MEMALIGN +#endif + + +#if defined(__APPLE__) || defined(__apple_build_version__) + // NOTE: Apple accelerate framework has broken implementations of functions that return a float value, + // NOTE: such as sdot(), slange(), clange(), slansy(), clanhe(), slangb() + #undef ARMA_BLAS_FLOAT_BUG + #define ARMA_BLAS_FLOAT_BUG + + // #undef ARMA_HAVE_POSIX_MEMALIGN + // NOTE: posix_memalign() is available since macOS 10.6 (late 2009 onwards) +#endif + + +#if defined(__MINGW32__) || defined(__CYGWIN__) || defined(_MSC_VER) + #undef ARMA_HAVE_POSIX_MEMALIGN +#endif + + +#undef ARMA_FNSIG + +#if defined (__GNUG__) + #define ARMA_FNSIG __PRETTY_FUNCTION__ +#elif defined (_MSC_VER) + #define ARMA_FNSIG __FUNCSIG__ +#elif defined(__INTEL_COMPILER) + #define ARMA_FNSIG __FUNCTION__ +#else + #define ARMA_FNSIG __func__ +#endif + + +#if !defined(ARMA_ALLOW_FAKE_GCC) + #if (defined(__GNUG__) || defined(__GNUC__)) && (defined(__INTEL_COMPILER) || defined(__NVCC__) || defined(__CUDACC__) || defined(__PGI) || defined(__PATHSCALE__) || defined(__ARMCC_VERSION) || defined(__IBMCPP__)) + #undef ARMA_DETECTED_FAKE_GCC + #define ARMA_DETECTED_FAKE_GCC + + #pragma message ("WARNING: this compiler is pretending to be GCC but it may not be fully compatible;") + #pragma message ("WARNING: to allow this compiler to use GCC features such as data alignment attributes,") + #pragma message ("WARNING: #define ARMA_ALLOW_FAKE_GCC before #include ") + #endif +#endif + + +#if defined(__GNUG__) && (!defined(__clang__) && !defined(ARMA_DETECTED_FAKE_GCC)) + + // #pragma message ("using GCC extensions") + + #undef ARMA_GCC_VERSION + #define ARMA_GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) + + #if (ARMA_GCC_VERSION < 40803) + #error "*** newer compiler required; need gcc 4.8.3 or newer ***" + #endif + + // #if (ARMA_GCC_VERSION < 60100) + // #pragma message ("WARNING: support for gcc versions older than 6.1 is deprecated") + // #endif + + #define ARMA_GOOD_COMPILER + + #undef arma_hot + #undef arma_cold + #undef arma_aligned + #undef arma_align_mem + #undef arma_warn_unused + #undef arma_deprecated + #undef arma_frown + #undef arma_malloc + #undef arma_inline + #undef arma_noinline + + #define arma_hot __attribute__((__hot__)) + #define arma_cold __attribute__((__cold__)) + #define arma_aligned __attribute__((__aligned__)) + #define arma_align_mem __attribute__((__aligned__(16))) + #define arma_warn_unused __attribute__((__warn_unused_result__)) + #define arma_deprecated __attribute__((__deprecated__)) + #define arma_frown(msg) __attribute__((__deprecated__(msg))) + #define arma_malloc __attribute__((__malloc__)) + #define arma_inline __attribute__((__always_inline__)) inline + #define arma_noinline __attribute__((__noinline__)) + + #undef ARMA_HAVE_ALIGNED_ATTRIBUTE + #define ARMA_HAVE_ALIGNED_ATTRIBUTE + + #undef ARMA_HAVE_GCC_ASSUME_ALIGNED + #define ARMA_HAVE_GCC_ASSUME_ALIGNED + + // gcc's vectoriser can handle elaborate loops + #undef ARMA_SIMPLE_LOOPS + + #if defined(__OPTIMIZE_SIZE__) + #define ARMA_SIMPLE_LOOPS + #endif + +#endif + + +// TODO: __INTEL_CLANG_COMPILER indicates the clang based intel compiler, distinct from the classic intel compiler +#if !defined(ARMA_ALLOW_FAKE_CLANG) + #if defined(__clang__) && (defined(__INTEL_COMPILER) || defined(__NVCC__) || defined(__CUDACC__) || defined(__PGI) || defined(__PATHSCALE__) || defined(__ARMCC_VERSION) || defined(__IBMCPP__)) + #undef ARMA_DETECTED_FAKE_CLANG + #define ARMA_DETECTED_FAKE_CLANG + + #pragma message ("WARNING: this compiler is pretending to be Clang but it may not be fully compatible;") + #pragma message ("WARNING: to allow this compiler to use Clang features such as data alignment attributes,") + #pragma message ("WARNING: #define ARMA_ALLOW_FAKE_CLANG before #include ") + #endif +#endif + + +#if defined(__clang__) && !defined(ARMA_DETECTED_FAKE_CLANG) + + // #pragma message ("using Clang extensions") + + #define ARMA_GOOD_COMPILER + + #if !defined(__has_attribute) + #define __has_attribute(x) 0 + #endif + + #if __has_attribute(__aligned__) + #undef arma_aligned + #undef arma_align_mem + + #define arma_aligned __attribute__((__aligned__)) + #define arma_align_mem __attribute__((__aligned__(16))) + + #undef ARMA_HAVE_ALIGNED_ATTRIBUTE + #define ARMA_HAVE_ALIGNED_ATTRIBUTE + #endif + + #if __has_attribute(__warn_unused_result__) + #undef arma_warn_unused + #define arma_warn_unused __attribute__((__warn_unused_result__)) + #endif + + #if __has_attribute(__deprecated__) + #undef arma_deprecated + #define arma_deprecated __attribute__((__deprecated__)) + #endif + + #if __has_attribute(__deprecated__) + #undef arma_frown + #define arma_frown(msg) __attribute__((__deprecated__(msg))) + #endif + + #if __has_attribute(__malloc__) + #undef arma_malloc + #define arma_malloc __attribute__((__malloc__)) + #endif + + #if __has_attribute(__always_inline__) + #undef arma_inline + #define arma_inline __attribute__((__always_inline__)) inline + #endif + + #if __has_attribute(__noinline__) + #undef arma_noinline + #define arma_noinline __attribute__((__noinline__)) + #endif + + #if __has_attribute(__hot__) + #undef arma_hot + #define arma_hot __attribute__((__hot__)) + #endif + + #if __has_attribute(__cold__) + #undef arma_cold + #define arma_cold __attribute__((__cold__)) + #elif __has_attribute(__minsize__) + #undef arma_cold + #define arma_cold __attribute__((__minsize__)) + #endif + + #if defined(__has_builtin) && __has_builtin(__builtin_assume_aligned) + #undef ARMA_HAVE_GCC_ASSUME_ALIGNED + #define ARMA_HAVE_GCC_ASSUME_ALIGNED + #endif + +#endif + + +#if defined(__INTEL_COMPILER) + + #if (__INTEL_COMPILER == 9999) + #error "*** newer compiler required ***" + #endif + + #if (__INTEL_COMPILER < 1500) + #error "*** newer compiler required ***" + #endif + + #undef ARMA_HAVE_GCC_ASSUME_ALIGNED + #undef ARMA_HAVE_ICC_ASSUME_ALIGNED + #define ARMA_HAVE_ICC_ASSUME_ALIGNED + +#endif + + +#if defined(_MSC_VER) + + #if (_MSC_VER < 1900) + #error "*** newer compiler required ***" + #endif + + #undef arma_deprecated + #define arma_deprecated __declspec(deprecated) + // #undef arma_inline + // #define arma_inline __forceinline inline + + #pragma warning(push) + + #pragma warning(disable: 4127) // conditional expression is constant + #pragma warning(disable: 4180) // qualifier has no meaning + #pragma warning(disable: 4244) // possible loss of data when converting types (see also 4305) + #pragma warning(disable: 4510) // default constructor could not be generated + #pragma warning(disable: 4511) // copy constructor can't be generated + #pragma warning(disable: 4512) // assignment operator can't be generated + #pragma warning(disable: 4513) // destructor can't be generated + #pragma warning(disable: 4514) // unreferenced inline function has been removed + #pragma warning(disable: 4519) // default template args are only allowed on a class template (C++11) + #pragma warning(disable: 4522) // multiple assignment operators specified + #pragma warning(disable: 4623) // default constructor can't be generated + #pragma warning(disable: 4624) // destructor can't be generated + #pragma warning(disable: 4625) // copy constructor can't be generated + #pragma warning(disable: 4626) // assignment operator can't be generated + #pragma warning(disable: 4702) // unreachable code + #pragma warning(disable: 4710) // function not inlined + #pragma warning(disable: 4711) // call was inlined + #pragma warning(disable: 4714) // __forceinline can't be inlined + #pragma warning(disable: 4800) // value forced to bool + + // NOTE: also possible to disable 4146 (unary minus operator applied to unsigned type, result still unsigned) + + #if defined(ARMA_HAVE_CXX17) + #pragma warning(disable: 26812) // unscoped enum + #pragma warning(disable: 26819) // unannotated fallthrough + #endif + + // #if (_MANAGED == 1) || (_M_CEE == 1) + // + // // don't do any alignment when compiling in "managed code" mode + // + // #undef arma_aligned + // #define arma_aligned + // + // #undef arma_align_mem + // #define arma_align_mem + // + // #elif (_MSC_VER >= 1700) + // + // #undef arma_align_mem + // #define arma_align_mem __declspec(align(16)) + // + // #define ARMA_HAVE_ALIGNED_ATTRIBUTE + // + // // disable warnings: "structure was padded due to __declspec(align(16))" + // #pragma warning(disable: 4324) + // + // #endif + +#endif + + +#if defined(__SUNPRO_CC) + + // http://www.oracle.com/technetwork/server-storage/solarisstudio/training/index-jsp-141991.html + // http://www.oracle.com/technetwork/server-storage/solarisstudio/documentation/cplusplus-faq-355066.html + + #if (__SUNPRO_CC < 0x5140) + #error "*** newer compiler required ***" + #endif + +#endif + + +#if defined(ARMA_HAVE_CXX14) + #undef arma_deprecated + #define arma_deprecated [[deprecated]] + + #undef arma_frown + #define arma_frown(msg) [[deprecated(msg)]] +#endif + + +#if defined(ARMA_HAVE_CXX17) + #undef arma_warn_unused + #define arma_warn_unused [[nodiscard]] +#endif + + +#if !defined(ARMA_DONT_USE_OPENMP) + #if (defined(_OPENMP) && (_OPENMP >= 201107)) + #undef ARMA_USE_OPENMP + #define ARMA_USE_OPENMP + #endif +#endif + + +#if ( defined(ARMA_USE_OPENMP) && (!defined(_OPENMP) || (defined(_OPENMP) && (_OPENMP < 201107))) ) + // OpenMP 3.0 required for parallelisation of loops with unsigned integers + // OpenMP 3.1 required for atomic read and atomic write + #undef ARMA_USE_OPENMP + #undef ARMA_PRINT_OPENMP_WARNING + #define ARMA_PRINT_OPENMP_WARNING +#endif + + +#if defined(ARMA_PRINT_OPENMP_WARNING) && !defined(ARMA_DONT_PRINT_OPENMP_WARNING) + #pragma message ("WARNING: use of OpenMP disabled; compiler support for OpenMP 3.1+ not detected") + + #if (defined(_OPENMP) && (_OPENMP < 201107)) + #pragma message ("NOTE: your compiler has an outdated version of OpenMP") + #pragma message ("NOTE: consider upgrading to a better compiler") + #endif +#endif + + +#if defined(ARMA_USE_OPENMP) + #if (defined(ARMA_GCC_VERSION) && (ARMA_GCC_VERSION < 50400)) + // due to https://gcc.gnu.org/bugzilla/show_bug.cgi?id=57580 + #undef ARMA_USE_OPENMP + #if !defined(ARMA_DONT_PRINT_OPENMP_WARNING) + #pragma message ("WARNING: use of OpenMP disabled due to compiler bug in gcc <= 5.3") + #endif + #endif +#endif + + +#if (defined(__FAST_MATH__) || (defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0)) || defined(_M_FP_FAST)) + #undef ARMA_FAST_MATH + #define ARMA_FAST_MATH +#endif + + +#if defined(ARMA_FAST_MATH) && !defined(ARMA_DONT_PRINT_FAST_MATH_WARNING) + #pragma message ("WARNING: compiler is in fast math mode; some functions may be unreliable.") + #pragma message ("WARNING: to suppress this warning and related warnings,") + #pragma message ("WARNING: #define ARMA_DONT_PRINT_FAST_MATH_WARNING before #include ") +#endif + + +#if ( (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER)) && (!defined(__MINGW32__) && !defined(__MINGW64__)) ) + #undef ARMA_PRINT_EXCEPTIONS_INTERNAL + #define ARMA_PRINT_EXCEPTIONS_INTERNAL +#endif + + +#if (defined(ARMA_ALIEN_MEM_ALLOC_FUNCTION) && !defined(ARMA_ALIEN_MEM_FREE_FUNCTION)) || (!defined(ARMA_ALIEN_MEM_ALLOC_FUNCTION) && defined(ARMA_ALIEN_MEM_FREE_FUNCTION)) + #error "*** both ARMA_ALIEN_MEM_ALLOC_FUNCTION and ARMA_ALIEN_MEM_FREE_FUNCTION must be defined ***" +#endif + + + +// cleanup + +#undef ARMA_DETECTED_FAKE_GCC +#undef ARMA_DETECTED_FAKE_CLANG +#undef ARMA_GCC_VERSION +#undef ARMA_PRINT_OPENMP_WARNING + + + +// undefine conflicting macros + +#if defined(log2) + #undef log2 + #pragma message ("WARNING: undefined conflicting 'log2' macro") +#endif + +#if defined(check) + #undef check + #pragma message ("WARNING: undefined conflicting 'check' macro") +#endif + +#if defined(min) || defined(max) + #undef min + #undef max + #pragma message ("WARNING: undefined conflicting 'min' and/or 'max' macros") +#endif + +// https://sourceware.org/bugzilla/show_bug.cgi?id=19239 +#undef minor +#undef major + + +// optionally allow disabling of compile-time deprecation messages (not recommended) +// NOTE: option 'ARMA_IGNORE_DEPRECATED_MARKER' will be removed +// NOTE: disabling deprecation messages is counter-productive + +#if defined(ARMA_IGNORE_DEPRECATED_MARKER) && (!defined(ARMA_DONT_IGNORE_DEPRECATED_MARKER)) && (!defined(ARMA_EXTRA_DEBUG)) + #undef arma_deprecated + #define arma_deprecated + + #undef arma_frown + #define arma_frown(msg) +#endif diff --git a/src/armadillo/include/armadillo_bits/compiler_setup_post.hpp b/src/armadillo/include/armadillo_bits/compiler_setup_post.hpp new file mode 100644 index 0000000..6274b7e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/compiler_setup_post.hpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +#if defined(_MSC_VER) + + #pragma warning(pop) + +#endif diff --git a/src/armadillo/include/armadillo_bits/cond_rel_bones.hpp b/src/armadillo/include/armadillo_bits/cond_rel_bones.hpp new file mode 100644 index 0000000..a160d26 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/cond_rel_bones.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup cond_rel +//! @{ + + +// +// for preventing pedantic compiler warnings + +template +class cond_rel + { + public: + + template arma_inline static bool lt(const eT A, const eT B); + template arma_inline static bool gt(const eT A, const eT B); + + template arma_inline static bool leq(const eT A, const eT B); + template arma_inline static bool geq(const eT A, const eT B); + + template arma_inline static eT make_neg(const eT val); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/cond_rel_meat.hpp b/src/armadillo/include/armadillo_bits/cond_rel_meat.hpp new file mode 100644 index 0000000..a285774 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/cond_rel_meat.hpp @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup cond_rel +//! @{ + + + +template<> +template +arma_inline +bool +cond_rel::lt(const eT A, const eT B) + { + return (A < B); + } + + + +template<> +template +arma_inline +bool +cond_rel::lt(const eT, const eT) + { + return false; + } + + + +template<> +template +arma_inline +bool +cond_rel::gt(const eT A, const eT B) + { + return (A > B); + } + + + +template<> +template +arma_inline +bool +cond_rel::gt(const eT, const eT) + { + return false; + } + + + +template<> +template +arma_inline +bool +cond_rel::leq(const eT A, const eT B) + { + return (A <= B); + } + + + +template<> +template +arma_inline +bool +cond_rel::leq(const eT, const eT) + { + return false; + } + + + +template<> +template +arma_inline +bool +cond_rel::geq(const eT A, const eT B) + { + return (A >= B); + } + + + +template<> +template +arma_inline +bool +cond_rel::geq(const eT, const eT) + { + return false; + } + + + +template<> +template +arma_inline +eT +cond_rel::make_neg(const eT val) + { + return -val; + } + + + +template<> +template +arma_inline +eT +cond_rel::make_neg(const eT) + { + return eT(0); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/config.hpp b/src/armadillo/include/armadillo_bits/config.hpp new file mode 100644 index 0000000..6d7874a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/config.hpp @@ -0,0 +1,351 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +#if !defined(ARMA_WARN_LEVEL) + #define ARMA_WARN_LEVEL 2 +#endif +//// The level of warning messages printed to ARMA_CERR_STREAM. +//// Must be an integer >= 0. The default value is 2. +//// 0 = no warnings; generally not recommended +//// 1 = only critical warnings about arguments and/or data which are likely to lead to incorrect results +//// 2 = as per level 1, and warnings about poorly conditioned systems (low rcond) detected by solve(), spsolve(), etc +//// 3 = as per level 2, and warnings about failed decompositions, failed saving/loading, etc + +// #define ARMA_USE_WRAPPER +//// Comment out the above line if you prefer to directly link with BLAS, LAPACK, etc +//// instead of the Armadillo runtime library. +//// You will need to link your programs directly with -lopenblas -llapack instead of -larmadillo + +#if !defined(ARMA_USE_LAPACK) +#define ARMA_USE_LAPACK +//// Comment out the above line if you don't have LAPACK or a high-speed replacement for LAPACK, +//// such as OpenBLAS, Intel MKL, or the Accelerate framework. +//// LAPACK is required for matrix decompositions (eg. SVD) and matrix inverse. +#endif + +#if !defined(ARMA_USE_BLAS) +#define ARMA_USE_BLAS +//// Comment out the above line if you don't have BLAS or a high-speed replacement for BLAS, +//// such as OpenBLAS, Intel MKL, or the Accelerate framework. +//// BLAS is used for matrix multiplication. +//// Without BLAS, matrix multiplication will still work, but might be slower. +#endif + +#if !defined(ARMA_USE_NEWARP) +#define ARMA_USE_NEWARP +//// Uncomment the above line to enable the built-in partial emulation of ARPACK. +//// This is used for eigen decompositions of real (non-complex) sparse matrices, eg. eigs_sym(), svds() +#endif + +#if !defined(ARMA_USE_ARPACK) +// #define ARMA_USE_ARPACK +//// Uncomment the above line if you have ARPACK or a high-speed replacement for ARPACK. +//// ARPACK is required for eigen decompositions of complex sparse matrices +#endif + +#if !defined(ARMA_USE_SUPERLU) +// #define ARMA_USE_SUPERLU +//// Uncomment the above line if you have SuperLU. +//// SuperLU is used for solving sparse linear systems via spsolve() +//// Caveat: only SuperLU version 5.2 can be used! +#endif + +#if !defined(ARMA_SUPERLU_INCLUDE_DIR) +// #define ARMA_SUPERLU_INCLUDE_DIR /usr/include/ +//// If you're using SuperLU and want to explicitly include the SuperLU headers, +//// uncomment the above define and specify the appropriate include directory. +//// Make sure the directory has a trailing / +#endif + +#if !defined(ARMA_USE_ATLAS) +// #define ARMA_USE_ATLAS +//// NOTE: support for ATLAS is deprecated and will be removed. +#endif + +#if !defined(ARMA_USE_HDF5) +// #define ARMA_USE_HDF5 +//// Uncomment the above line to allow the ability to save and load matrices stored in HDF5 format; +//// the hdf5.h header file must be available on your system, +//// and you will need to link with the hdf5 library (eg. -lhdf5) +#endif + +#if !defined(ARMA_USE_FFTW3) +// #define ARMA_USE_FFTW3 +//// Uncomment the above line to allow the use of the FFTW3 library by fft() and ifft() functions; +//// you will need to link with the FFTW3 library (eg. -lfftw3) +#endif + +#if defined(ARMA_USE_FFTW) + #error "use ARMA_USE_FFTW3 instead of ARMA_USE_FFTW" +#endif + +// #define ARMA_BLAS_CAPITALS +//// Uncomment the above line if your BLAS and LAPACK libraries have capitalised function names + +#define ARMA_BLAS_UNDERSCORE +//// Uncomment the above line if your BLAS and LAPACK libraries have function names with a trailing underscore. +//// Conversely, comment it out if the function names don't have a trailing underscore. + +// #define ARMA_BLAS_LONG +//// Uncomment the above line if your BLAS and LAPACK libraries use "long" instead of "int" + +// #define ARMA_BLAS_LONG_LONG +//// Uncomment the above line if your BLAS and LAPACK libraries use "long long" instead of "int" + +// #define ARMA_BLAS_NOEXCEPT +//// Uncomment the above line if you require BLAS functions to have the 'noexcept' specification + +// #define ARMA_LAPACK_NOEXCEPT +//// Uncomment the above line if you require LAPACK functions to have the 'noexcept' specification + +#define ARMA_USE_FORTRAN_HIDDEN_ARGS +//// Comment out the above line to call BLAS and LAPACK functions without using so-called "hidden" arguments. +//// Fortran functions (compiled without a BIND(C) declaration) that have char arguments +//// (like many BLAS and LAPACK functions) also have associated "hidden" arguments. +//// For each char argument, the corresponding "hidden" argument specifies the number of characters. +//// These "hidden" arguments are typically tacked onto the end of function definitions. + +// #define ARMA_USE_TBB_ALLOC +//// Uncomment the above line to use Intel TBB scalable_malloc() and scalable_free() instead of standard malloc() and free() + +// #define ARMA_USE_MKL_ALLOC +//// Uncomment the above line to use Intel MKL mkl_malloc() and mkl_free() instead of standard malloc() and free() + +// #define ARMA_USE_MKL_TYPES +//// Uncomment the above line to use Intel MKL types for complex numbers. +//// You will need to include appropriate MKL headers before the Armadillo header. +//// You may also need to enable or disable the following options: +//// ARMA_BLAS_LONG, ARMA_BLAS_LONG_LONG, ARMA_USE_FORTRAN_HIDDEN_ARGS + +#if !defined(ARMA_USE_OPENMP) +// #define ARMA_USE_OPENMP +//// Uncomment the above line to forcefully enable use of OpenMP for parallelisation. +//// Note that ARMA_USE_OPENMP is automatically enabled when a compiler supporting OpenMP 3.1 is detected. +#endif + +#if !defined(ARMA_64BIT_WORD) +// #define ARMA_64BIT_WORD +//// Uncomment the above line if you require matrices/vectors capable of holding more than 4 billion elements. +//// Note that ARMA_64BIT_WORD is automatically enabled when std::size_t has 64 bits and ARMA_32BIT_WORD is not defined. +#endif + +#if !defined(ARMA_OPTIMISE_BAND) + #define ARMA_OPTIMISE_BAND + //// Comment out the above line to disable optimised handling + //// of band matrices by solve() and chol() +#endif + +#if !defined(ARMA_OPTIMISE_SYM) + #define ARMA_OPTIMISE_SYM + //// Comment out the above line to disable optimised handling + //// of symmetric/hermitian matrices by various functions: + //// solve(), inv(), pinv(), expmat(), logmat(), sqrtmat(), rcond(), rank() +#endif + +#if !defined(ARMA_OPTIMISE_INVEXPR) + #define ARMA_OPTIMISE_INVEXPR + //// Comment out the above line to disable optimised handling + //// of inv() and inv_sympd() within compound expressions +#endif + +#if !defined(ARMA_CHECK_NONFINITE) + #define ARMA_CHECK_NONFINITE + //// Comment out the above line to disable checking for nonfinite matrices +#endif + +#if !defined(ARMA_MAT_PREALLOC) + #define ARMA_MAT_PREALLOC 16 +#endif +//// This is the number of preallocated elements used by matrices and vectors; +//// it must be an integer that is at least 1. +//// If you mainly use lots of very small vectors (eg. <= 4 elements), +//// change the number to the size of your vectors. + +#if !defined(ARMA_OPENMP_THRESHOLD) + #define ARMA_OPENMP_THRESHOLD 320 +#endif +//// The minimum number of elements in a matrix to allow OpenMP based parallelisation; +//// it must be an integer that is at least 1. + +#if !defined(ARMA_OPENMP_THREADS) + #define ARMA_OPENMP_THREADS 8 +#endif +//// The maximum number of threads to use for OpenMP based parallelisation; +//// it must be an integer that is at least 1. + +// #define ARMA_NO_DEBUG +//// Uncomment the above line to disable all run-time checks. NOT RECOMMENDED. +//// It is strongly recommended that run-time checks are enabled during development, +//// as this greatly aids in finding mistakes in your code. + +// #define ARMA_EXTRA_DEBUG +//// Uncomment the above line to see the function traces of how Armadillo evaluates expressions. +//// This is mainly useful for debugging of the library. + + +#if defined(ARMA_EXTRA_DEBUG) + #undef ARMA_NO_DEBUG + #undef ARMA_WARN_LEVEL + #define ARMA_WARN_LEVEL 3 +#endif + + +#if defined(ARMA_DEFAULT_OSTREAM) + #pragma message ("WARNING: support for ARMA_DEFAULT_OSTREAM is deprecated and will be removed;") + #pragma message ("WARNING: use ARMA_COUT_STREAM and ARMA_CERR_STREAM instead") +#endif + + +#if !defined(ARMA_COUT_STREAM) + #if defined(ARMA_DEFAULT_OSTREAM) + // for compatibility with earlier versions of Armadillo + #define ARMA_COUT_STREAM ARMA_DEFAULT_OSTREAM + #else + #define ARMA_COUT_STREAM std::cout + #endif +#endif + +#if !defined(ARMA_CERR_STREAM) + #if defined(ARMA_DEFAULT_OSTREAM) + // for compatibility with earlier versions of Armadillo + #define ARMA_CERR_STREAM ARMA_DEFAULT_OSTREAM + #else + #define ARMA_CERR_STREAM std::cerr + #endif +#endif + + +#if !defined(ARMA_PRINT_EXCEPTIONS) + // #define ARMA_PRINT_EXCEPTIONS + #if defined(ARMA_PRINT_EXCEPTIONS_INTERNAL) + #undef ARMA_PRINT_EXCEPTIONS + #define ARMA_PRINT_EXCEPTIONS + #endif +#endif + +#if defined(ARMA_DONT_USE_LAPACK) + #undef ARMA_USE_LAPACK +#endif + +#if defined(ARMA_DONT_USE_BLAS) + #undef ARMA_USE_BLAS +#endif + +#if defined(ARMA_DONT_USE_NEWARP) || !defined(ARMA_USE_LAPACK) + #undef ARMA_USE_NEWARP +#endif + +#if defined(ARMA_DONT_USE_ARPACK) + #undef ARMA_USE_ARPACK +#endif + +#if defined(ARMA_DONT_USE_SUPERLU) + #undef ARMA_USE_SUPERLU + #undef ARMA_SUPERLU_INCLUDE_DIR +#endif + +#if defined(ARMA_DONT_USE_ATLAS) + #undef ARMA_USE_ATLAS +#endif + +#if defined(ARMA_DONT_USE_HDF5) + #undef ARMA_USE_HDF5 +#endif + +#if defined(ARMA_DONT_USE_FFTW3) + #undef ARMA_USE_FFTW3 +#endif + +#if defined(ARMA_DONT_USE_WRAPPER) + #undef ARMA_USE_WRAPPER +#endif + +#if defined(ARMA_DONT_USE_FORTRAN_HIDDEN_ARGS) + #undef ARMA_USE_FORTRAN_HIDDEN_ARGS +#endif + +#if !defined(ARMA_DONT_USE_STD_MUTEX) + // #define ARMA_DONT_USE_STD_MUTEX + //// Uncomment the above line to disable use of std::mutex +#endif + +// for compatibility with earlier versions of Armadillo +#if defined(ARMA_DONT_USE_CXX11_MUTEX) + #pragma message ("WARNING: support for ARMA_DONT_USE_CXX11_MUTEX is deprecated and will be removed;") + #pragma message ("WARNING: use ARMA_DONT_USE_STD_MUTEX instead") + #undef ARMA_DONT_USE_STD_MUTEX + #define ARMA_DONT_USE_STD_MUTEX +#endif + +#if defined(ARMA_DONT_USE_OPENMP) + #undef ARMA_USE_OPENMP +#endif + +#if defined(ARMA_32BIT_WORD) + #undef ARMA_64BIT_WORD +#endif + +#if defined(ARMA_DONT_OPTIMISE_BAND) || defined(ARMA_DONT_OPTIMISE_SOLVE_BAND) + #undef ARMA_OPTIMISE_BAND +#endif + +#if defined(ARMA_DONT_OPTIMISE_SYM) || defined(ARMA_DONT_OPTIMISE_SYMPD) || defined(ARMA_DONT_OPTIMISE_SOLVE_SYMPD) + #undef ARMA_OPTIMISE_SYM +#endif + +#if defined(ARMA_DONT_OPTIMISE_INVEXPR) + #undef ARMA_OPTIMISE_INVEXPR +#endif + +#if defined(ARMA_DONT_CHECK_NONFINITE) + #undef ARMA_CHECK_NONFINITE +#endif + +#if defined(ARMA_DONT_PRINT_ERRORS) + #pragma message ("INFO: support for ARMA_DONT_PRINT_ERRORS option has been removed") + + #if defined(ARMA_PRINT_EXCEPTIONS) + #pragma message ("INFO: suggest to use ARMA_WARN_LEVEL and ARMA_DONT_PRINT_EXCEPTIONS options instead") + #else + #pragma message ("INFO: suggest to use ARMA_WARN_LEVEL option instead") + #endif + + #pragma message ("INFO: see the documentation for details") +#endif + +#if defined(ARMA_DONT_PRINT_EXCEPTIONS) + #undef ARMA_PRINT_EXCEPTIONS +#endif + +#if !defined(ARMA_DONT_ZERO_INIT) + // #define ARMA_DONT_ZERO_INIT + //// Uncomment the above line to disable initialising elements to zero during construction of dense matrices and cubes +#endif + +#if defined(ARMA_NO_CRIPPLED_LAPACK) + #undef ARMA_CRIPPLED_LAPACK +#endif + + +// if Armadillo was installed on this system via CMake and ARMA_USE_WRAPPER is not defined, +// ARMA_AUX_LIBS lists the libraries required by Armadillo on this system, and +// ARMA_AUX_INCDIRS lists the include directories required by Armadillo on this system. +// Do not use these unless you know what you are doing. +#define ARMA_AUX_LIBS +#define ARMA_AUX_INCDIRS diff --git a/src/armadillo/include/armadillo_bits/config.hpp.cmake b/src/armadillo/include/armadillo_bits/config.hpp.cmake new file mode 100644 index 0000000..4ac633b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/config.hpp.cmake @@ -0,0 +1,351 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +#if !defined(ARMA_WARN_LEVEL) + #define ARMA_WARN_LEVEL 2 +#endif +//// The level of warning messages printed to ARMA_CERR_STREAM. +//// Must be an integer >= 0. The default value is 2. +//// 0 = no warnings; generally not recommended +//// 1 = only critical warnings about arguments and/or data which are likely to lead to incorrect results +//// 2 = as per level 1, and warnings about poorly conditioned systems (low rcond) detected by solve(), spsolve(), etc +//// 3 = as per level 2, and warnings about failed decompositions, failed saving/loading, etc + +#cmakedefine ARMA_USE_WRAPPER +//// Comment out the above line if you prefer to directly link with BLAS, LAPACK, etc +//// instead of the Armadillo runtime library. +//// You will need to link your programs directly with -lopenblas -llapack instead of -larmadillo + +#if !defined(ARMA_USE_LAPACK) +#cmakedefine ARMA_USE_LAPACK +//// Comment out the above line if you don't have LAPACK or a high-speed replacement for LAPACK, +//// such as OpenBLAS, Intel MKL, or the Accelerate framework. +//// LAPACK is required for matrix decompositions (eg. SVD) and matrix inverse. +#endif + +#if !defined(ARMA_USE_BLAS) +#cmakedefine ARMA_USE_BLAS +//// Comment out the above line if you don't have BLAS or a high-speed replacement for BLAS, +//// such as OpenBLAS, Intel MKL, or the Accelerate framework. +//// BLAS is used for matrix multiplication. +//// Without BLAS, matrix multiplication will still work, but might be slower. +#endif + +#if !defined(ARMA_USE_NEWARP) +#define ARMA_USE_NEWARP +//// Uncomment the above line to enable the built-in partial emulation of ARPACK. +//// This is used for eigen decompositions of real (non-complex) sparse matrices, eg. eigs_sym(), svds() +#endif + +#if !defined(ARMA_USE_ARPACK) +#cmakedefine ARMA_USE_ARPACK +//// Uncomment the above line if you have ARPACK or a high-speed replacement for ARPACK. +//// ARPACK is required for eigen decompositions of complex sparse matrices +#endif + +#if !defined(ARMA_USE_SUPERLU) +#cmakedefine ARMA_USE_SUPERLU +//// Uncomment the above line if you have SuperLU. +//// SuperLU is used for solving sparse linear systems via spsolve() +//// Caveat: only SuperLU version 5.2 can be used! +#endif + +#if !defined(ARMA_SUPERLU_INCLUDE_DIR) +#define ARMA_SUPERLU_INCLUDE_DIR ${ARMA_SUPERLU_INCLUDE_DIR}/ +//// If you're using SuperLU and want to explicitly include the SuperLU headers, +//// uncomment the above define and specify the appropriate include directory. +//// Make sure the directory has a trailing / +#endif + +#if !defined(ARMA_USE_ATLAS) +#cmakedefine ARMA_USE_ATLAS +//// NOTE: support for ATLAS is deprecated and will be removed. +#endif + +#if !defined(ARMA_USE_HDF5) +// #define ARMA_USE_HDF5 +//// Uncomment the above line to allow the ability to save and load matrices stored in HDF5 format; +//// the hdf5.h header file must be available on your system, +//// and you will need to link with the hdf5 library (eg. -lhdf5) +#endif + +#if !defined(ARMA_USE_FFTW3) +// #define ARMA_USE_FFTW3 +//// Uncomment the above line to allow the use of the FFTW3 library by fft() and ifft() functions; +//// you will need to link with the FFTW3 library (eg. -lfftw3) +#endif + +#if defined(ARMA_USE_FFTW) + #error "use ARMA_USE_FFTW3 instead of ARMA_USE_FFTW" +#endif + +// #define ARMA_BLAS_CAPITALS +//// Uncomment the above line if your BLAS and LAPACK libraries have capitalised function names + +#define ARMA_BLAS_UNDERSCORE +//// Uncomment the above line if your BLAS and LAPACK libraries have function names with a trailing underscore. +//// Conversely, comment it out if the function names don't have a trailing underscore. + +// #define ARMA_BLAS_LONG +//// Uncomment the above line if your BLAS and LAPACK libraries use "long" instead of "int" + +// #define ARMA_BLAS_LONG_LONG +//// Uncomment the above line if your BLAS and LAPACK libraries use "long long" instead of "int" + +// #define ARMA_BLAS_NOEXCEPT +//// Uncomment the above line if you require BLAS functions to have the 'noexcept' specification + +// #define ARMA_LAPACK_NOEXCEPT +//// Uncomment the above line if you require LAPACK functions to have the 'noexcept' specification + +#define ARMA_USE_FORTRAN_HIDDEN_ARGS +//// Comment out the above line to call BLAS and LAPACK functions without using so-called "hidden" arguments. +//// Fortran functions (compiled without a BIND(C) declaration) that have char arguments +//// (like many BLAS and LAPACK functions) also have associated "hidden" arguments. +//// For each char argument, the corresponding "hidden" argument specifies the number of characters. +//// These "hidden" arguments are typically tacked onto the end of function definitions. + +// #define ARMA_USE_TBB_ALLOC +//// Uncomment the above line to use Intel TBB scalable_malloc() and scalable_free() instead of standard malloc() and free() + +// #define ARMA_USE_MKL_ALLOC +//// Uncomment the above line to use Intel MKL mkl_malloc() and mkl_free() instead of standard malloc() and free() + +// #define ARMA_USE_MKL_TYPES +//// Uncomment the above line to use Intel MKL types for complex numbers. +//// You will need to include appropriate MKL headers before the Armadillo header. +//// You may also need to enable or disable the following options: +//// ARMA_BLAS_LONG, ARMA_BLAS_LONG_LONG, ARMA_USE_FORTRAN_HIDDEN_ARGS + +#if !defined(ARMA_USE_OPENMP) +// #define ARMA_USE_OPENMP +//// Uncomment the above line to forcefully enable use of OpenMP for parallelisation. +//// Note that ARMA_USE_OPENMP is automatically enabled when a compiler supporting OpenMP 3.1 is detected. +#endif + +#if !defined(ARMA_64BIT_WORD) +// #define ARMA_64BIT_WORD +//// Uncomment the above line if you require matrices/vectors capable of holding more than 4 billion elements. +//// Note that ARMA_64BIT_WORD is automatically enabled when std::size_t has 64 bits and ARMA_32BIT_WORD is not defined. +#endif + +#if !defined(ARMA_OPTIMISE_BAND) + #define ARMA_OPTIMISE_BAND + //// Comment out the above line to disable optimised handling + //// of band matrices by solve() and chol() +#endif + +#if !defined(ARMA_OPTIMISE_SYM) + #define ARMA_OPTIMISE_SYM + //// Comment out the above line to disable optimised handling + //// of symmetric/hermitian matrices by various functions: + //// solve(), inv(), pinv(), expmat(), logmat(), sqrtmat(), rcond(), rank() +#endif + +#if !defined(ARMA_OPTIMISE_INVEXPR) + #define ARMA_OPTIMISE_INVEXPR + //// Comment out the above line to disable optimised handling + //// of inv() and inv_sympd() within compound expressions +#endif + +#if !defined(ARMA_CHECK_NONFINITE) + #define ARMA_CHECK_NONFINITE + //// Comment out the above line to disable checking for nonfinite matrices +#endif + +#if !defined(ARMA_MAT_PREALLOC) + #define ARMA_MAT_PREALLOC 16 +#endif +//// This is the number of preallocated elements used by matrices and vectors; +//// it must be an integer that is at least 1. +//// If you mainly use lots of very small vectors (eg. <= 4 elements), +//// change the number to the size of your vectors. + +#if !defined(ARMA_OPENMP_THRESHOLD) + #define ARMA_OPENMP_THRESHOLD 320 +#endif +//// The minimum number of elements in a matrix to allow OpenMP based parallelisation; +//// it must be an integer that is at least 1. + +#if !defined(ARMA_OPENMP_THREADS) + #define ARMA_OPENMP_THREADS 8 +#endif +//// The maximum number of threads to use for OpenMP based parallelisation; +//// it must be an integer that is at least 1. + +// #define ARMA_NO_DEBUG +//// Uncomment the above line to disable all run-time checks. NOT RECOMMENDED. +//// It is strongly recommended that run-time checks are enabled during development, +//// as this greatly aids in finding mistakes in your code. + +// #define ARMA_EXTRA_DEBUG +//// Uncomment the above line to see the function traces of how Armadillo evaluates expressions. +//// This is mainly useful for debugging of the library. + + +#if defined(ARMA_EXTRA_DEBUG) + #undef ARMA_NO_DEBUG + #undef ARMA_WARN_LEVEL + #define ARMA_WARN_LEVEL 3 +#endif + + +#if defined(ARMA_DEFAULT_OSTREAM) + #pragma message ("WARNING: support for ARMA_DEFAULT_OSTREAM is deprecated and will be removed;") + #pragma message ("WARNING: use ARMA_COUT_STREAM and ARMA_CERR_STREAM instead") +#endif + + +#if !defined(ARMA_COUT_STREAM) + #if defined(ARMA_DEFAULT_OSTREAM) + // for compatibility with earlier versions of Armadillo + #define ARMA_COUT_STREAM ARMA_DEFAULT_OSTREAM + #else + #define ARMA_COUT_STREAM std::cout + #endif +#endif + +#if !defined(ARMA_CERR_STREAM) + #if defined(ARMA_DEFAULT_OSTREAM) + // for compatibility with earlier versions of Armadillo + #define ARMA_CERR_STREAM ARMA_DEFAULT_OSTREAM + #else + #define ARMA_CERR_STREAM std::cerr + #endif +#endif + + +#if !defined(ARMA_PRINT_EXCEPTIONS) + // #define ARMA_PRINT_EXCEPTIONS + #if defined(ARMA_PRINT_EXCEPTIONS_INTERNAL) + #undef ARMA_PRINT_EXCEPTIONS + #define ARMA_PRINT_EXCEPTIONS + #endif +#endif + +#if defined(ARMA_DONT_USE_LAPACK) + #undef ARMA_USE_LAPACK +#endif + +#if defined(ARMA_DONT_USE_BLAS) + #undef ARMA_USE_BLAS +#endif + +#if defined(ARMA_DONT_USE_NEWARP) || !defined(ARMA_USE_LAPACK) + #undef ARMA_USE_NEWARP +#endif + +#if defined(ARMA_DONT_USE_ARPACK) + #undef ARMA_USE_ARPACK +#endif + +#if defined(ARMA_DONT_USE_SUPERLU) + #undef ARMA_USE_SUPERLU + #undef ARMA_SUPERLU_INCLUDE_DIR +#endif + +#if defined(ARMA_DONT_USE_ATLAS) + #undef ARMA_USE_ATLAS +#endif + +#if defined(ARMA_DONT_USE_HDF5) + #undef ARMA_USE_HDF5 +#endif + +#if defined(ARMA_DONT_USE_FFTW3) + #undef ARMA_USE_FFTW3 +#endif + +#if defined(ARMA_DONT_USE_WRAPPER) + #undef ARMA_USE_WRAPPER +#endif + +#if defined(ARMA_DONT_USE_FORTRAN_HIDDEN_ARGS) + #undef ARMA_USE_FORTRAN_HIDDEN_ARGS +#endif + +#if !defined(ARMA_DONT_USE_STD_MUTEX) + // #define ARMA_DONT_USE_STD_MUTEX + //// Uncomment the above line to disable use of std::mutex +#endif + +// for compatibility with earlier versions of Armadillo +#if defined(ARMA_DONT_USE_CXX11_MUTEX) + #pragma message ("WARNING: support for ARMA_DONT_USE_CXX11_MUTEX is deprecated and will be removed;") + #pragma message ("WARNING: use ARMA_DONT_USE_STD_MUTEX instead") + #undef ARMA_DONT_USE_STD_MUTEX + #define ARMA_DONT_USE_STD_MUTEX +#endif + +#if defined(ARMA_DONT_USE_OPENMP) + #undef ARMA_USE_OPENMP +#endif + +#if defined(ARMA_32BIT_WORD) + #undef ARMA_64BIT_WORD +#endif + +#if defined(ARMA_DONT_OPTIMISE_BAND) || defined(ARMA_DONT_OPTIMISE_SOLVE_BAND) + #undef ARMA_OPTIMISE_BAND +#endif + +#if defined(ARMA_DONT_OPTIMISE_SYM) || defined(ARMA_DONT_OPTIMISE_SYMPD) || defined(ARMA_DONT_OPTIMISE_SOLVE_SYMPD) + #undef ARMA_OPTIMISE_SYM +#endif + +#if defined(ARMA_DONT_OPTIMISE_INVEXPR) + #undef ARMA_OPTIMISE_INVEXPR +#endif + +#if defined(ARMA_DONT_CHECK_NONFINITE) + #undef ARMA_CHECK_NONFINITE +#endif + +#if defined(ARMA_DONT_PRINT_ERRORS) + #pragma message ("INFO: support for ARMA_DONT_PRINT_ERRORS option has been removed") + + #if defined(ARMA_PRINT_EXCEPTIONS) + #pragma message ("INFO: suggest to use ARMA_WARN_LEVEL and ARMA_DONT_PRINT_EXCEPTIONS options instead") + #else + #pragma message ("INFO: suggest to use ARMA_WARN_LEVEL option instead") + #endif + + #pragma message ("INFO: see the documentation for details") +#endif + +#if defined(ARMA_DONT_PRINT_EXCEPTIONS) + #undef ARMA_PRINT_EXCEPTIONS +#endif + +#if !defined(ARMA_DONT_ZERO_INIT) + // #define ARMA_DONT_ZERO_INIT + //// Uncomment the above line to disable initialising elements to zero during construction of dense matrices and cubes +#endif + +#if defined(ARMA_NO_CRIPPLED_LAPACK) + #undef ARMA_CRIPPLED_LAPACK +#endif + + +// if Armadillo was installed on this system via CMake and ARMA_USE_WRAPPER is not defined, +// ARMA_AUX_LIBS lists the libraries required by Armadillo on this system, and +// ARMA_AUX_INCDIRS lists the include directories required by Armadillo on this system. +// Do not use these unless you know what you are doing. +#define ARMA_AUX_LIBS ${ARMA_LIBS} +#define ARMA_AUX_INCDIRS ${CMAKE_REQUIRED_INCLUDES} diff --git a/src/armadillo/include/armadillo_bits/constants.hpp b/src/armadillo/include/armadillo_bits/constants.hpp new file mode 100644 index 0000000..9adf9ea --- /dev/null +++ b/src/armadillo/include/armadillo_bits/constants.hpp @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup constants +//! @{ + + +namespace priv + { + class Datum_helper + { + public: + + template + static + typename arma_real_only::result + nan(typename arma_real_only::result* junk = nullptr) + { + arma_ignore(junk); + + return (std::numeric_limits::has_quiet_NaN) ? eT(std::numeric_limits::quiet_NaN()) : eT(0); + } + + + template + static + typename arma_cx_only::result + nan(typename arma_cx_only::result* junk = nullptr) + { + arma_ignore(junk); + + typedef typename get_pod_type::result T; + + return eT( Datum_helper::nan(), Datum_helper::nan() ); + } + + + template + static + typename arma_integral_only::result + nan(typename arma_integral_only::result* junk = nullptr) + { + arma_ignore(junk); + + return eT(0); + } + + + template + static + typename arma_real_only::result + inf(typename arma_real_only::result* junk = nullptr) + { + arma_ignore(junk); + + return (std::numeric_limits::has_infinity) ? eT(std::numeric_limits::infinity()) : eT(std::numeric_limits::max()); + } + + + template + static + typename arma_cx_only::result + inf(typename arma_cx_only::result* junk = nullptr) + { + arma_ignore(junk); + + typedef typename get_pod_type::result T; + + return eT( Datum_helper::inf(), Datum_helper::inf() ); + } + + + template + static + typename arma_integral_only::result + inf(typename arma_integral_only::result* junk = nullptr) + { + arma_ignore(junk); + + return std::numeric_limits::max(); + } + }; + } + + + +//! various constants. +//! Physical constants taken from NIST 2018 CODATA values, and some from WolframAlpha (values provided as of 2009-06-23) +//! http://physics.nist.gov/cuu/Constants +//! http://www.wolframalpha.com +//! See also http://en.wikipedia.org/wiki/Physical_constant + + +template +class Datum + { + public: + + static const eT pi; //!< ratio of any circle's circumference to its diameter + static const eT tau; //!< ratio of any circle's circumference to its radius (replacement of 2*pi) + static const eT e; //!< base of the natural logarithm + static const eT euler; //!< Euler's constant, aka Euler-Mascheroni constant + static const eT gratio; //!< golden ratio + static const eT sqrt2; //!< square root of 2 + static const eT sqrt2pi; //!< square root of 2*pi + static const eT log_sqrt2pi; //!< log of square root of 2*pi + static const eT eps; //!< the difference between 1 and the least value greater than 1 that is representable + static const eT log_min; //!< log of the minimum representable value + static const eT log_max; //!< log of the maximum representable value + static const eT nan; //!< "not a number" + static const eT inf; //!< infinity + + // + + static const eT m_u; //!< atomic mass constant (in kg) + static const eT N_A; //!< Avogadro constant + static const eT k; //!< Boltzmann constant (in joules per kelvin) + static const eT k_evk; //!< Boltzmann constant (in eV/K) + static const eT a_0; //!< Bohr radius (in meters) + static const eT mu_B; //!< Bohr magneton + static const eT Z_0; //!< characteristic impedance of vacuum (in ohms) + static const eT G_0; //!< conductance quantum (in siemens) + static const eT k_e; //!< Coulomb's constant (in meters per farad) + static const eT eps_0; //!< electric constant (in farads per meter) + static const eT m_e; //!< electron mass (in kg) + static const eT eV; //!< electron volt (in joules) + static const eT ec; //!< elementary charge (in coulombs) + static const eT F; //!< Faraday constant (in coulombs) + static const eT alpha; //!< fine-structure constant + static const eT alpha_inv; //!< inverse fine-structure constant + static const eT K_J; //!< Josephson constant + static const eT mu_0; //!< magnetic constant (in henries per meter) + static const eT phi_0; //!< magnetic flux quantum (in webers) + static const eT R; //!< molar gas constant (in joules per mole kelvin) + static const eT G; //!< Newtonian constant of gravitation (in newton square meters per kilogram squared) + static const eT h; //!< Planck constant (in joule seconds) + static const eT h_bar; //!< Planck constant over 2 pi, aka reduced Planck constant (in joule seconds) + static const eT m_p; //!< proton mass (in kg) + static const eT R_inf; //!< Rydberg constant (in reciprocal meters) + static const eT c_0; //!< speed of light in vacuum (in meters per second) + static const eT sigma; //!< Stefan-Boltzmann constant + static const eT R_k; //!< von Klitzing constant (in ohms) + static const eT b; //!< Wien wavelength displacement law constant + }; + + +// the long lengths of the constants are for future support of "long double" +// and any smart compiler that does high-precision computation at compile-time + +template const eT Datum::pi = eT(3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679); +template const eT Datum::tau = eT(6.2831853071795864769252867665590057683943387987502116419498891846156328125724179972560696506842341359); +template const eT Datum::e = eT(2.7182818284590452353602874713526624977572470936999595749669676277240766303535475945713821785251664274); +template const eT Datum::euler = eT(0.5772156649015328606065120900824024310421593359399235988057672348848677267776646709369470632917467495); +template const eT Datum::gratio = eT(1.6180339887498948482045868343656381177203091798057628621354486227052604628189024497072072041893911374); +template const eT Datum::sqrt2 = eT(1.4142135623730950488016887242096980785696718753769480731766797379907324784621070388503875343276415727); +template const eT Datum::sqrt2pi = eT(2.5066282746310005024157652848110452530069867406099383166299235763422936546078419749465958383780572661); +template const eT Datum::log_sqrt2pi = eT(0.9189385332046727417803297364056176398613974736377834128171515404827656959272603976947432986359541976); +template const eT Datum::eps = std::numeric_limits::epsilon(); +template const eT Datum::log_min = std::log(std::numeric_limits::min()); +template const eT Datum::log_max = std::log(std::numeric_limits::max()); +template const eT Datum::nan = priv::Datum_helper::nan(); +template const eT Datum::inf = priv::Datum_helper::inf(); + +template const eT Datum::m_u = eT(1.66053906660e-27); +template const eT Datum::N_A = eT(6.02214076e23); +template const eT Datum::k = eT(1.380649e-23); +template const eT Datum::k_evk = eT(8.617333262e-5); +template const eT Datum::a_0 = eT(5.29177210903e-11); +template const eT Datum::mu_B = eT(9.2740100783e-24); +template const eT Datum::Z_0 = eT(376.730313668); +template const eT Datum::G_0 = eT(7.748091729e-5); +template const eT Datum::k_e = eT(8.9875517923e9); +template const eT Datum::eps_0 = eT(8.8541878128e-12); +template const eT Datum::m_e = eT(9.1093837015e-31); +template const eT Datum::eV = eT(1.602176634e-19); +template const eT Datum::ec = eT(1.602176634e-19); +template const eT Datum::F = eT(96485.33212); +template const eT Datum::alpha = eT(7.2973525693e-3); +template const eT Datum::alpha_inv = eT(137.035999084); +template const eT Datum::K_J = eT(483597.8484e9); +template const eT Datum::mu_0 = eT(1.25663706212e-6); +template const eT Datum::phi_0 = eT(2.067833848e-15); +template const eT Datum::R = eT(8.314462618); +template const eT Datum::G = eT(6.67430e-11); +template const eT Datum::h = eT(6.62607015e-34); +template const eT Datum::h_bar = eT(1.054571817e-34); +template const eT Datum::m_p = eT(1.67262192369e-27); +template const eT Datum::R_inf = eT(10973731.568160); +template const eT Datum::c_0 = eT(299792458.0); +template const eT Datum::sigma = eT(5.670374419e-8); +template const eT Datum::R_k = eT(25812.80745); +template const eT Datum::b = eT(2.897771955e-3); + + + +typedef Datum fdatum; +typedef Datum datum; + + + + +namespace priv + { + + template + static + constexpr + typename arma_real_only::result + most_neg() + { + return (std::numeric_limits::has_infinity) ? -(std::numeric_limits::infinity()) : std::numeric_limits::lowest(); + } + + + template + static + constexpr + typename arma_integral_only::result + most_neg() + { + return std::numeric_limits::lowest(); + } + + + template + static + constexpr + typename arma_real_only::result + most_pos() + { + return (std::numeric_limits::has_infinity) ? std::numeric_limits::infinity() : std::numeric_limits::max(); + } + + + template + static + constexpr + typename arma_integral_only::result + most_pos() + { + return std::numeric_limits::max(); + } + + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/constants_old.hpp b/src/armadillo/include/armadillo_bits/constants_old.hpp new file mode 100644 index 0000000..a2bc046 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/constants_old.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup constants_old +//! @{ + + +// DO NOT USE IN NEW CODE !!! +// the Math and Phy classes are kept for compatibility with old code; +// for new code, use the Datum class instead +// eg. instead of math::pi(), use datum::pi + +template +class Math + { + public: + + arma_frown("use datum::pi instead") static eT pi() { return eT(Datum::pi); } + arma_frown("use datum::e instead") static eT e() { return eT(Datum::e); } + arma_frown("use datum::euler instead") static eT euler() { return eT(Datum::euler); } + arma_frown("use datum::gratio instead") static eT gratio() { return eT(Datum::gratio); } + arma_frown("use datum::sqrt2 instead") static eT sqrt2() { return eT(Datum::sqrt2); } + arma_frown("use datum::eps instead") static eT eps() { return eT(Datum::eps); } + arma_frown("use datum::log_min instead") static eT log_min() { return eT(Datum::log_min); } + arma_frown("use datum::log_max instead") static eT log_max() { return eT(Datum::log_max); } + arma_frown("use datum::nan instead") static eT nan() { return eT(Datum::nan); } + arma_frown("use datum::inf instead") static eT inf() { return eT(Datum::inf); } + }; + + + +template +class Phy + { + public: + + arma_deprecated static eT m_u() { return eT(Datum::m_u); } + arma_deprecated static eT N_A() { return eT(Datum::N_A); } + arma_deprecated static eT k() { return eT(Datum::k); } + arma_deprecated static eT k_evk() { return eT(Datum::k_evk); } + arma_deprecated static eT a_0() { return eT(Datum::a_0); } + arma_deprecated static eT mu_B() { return eT(Datum::mu_B); } + arma_deprecated static eT Z_0() { return eT(Datum::Z_0); } + arma_deprecated static eT G_0() { return eT(Datum::G_0); } + arma_deprecated static eT k_e() { return eT(Datum::k_e); } + arma_deprecated static eT eps_0() { return eT(Datum::eps_0); } + arma_deprecated static eT m_e() { return eT(Datum::m_e); } + arma_deprecated static eT eV() { return eT(Datum::eV); } + arma_deprecated static eT e() { return eT(Datum::ec); } + arma_deprecated static eT F() { return eT(Datum::F); } + arma_deprecated static eT alpha() { return eT(Datum::alpha); } + arma_deprecated static eT alpha_inv() { return eT(Datum::alpha_inv); } + arma_deprecated static eT K_J() { return eT(Datum::K_J); } + arma_deprecated static eT mu_0() { return eT(Datum::mu_0); } + arma_deprecated static eT phi_0() { return eT(Datum::phi_0); } + arma_deprecated static eT R() { return eT(Datum::R); } + arma_deprecated static eT G() { return eT(Datum::G); } + arma_deprecated static eT h() { return eT(Datum::h); } + arma_deprecated static eT h_bar() { return eT(Datum::h_bar); } + arma_deprecated static eT m_p() { return eT(Datum::m_p); } + arma_deprecated static eT R_inf() { return eT(Datum::R_inf); } + arma_deprecated static eT c_0() { return eT(Datum::c_0); } + arma_deprecated static eT sigma() { return eT(Datum::sigma); } + arma_deprecated static eT R_k() { return eT(Datum::R_k); } + arma_deprecated static eT b() { return eT(Datum::b); } + }; + + + +typedef Math fmath; +typedef Math math; + +typedef Phy fphy; +typedef Phy phy; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/csv_name.hpp b/src/armadillo/include/armadillo_bits/csv_name.hpp new file mode 100644 index 0000000..c6a1df5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/csv_name.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup diskio +//! @{ + + +namespace csv_opts + { + typedef unsigned int flag_type; + + struct opts + { + const flag_type flags; + + inline constexpr explicit opts(const flag_type in_flags); + + inline const opts operator+(const opts& rhs) const; + }; + + inline + constexpr + opts::opts(const flag_type in_flags) + : flags(in_flags) + {} + + inline + const opts + opts::operator+(const opts& rhs) const + { + const opts result( flags | rhs.flags ); + + return result; + } + + // The values below (eg. 1u << 0) are for internal Armadillo use only. + // The values can change without notice. + + static constexpr flag_type flag_none = flag_type(0 ); + static constexpr flag_type flag_trans = flag_type(1u << 0); + static constexpr flag_type flag_no_header = flag_type(1u << 1); + static constexpr flag_type flag_with_header = flag_type(1u << 2); + static constexpr flag_type flag_semicolon = flag_type(1u << 3); + static constexpr flag_type flag_strict = flag_type(1u << 4); + + struct opts_none : public opts { inline constexpr opts_none() : opts(flag_none ) {} }; + struct opts_trans : public opts { inline constexpr opts_trans() : opts(flag_trans ) {} }; + struct opts_no_header : public opts { inline constexpr opts_no_header() : opts(flag_no_header ) {} }; + struct opts_with_header : public opts { inline constexpr opts_with_header() : opts(flag_with_header) {} }; + struct opts_semicolon : public opts { inline constexpr opts_semicolon() : opts(flag_semicolon ) {} }; + struct opts_strict : public opts { inline constexpr opts_strict() : opts(flag_strict ) {} }; + + static constexpr opts_none none; + static constexpr opts_trans trans; + static constexpr opts_no_header no_header; + static constexpr opts_with_header with_header; + static constexpr opts_semicolon semicolon; + static constexpr opts_strict strict; + } + + +struct csv_name + { + typedef field header_type; + + const std::string filename; + const csv_opts::opts opts; + + header_type header_junk; + const header_type& header_ro; + header_type& header_rw; + + inline + csv_name(const std::string& in_filename) + : filename (in_filename ) + , opts (csv_opts::no_header) + , header_ro(header_junk ) + , header_rw(header_junk ) + {} + + inline + csv_name(const std::string& in_filename, const csv_opts::opts& in_opts) + : filename (in_filename ) + , opts (csv_opts::no_header + in_opts) + , header_ro(header_junk ) + , header_rw(header_junk ) + {} + + inline + csv_name(const std::string& in_filename, field& in_header) + : filename (in_filename ) + , opts (csv_opts::with_header) + , header_ro(in_header ) + , header_rw(in_header ) + {} + + inline + csv_name(const std::string& in_filename, const field& in_header) + : filename (in_filename ) + , opts (csv_opts::with_header) + , header_ro(in_header ) + , header_rw(header_junk ) + {} + + inline + csv_name(const std::string& in_filename, field& in_header, const csv_opts::opts& in_opts) + : filename (in_filename ) + , opts (csv_opts::with_header + in_opts) + , header_ro(in_header ) + , header_rw(in_header ) + {} + + inline + csv_name(const std::string& in_filename, const field& in_header, const csv_opts::opts& in_opts) + : filename (in_filename ) + , opts (csv_opts::with_header + in_opts) + , header_ro(in_header ) + , header_rw(header_junk ) + {} + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/debug.hpp b/src/armadillo/include/armadillo_bits/debug.hpp new file mode 100644 index 0000000..7a0b95c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/debug.hpp @@ -0,0 +1,1467 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup debug +//! @{ + + + +inline +std::ostream& +get_cout_stream() + { + return (ARMA_COUT_STREAM); + } + + + +inline +std::ostream& +get_cerr_stream() + { + return (ARMA_CERR_STREAM); + } + + + +arma_deprecated +inline +std::ostream& +get_stream_err1() + { + return get_cerr_stream(); + } + + + +arma_deprecated +inline +std::ostream& +get_stream_err2() + { + return get_cerr_stream(); + } + + + +arma_frown("this function does nothing; instead use ARMA_COUT_STREAM or ARMA_WARN_LEVEL; see documentation") +inline +void +set_cout_stream(const std::ostream&) + { + } + + + +arma_frown("this function does nothing; instead use ARMA_CERR_STREAM or ARMA_WARN_LEVEL; see documentation") +inline +void +set_cerr_stream(const std::ostream&) + { + } + + + +arma_frown("this function does nothing; instead use ARMA_CERR_STREAM or ARMA_WARN_LEVEL; see documentation") +inline +void +set_stream_err1(const std::ostream&) + { + } + + + +arma_frown("this function does nothing; instead use ARMA_CERR_STREAM or ARMA_WARN_LEVEL; see documentation") +inline +void +set_stream_err2(const std::ostream&) + { + } + + + +template +arma_frown("this function does nothing; instead use ARMA_COUT_STREAM or ARMA_WARN_LEVEL; see documentation") +inline +std::ostream& +arma_cout_stream(std::ostream*) + { + return (ARMA_COUT_STREAM); + } + + + +template +arma_frown("this function does nothing; instead use ARMA_CERR_STREAM or ARMA_WARN_LEVEL; see documentation") +inline +std::ostream& +arma_cerr_stream(std::ostream*) + { + return (ARMA_CERR_STREAM); + } + + + +//! print a message to get_cerr_stream() and throw logic_error exception +template +arma_cold +arma_noinline +static +void +arma_stop_logic_error(const T1& x) + { + #if defined(ARMA_PRINT_EXCEPTIONS) + { + get_cerr_stream() << "\nerror: " << x << std::endl; + } + #endif + + throw std::logic_error( std::string(x) ); + } + + + +arma_cold +arma_noinline +static +void +arma_stop_logic_error(const char* x, const char* y) + { + arma_stop_logic_error( std::string(x) + std::string(y) ); + } + + + +//! print a message to get_cerr_stream() and throw out_of_range exception +template +arma_cold +arma_noinline +static +void +arma_stop_bounds_error(const T1& x) + { + #if defined(ARMA_PRINT_EXCEPTIONS) + { + get_cerr_stream() << "\nerror: " << x << std::endl; + } + #endif + + throw std::out_of_range( std::string(x) ); + } + + + +//! print a message to get_cerr_stream() and throw bad_alloc exception +template +arma_cold +arma_noinline +static +void +arma_stop_bad_alloc(const T1& x) + { + #if defined(ARMA_PRINT_EXCEPTIONS) + { + get_cerr_stream() << "\nerror: " << x << std::endl; + } + #else + { + arma_ignore(x); + } + #endif + + throw std::bad_alloc(); + } + + + +//! print a message to get_cerr_stream() and throw runtime_error exception +template +arma_cold +arma_noinline +static +void +arma_stop_runtime_error(const T1& x) + { + #if defined(ARMA_PRINT_EXCEPTIONS) + { + get_cerr_stream() << "\nerror: " << x << std::endl; + } + #endif + + throw std::runtime_error( std::string(x) ); + } + + + +// +// arma_print + + +arma_cold +inline +void +arma_print() + { + get_cerr_stream() << std::endl; + } + + +template +arma_cold +arma_noinline +static +void +arma_print(const T1& x) + { + get_cerr_stream() << x << std::endl; + } + + + +template +arma_cold +arma_noinline +static +void +arma_print(const T1& x, const T2& y) + { + get_cerr_stream() << x << y << std::endl; + } + + + +template +arma_cold +arma_noinline +static +void +arma_print(const T1& x, const T2& y, const T3& z) + { + get_cerr_stream() << x << y << z << std::endl; + } + + + + + + +// +// arma_sigprint + +//! print a message to the log stream with a preceding @ character. +//! by default the log stream is cout. +//! used for printing the signature of a function +//! (see the arma_extra_debug_sigprint macro) +inline +void +arma_sigprint(const char* x) + { + get_cerr_stream() << "@ " << x; + } + + + +// +// arma_bktprint + + +inline +void +arma_bktprint() + { + get_cerr_stream() << std::endl; + } + + +template +inline +void +arma_bktprint(const T1& x) + { + get_cerr_stream() << " [" << x << ']' << std::endl; + } + + + +template +inline +void +arma_bktprint(const T1& x, const T2& y) + { + get_cerr_stream() << " [" << x << y << ']' << std::endl; + } + + + + + + +// +// arma_thisprint + +inline +void +arma_thisprint(const void* this_ptr) + { + get_cerr_stream() << " [this = " << this_ptr << ']' << std::endl; + } + + + +// +// arma_warn + + +//! print a message to the warn stream +template +arma_cold +arma_noinline +static +void +arma_warn(const T1& arg1) + { + get_cerr_stream() << "\nwarning: " << arg1 << std::endl; + } + + +template +arma_cold +arma_noinline +static +void +arma_warn(const T1& arg1, const T2& arg2) + { + get_cerr_stream() << "\nwarning: " << arg1 << arg2 << std::endl; + } + + +template +arma_cold +arma_noinline +static +void +arma_warn(const T1& arg1, const T2& arg2, const T3& arg3) + { + get_cerr_stream() << "\nwarning: " << arg1 << arg2 << arg3 << std::endl; + } + + +template +arma_cold +arma_noinline +static +void +arma_warn(const T1& arg1, const T2& arg2, const T3& arg3, const T4& arg4) + { + get_cerr_stream() << "\nwarning: " << arg1 << arg2 << arg3 << arg4 << std::endl; + } + + + +// +// arma_warn_level + + +template +inline +void +arma_warn_level(const uword level, const T1& arg1) + { + constexpr uword config_level = (sword(ARMA_WARN_LEVEL) > 0) ? uword(ARMA_WARN_LEVEL) : uword(0); + + if((config_level > 0) && (level <= config_level)) { arma_warn(arg1); } + } + + +template +inline +void +arma_warn_level(const uword level, const T1& arg1, const T2& arg2) + { + constexpr uword config_level = (sword(ARMA_WARN_LEVEL) > 0) ? uword(ARMA_WARN_LEVEL) : uword(0); + + if((config_level > 0) && (level <= config_level)) { arma_warn(arg1,arg2); } + } + + +template +inline +void +arma_warn_level(const uword level, const T1& arg1, const T2& arg2, const T3& arg3) + { + constexpr uword config_level = (sword(ARMA_WARN_LEVEL) > 0) ? uword(ARMA_WARN_LEVEL) : uword(0); + + if((config_level > 0) && (level <= config_level)) { arma_warn(arg1,arg2,arg3); } + } + + +template +inline +void +arma_warn_level(const uword level, const T1& arg1, const T2& arg2, const T3& arg3, const T4& arg4) + { + constexpr uword config_level = (sword(ARMA_WARN_LEVEL) > 0) ? uword(ARMA_WARN_LEVEL) : uword(0); + + if((config_level > 0) && (level <= config_level)) { arma_warn(arg1,arg2,arg3,arg4); } + } + + + +// +// arma_check + +//! if state is true, abort program +template +arma_hot +inline +void +arma_check(const bool state, const T1& x) + { + if(state) { arma_stop_logic_error(arma_str::str_wrapper(x)); } + } + + +template +arma_hot +inline +void +arma_check(const bool state, const char* x, const Functor& fn) + { + if(state) { fn(); arma_stop_logic_error(x); } + } + + +arma_hot +inline +void +arma_check(const bool state, const char* x, const char* y) + { + if(state) { arma_stop_logic_error(x,y); } + } + + +template +arma_hot +inline +void +arma_check(const bool state, const char* x, const char* y, const Functor& fn) + { + if(state) { fn(); arma_stop_logic_error(x,y); } + } + + +template +arma_hot +inline +void +arma_check_bounds(const bool state, const T1& x) + { + if(state) { arma_stop_bounds_error(arma_str::str_wrapper(x)); } + } + + +template +arma_hot +inline +void +arma_check_bad_alloc(const bool state, const T1& x) + { + if(state) { arma_stop_bad_alloc(x); } + } + + + +// +// arma_set_error + + +arma_hot +arma_inline +void +arma_set_error(bool& err_state, char*& err_msg, const bool expression, const char* message) + { + if(expression) + { + err_state = true; + err_msg = const_cast(message); + } + } + + + + +// +// functions for generating strings indicating size errors + +arma_cold +arma_noinline +static +std::string +arma_incompat_size_string(const uword A_n_rows, const uword A_n_cols, const uword B_n_rows, const uword B_n_cols, const char* x) + { + std::ostringstream tmp; + + tmp << x << ": incompatible matrix dimensions: " << A_n_rows << 'x' << A_n_cols << " and " << B_n_rows << 'x' << B_n_cols; + + return tmp.str(); + } + + + +arma_cold +arma_noinline +static +std::string +arma_incompat_size_string(const uword A_n_rows, const uword A_n_cols, const uword A_n_slices, const uword B_n_rows, const uword B_n_cols, const uword B_n_slices, const char* x) + { + std::ostringstream tmp; + + tmp << x << ": incompatible cube dimensions: " << A_n_rows << 'x' << A_n_cols << 'x' << A_n_slices << " and " << B_n_rows << 'x' << B_n_cols << 'x' << B_n_slices; + + return tmp.str(); + } + + + +template +arma_cold +arma_noinline +static +std::string +arma_incompat_size_string(const subview_cube& Q, const Mat& A, const char* x) + { + std::ostringstream tmp; + + tmp << x + << ": interpreting matrix as cube with dimensions: " + << A.n_rows << 'x' << A.n_cols << 'x' << 1 + << " or " + << A.n_rows << 'x' << 1 << 'x' << A.n_cols + << " or " + << 1 << 'x' << A.n_rows << 'x' << A.n_cols + << " is incompatible with cube dimensions: " + << Q.n_rows << 'x' << Q.n_cols << 'x' << Q.n_slices; + + return tmp.str(); + } + + + +// +// functions for checking whether two dense matrices have the same dimensions + + + +arma_hot +arma_inline +void +arma_assert_same_size(const uword A_n_rows, const uword A_n_cols, const uword B_n_rows, const uword B_n_cols, const char* x) + { + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +//! stop if given matrices have different sizes +template +arma_hot +inline +void +arma_assert_same_size(const Mat& A, const Mat& B, const char* x) + { + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +//! stop if given proxies have different sizes +template +arma_hot +inline +void +arma_assert_same_size(const Proxy& A, const Proxy& B, const char* x) + { + const uword A_n_rows = A.get_n_rows(); + const uword A_n_cols = A.get_n_cols(); + + const uword B_n_rows = B.get_n_rows(); + const uword B_n_cols = B.get_n_cols(); + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const subview& A, const subview& B, const char* x) + { + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const Mat& A, const subview& B, const char* x) + { + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const subview& A, const Mat& B, const char* x) + { + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const Mat& A, const Proxy& B, const char* x) + { + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.get_n_rows(); + const uword B_n_cols = B.get_n_cols(); + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const Proxy& A, const Mat& B, const char* x) + { + const uword A_n_rows = A.get_n_rows(); + const uword A_n_cols = A.get_n_cols(); + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const Proxy& A, const subview& B, const char* x) + { + const uword A_n_rows = A.get_n_rows(); + const uword A_n_cols = A.get_n_cols(); + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const subview& A, const Proxy& B, const char* x) + { + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.get_n_rows(); + const uword B_n_cols = B.get_n_cols(); + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +// +// functions for checking whether two sparse matrices have the same dimensions + + + +template +arma_hot +inline +void +arma_assert_same_size(const SpMat& A, const SpMat& B, const char* x) + { + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +// +// functions for checking whether two cubes have the same dimensions + + + +arma_hot +inline +void +arma_assert_same_size(const uword A_n_rows, const uword A_n_cols, const uword A_n_slices, const uword B_n_rows, const uword B_n_cols, const uword B_n_slices, const char* x) + { + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) || (A_n_slices != B_n_slices) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, A_n_slices, B_n_rows, B_n_cols, B_n_slices, x) ); + } + } + + + +//! stop if given cubes have different sizes +template +arma_hot +inline +void +arma_assert_same_size(const Cube& A, const Cube& B, const char* x) + { + if( (A.n_rows != B.n_rows) || (A.n_cols != B.n_cols) || (A.n_slices != B.n_slices) ) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A.n_cols, A.n_slices, B.n_rows, B.n_cols, B.n_slices, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const Cube& A, const subview_cube& B, const char* x) + { + if( (A.n_rows != B.n_rows) || (A.n_cols != B.n_cols) || (A.n_slices != B.n_slices) ) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A.n_cols, A.n_slices, B.n_rows, B.n_cols, B.n_slices, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const subview_cube& A, const Cube& B, const char* x) + { + if( (A.n_rows != B.n_rows) || (A.n_cols != B.n_cols) || (A.n_slices != B.n_slices) ) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A.n_cols, A.n_slices, B.n_rows, B.n_cols, B.n_slices, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const subview_cube& A, const subview_cube& B, const char* x) + { + if( (A.n_rows != B.n_rows) || (A.n_cols != B.n_cols) || (A.n_slices != B.n_slices)) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A.n_cols, A.n_slices, B.n_rows, B.n_cols, B.n_slices, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const subview_cube& A, const ProxyCube& B, const char* x) + { + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + const uword A_n_slices = A.n_slices; + + const uword B_n_rows = B.get_n_rows(); + const uword B_n_cols = B.get_n_cols(); + const uword B_n_slices = B.get_n_slices(); + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) || (A_n_slices != B_n_slices) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, A_n_slices, B_n_rows, B_n_cols, B_n_slices, x) ); + } + } + + + +//! stop if given cube proxies have different sizes +template +arma_hot +inline +void +arma_assert_same_size(const ProxyCube& A, const ProxyCube& B, const char* x) + { + const uword A_n_rows = A.get_n_rows(); + const uword A_n_cols = A.get_n_cols(); + const uword A_n_slices = A.get_n_slices(); + + const uword B_n_rows = B.get_n_rows(); + const uword B_n_cols = B.get_n_cols(); + const uword B_n_slices = B.get_n_slices(); + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) || (A_n_slices != B_n_slices)) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, A_n_slices, B_n_rows, B_n_cols, B_n_slices, x) ); + } + } + + + +// +// functions for checking whether a cube or subcube can be interpreted as a matrix (ie. single slice) + + + +template +arma_hot +inline +void +arma_assert_same_size(const Cube& A, const Mat& B, const char* x) + { + if( (A.n_rows != B.n_rows) || (A.n_cols != B.n_cols) || (A.n_slices != 1) ) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A.n_cols, A.n_slices, B.n_rows, B.n_cols, 1, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const Mat& A, const Cube& B, const char* x) + { + if( (A.n_rows != B.n_rows) || (A.n_cols != B.n_cols) || (1 != B.n_slices) ) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A.n_cols, 1, B.n_rows, B.n_cols, B.n_slices, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const subview_cube& A, const Mat& B, const char* x) + { + if( (A.n_rows != B.n_rows) || (A.n_cols != B.n_cols) || (A.n_slices != 1) ) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A.n_cols, A.n_slices, B.n_rows, B.n_cols, 1, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_same_size(const Mat& A, const subview_cube& B, const char* x) + { + if( (A.n_rows != B.n_rows) || (A.n_cols != B.n_cols) || (1 != B.n_slices) ) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A.n_cols, 1, B.n_rows, B.n_cols, B.n_slices, x) ); + } + } + + + +template +inline +void +arma_assert_cube_as_mat(const Mat& M, const T1& Q, const char* x, const bool check_compat_size) + { + const uword Q_n_rows = Q.n_rows; + const uword Q_n_cols = Q.n_cols; + const uword Q_n_slices = Q.n_slices; + + const uword M_vec_state = M.vec_state; + + if(M_vec_state == 0) + { + if( ( (Q_n_rows == 1) || (Q_n_cols == 1) || (Q_n_slices == 1) ) == false ) + { + std::ostringstream tmp; + + tmp << x + << ": can't interpret cube with dimensions " + << Q_n_rows << 'x' << Q_n_cols << 'x' << Q_n_slices + << " as a matrix; one of the dimensions must be 1"; + + arma_stop_logic_error( tmp.str() ); + } + } + else + { + if(Q_n_slices == 1) + { + if( (M_vec_state == 1) && (Q_n_cols != 1) ) + { + std::ostringstream tmp; + + tmp << x + << ": can't interpret cube with dimensions " + << Q_n_rows << 'x' << Q_n_cols << 'x' << Q_n_slices + << " as a column vector"; + + arma_stop_logic_error( tmp.str() ); + } + + if( (M_vec_state == 2) && (Q_n_rows != 1) ) + { + std::ostringstream tmp; + + tmp << x + << ": can't interpret cube with dimensions " + << Q_n_rows << 'x' << Q_n_cols << 'x' << Q_n_slices + << " as a row vector"; + + arma_stop_logic_error( tmp.str() ); + } + } + else + { + if( (Q_n_cols != 1) && (Q_n_rows != 1) ) + { + std::ostringstream tmp; + + tmp << x + << ": can't interpret cube with dimensions " + << Q_n_rows << 'x' << Q_n_cols << 'x' << Q_n_slices + << " as a vector"; + + arma_stop_logic_error( tmp.str() ); + } + } + } + + + if(check_compat_size) + { + const uword M_n_rows = M.n_rows; + const uword M_n_cols = M.n_cols; + + if(M_vec_state == 0) + { + if( + ( + ( (Q_n_rows == M_n_rows) && (Q_n_cols == M_n_cols) ) + || + ( (Q_n_rows == M_n_rows) && (Q_n_slices == M_n_cols) ) + || + ( (Q_n_cols == M_n_rows) && (Q_n_slices == M_n_cols) ) + ) + == false + ) + { + std::ostringstream tmp; + + tmp << x + << ": can't interpret cube with dimensions " + << Q_n_rows << 'x' << Q_n_cols << 'x' << Q_n_slices + << " as a matrix with dimensions " + << M_n_rows << 'x' << M_n_cols; + + arma_stop_logic_error( tmp.str() ); + } + } + else + { + if(Q_n_slices == 1) + { + if( (M_vec_state == 1) && (Q_n_rows != M_n_rows) ) + { + std::ostringstream tmp; + + tmp << x + << ": can't interpret cube with dimensions " + << Q_n_rows << 'x' << Q_n_cols << 'x' << Q_n_slices + << " as a column vector with dimensions " + << M_n_rows << 'x' << M_n_cols; + + arma_stop_logic_error( tmp.str() ); + } + + if( (M_vec_state == 2) && (Q_n_cols != M_n_cols) ) + { + std::ostringstream tmp; + + tmp << x + << ": can't interpret cube with dimensions " + << Q_n_rows << 'x' << Q_n_cols << 'x' << Q_n_slices + << " as a row vector with dimensions " + << M_n_rows << 'x' << M_n_cols; + + arma_stop_logic_error( tmp.str() ); + } + } + else + { + if( ( (M_n_cols == Q_n_slices) || (M_n_rows == Q_n_slices) ) == false ) + { + std::ostringstream tmp; + + tmp << x + << ": can't interpret cube with dimensions " + << Q_n_rows << 'x' << Q_n_cols << 'x' << Q_n_slices + << " as a vector with dimensions " + << M_n_rows << 'x' << M_n_cols; + + arma_stop_logic_error( tmp.str() ); + } + } + } + } + } + + + +// +// functions for checking whether two matrices have dimensions that are compatible with the matrix multiply operation + + + +arma_hot +inline +void +arma_assert_mul_size(const uword A_n_rows, const uword A_n_cols, const uword B_n_rows, const uword B_n_cols, const char* x) + { + if(A_n_cols != B_n_rows) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, B_n_rows, B_n_cols, x) ); + } + } + + + +//! stop if given matrices are incompatible for multiplication +template +arma_hot +inline +void +arma_assert_mul_size(const Mat& A, const Mat& B, const char* x) + { + const uword A_n_cols = A.n_cols; + const uword B_n_rows = B.n_rows; + + if(A_n_cols != B_n_rows) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A_n_cols, B_n_rows, B.n_cols, x) ); + } + } + + + +//! stop if given matrices are incompatible for multiplication +template +arma_hot +inline +void +arma_assert_mul_size(const Mat& A, const Mat& B, const bool do_trans_A, const bool do_trans_B, const char* x) + { + const uword final_A_n_cols = (do_trans_A == false) ? A.n_cols : A.n_rows; + const uword final_B_n_rows = (do_trans_B == false) ? B.n_rows : B.n_cols; + + if(final_A_n_cols != final_B_n_rows) + { + const uword final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; + const uword final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; + + arma_stop_logic_error( arma_incompat_size_string(final_A_n_rows, final_A_n_cols, final_B_n_rows, final_B_n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_trans_mul_size(const uword A_n_rows, const uword A_n_cols, const uword B_n_rows, const uword B_n_cols, const char* x) + { + const uword final_A_n_cols = (do_trans_A == false) ? A_n_cols : A_n_rows; + const uword final_B_n_rows = (do_trans_B == false) ? B_n_rows : B_n_cols; + + if(final_A_n_cols != final_B_n_rows) + { + const uword final_A_n_rows = (do_trans_A == false) ? A_n_rows : A_n_cols; + const uword final_B_n_cols = (do_trans_B == false) ? B_n_cols : B_n_rows; + + arma_stop_logic_error( arma_incompat_size_string(final_A_n_rows, final_A_n_cols, final_B_n_rows, final_B_n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_mul_size(const Mat& A, const subview& B, const char* x) + { + if(A.n_cols != B.n_rows) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A.n_cols, B.n_rows, B.n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_mul_size(const subview& A, const Mat& B, const char* x) + { + if(A.n_cols != B.n_rows) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A.n_cols, B.n_rows, B.n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_mul_size(const subview& A, const subview& B, const char* x) + { + if(A.n_cols != B.n_rows) + { + arma_stop_logic_error( arma_incompat_size_string(A.n_rows, A.n_cols, B.n_rows, B.n_cols, x) ); + } + } + + + +template +arma_hot +inline +void +arma_assert_blas_size(const T1& A) + { + if(sizeof(uword) >= sizeof(blas_int)) + { + bool overflow; + + overflow = (A.n_rows > ARMA_MAX_BLAS_INT); + overflow = (A.n_cols > ARMA_MAX_BLAS_INT) || overflow; + + if(overflow) + { + arma_stop_runtime_error("integer overflow: matrix dimensions are too large for integer type used by BLAS and LAPACK"); + } + } + } + + + +template +arma_hot +inline +void +arma_assert_blas_size(const T1& A, const T2& B) + { + if(sizeof(uword) >= sizeof(blas_int)) + { + bool overflow; + + overflow = (A.n_rows > ARMA_MAX_BLAS_INT); + overflow = (A.n_cols > ARMA_MAX_BLAS_INT) || overflow; + overflow = (B.n_rows > ARMA_MAX_BLAS_INT) || overflow; + overflow = (B.n_cols > ARMA_MAX_BLAS_INT) || overflow; + + if(overflow) + { + arma_stop_runtime_error("integer overflow: matrix dimensions are too large for integer type used by BLAS and LAPACK"); + } + } + } + + + +// TODO: remove support for ATLAS in next major version +template +arma_hot +inline +void +arma_assert_atlas_size(const T1& A) + { + if(sizeof(uword) >= sizeof(int)) + { + bool overflow; + + overflow = (A.n_rows > INT_MAX); + overflow = (A.n_cols > INT_MAX) || overflow; + + if(overflow) + { + arma_stop_runtime_error("integer overflow: matrix dimensions are too large for integer type used by ATLAS"); + } + } + } + + + +// TODO: remove support for ATLAS in next major version +template +arma_hot +inline +void +arma_assert_atlas_size(const T1& A, const T2& B) + { + if(sizeof(uword) >= sizeof(int)) + { + bool overflow; + + overflow = (A.n_rows > INT_MAX); + overflow = (A.n_cols > INT_MAX) || overflow; + overflow = (B.n_rows > INT_MAX) || overflow; + overflow = (B.n_cols > INT_MAX) || overflow; + + if(overflow) + { + arma_stop_runtime_error("integer overflow: matrix dimensions are too large for integer type used by ATLAS"); + } + } + } + + + +// +// macros + + +// #define ARMA_STRING1(x) #x +// #define ARMA_STRING2(x) ARMA_STRING1(x) +// #define ARMA_FILELINE __FILE__ ": " ARMA_STRING2(__LINE__) + + +#if defined(ARMA_NO_DEBUG) + + #define arma_debug_print true ? (void)0 : arma_print + #define arma_debug_warn true ? (void)0 : arma_warn + #define arma_debug_warn_level true ? (void)0 : arma_warn_level + #define arma_debug_check true ? (void)0 : arma_check + #define arma_debug_check_bounds true ? (void)0 : arma_check_bounds + #define arma_debug_set_error true ? (void)0 : arma_set_error + #define arma_debug_assert_same_size true ? (void)0 : arma_assert_same_size + #define arma_debug_assert_mul_size true ? (void)0 : arma_assert_mul_size + #define arma_debug_assert_trans_mul_size true ? (void)0 : arma_assert_trans_mul_size + #define arma_debug_assert_cube_as_mat true ? (void)0 : arma_assert_cube_as_mat + #define arma_debug_assert_blas_size true ? (void)0 : arma_assert_blas_size + #define arma_debug_assert_atlas_size true ? (void)0 : arma_assert_atlas_size + +#else + + #define arma_debug_print arma_print + #define arma_debug_warn arma_warn + #define arma_debug_warn_level arma_warn_level + #define arma_debug_check arma_check + #define arma_debug_check_bounds arma_check_bounds + #define arma_debug_set_error arma_set_error + #define arma_debug_assert_same_size arma_assert_same_size + #define arma_debug_assert_mul_size arma_assert_mul_size + #define arma_debug_assert_trans_mul_size arma_assert_trans_mul_size + #define arma_debug_assert_cube_as_mat arma_assert_cube_as_mat + #define arma_debug_assert_blas_size arma_assert_blas_size + #define arma_debug_assert_atlas_size arma_assert_atlas_size + +#endif + + + +#if defined(ARMA_EXTRA_DEBUG) + + #define arma_extra_debug_sigprint arma_sigprint(ARMA_FNSIG); arma_bktprint + #define arma_extra_debug_sigprint_this arma_sigprint(ARMA_FNSIG); arma_thisprint + #define arma_extra_debug_print arma_print + +#else + + #define arma_extra_debug_sigprint true ? (void)0 : arma_bktprint + #define arma_extra_debug_sigprint_this true ? (void)0 : arma_thisprint + #define arma_extra_debug_print true ? (void)0 : arma_print + +#endif + + + + +#if defined(ARMA_EXTRA_DEBUG) + + namespace junk + { + class arma_first_extra_debug_message + { + public: + + inline + arma_first_extra_debug_message() + { + union + { + unsigned short a; + unsigned char b[sizeof(unsigned short)]; + } endian_test; + + endian_test.a = 1; + + const bool little_endian = (endian_test.b[0] == 1); + const char* nickname = ARMA_VERSION_NAME; + + std::ostream& out = get_cerr_stream(); + + out << "@ ---" << '\n'; + out << "@ Armadillo " + << arma_version::major << '.' << arma_version::minor << '.' << arma_version::patch + << " (" << nickname << ")\n"; + + out << "@ arma_config::wrapper = " << arma_config::wrapper << '\n'; + out << "@ arma_config::cxx14 = " << arma_config::cxx14 << '\n'; + out << "@ arma_config::cxx17 = " << arma_config::cxx17 << '\n'; + out << "@ arma_config::cxx20 = " << arma_config::cxx20 << '\n'; + out << "@ arma_config::std_mutex = " << arma_config::std_mutex << '\n'; + out << "@ arma_config::posix = " << arma_config::posix << '\n'; + out << "@ arma_config::openmp = " << arma_config::openmp << '\n'; + out << "@ arma_config::lapack = " << arma_config::lapack << '\n'; + out << "@ arma_config::blas = " << arma_config::blas << '\n'; + out << "@ arma_config::newarp = " << arma_config::newarp << '\n'; + out << "@ arma_config::arpack = " << arma_config::arpack << '\n'; + out << "@ arma_config::superlu = " << arma_config::superlu << '\n'; + out << "@ arma_config::atlas = " << arma_config::atlas << '\n'; + out << "@ arma_config::hdf5 = " << arma_config::hdf5 << '\n'; + out << "@ arma_config::good_comp = " << arma_config::good_comp << '\n'; + out << "@ arma_config::extra_code = " << arma_config::extra_code << '\n'; + out << "@ arma_config::hidden_args = " << arma_config::hidden_args << '\n'; + out << "@ arma_config::mat_prealloc = " << arma_config::mat_prealloc << '\n'; + out << "@ arma_config::mp_threshold = " << arma_config::mp_threshold << '\n'; + out << "@ arma_config::mp_threads = " << arma_config::mp_threads << '\n'; + out << "@ arma_config::optimise_band = " << arma_config::optimise_band << '\n'; + out << "@ arma_config::optimise_sym = " << arma_config::optimise_sym << '\n'; + out << "@ arma_config::optimise_invexpr = " << arma_config::optimise_invexpr << '\n'; + out << "@ arma_config::check_nonfinite = " << arma_config::check_nonfinite << '\n'; + out << "@ arma_config::zero_init = " << arma_config::zero_init << '\n'; + out << "@ arma_config::fast_math = " << arma_config::fast_math << '\n'; + out << "@ sizeof(void*) = " << sizeof(void*) << '\n'; + out << "@ sizeof(int) = " << sizeof(int) << '\n'; + out << "@ sizeof(long) = " << sizeof(long) << '\n'; + out << "@ sizeof(uword) = " << sizeof(uword) << '\n'; + out << "@ sizeof(blas_int) = " << sizeof(blas_int) << '\n'; + out << "@ little_endian = " << little_endian << '\n'; + out << "@ ---" << std::endl; + } + + }; + + static arma_first_extra_debug_message arma_first_extra_debug_message_run; + } + +#endif + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/def_arpack.hpp b/src/armadillo/include/armadillo_bits/def_arpack.hpp new file mode 100644 index 0000000..5bbbb7f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/def_arpack.hpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +#if defined(ARMA_USE_ARPACK) + +// I'm not sure this is necessary. +#if !defined(ARMA_BLAS_CAPITALS) + + #define arma_snaupd snaupd + #define arma_dnaupd dnaupd + #define arma_cnaupd cnaupd + #define arma_znaupd znaupd + + #define arma_sneupd sneupd + #define arma_dneupd dneupd + #define arma_cneupd cneupd + #define arma_zneupd zneupd + + #define arma_ssaupd ssaupd + #define arma_dsaupd dsaupd + + #define arma_sseupd sseupd + #define arma_dseupd dseupd + +#else + + #define arma_snaupd SNAUPD + #define arma_dnaupd DNAUPD + #define arma_cnaupd CNAUPD + #define arma_znaupd ZNAUPD + + #define arma_sneupd SNEUPD + #define arma_dneupd DNEUPD + #define arma_cneupd CNEUPD + #define arma_zneupd ZNEUPD + + #define arma_ssaupd SSAUPD + #define arma_dsaupd DSAUPD + + #define arma_sseupd SSEUPD + #define arma_dseupd DSEUPD + +#endif + +extern "C" +{ +#if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + + // eigendecomposition of non-symmetric positive semi-definite matrices + void arma_fortran(arma_snaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, float* resid, blas_int* ncv, float* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, float* workd, float* workl, blas_int* lworkl, blas_int* info, blas_len bmat_len, blas_len which_len); + void arma_fortran(arma_dnaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, double* resid, blas_int* ncv, double* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, double* workd, double* workl, blas_int* lworkl, blas_int* info, blas_len bmat_len, blas_len which_len); + void arma_fortran(arma_cnaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, void* resid, blas_int* ncv, void* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, void* workd, void* workl, blas_int* lworkl, float* rwork, blas_int* info, blas_len bmat_len, blas_len which_len); + void arma_fortran(arma_znaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, void* resid, blas_int* ncv, void* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, void* workd, void* workl, blas_int* lworkl, double* rwork, blas_int* info, blas_len bmat_len, blas_len which_len); + + // recovery of eigenvectors after naupd(); uses blas_int for LOGICAL types + void arma_fortran(arma_sneupd)(blas_int* rvec, char* howmny, blas_int* select, float* dr, float* di, float* z, blas_int* ldz, float* sigmar, float* sigmai, float* workev, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, float* resid, blas_int* ncv, float* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, float* workd, float* workl, blas_int* lworkl, blas_int* info, blas_len howmny_len, blas_len bmat_len, blas_len which_len); + void arma_fortran(arma_dneupd)(blas_int* rvec, char* howmny, blas_int* select, double* dr, double* di, double* z, blas_int* ldz, double* sigmar, double* sigmai, double* workev, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, double* resid, blas_int* ncv, double* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, double* workd, double* workl, blas_int* lworkl, blas_int* info, blas_len howmny_len, blas_len bmat_len, blas_len which_len); + void arma_fortran(arma_cneupd)(blas_int* rvec, char* howmny, blas_int* select, void* d, void* z, blas_int* ldz, void* sigma, void* workev, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, void* resid, blas_int* ncv, void* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, void* workd, void* workl, blas_int* lworkl, float* rwork, blas_int* info, blas_len howmny_len, blas_len bmat_len, blas_len which_len); + void arma_fortran(arma_zneupd)(blas_int* rvec, char* howmny, blas_int* select, void* d, void* z, blas_int* ldz, void* sigma, void* workev, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, void* resid, blas_int* ncv, void* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, void* workd, void* workl, blas_int* lworkl, double* rwork, blas_int* info, blas_len howmny_len, blas_len bmat_len, blas_len which_len); + + // eigendecomposition of symmetric positive semi-definite matrices + void arma_fortran(arma_ssaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, float* resid, blas_int* ncv, float* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, float* workd, float* workl, blas_int* lworkl, blas_int* info, blas_len bmat_len, blas_len which_len); + void arma_fortran(arma_dsaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, double* resid, blas_int* ncv, double* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, double* workd, double* workl, blas_int* lworkl, blas_int* info, blas_len bmat_len, blas_len which_len); + + // recovery of eigenvectors after saupd(); uses blas_int for LOGICAL types + void arma_fortran(arma_sseupd)(blas_int* rvec, char* howmny, blas_int* select, float* d, float* z, blas_int* ldz, float* sigma, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, float* resid, blas_int* ncv, float* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, float* workd, float* workl, blas_int* lworkl, blas_int* info, blas_len howmny_len, blas_len bmat_len, blas_len which_len); + void arma_fortran(arma_dseupd)(blas_int* rvec, char* howmny, blas_int* select, double* d, double* z, blas_int* ldz, double* sigma, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, double* resid, blas_int* ncv, double* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, double* workd, double* workl, blas_int* lworkl, blas_int* info, blas_len howmny_len, blas_len bmat_len, blas_len which_len); + +#else + + // eigendecomposition of non-symmetric positive semi-definite matrices + void arma_fortran(arma_snaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, float* resid, blas_int* ncv, float* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, float* workd, float* workl, blas_int* lworkl, blas_int* info); + void arma_fortran(arma_dnaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, double* resid, blas_int* ncv, double* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, double* workd, double* workl, blas_int* lworkl, blas_int* info); + void arma_fortran(arma_cnaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, void* resid, blas_int* ncv, void* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, void* workd, void* workl, blas_int* lworkl, float* rwork, blas_int* info); + void arma_fortran(arma_znaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, void* resid, blas_int* ncv, void* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, void* workd, void* workl, blas_int* lworkl, double* rwork, blas_int* info); + + // recovery of eigenvectors after naupd(); uses blas_int for LOGICAL types + void arma_fortran(arma_sneupd)(blas_int* rvec, char* howmny, blas_int* select, float* dr, float* di, float* z, blas_int* ldz, float* sigmar, float* sigmai, float* workev, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, float* resid, blas_int* ncv, float* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, float* workd, float* workl, blas_int* lworkl, blas_int* info); + void arma_fortran(arma_dneupd)(blas_int* rvec, char* howmny, blas_int* select, double* dr, double* di, double* z, blas_int* ldz, double* sigmar, double* sigmai, double* workev, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, double* resid, blas_int* ncv, double* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, double* workd, double* workl, blas_int* lworkl, blas_int* info); + void arma_fortran(arma_cneupd)(blas_int* rvec, char* howmny, blas_int* select, void* d, void* z, blas_int* ldz, void* sigma, void* workev, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, void* resid, blas_int* ncv, void* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, void* workd, void* workl, blas_int* lworkl, float* rwork, blas_int* info); + void arma_fortran(arma_zneupd)(blas_int* rvec, char* howmny, blas_int* select, void* d, void* z, blas_int* ldz, void* sigma, void* workev, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, void* resid, blas_int* ncv, void* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, void* workd, void* workl, blas_int* lworkl, double* rwork, blas_int* info); + + // eigendecomposition of symmetric positive semi-definite matrices + void arma_fortran(arma_ssaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, float* resid, blas_int* ncv, float* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, float* workd, float* workl, blas_int* lworkl, blas_int* info); + void arma_fortran(arma_dsaupd)(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, double* resid, blas_int* ncv, double* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, double* workd, double* workl, blas_int* lworkl, blas_int* info); + + // recovery of eigenvectors after saupd(); uses blas_int for LOGICAL types + void arma_fortran(arma_sseupd)(blas_int* rvec, char* howmny, blas_int* select, float* d, float* z, blas_int* ldz, float* sigma, char* bmat, blas_int* n, char* which, blas_int* nev, float* tol, float* resid, blas_int* ncv, float* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, float* workd, float* workl, blas_int* lworkl, blas_int* info); + void arma_fortran(arma_dseupd)(blas_int* rvec, char* howmny, blas_int* select, double* d, double* z, blas_int* ldz, double* sigma, char* bmat, blas_int* n, char* which, blas_int* nev, double* tol, double* resid, blas_int* ncv, double* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, double* workd, double* workl, blas_int* lworkl, blas_int* info); + +#endif +} + +#endif diff --git a/src/armadillo/include/armadillo_bits/def_atlas.hpp b/src/armadillo/include/armadillo_bits/def_atlas.hpp new file mode 100644 index 0000000..e410d9b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/def_atlas.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +// TODO: remove support for ATLAS in next major version + +#if defined(ARMA_USE_ATLAS) + + +typedef enum + { + atlas_CblasRowMajor = 101, + atlas_CblasColMajor = 102 + } + atlas_CBLAS_LAYOUT; + +typedef enum + { + atlas_CblasNoTrans = 111, + atlas_CblasTrans = 112, + atlas_CblasConjTrans = 113 + } + atlas_CBLAS_TRANS; + +typedef enum + { + atlas_CblasUpper = 121, + atlas_CblasLower = 122 + } + atlas_CBLAS_UPLO; + + +extern "C" + { + float arma_wrapper(cblas_sasum)(const int N, const float *X, const int incX); + double arma_wrapper(cblas_dasum)(const int N, const double *X, const int incX); + + float arma_wrapper(cblas_snrm2)(const int N, const float *X, const int incX); + double arma_wrapper(cblas_dnrm2)(const int N, const double *X, const int incX); + + float arma_wrapper(cblas_sdot)(const int N, const float *X, const int incX, const float *Y, const int incY); + double arma_wrapper(cblas_ddot)(const int N, const double *X, const int incX, const double *Y, const int incY); + + void arma_wrapper(cblas_cdotu_sub)(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotu); + void arma_wrapper(cblas_zdotu_sub)(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotu); + + void arma_wrapper(cblas_sgemv)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const int M, const int N, const float alpha, const float *A, const int lda, const float *X, const int incX, const float beta, float *Y, const int incY); + void arma_wrapper(cblas_dgemv)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const int M, const int N, const double alpha, const double *A, const int lda, const double *X, const int incX, const double beta, double *Y, const int incY); + void arma_wrapper(cblas_cgemv)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const int M, const int N, const void *alpha, const void *A, const int lda, const void *X, const int incX, const void *beta, void *Y, const int incY); + void arma_wrapper(cblas_zgemv)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const int M, const int N, const void *alpha, const void *A, const int lda, const void *X, const int incX, const void *beta, void *Y, const int incY); + + void arma_wrapper(cblas_sgemm)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const atlas_CBLAS_TRANS TransB, const int M, const int N, const int K, const float alpha, const float *A, const int lda, const float *B, const int ldb, const float beta, float *C, const int ldc); + void arma_wrapper(cblas_dgemm)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const atlas_CBLAS_TRANS TransB, const int M, const int N, const int K, const double alpha, const double *A, const int lda, const double *B, const int ldb, const double beta, double *C, const int ldc); + void arma_wrapper(cblas_cgemm)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const atlas_CBLAS_TRANS TransB, const int M, const int N, const int K, const void *alpha, const void *A, const int lda, const void *B, const int ldb, const void *beta, void *C, const int ldc); + void arma_wrapper(cblas_zgemm)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const atlas_CBLAS_TRANS TransB, const int M, const int N, const int K, const void *alpha, const void *A, const int lda, const void *B, const int ldb, const void *beta, void *C, const int ldc); + + void arma_wrapper(cblas_ssyrk)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, const int N, const int K, const float alpha, const float *A, const int lda, const float beta, float *C, const int ldc); + void arma_wrapper(cblas_dsyrk)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, const int N, const int K, const double alpha, const double *A, const int lda, const double beta, double *C, const int ldc); + + void arma_wrapper(cblas_cherk)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, const int N, const int K, const float alpha, const void *A, const int lda, const float beta, void *C, const int ldc); + void arma_wrapper(cblas_zherk)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, const int N, const int K, const double alpha, const void *A, const int lda, const double beta, void *C, const int ldc); + } + + +#endif diff --git a/src/armadillo/include/armadillo_bits/def_blas.hpp b/src/armadillo/include/armadillo_bits/def_blas.hpp new file mode 100644 index 0000000..e27ca6c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/def_blas.hpp @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +#if defined(ARMA_USE_BLAS) + +#if defined(dgemm) || defined(DGEMM) + #pragma message ("WARNING: detected possible interference with definitions of BLAS functions;") + #pragma message ("WARNING: include the armadillo header before any other header as a workaround") +#endif + + +#if defined(ARMA_BLAS_NOEXCEPT) + #undef ARMA_NOEXCEPT + #define ARMA_NOEXCEPT noexcept +#else + #undef ARMA_NOEXCEPT + #define ARMA_NOEXCEPT +#endif + + +#if !defined(ARMA_BLAS_CAPITALS) + + #define arma_sasum sasum + #define arma_dasum dasum + + #define arma_snrm2 snrm2 + #define arma_dnrm2 dnrm2 + + #define arma_sdot sdot + #define arma_ddot ddot + + #define arma_sgemv sgemv + #define arma_dgemv dgemv + #define arma_cgemv cgemv + #define arma_zgemv zgemv + + #define arma_sgemm sgemm + #define arma_dgemm dgemm + #define arma_cgemm cgemm + #define arma_zgemm zgemm + + #define arma_ssyrk ssyrk + #define arma_dsyrk dsyrk + + #define arma_cherk cherk + #define arma_zherk zherk + +#else + + #define arma_sasum SASUM + #define arma_dasum DASUM + + #define arma_snrm2 SNRM2 + #define arma_dnrm2 DNRM2 + + #define arma_sdot SDOT + #define arma_ddot DDOT + + #define arma_sgemv SGEMV + #define arma_dgemv DGEMV + #define arma_cgemv CGEMV + #define arma_zgemv ZGEMV + + #define arma_sgemm SGEMM + #define arma_dgemm DGEMM + #define arma_cgemm CGEMM + #define arma_zgemm ZGEMM + + #define arma_ssyrk SSYRK + #define arma_dsyrk DSYRK + + #define arma_cherk CHERK + #define arma_zherk ZHERK + +#endif + + +// NOTE: "For arguments of CHARACTER type, the character length is passed as a hidden argument at the end of the argument list." +// NOTE: https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html + + +extern "C" +{ +#if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + + float arma_fortran(arma_sasum)(const blas_int* n, const float* x, const blas_int* incx) ARMA_NOEXCEPT; + double arma_fortran(arma_dasum)(const blas_int* n, const double* x, const blas_int* incx) ARMA_NOEXCEPT; + + float arma_fortran(arma_snrm2)(const blas_int* n, const float* x, const blas_int* incx) ARMA_NOEXCEPT; + double arma_fortran(arma_dnrm2)(const blas_int* n, const double* x, const blas_int* incx) ARMA_NOEXCEPT; + + float arma_fortran(arma_sdot)(const blas_int* n, const float* x, const blas_int* incx, const float* y, const blas_int* incy) ARMA_NOEXCEPT; + double arma_fortran(arma_ddot)(const blas_int* n, const double* x, const blas_int* incx, const double* y, const blas_int* incy) ARMA_NOEXCEPT; + + void arma_fortran(arma_sgemv)(const char* transA, const blas_int* m, const blas_int* n, const float* alpha, const float* A, const blas_int* ldA, const float* x, const blas_int* incx, const float* beta, float* y, const blas_int* incy, blas_len transA_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgemv)(const char* transA, const blas_int* m, const blas_int* n, const double* alpha, const double* A, const blas_int* ldA, const double* x, const blas_int* incx, const double* beta, double* y, const blas_int* incy, blas_len transA_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* x, const blas_int* incx, const blas_cxf* beta, blas_cxf* y, const blas_int* incy, blas_len transA_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* x, const blas_int* incx, const blas_cxd* beta, blas_cxd* y, const blas_int* incy, blas_len transA_len) ARMA_NOEXCEPT; + + void arma_fortran(arma_sgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* B, const blas_int* ldB, const float* beta, float* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* B, const blas_int* ldB, const double* beta, double* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* B, const blas_int* ldB, const blas_cxf* beta, blas_cxf* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* B, const blas_int* ldB, const blas_cxd* beta, blas_cxd* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len) ARMA_NOEXCEPT; + + void arma_fortran(arma_ssyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* beta, float* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* beta, double* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len) ARMA_NOEXCEPT; + + void arma_fortran(arma_cherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const blas_cxf* A, const blas_int* ldA, const float* beta, blas_cxf* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const blas_cxd* A, const blas_int* ldA, const double* beta, blas_cxd* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len) ARMA_NOEXCEPT; + +#else + + // prototypes without hidden arguments + + float arma_fortran(arma_sasum)(const blas_int* n, const float* x, const blas_int* incx) ARMA_NOEXCEPT; + double arma_fortran(arma_dasum)(const blas_int* n, const double* x, const blas_int* incx) ARMA_NOEXCEPT; + + float arma_fortran(arma_snrm2)(const blas_int* n, const float* x, const blas_int* incx) ARMA_NOEXCEPT; + double arma_fortran(arma_dnrm2)(const blas_int* n, const double* x, const blas_int* incx) ARMA_NOEXCEPT; + + float arma_fortran(arma_sdot)(const blas_int* n, const float* x, const blas_int* incx, const float* y, const blas_int* incy) ARMA_NOEXCEPT; + double arma_fortran(arma_ddot)(const blas_int* n, const double* x, const blas_int* incx, const double* y, const blas_int* incy) ARMA_NOEXCEPT; + + void arma_fortran(arma_sgemv)(const char* transA, const blas_int* m, const blas_int* n, const float* alpha, const float* A, const blas_int* ldA, const float* x, const blas_int* incx, const float* beta, float* y, const blas_int* incy) ARMA_NOEXCEPT; + void arma_fortran(arma_dgemv)(const char* transA, const blas_int* m, const blas_int* n, const double* alpha, const double* A, const blas_int* ldA, const double* x, const blas_int* incx, const double* beta, double* y, const blas_int* incy) ARMA_NOEXCEPT; + void arma_fortran(arma_cgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* x, const blas_int* incx, const blas_cxf* beta, blas_cxf* y, const blas_int* incy) ARMA_NOEXCEPT; + void arma_fortran(arma_zgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* x, const blas_int* incx, const blas_cxd* beta, blas_cxd* y, const blas_int* incy) ARMA_NOEXCEPT; + + void arma_fortran(arma_sgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* B, const blas_int* ldB, const float* beta, float* C, const blas_int* ldC) ARMA_NOEXCEPT; + void arma_fortran(arma_dgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* B, const blas_int* ldB, const double* beta, double* C, const blas_int* ldC) ARMA_NOEXCEPT; + void arma_fortran(arma_cgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* B, const blas_int* ldB, const blas_cxf* beta, blas_cxf* C, const blas_int* ldC) ARMA_NOEXCEPT; + void arma_fortran(arma_zgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* B, const blas_int* ldB, const blas_cxd* beta, blas_cxd* C, const blas_int* ldC) ARMA_NOEXCEPT; + + void arma_fortran(arma_ssyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* beta, float* C, const blas_int* ldC) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* beta, double* C, const blas_int* ldC) ARMA_NOEXCEPT; + + void arma_fortran(arma_cherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const blas_cxf* A, const blas_int* ldA, const float* beta, blas_cxf* C, const blas_int* ldC) ARMA_NOEXCEPT; + void arma_fortran(arma_zherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const blas_cxd* A, const blas_int* ldA, const double* beta, blas_cxd* C, const blas_int* ldC) ARMA_NOEXCEPT; + +#endif +} + +#undef ARMA_NOEXCEPT + +#endif diff --git a/src/armadillo/include/armadillo_bits/def_fftw3.hpp b/src/armadillo/include/armadillo_bits/def_fftw3.hpp new file mode 100644 index 0000000..454d752 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/def_fftw3.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +#if defined(ARMA_USE_FFTW3) + + +extern "C" + { + // function prefix for single precision: fftwf_ + // function prefix for double precision: fftw_ + + + // single precision (float) + + void_ptr fftwf_plan_dft_1d(int N, void* input, void* output, int fftw3_sign, unsigned int fftw3_flags); + + void fftwf_execute(void_ptr plan); + void fftwf_destroy_plan(void_ptr plan); + + void fftwf_cleanup(); + + + // double precision (double) + + void_ptr fftw_plan_dft_1d(int N, void* input, void* output, int fftw3_sign, unsigned int fftw3_flags); + + void fftw_execute(void_ptr plan); + void fftw_destroy_plan(void_ptr plan); + + void fftw_cleanup(); + } + + +#endif diff --git a/src/armadillo/include/armadillo_bits/def_lapack.hpp b/src/armadillo/include/armadillo_bits/def_lapack.hpp new file mode 100644 index 0000000..00854ab --- /dev/null +++ b/src/armadillo/include/armadillo_bits/def_lapack.hpp @@ -0,0 +1,1178 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +#if defined(ARMA_USE_LAPACK) + +#if defined(dgetrf) || defined(DGETRF) + #pragma message ("WARNING: detected possible interference with definitions of LAPACK functions;") + #pragma message ("WARNING: include the armadillo header before any other header as a workaround") +#endif + + +#if defined(ARMA_LAPACK_NOEXCEPT) + #undef ARMA_NOEXCEPT + #define ARMA_NOEXCEPT noexcept +#else + #undef ARMA_NOEXCEPT + #define ARMA_NOEXCEPT +#endif + + +#if !defined(ARMA_BLAS_CAPITALS) + #define arma_sgetrf sgetrf + #define arma_dgetrf dgetrf + #define arma_cgetrf cgetrf + #define arma_zgetrf zgetrf + + #define arma_sgetrs sgetrs + #define arma_dgetrs dgetrs + #define arma_cgetrs cgetrs + #define arma_zgetrs zgetrs + + #define arma_sgetri sgetri + #define arma_dgetri dgetri + #define arma_cgetri cgetri + #define arma_zgetri zgetri + + #define arma_strtri strtri + #define arma_dtrtri dtrtri + #define arma_ctrtri ctrtri + #define arma_ztrtri ztrtri + + #define arma_sgeev sgeev + #define arma_dgeev dgeev + #define arma_cgeev cgeev + #define arma_zgeev zgeev + + #define arma_sgeevx sgeevx + #define arma_dgeevx dgeevx + #define arma_cgeevx cgeevx + #define arma_zgeevx zgeevx + + #define arma_ssyev ssyev + #define arma_dsyev dsyev + + #define arma_cheev cheev + #define arma_zheev zheev + + #define arma_ssyevd ssyevd + #define arma_dsyevd dsyevd + + #define arma_cheevd cheevd + #define arma_zheevd zheevd + + #define arma_sggev sggev + #define arma_dggev dggev + + #define arma_cggev cggev + #define arma_zggev zggev + + #define arma_spotrf spotrf + #define arma_dpotrf dpotrf + #define arma_cpotrf cpotrf + #define arma_zpotrf zpotrf + + #define arma_spotrs spotrs + #define arma_dpotrs dpotrs + #define arma_cpotrs cpotrs + #define arma_zpotrs zpotrs + + #define arma_spbtrf spbtrf + #define arma_dpbtrf dpbtrf + #define arma_cpbtrf cpbtrf + #define arma_zpbtrf zpbtrf + + #define arma_spotri spotri + #define arma_dpotri dpotri + #define arma_cpotri cpotri + #define arma_zpotri zpotri + + #define arma_sgeqrf sgeqrf + #define arma_dgeqrf dgeqrf + #define arma_cgeqrf cgeqrf + #define arma_zgeqrf zgeqrf + + #define arma_sgeqp3 sgeqp3 + #define arma_dgeqp3 dgeqp3 + #define arma_cgeqp3 cgeqp3 + #define arma_zgeqp3 zgeqp3 + + #define arma_sorgqr sorgqr + #define arma_dorgqr dorgqr + + #define arma_cungqr cungqr + #define arma_zungqr zungqr + + #define arma_sgesvd sgesvd + #define arma_dgesvd dgesvd + + #define arma_cgesvd cgesvd + #define arma_zgesvd zgesvd + + #define arma_sgesdd sgesdd + #define arma_dgesdd dgesdd + #define arma_cgesdd cgesdd + #define arma_zgesdd zgesdd + + #define arma_sgesv sgesv + #define arma_dgesv dgesv + #define arma_cgesv cgesv + #define arma_zgesv zgesv + + #define arma_sgesvx sgesvx + #define arma_dgesvx dgesvx + #define arma_cgesvx cgesvx + #define arma_zgesvx zgesvx + + #define arma_sposv sposv + #define arma_dposv dposv + #define arma_cposv cposv + #define arma_zposv zposv + + #define arma_sposvx sposvx + #define arma_dposvx dposvx + #define arma_cposvx cposvx + #define arma_zposvx zposvx + + #define arma_sgels sgels + #define arma_dgels dgels + #define arma_cgels cgels + #define arma_zgels zgels + + #define arma_sgelsd sgelsd + #define arma_dgelsd dgelsd + #define arma_cgelsd cgelsd + #define arma_zgelsd zgelsd + + #define arma_strtrs strtrs + #define arma_dtrtrs dtrtrs + #define arma_ctrtrs ctrtrs + #define arma_ztrtrs ztrtrs + + #define arma_sgbtrf sgbtrf + #define arma_dgbtrf dgbtrf + #define arma_cgbtrf cgbtrf + #define arma_zgbtrf zgbtrf + + #define arma_sgbtrs sgbtrs + #define arma_dgbtrs dgbtrs + #define arma_cgbtrs cgbtrs + #define arma_zgbtrs zgbtrs + + #define arma_sgbsv sgbsv + #define arma_dgbsv dgbsv + #define arma_cgbsv cgbsv + #define arma_zgbsv zgbsv + + #define arma_sgbsvx sgbsvx + #define arma_dgbsvx dgbsvx + #define arma_cgbsvx cgbsvx + #define arma_zgbsvx zgbsvx + + #define arma_sgtsv sgtsv + #define arma_dgtsv dgtsv + #define arma_cgtsv cgtsv + #define arma_zgtsv zgtsv + + #define arma_sgtsvx sgtsvx + #define arma_dgtsvx dgtsvx + #define arma_cgtsvx cgtsvx + #define arma_zgtsvx zgtsvx + + #define arma_sgees sgees + #define arma_dgees dgees + #define arma_cgees cgees + #define arma_zgees zgees + + #define arma_strsyl strsyl + #define arma_dtrsyl dtrsyl + #define arma_ctrsyl ctrsyl + #define arma_ztrsyl ztrsyl + + #define arma_sgges sgges + #define arma_dgges dgges + #define arma_cgges cgges + #define arma_zgges zgges + + #define arma_slange slange + #define arma_dlange dlange + #define arma_clange clange + #define arma_zlange zlange + + #define arma_slansy slansy + #define arma_dlansy dlansy + #define arma_clansy clansy + #define arma_zlansy zlansy + + #define arma_clanhe clanhe + #define arma_zlanhe zlanhe + + #define arma_slangb slangb + #define arma_dlangb dlangb + #define arma_clangb clangb + #define arma_zlangb zlangb + + #define arma_sgecon sgecon + #define arma_dgecon dgecon + #define arma_cgecon cgecon + #define arma_zgecon zgecon + + #define arma_spocon spocon + #define arma_dpocon dpocon + #define arma_cpocon cpocon + #define arma_zpocon zpocon + + #define arma_strcon strcon + #define arma_dtrcon dtrcon + #define arma_ctrcon ctrcon + #define arma_ztrcon ztrcon + + #define arma_sgbcon sgbcon + #define arma_dgbcon dgbcon + #define arma_cgbcon cgbcon + #define arma_zgbcon zgbcon + + #define arma_ilaenv ilaenv + + #define arma_slahqr slahqr + #define arma_dlahqr dlahqr + + #define arma_sstedc sstedc + #define arma_dstedc dstedc + + #define arma_strevc strevc + #define arma_dtrevc dtrevc + + #define arma_sgehrd sgehrd + #define arma_dgehrd dgehrd + #define arma_cgehrd cgehrd + #define arma_zgehrd zgehrd + + #define arma_spstrf spstrf + #define arma_dpstrf dpstrf + #define arma_cpstrf cpstrf + #define arma_zpstrf zpstrf + +#else + + #define arma_sgetrf SGETRF + #define arma_dgetrf DGETRF + #define arma_cgetrf CGETRF + #define arma_zgetrf ZGETRF + + #define arma_sgetrs SGETRS + #define arma_dgetrs DGETRS + #define arma_cgetrs CGETRS + #define arma_zgetrs ZGETRS + + #define arma_sgetri SGETRI + #define arma_dgetri DGETRI + #define arma_cgetri CGETRI + #define arma_zgetri ZGETRI + + #define arma_strtri STRTRI + #define arma_dtrtri DTRTRI + #define arma_ctrtri CTRTRI + #define arma_ztrtri ZTRTRI + + #define arma_sgeev SGEEV + #define arma_dgeev DGEEV + #define arma_cgeev CGEEV + #define arma_zgeev ZGEEV + + #define arma_sgeevx SGEEVX + #define arma_dgeevx DGEEVX + #define arma_cgeevx CGEEVX + #define arma_zgeevx ZGEEVX + + #define arma_ssyev SSYEV + #define arma_dsyev DSYEV + + #define arma_cheev CHEEV + #define arma_zheev ZHEEV + + #define arma_ssyevd SSYEVD + #define arma_dsyevd DSYEVD + + #define arma_cheevd CHEEVD + #define arma_zheevd ZHEEVD + + #define arma_sggev SGGEV + #define arma_dggev DGGEV + + #define arma_cggev CGGEV + #define arma_zggev ZGGEV + + #define arma_spotrf SPOTRF + #define arma_dpotrf DPOTRF + #define arma_cpotrf CPOTRF + #define arma_zpotrf ZPOTRF + + #define arma_spotrs SPOTRS + #define arma_dpotrs DPOTRS + #define arma_cpotrs CPOTRS + #define arma_zpotrs ZPOTRS + + #define arma_spbtrf SPBTRF + #define arma_dpbtrf DPBTRF + #define arma_cpbtrf CPBTRF + #define arma_zpbtrf ZPBTRF + + #define arma_spotri SPOTRI + #define arma_dpotri DPOTRI + #define arma_cpotri CPOTRI + #define arma_zpotri ZPOTRI + + #define arma_sgeqrf SGEQRF + #define arma_dgeqrf DGEQRF + #define arma_cgeqrf CGEQRF + #define arma_zgeqrf ZGEQRF + + #define arma_sgeqp3 SGEQP3 + #define arma_dgeqp3 DGEQP3 + #define arma_cgeqp3 CGEQP3 + #define arma_zgeqp3 ZGEQP3 + + #define arma_sorgqr SORGQR + #define arma_dorgqr DORGQR + + #define arma_cungqr CUNGQR + #define arma_zungqr ZUNGQR + + #define arma_sgesvd SGESVD + #define arma_dgesvd DGESVD + + #define arma_cgesvd CGESVD + #define arma_zgesvd ZGESVD + + #define arma_sgesdd SGESDD + #define arma_dgesdd DGESDD + #define arma_cgesdd CGESDD + #define arma_zgesdd ZGESDD + + #define arma_sgesv SGESV + #define arma_dgesv DGESV + #define arma_cgesv CGESV + #define arma_zgesv ZGESV + + #define arma_sgesvx SGESVX + #define arma_dgesvx DGESVX + #define arma_cgesvx CGESVX + #define arma_zgesvx ZGESVX + + #define arma_sposv SPOSV + #define arma_dposv DPOSV + #define arma_cposv CPOSV + #define arma_zposv ZPOSV + + #define arma_sposvx SPOSVX + #define arma_dposvx DPOSVX + #define arma_cposvx CPOSVX + #define arma_zposvx ZPOSVX + + #define arma_sgels SGELS + #define arma_dgels DGELS + #define arma_cgels CGELS + #define arma_zgels ZGELS + + #define arma_sgelsd SGELSD + #define arma_dgelsd DGELSD + #define arma_cgelsd CGELSD + #define arma_zgelsd ZGELSD + + #define arma_strtrs STRTRS + #define arma_dtrtrs DTRTRS + #define arma_ctrtrs CTRTRS + #define arma_ztrtrs ZTRTRS + + #define arma_sgbtrf SGBTRF + #define arma_dgbtrf DGBTRF + #define arma_cgbtrf CGBTRF + #define arma_zgbtrf ZGBTRF + + #define arma_sgbtrs SGBTRS + #define arma_dgbtrs DGBTRS + #define arma_cgbtrs CGBTRS + #define arma_zgbtrs ZGBTRS + + #define arma_sgbsv SGBSV + #define arma_dgbsv DGBSV + #define arma_cgbsv CGBSV + #define arma_zgbsv ZGBSV + + #define arma_sgbsvx SGBSVX + #define arma_dgbsvx DGBSVX + #define arma_cgbsvx CGBSVX + #define arma_zgbsvx ZGBSVX + + #define arma_sgtsv SGTSV + #define arma_dgtsv DGTSV + #define arma_cgtsv CGTSV + #define arma_zgtsv ZGTSV + + #define arma_sgtsvx SGTSVX + #define arma_dgtsvx DGTSVX + #define arma_cgtsvx CGTSVX + #define arma_zgtsvx ZGTSVX + + #define arma_sgees SGEES + #define arma_dgees DGEES + #define arma_cgees CGEES + #define arma_zgees ZGEES + + #define arma_strsyl STRSYL + #define arma_dtrsyl DTRSYL + #define arma_ctrsyl CTRSYL + #define arma_ztrsyl ZTRSYL + + #define arma_sgges SGGES + #define arma_dgges DGGES + #define arma_cgges CGGES + #define arma_zgges ZGGES + + #define arma_slange SLANGE + #define arma_dlange DLANGE + #define arma_clange CLANGE + #define arma_zlange ZLANGE + + #define arma_slansy SLANSY + #define arma_dlansy DLANSY + #define arma_clansy CLANSY + #define arma_zlansy ZLANSY + + #define arma_clanhe CLANHE + #define arma_zlanhe ZLANHE + + #define arma_slangb SLANGB + #define arma_dlangb DLANGB + #define arma_clangb CLANGB + #define arma_zlangb ZLANGB + + #define arma_sgecon SGECON + #define arma_dgecon DGECON + #define arma_cgecon CGECON + #define arma_zgecon ZGECON + + #define arma_spocon SPOCON + #define arma_dpocon DPOCON + #define arma_cpocon CPOCON + #define arma_zpocon ZPOCON + + #define arma_strcon STRCON + #define arma_dtrcon DTRCON + #define arma_ctrcon CTRCON + #define arma_ztrcon ZTRCON + + #define arma_sgbcon SGBCON + #define arma_dgbcon DGBCON + #define arma_cgbcon CGBCON + #define arma_zgbcon ZGBCON + + #define arma_ilaenv ILAENV + + #define arma_slahqr SLAHQR + #define arma_dlahqr DLAHQR + + #define arma_sstedc SSTEDC + #define arma_dstedc DSTEDC + + #define arma_strevc STREVC + #define arma_dtrevc DTREVC + + #define arma_sgehrd SGEHRD + #define arma_dgehrd DGEHRD + #define arma_cgehrd CGEHRD + #define arma_zgehrd ZGEHRD + + #define arma_spstrf SPSTRF + #define arma_dpstrf DPSTRF + #define arma_cpstrf CPSTRF + #define arma_zpstrf ZPSTRF + +#endif + + +typedef blas_int (*fn_select_s2) (const float*, const float* ); +typedef blas_int (*fn_select_s3) (const float*, const float*, const float*); + +typedef blas_int (*fn_select_d2) (const double*, const double* ); +typedef blas_int (*fn_select_d3) (const double*, const double*, const double*); + +typedef blas_int (*fn_select_c1) (const blas_cxf* ); +typedef blas_int (*fn_select_c2) (const blas_cxf*, const blas_cxf*); + +typedef blas_int (*fn_select_z1) (const blas_cxd* ); +typedef blas_int (*fn_select_z2) (const blas_cxd*, const blas_cxd*); + + +// NOTE: "For arguments of CHARACTER type, the character length is passed as a hidden argument at the end of the argument list." +// NOTE: https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html + + +extern "C" +{ +#if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + + // LU decomposition + void arma_fortran(arma_sgetrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations using pre-computed LU decomposition + void arma_fortran(arma_sgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info, const blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info, const blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info, const blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info, const blas_len trans_len) ARMA_NOEXCEPT; + + // matrix inversion (using pre-computed LU decomposition) + void arma_fortran(arma_sgetri)(const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetri)(const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetri)(const blas_int* n, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetri)(const blas_int* n, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // matrix inversion (triangular matrices) + void arma_fortran(arma_strtri)(const char* uplo, const char* diag, const blas_int* n, float* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrtri)(const char* uplo, const char* diag, const blas_int* n, double* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + + // eigen decomposition of general matrix (real) + void arma_fortran(arma_sgeev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + + // eigen decomposition of general matrix (complex) + void arma_fortran(arma_cgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + + // eigen decomposition of general matrix (real; advanced form) + void arma_fortran(arma_sgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len) ARMA_NOEXCEPT; + + // eigen decomposition of general matrix (complex; advanced form) + void arma_fortran(arma_cgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, blas_cxf* work, const blas_int* lwork, float* rwork, const blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, blas_cxd* work, const blas_int* lwork, double* rwork, const blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len) ARMA_NOEXCEPT; + + // eigen decomposition of symmetric real matrices + void arma_fortran(arma_ssyev)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyev)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + + // eigen decomposition of hermitian matrices (complex) + void arma_fortran(arma_cheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + + // eigen decomposition of symmetric real matrices by divide and conquer + void arma_fortran(arma_ssyevd)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyevd)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + + // eigen decomposition of hermitian matrices (complex) by divide and conquer + void arma_fortran(arma_cheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + + // eigen decomposition of general real matrix pair + void arma_fortran(arma_sggev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* alphar, float* alphai, float* beta, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dggev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* alphar, double* alphai, double* beta, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + + // eigen decomposition of general complex matrix pair + void arma_fortran(arma_cggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + + // Cholesky decomposition + void arma_fortran(arma_spotrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // solve system of linear equations using pre-computed Cholesky decomposition + void arma_fortran(arma_spotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // Cholesky decomposition (band matrices) + void arma_fortran(arma_spbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, float* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, double* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxf* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxd* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // matrix inversion (using pre-computed Cholesky decomposition) + void arma_fortran(arma_spotri)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotri)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotri)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotri)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // QR decomposition + void arma_fortran(arma_sgeqrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeqrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgeqrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeqrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // QR decomposition with pivoting (real matrices) + void arma_fortran(arma_sgeqp3)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, blas_int* jpvt, float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeqp3)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, blas_int* jpvt, double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // QR decomposition with pivoting (complex matrices) + void arma_fortran(arma_cgeqp3)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* jpvt, blas_cxf* tau, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeqp3)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* jpvt, blas_cxd* tau, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // Q matrix calculation from QR decomposition (real matrices) + void arma_fortran(arma_sorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, float* a, const blas_int* lda, const float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, double* a, const blas_int* lda, const double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // Q matrix calculation from QR decomposition (complex matrices) + void arma_fortran(arma_cungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxf* a, const blas_int* lda, const blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxd* a, const blas_int* lda, const blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // SVD (real matrices) + void arma_fortran(arma_sgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len) ARMA_NOEXCEPT; + + // SVD (complex matrices) + void arma_fortran(arma_cgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len) ARMA_NOEXCEPT; + + // SVD (real matrices) by divide and conquer + void arma_fortran(arma_sgesdd)(const char* jobz, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len jobz_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesdd)(const char* jobz, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len jobz_len) ARMA_NOEXCEPT; + + // SVD (complex matrices) by divide and conquer + void arma_fortran(arma_cgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info, blas_len jobz_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info, blas_len jobz_len) ARMA_NOEXCEPT; + + // solve system of linear equations (general square matrix) + void arma_fortran(arma_sgesv)(const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesv)(const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgesv)(const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesv)(const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (general square matrix, advanced form, real matrices) + void arma_fortran(arma_sgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + + // solve system of linear equations (general square matrix, advanced form, complex matrices) + void arma_fortran(arma_cgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + + // solve system of linear equations (symmetric positive definite matrix) + void arma_fortran(arma_sposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // solve system of linear equations (symmetric positive definite matrix, advanced form, real matrices) + void arma_fortran(arma_sposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, char* equed, float* s, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, char* equed, double* s, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len) ARMA_NOEXCEPT; + + // solve system of linear equations (hermitian positive definite matrix, advanced form, complex matrices) + void arma_fortran(arma_cposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, char* equed, float* s, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, char* equed, double* s, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len) ARMA_NOEXCEPT; + + // solve over/under-determined system of linear equations + void arma_fortran(arma_sgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* work, const blas_int* lwork, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* work, const blas_int* lwork, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* work, const blas_int* lwork, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* work, const blas_int* lwork, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + + // approximately solve system of linear equations using svd (real) + void arma_fortran(arma_sgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // approximately solve system of linear equations using svd (complex) + void arma_fortran(arma_cgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (triangular matrix) + void arma_fortran(arma_strtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len) ARMA_NOEXCEPT; + + // LU factorisation (general band matrix) + void arma_fortran(arma_sgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, float* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, double* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations using pre-computed LU decomposition (general band matrix) + void arma_fortran(arma_sgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const float* ab, const blas_int* ldab, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const double* ab, const blas_int* ldab, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const blas_cxf* ab, const blas_int* ldab, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const blas_cxd* ab, const blas_int* ldab, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + + // solve system of linear equations (general band matrix) + void arma_fortran(arma_sgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (general band matrix, advanced form, real matrices) + void arma_fortran(arma_sgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, float* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, double* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + + // solve system of linear equations (general band matrix, advanced form, complex matrices) + void arma_fortran(arma_cgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_cxf* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_cxd* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + + // solve system of linear equations (tridiagonal band matrix) + void arma_fortran(arma_sgtsv)(const blas_int* n, const blas_int* nrhs, float* dl, float* d, float* du, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgtsv)(const blas_int* n, const blas_int* nrhs, double* dl, double* d, double* du, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxf* dl, blas_cxf* d, blas_cxf* du, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxd* dl, blas_cxd* d, blas_cxd* du, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (tridiagonal band matrix, advanced form, real matrices) + void arma_fortran(arma_sgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const float* dl, const float* d, const float* du, float* dlf, float* df, float* duf, float* du2, blas_int* ipiv, const float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const double* dl, const double* d, const double* du, double* dlf, double* df, double* duf, double* du2, blas_int* ipiv, const double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len) ARMA_NOEXCEPT; + + // solve system of linear equations (tridiagonal band matrix, advanced form, complex matrices) + void arma_fortran(arma_cgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxf* dl, const blas_cxf* d, const blas_cxf* du, blas_cxf* dlf, blas_cxf* df, blas_cxf* duf, blas_cxf* du2, blas_int* ipiv, const blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxd* dl, const blas_cxd* d, const blas_cxd* du, blas_cxd* dlf, blas_cxd* df, blas_cxd* duf, blas_cxd* du2, blas_int* ipiv, const blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len trans_len) ARMA_NOEXCEPT; + + // Schur decomposition (real matrices) + void arma_fortran(arma_sgees)(const char* jobvs, const char* sort, fn_select_s2 select, const blas_int* n, float* a, const blas_int* lda, blas_int* sdim, float* wr, float* wi, float* vs, const blas_int* ldvs, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgees)(const char* jobvs, const char* sort, fn_select_d2 select, const blas_int* n, double* a, const blas_int* lda, blas_int* sdim, double* wr, double* wi, double* vs, const blas_int* ldvs, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len) ARMA_NOEXCEPT; + + // Schur decomposition (complex matrices) + void arma_fortran(arma_cgees)(const char* jobvs, const char* sort, fn_select_c1 select, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* sdim, blas_cxf* w, blas_cxf* vs, const blas_int* ldvs, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgees)(const char* jobvs, const char* sort, fn_select_z1 select, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* sdim, blas_cxd* w, blas_cxd* vs, const blas_int* ldvs, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len) ARMA_NOEXCEPT; + + // solve a Sylvester equation ax + xb = c, with a and b assumed to be in Schur form + void arma_fortran(arma_strsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, const float* b, const blas_int* ldb, float* c, const blas_int* ldc, float* scale, blas_int* info, blas_len transa_len, blas_len transb_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, const double* b, const blas_int* ldb, double* c, const blas_int* ldc, double* scale, blas_int* info, blas_len transa_len, blas_len transb_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, const blas_cxf* b, const blas_int* ldb, blas_cxf* c, const blas_int* ldc, float* scale, blas_int* info, blas_len transa_len, blas_len transb_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, const blas_cxd* b, const blas_int* ldb, blas_cxd* c, const blas_int* ldc, double* scale, blas_int* info, blas_len transa_len, blas_len transb_len) ARMA_NOEXCEPT; + + // QZ decomposition (real matrices) + void arma_fortran(arma_sgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_s3 selctg, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* sdim, float* alphar, float* alphai, float* beta, float* vsl, const blas_int* ldvsl, float* vsr, const blas_int* ldvsr, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_d3 selctg, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* sdim, double* alphar, double* alphai, double* beta, double* vsl, const blas_int* ldvsl, double* vsr, const blas_int* ldvsr, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len) ARMA_NOEXCEPT; + + // QZ decomposition (complex matrices) + void arma_fortran(arma_cgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_c2 selctg, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* sdim, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vsl, const blas_int* ldvsl, blas_cxf* vsr, const blas_int* ldvsr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_z2 selctg, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* sdim, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vsl, const blas_int* ldvsl, blas_cxd* vsr, const blas_int* ldvsr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len) ARMA_NOEXCEPT; + + // 1-norm (general matrix) + float arma_fortran(arma_slange)(const char* norm, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, float* work, blas_len norm_len) ARMA_NOEXCEPT; + double arma_fortran(arma_dlange)(const char* norm, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, double* work, blas_len norm_len) ARMA_NOEXCEPT; + float arma_fortran(arma_clange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work, blas_len norm_len) ARMA_NOEXCEPT; + double arma_fortran(arma_zlange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work, blas_len norm_len) ARMA_NOEXCEPT; + + // 1-norm (real symmetric matrix) + float arma_fortran(arma_slansy)(const char* norm, const char* uplo, const blas_int* n, const float* a, const blas_int* lda, float* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; + double arma_fortran(arma_dlansy)(const char* norm, const char* uplo, const blas_int* n, const double* a, const blas_int* lda, double* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; + float arma_fortran(arma_clansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; + double arma_fortran(arma_zlansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; + + // 1-norm (complex hermitian matrix) + float arma_fortran(arma_clanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; + double arma_fortran(arma_zlanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; + + // 1-norm (band matrix) + float arma_fortran(arma_slangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, float* work, blas_len norm_len) ARMA_NOEXCEPT; + double arma_fortran(arma_dlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, double* work, blas_len norm_len) ARMA_NOEXCEPT; + float arma_fortran(arma_clangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, float* work, blas_len norm_len) ARMA_NOEXCEPT; + double arma_fortran(arma_zlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, double* work, blas_len norm_len) ARMA_NOEXCEPT; + + // reciprocal of condition number (real, generic matrix) + void arma_fortran(arma_sgecon)(const char* norm, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgecon)(const char* norm, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + + // reciprocal of condition number (complex, generic matrix) + void arma_fortran(arma_cgecon)(const char* norm, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgecon)(const char* norm, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + + // reciprocal of condition number (real, symmetric positive definite matrix) + void arma_fortran(arma_spocon)(const char* uplo, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpocon)(const char* uplo, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // reciprocal of condition number (complex, hermitian positive definite matrix) + void arma_fortran(arma_cpocon)(const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpocon)(const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // reciprocal of condition number (real, triangular matrix) + void arma_fortran(arma_strcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const float* a, const blas_int* lda, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const double* a, const blas_int* lda, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + + // reciprocal of condition number (complex, triangular matrix) + void arma_fortran(arma_ctrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + + // reciprocal of condition number (real, band matrix) + void arma_fortran(arma_sgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + + // reciprocal of condition number (complex, band matrix) + void arma_fortran(arma_cgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + + // obtain parameters according to the local configuration of lapack + blas_int arma_fortran(arma_ilaenv)(const blas_int* ispec, const char* name, const char* opts, const blas_int* n1, const blas_int* n2, const blas_int* n3, const blas_int* n4, blas_len name_len, blas_len opts_len) ARMA_NOEXCEPT; + + // calculate eigenvalues of an upper Hessenberg matrix + void arma_fortran(arma_slahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* h, const blas_int* ldh, float* wr, float* wi, const blas_int* iloz, const blas_int* ihiz, float* z, const blas_int* ldz, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dlahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* h, const blas_int* ldh, double* wr, double* wi, const blas_int* iloz, const blas_int* ihiz, double* z, const blas_int* ldz, blas_int* info) ARMA_NOEXCEPT; + + // calculate eigenvalues of a symmetric tridiagonal matrix + void arma_fortran(arma_sstedc)(const char* compz, const blas_int* n, float* d, float* e, float* z, const blas_int* ldz, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len compz_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dstedc)(const char* compz, const blas_int* n, double* d, double* e, double* z, const blas_int* ldz, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len compz_len) ARMA_NOEXCEPT; + + // calculate eigenvectors of a Schur form matrix + void arma_fortran(arma_strevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const float* t, const blas_int* ldt, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, float* work, blas_int* info, blas_len side_len, blas_len howmny_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const double* t, const blas_int* ldt, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, double* work, blas_int* info, blas_len side_len, blas_len howmny_len) ARMA_NOEXCEPT; + + // hessenberg decomposition + void arma_fortran(arma_sgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* a, const blas_int* lda, float* tao, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* a, const blas_int* lda, double* tao, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxf* a, const blas_int* lda, blas_cxf* tao, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxd* a, const blas_int* lda, blas_cxd* tao, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // pivoted cholesky + void arma_fortran(arma_spstrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* piv, blas_int* rank, const float* tol, float* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpstrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* piv, blas_int* rank, const double* tol, double* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cpstrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* piv, blas_int* rank, const float* tol, float* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpstrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* piv, blas_int* rank, const double* tol, double* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + +#else + + // prototypes without hidden arguments + + // LU decomposition + void arma_fortran(arma_sgetrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations using pre-computed LU decomposition + void arma_fortran(arma_sgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // matrix inversion (using pre-computed LU decomposition) + void arma_fortran(arma_sgetri)(const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetri)(const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetri)(const blas_int* n, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetri)(const blas_int* n, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // matrix inversion (triangular matrices) + void arma_fortran(arma_strtri)(const char* uplo, const char* diag, const blas_int* n, float* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrtri)(const char* uplo, const char* diag, const blas_int* n, double* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + + // eigen decomposition of general matrix (real) + void arma_fortran(arma_sgeev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // eigen decomposition of general matrix (complex) + void arma_fortran(arma_cgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // eigen decomposition of general matrix (real; advanced form) + void arma_fortran(arma_sgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // eigen decomposition of general matrix (complex; advanced form) + void arma_fortran(arma_cgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // eigen decomposition of symmetric real matrices + void arma_fortran(arma_ssyev)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyev)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // eigen decomposition of hermitian matrices (complex) + void arma_fortran(arma_cheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // eigen decomposition of symmetric real matrices by divide and conquer + void arma_fortran(arma_ssyevd)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyevd)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; + + // eigen decomposition of hermitian matrices (complex) by divide and conquer + void arma_fortran(arma_cheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; + + // eigen decomposition of general real matrix pair + void arma_fortran(arma_sggev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* alphar, float* alphai, float* beta, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dggev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* alphar, double* alphai, double* beta, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // eigen decomposition of general complex matrix pair + void arma_fortran(arma_cggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // Cholesky decomposition + void arma_fortran(arma_spotrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations with pre-computed Cholesky decomposition + void arma_fortran(arma_spotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // Cholesky decomposition (band matrices) + void arma_fortran(arma_spbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, float* ab, const blas_int* ldab, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, double* ab, const blas_int* ldab, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxf* ab, const blas_int* ldab, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxd* ab, const blas_int* ldab, blas_int* info) ARMA_NOEXCEPT; + + // matrix inversion (using pre-computed Cholesky decomposition) + void arma_fortran(arma_spotri)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotri)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotri)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotri)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + + // QR decomposition + void arma_fortran(arma_sgeqrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeqrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgeqrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeqrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // QR decomposition with pivoting (real matrices) + void arma_fortran(arma_sgeqp3)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, blas_int* jpvt, float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeqp3)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, blas_int* jpvt, double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // QR decomposition with pivoting (complex matrices) + void arma_fortran(arma_cgeqp3)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* jpvt, blas_cxf* tau, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeqp3)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* jpvt, blas_cxd* tau, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // Q matrix calculation from QR decomposition (real matrices) + void arma_fortran(arma_sorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, float* a, const blas_int* lda, const float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, double* a, const blas_int* lda, const double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // Q matrix calculation from QR decomposition (complex matrices) + void arma_fortran(arma_cungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxf* a, const blas_int* lda, const blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxd* a, const blas_int* lda, const blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // SVD (real matrices) + void arma_fortran(arma_sgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // SVD (complex matrices) + void arma_fortran(arma_cgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // SVD (real matrices) by divide and conquer + void arma_fortran(arma_sgesdd)(const char* jobz, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesdd)(const char* jobz, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // SVD (complex matrices) by divide and conquer + void arma_fortran(arma_cgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (general square matrix) + void arma_fortran(arma_sgesv)(const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesv)(const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgesv)(const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesv)(const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (general square matrix, advanced form, real matrices) + void arma_fortran(arma_sgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (general square matrix, advanced form, complex matrices) + void arma_fortran(arma_cgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (symmetric positive definite matrix) + void arma_fortran(arma_sposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (symmetric positive definite matrix, advanced form, real matrices) + void arma_fortran(arma_sposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, char* equed, float* s, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, char* equed, double* s, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (hermitian positive definite matrix, advanced form, complex matrices) + void arma_fortran(arma_cposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, char* equed, float* s, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, char* equed, double* s, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // solve over/under-determined system of linear equations + void arma_fortran(arma_sgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // approximately solve system of linear equations using svd (real) + void arma_fortran(arma_sgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + + // approximately solve system of linear equations using svd (complex) + void arma_fortran(arma_cgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (triangular matrix) + void arma_fortran(arma_strtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // LU factorisation (general band matrix) + void arma_fortran(arma_sgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, float* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, double* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations using pre-computed LU decomposition (general band matrix) + void arma_fortran(arma_sgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const float* ab, const blas_int* ldab, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const double* ab, const blas_int* ldab, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const blas_cxf* ab, const blas_int* ldab, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const blas_cxd* ab, const blas_int* ldab, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (general band matrix) + void arma_fortran(arma_sgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (general band matrix, advanced form, real matrices) + void arma_fortran(arma_sgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, float* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, double* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (general band matrix, advanced form, complex matrices) + void arma_fortran(arma_cgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_cxf* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_cxd* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (tridiagonal band matrix) + void arma_fortran(arma_sgtsv)(const blas_int* n, const blas_int* nrhs, float* dl, float* d, float* du, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgtsv)(const blas_int* n, const blas_int* nrhs, double* dl, double* d, double* du, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxf* dl, blas_cxf* d, blas_cxf* du, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxd* dl, blas_cxd* d, blas_cxd* du, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (tridiagonal band matrix, advanced form, real matrices) + void arma_fortran(arma_sgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const float* dl, const float* d, const float* du, float* dlf, float* df, float* duf, float* du2, blas_int* ipiv, const float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const double* dl, const double* d, const double* du, double* dlf, double* df, double* duf, double* du2, blas_int* ipiv, const double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // solve system of linear equations (tridiagonal band matrix, advanced form, complex matrices) + void arma_fortran(arma_cgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxf* dl, const blas_cxf* d, const blas_cxf* du, blas_cxf* dlf, blas_cxf* df, blas_cxf* duf, blas_cxf* du2, blas_int* ipiv, const blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxd* dl, const blas_cxd* d, const blas_cxd* du, blas_cxd* dlf, blas_cxd* df, blas_cxd* duf, blas_cxd* du2, blas_int* ipiv, const blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // Schur decomposition (real matrices) + void arma_fortran(arma_sgees)(const char* jobvs, const char* sort, fn_select_s2 select, const blas_int* n, float* a, const blas_int* lda, blas_int* sdim, float* wr, float* wi, float* vs, const blas_int* ldvs, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgees)(const char* jobvs, const char* sort, fn_select_d2 select, const blas_int* n, double* a, const blas_int* lda, blas_int* sdim, double* wr, double* wi, double* vs, const blas_int* ldvs, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + + // Schur decomposition (complex matrices) + void arma_fortran(arma_cgees)(const char* jobvs, const char* sort, fn_select_c1 select, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* sdim, blas_cxf* w, blas_cxf* vs, const blas_int* ldvs, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgees)(const char* jobvs, const char* sort, fn_select_z1 select, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* sdim, blas_cxd* w, blas_cxd* vs, const blas_int* ldvs, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + + // solve a Sylvester equation ax + xb = c, with a and b assumed to be in Schur form + void arma_fortran(arma_strsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, const float* b, const blas_int* ldb, float* c, const blas_int* ldc, float* scale, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, const double* b, const blas_int* ldb, double* c, const blas_int* ldc, double* scale, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, const blas_cxf* b, const blas_int* ldb, blas_cxf* c, const blas_int* ldc, float* scale, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, const blas_cxd* b, const blas_int* ldb, blas_cxd* c, const blas_int* ldc, double* scale, blas_int* info) ARMA_NOEXCEPT; + + // QZ decomposition (real matrices) + void arma_fortran(arma_sgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_s3 selctg, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* sdim, float* alphar, float* alphai, float* beta, float* vsl, const blas_int* ldvsl, float* vsr, const blas_int* ldvsr, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_d3 selctg, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* sdim, double* alphar, double* alphai, double* beta, double* vsl, const blas_int* ldvsl, double* vsr, const blas_int* ldvsr, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + + // QZ decomposition (complex matrices) + void arma_fortran(arma_cgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_c2 selctg, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* sdim, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vsl, const blas_int* ldvsl, blas_cxf* vsr, const blas_int* ldvsr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_z2 selctg, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* sdim, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vsl, const blas_int* ldvsl, blas_cxd* vsr, const blas_int* ldvsr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + + // 1-norm (general matrix) + float arma_fortran(arma_slange)(const char* norm, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_dlange)(const char* norm, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, double* work) ARMA_NOEXCEPT; + float arma_fortran(arma_clange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_zlange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work) ARMA_NOEXCEPT; + + // 1-norm (real symmetric matrix) + float arma_fortran(arma_slansy)(const char* norm, const char* uplo, const blas_int* n, const float* a, const blas_int* lda, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_dlansy)(const char* norm, const char* uplo, const blas_int* n, const double* a, const blas_int* lda, double* work) ARMA_NOEXCEPT; + float arma_fortran(arma_clansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_zlansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work) ARMA_NOEXCEPT; + + // 1-norm (complex hermitian matrix) + float arma_fortran(arma_clanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_zlanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work) ARMA_NOEXCEPT; + + // 1-norm (band matrix) + float arma_fortran(arma_slangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_dlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, double* work) ARMA_NOEXCEPT; + float arma_fortran(arma_clangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_zlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, double* work) ARMA_NOEXCEPT; + + // reciprocal of condition number (real, generic matrix) + void arma_fortran(arma_sgecon)(const char* norm, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgecon)(const char* norm, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // reciprocal of condition number (complex, generic matrix) + void arma_fortran(arma_cgecon)(const char* norm, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgecon)(const char* norm, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // reciprocal of condition number (real, symmetric positive definite matrix) + void arma_fortran(arma_spocon)(const char* uplo, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpocon)(const char* uplo, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // reciprocal of condition number (complex, hermitian positive definite matrix) + void arma_fortran(arma_cpocon)(const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpocon)(const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // reciprocal of condition number (real, triangular matrix) + void arma_fortran(arma_strcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const float* a, const blas_int* lda, float* rcond, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const double* a, const blas_int* lda, double* rcond, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // reciprocal of condition number (complex, triangular matrix) + void arma_fortran(arma_ctrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* rcond, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* rcond, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // reciprocal of condition number (real, band matrix) + void arma_fortran(arma_sgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // reciprocal of condition number (complex, band matrix) + void arma_fortran(arma_cgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; + + // obtain parameters according to the local configuration of lapack + // NOTE: DO NOT USE THIS FORM; kept only for compatibility + // NOTE: this function takes 'name' and 'opts' argumments, which are strings with length != 1; their length needs to be given via "hidden" parameters, which this form lacks + blas_int arma_fortran(arma_ilaenv)(const blas_int* ispec, const char* name, const char* opts, const blas_int* n1, const blas_int* n2, const blas_int* n3, const blas_int* n4) ARMA_NOEXCEPT; + + // calculate eigenvalues of an upper Hessenberg matrix + void arma_fortran(arma_slahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* h, const blas_int* ldh, float* wr, float* wi, const blas_int* iloz, const blas_int* ihiz, float* z, const blas_int* ldz, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dlahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* h, const blas_int* ldh, double* wr, double* wi, const blas_int* iloz, const blas_int* ihiz, double* z, const blas_int* ldz, blas_int* info) ARMA_NOEXCEPT; + + // calculate eigenvalues of a symmetric tridiagonal matrix + void arma_fortran(arma_sstedc)(const char* compz, const blas_int* n, float* d, float* e, float* z, const blas_int* ldz, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dstedc)(const char* compz, const blas_int* n, double* d, double* e, double* z, const blas_int* ldz, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; + + // calculate eigenvectors of a Schur form matrix + void arma_fortran(arma_strevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const float* t, const blas_int* ldt, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, float* work, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const double* t, const blas_int* ldt, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, double* work, blas_int* info) ARMA_NOEXCEPT; + + // hessenberg decomposition + void arma_fortran(arma_sgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* a, const blas_int* lda, float* tao, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* a, const blas_int* lda, double* tao, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxf* a, const blas_int* lda, blas_cxf* tao, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxd* a, const blas_int* lda, blas_cxd* tao, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // pivoted cholesky + void arma_fortran(arma_spstrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* piv, blas_int* rank, const float* tol, float* work, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpstrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* piv, blas_int* rank, const double* tol, double* work, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cpstrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* piv, blas_int* rank, const float* tol, float* work, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpstrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* piv, blas_int* rank, const double* tol, double* work, blas_int* info) ARMA_NOEXCEPT; + +#endif +} + +#undef ARMA_NOEXCEPT + +#endif diff --git a/src/armadillo/include/armadillo_bits/def_superlu.hpp b/src/armadillo/include/armadillo_bits/def_superlu.hpp new file mode 100644 index 0000000..81f6ac3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/def_superlu.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + +#if defined(ARMA_USE_SUPERLU) + +extern "C" + { + extern void arma_wrapper(sgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(dgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(cgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(zgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + + extern void arma_wrapper(sgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, float*, float*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, int, superlu::SuperMatrix*, superlu::SuperMatrix*, float*, float*, float*, float*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(dgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, double*, double*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, int, superlu::SuperMatrix*, superlu::SuperMatrix*, double*, double*, double*, double*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(cgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, float*, float*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, int, superlu::SuperMatrix*, superlu::SuperMatrix*, float*, float*, float*, float*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(zgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, double*, double*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, int, superlu::SuperMatrix*, superlu::SuperMatrix*, double*, double*, double*, double*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, int*); + + extern void arma_wrapper(sgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(dgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(cgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(zgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*); + + extern void arma_wrapper(sgstrs)(superlu::trans_t, superlu::SuperMatrix*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(dgstrs)(superlu::trans_t, superlu::SuperMatrix*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(cgstrs)(superlu::trans_t, superlu::SuperMatrix*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(zgstrs)(superlu::trans_t, superlu::SuperMatrix*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + + extern float arma_wrapper(slangs)(char* norm, superlu::SuperMatrix* A); + extern double arma_wrapper(dlangs)(char* norm, superlu::SuperMatrix* A); + extern float arma_wrapper(clangs)(char* norm, superlu::SuperMatrix* A); + extern double arma_wrapper(zlangs)(char* norm, superlu::SuperMatrix* A); + + extern void arma_wrapper(sgscon)(char* norm, superlu::SuperMatrix* L, superlu::SuperMatrix* U, float anorm, float* rcond, superlu::SuperLUStat_t* stat, int* info); + extern void arma_wrapper(dgscon)(char* norm, superlu::SuperMatrix* L, superlu::SuperMatrix* U, double anorm, double* rcond, superlu::SuperLUStat_t* stat, int* info); + extern void arma_wrapper(cgscon)(char* norm, superlu::SuperMatrix* L, superlu::SuperMatrix* U, float anorm, float* rcond, superlu::SuperLUStat_t* stat, int* info); + extern void arma_wrapper(zgscon)(char* norm, superlu::SuperMatrix* L, superlu::SuperMatrix* U, double anorm, double* rcond, superlu::SuperLUStat_t* stat, int* info); + + extern void arma_wrapper(StatInit)(superlu::SuperLUStat_t*); + extern void arma_wrapper(StatFree)(superlu::SuperLUStat_t*); + extern void arma_wrapper(set_default_options)(superlu::superlu_options_t*); + + extern void arma_wrapper(get_perm_c)(int, superlu::SuperMatrix*, int*); + extern int arma_wrapper(sp_ienv)(int); + extern void arma_wrapper(sp_preorder)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*); + + extern void arma_wrapper(Destroy_SuperNode_Matrix)(superlu::SuperMatrix*); + extern void arma_wrapper(Destroy_CompCol_Matrix)(superlu::SuperMatrix*); + extern void arma_wrapper(Destroy_CompCol_Permuted)(superlu::SuperMatrix*); + extern void arma_wrapper(Destroy_SuperMatrix_Store)(superlu::SuperMatrix*); + + // We also need superlu_malloc() and superlu_free(). + // When using the original SuperLU code directly, you (the user) may + // define USER_MALLOC and USER_FREE, but the joke is on you because + // if you are linking against SuperLU and not compiling from scratch, + // it won't actually make a difference anyway! If you've compiled + // SuperLU against a custom USER_MALLOC and USER_FREE, you're probably up + // shit creek about a thousand different ways before you even get to this + // code, so, don't do that! + + extern void* arma_wrapper(superlu_malloc)(size_t); + extern void arma_wrapper(superlu_free)(void*); + } + +#endif diff --git a/src/armadillo/include/armadillo_bits/diagmat_proxy.hpp b/src/armadillo/include/armadillo_bits/diagmat_proxy.hpp new file mode 100644 index 0000000..262dd95 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/diagmat_proxy.hpp @@ -0,0 +1,375 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup diagmat_proxy +//! @{ + + + +template +class diagmat_proxy_default + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + inline + diagmat_proxy_default(const T1& X) + : P ( X ) + , P_is_vec( (resolves_to_vector::yes) || (P.get_n_rows() == 1) || (P.get_n_cols() == 1) ) + , P_is_col( T1::is_col || (P.get_n_cols() == 1) ) + , n_rows ( P_is_vec ? P.get_n_elem() : P.get_n_rows() ) + , n_cols ( P_is_vec ? P.get_n_elem() : P.get_n_cols() ) + { + arma_extra_debug_sigprint(); + } + + + arma_inline + elem_type + operator[](const uword i) const + { + if(Proxy::use_at == false) + { + return P_is_vec ? P[i] : P.at(i,i); + } + else + { + if(P_is_vec) + { + return (P_is_col) ? P.at(i,0) : P.at(0,i); + } + else + { + return P.at(i,i); + } + } + } + + + arma_inline + elem_type + at(const uword row, const uword col) const + { + if(row == col) + { + if(Proxy::use_at == false) + { + return (P_is_vec) ? P[row] : P.at(row,row); + } + else + { + if(P_is_vec) + { + return (P_is_col) ? P.at(row,0) : P.at(0,row); + } + else + { + return P.at(row,row); + } + } + } + else + { + return elem_type(0); + } + } + + + inline bool is_alias(const Mat& X) const { return P.is_alias(X); } + + const Proxy P; + const bool P_is_vec; + const bool P_is_col; + const uword n_rows; + const uword n_cols; + }; + + + +template +class diagmat_proxy_fixed + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + inline + diagmat_proxy_fixed(const T1& X) + : P(X) + { + arma_extra_debug_sigprint(); + } + + + arma_inline + elem_type + operator[](const uword i) const + { + return (P_is_vec) ? P[i] : P.at(i,i); + } + + + arma_inline + elem_type + at(const uword row, const uword col) const + { + if(row == col) + { + return (P_is_vec) ? P[row] : P.at(row,row); + } + else + { + return elem_type(0); + } + } + + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&P)); } + + const T1& P; + + //// this may require T1::n_elem etc to be declared as static constexpr inline variables (C++17) + //// see also the notes in Mat::fixed + // static constexpr bool P_is_vec = (T1::n_rows == 1) || (T1::n_cols == 1); + // static constexpr uword n_rows = P_is_vec ? T1::n_elem : T1::n_rows; + // static constexpr uword n_cols = P_is_vec ? T1::n_elem : T1::n_cols; + + static const bool P_is_vec = (T1::n_rows == 1) || (T1::n_cols == 1); + static const uword n_rows = P_is_vec ? T1::n_elem : T1::n_rows; + static const uword n_cols = P_is_vec ? T1::n_elem : T1::n_cols; + }; + + + +template +struct diagmat_proxy_redirect {}; + +template +struct diagmat_proxy_redirect { typedef diagmat_proxy_default result; }; + +template +struct diagmat_proxy_redirect { typedef diagmat_proxy_fixed result; }; + + +template +class diagmat_proxy : public diagmat_proxy_redirect::value>::result + { + public: + inline diagmat_proxy(const T1& X) + : diagmat_proxy_redirect::value>::result(X) + { + } + }; + + + +template +class diagmat_proxy< Mat > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + inline + diagmat_proxy(const Mat& X) + : P ( X ) + , P_is_vec( (X.n_rows == 1) || (X.n_cols == 1) ) + , n_rows ( P_is_vec ? X.n_elem : X.n_rows ) + , n_cols ( P_is_vec ? X.n_elem : X.n_cols ) + { + arma_extra_debug_sigprint(); + } + + arma_inline elem_type operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); } + arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); } + + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&P)); } + + const Mat& P; + const bool P_is_vec; + const uword n_rows; + const uword n_cols; + }; + + + +template +class diagmat_proxy< Row > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + + inline + diagmat_proxy(const Row& X) + : P(X) + , n_rows(X.n_elem) + , n_cols(X.n_elem) + { + arma_extra_debug_sigprint(); + } + + arma_inline elem_type operator[] (const uword i) const { return P[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } + + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&P)); } + + static constexpr bool P_is_vec = true; + + const Row& P; + const uword n_rows; + const uword n_cols; + }; + + + +template +class diagmat_proxy< Col > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + + inline + diagmat_proxy(const Col& X) + : P(X) + , n_rows(X.n_elem) + , n_cols(X.n_elem) + { + arma_extra_debug_sigprint(); + } + + arma_inline elem_type operator[] (const uword i) const { return P[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } + + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&P)); } + + static constexpr bool P_is_vec = true; + + const Col& P; + const uword n_rows; + const uword n_cols; + }; + + + +template +class diagmat_proxy< subview_row > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + + inline + diagmat_proxy(const subview_row& X) + : P(X) + , n_rows(X.n_elem) + , n_cols(X.n_elem) + { + arma_extra_debug_sigprint(); + } + + arma_inline elem_type operator[] (const uword i) const { return P[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } + + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&(P.m))); } + + static constexpr bool P_is_vec = true; + + const subview_row& P; + const uword n_rows; + const uword n_cols; + }; + + + +template +class diagmat_proxy< subview_col > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + + inline + diagmat_proxy(const subview_col& X) + : P(X) + , n_rows(X.n_elem) + , n_cols(X.n_elem) + { + arma_extra_debug_sigprint(); + } + + arma_inline elem_type operator[] (const uword i) const { return P[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } + + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&(P.m))); } + + static constexpr bool P_is_vec = true; + + const subview_col& P; + const uword n_rows; + const uword n_cols; + }; + + + +template +class diagmat_proxy< Glue > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + inline + diagmat_proxy(const Glue& X) + { + op_diagmat::apply_times(P, X.A, X.B); + + n_rows = P.n_rows; + n_cols = P.n_cols; + + arma_extra_debug_sigprint(); + } + + arma_inline elem_type operator[] (const uword i) const { return P.at(i,i); } + arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P.at(row,row) : elem_type(0); } + + constexpr bool is_alias(const Mat&) const { return false; } + + static constexpr bool P_is_vec = false; + + Mat P; + uword n_rows; + uword n_cols; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/diagview_bones.hpp b/src/armadillo/include/armadillo_bits/diagview_bones.hpp new file mode 100644 index 0000000..5aa4bce --- /dev/null +++ b/src/armadillo/include/armadillo_bits/diagview_bones.hpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup diagview +//! @{ + + +//! Class for storing data required to extract and set the diagonals of a matrix +template +class diagview : public Base< eT, diagview > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + arma_aligned const Mat& m; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + const uword row_offset; + const uword col_offset; + + const uword n_rows; // equal to n_elem + const uword n_elem; + + static constexpr uword n_cols = 1; + + + protected: + + arma_inline diagview(const Mat& in_m, const uword in_row_offset, const uword in_col_offset, const uword len); + + + public: + + inline ~diagview(); + inline diagview() = delete; + + inline diagview(const diagview& in); + inline diagview( diagview&& in); + + inline void operator=(const diagview& x); + + inline void operator+=(const eT val); + inline void operator-=(const eT val); + inline void operator*=(const eT val); + inline void operator/=(const eT val); + + template inline void operator= (const Base& x); + template inline void operator+=(const Base& x); + template inline void operator-=(const Base& x); + template inline void operator%=(const Base& x); + template inline void operator/=(const Base& x); + + + arma_inline eT at_alt (const uword ii) const; + + arma_inline eT& operator[](const uword ii); + arma_inline eT operator[](const uword ii) const; + + arma_inline eT& at(const uword ii); + arma_inline eT at(const uword ii) const; + + arma_inline eT& operator()(const uword ii); + arma_inline eT operator()(const uword ii) const; + + arma_inline eT& at(const uword in_n_row, const uword); + arma_inline eT at(const uword in_n_row, const uword) const; + + arma_inline eT& operator()(const uword in_n_row, const uword in_n_col); + arma_inline eT operator()(const uword in_n_row, const uword in_n_col) const; + + + inline void replace(const eT old_val, const eT new_val); + + inline void clean(const pod_type threshold); + + inline void clamp(const eT min_val, const eT max_val); + + inline void fill(const eT val); + inline void zeros(); + inline void ones(); + inline void randu(); + inline void randn(); + + inline static void extract(Mat& out, const diagview& in); + + inline static void plus_inplace(Mat& out, const diagview& in); + inline static void minus_inplace(Mat& out, const diagview& in); + inline static void schur_inplace(Mat& out, const diagview& in); + inline static void div_inplace(Mat& out, const diagview& in); + + + friend class Mat; + friend class subview; + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/diagview_meat.hpp b/src/armadillo/include/armadillo_bits/diagview_meat.hpp new file mode 100644 index 0000000..e35f857 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/diagview_meat.hpp @@ -0,0 +1,1025 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup diagview +//! @{ + + +template +inline +diagview::~diagview() + { + arma_extra_debug_sigprint_this(this); + } + + + +template +arma_inline +diagview::diagview(const Mat& in_m, const uword in_row_offset, const uword in_col_offset, const uword in_len) + : m (in_m ) + , row_offset(in_row_offset) + , col_offset(in_col_offset) + , n_rows (in_len ) + , n_elem (in_len ) + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +diagview::diagview(const diagview& in) + : m (in.m ) + , row_offset(in.row_offset) + , col_offset(in.col_offset) + , n_rows (in.n_rows ) + , n_elem (in.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + } + + + +template +inline +diagview::diagview(diagview&& in) + : m (in.m ) + , row_offset(in.row_offset) + , col_offset(in.col_offset) + , n_rows (in.n_rows ) + , n_elem (in.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + + // for paranoia + + access::rw(in.row_offset) = 0; + access::rw(in.col_offset) = 0; + access::rw(in.n_rows ) = 0; + access::rw(in.n_elem ) = 0; + } + + + +//! set a diagonal of our matrix using a diagonal from a foreign matrix +template +inline +void +diagview::operator= (const diagview& x) + { + arma_extra_debug_sigprint(); + + diagview& d = *this; + + arma_debug_check( (d.n_elem != x.n_elem), "diagview: diagonals have incompatible lengths" ); + + Mat& d_m = const_cast< Mat& >(d.m); + const Mat& x_m = x.m; + + if(&d_m != &x_m) + { + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const uword x_row_offset = x.row_offset; + const uword x_col_offset = x.col_offset; + + uword ii,jj; + for(ii=0, jj=1; jj < d_n_elem; ii+=2, jj+=2) + { + const eT tmp_i = x_m.at(ii + x_row_offset, ii + x_col_offset); + const eT tmp_j = x_m.at(jj + x_row_offset, jj + x_col_offset); + + d_m.at(ii + d_row_offset, ii + d_col_offset) = tmp_i; + d_m.at(jj + d_row_offset, jj + d_col_offset) = tmp_j; + } + + if(ii < d_n_elem) + { + d_m.at(ii + d_row_offset, ii + d_col_offset) = x_m.at(ii + x_row_offset, ii + x_col_offset); + } + } + else + { + const Mat tmp = x; + + (*this).operator=(tmp); + } + } + + + +template +inline +void +diagview::operator+=(const eT val) + { + arma_extra_debug_sigprint(); + + Mat& t_m = const_cast< Mat& >(m); + + const uword t_n_elem = n_elem; + const uword t_row_offset = row_offset; + const uword t_col_offset = col_offset; + + for(uword ii=0; ii < t_n_elem; ++ii) + { + t_m.at( ii + t_row_offset, ii + t_col_offset) += val; + } + } + + + +template +inline +void +diagview::operator-=(const eT val) + { + arma_extra_debug_sigprint(); + + Mat& t_m = const_cast< Mat& >(m); + + const uword t_n_elem = n_elem; + const uword t_row_offset = row_offset; + const uword t_col_offset = col_offset; + + for(uword ii=0; ii < t_n_elem; ++ii) + { + t_m.at( ii + t_row_offset, ii + t_col_offset) -= val; + } + } + + + +template +inline +void +diagview::operator*=(const eT val) + { + arma_extra_debug_sigprint(); + + Mat& t_m = const_cast< Mat& >(m); + + const uword t_n_elem = n_elem; + const uword t_row_offset = row_offset; + const uword t_col_offset = col_offset; + + for(uword ii=0; ii < t_n_elem; ++ii) + { + t_m.at( ii + t_row_offset, ii + t_col_offset) *= val; + } + } + + + +template +inline +void +diagview::operator/=(const eT val) + { + arma_extra_debug_sigprint(); + + Mat& t_m = const_cast< Mat& >(m); + + const uword t_n_elem = n_elem; + const uword t_row_offset = row_offset; + const uword t_col_offset = col_offset; + + for(uword ii=0; ii < t_n_elem; ++ii) + { + t_m.at( ii + t_row_offset, ii + t_col_offset) /= val; + } + } + + + +//! set a diagonal of our matrix using data from a foreign object +template +template +inline +void +diagview::operator= (const Base& o) + { + arma_extra_debug_sigprint(); + + diagview& d = *this; + + Mat& d_m = const_cast< Mat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const Proxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "diagview: given object has incompatible size" + ); + + const bool is_alias = P.is_alias(d_m); + + if(is_alias) { arma_extra_debug_print("aliasing detected"); } + + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + { + const unwrap_check::stored_type> tmp(P.Q, is_alias); + const Mat& x = tmp.M; + + const eT* x_mem = x.memptr(); + + uword ii,jj; + for(ii=0, jj=1; jj < d_n_elem; ii+=2, jj+=2) + { + const eT tmp_i = x_mem[ii]; + const eT tmp_j = x_mem[jj]; + + d_m.at( ii + d_row_offset, ii + d_col_offset) = tmp_i; + d_m.at( jj + d_row_offset, jj + d_col_offset) = tmp_j; + } + + if(ii < d_n_elem) + { + d_m.at( ii + d_row_offset, ii + d_col_offset) = x_mem[ii]; + } + } + else + { + typename Proxy::ea_type Pea = P.get_ea(); + + uword ii,jj; + for(ii=0, jj=1; jj < d_n_elem; ii+=2, jj+=2) + { + const eT tmp_i = Pea[ii]; + const eT tmp_j = Pea[jj]; + + d_m.at( ii + d_row_offset, ii + d_col_offset) = tmp_i; + d_m.at( jj + d_row_offset, jj + d_col_offset) = tmp_j; + } + + if(ii < d_n_elem) + { + d_m.at( ii + d_row_offset, ii + d_col_offset) = Pea[ii]; + } + } + } + + + +template +template +inline +void +diagview::operator+=(const Base& o) + { + arma_extra_debug_sigprint(); + + diagview& d = *this; + + Mat& d_m = const_cast< Mat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const Proxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "diagview: given object has incompatible size" + ); + + const bool is_alias = P.is_alias(d_m); + + if(is_alias) { arma_extra_debug_print("aliasing detected"); } + + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + { + const unwrap_check::stored_type> tmp(P.Q, is_alias); + const Mat& x = tmp.M; + + const eT* x_mem = x.memptr(); + + uword ii,jj; + for(ii=0, jj=1; jj < d_n_elem; ii+=2, jj+=2) + { + const eT tmp_i = x_mem[ii]; + const eT tmp_j = x_mem[jj]; + + d_m.at( ii + d_row_offset, ii + d_col_offset) += tmp_i; + d_m.at( jj + d_row_offset, jj + d_col_offset) += tmp_j; + } + + if(ii < d_n_elem) + { + d_m.at( ii + d_row_offset, ii + d_col_offset) += x_mem[ii]; + } + } + else + { + typename Proxy::ea_type Pea = P.get_ea(); + + uword ii,jj; + for(ii=0, jj=1; jj < d_n_elem; ii+=2, jj+=2) + { + const eT tmp_i = Pea[ii]; + const eT tmp_j = Pea[jj]; + + d_m.at( ii + d_row_offset, ii + d_col_offset) += tmp_i; + d_m.at( jj + d_row_offset, jj + d_col_offset) += tmp_j; + } + + if(ii < d_n_elem) + { + d_m.at( ii + d_row_offset, ii + d_col_offset) += Pea[ii]; + } + } + } + + + +template +template +inline +void +diagview::operator-=(const Base& o) + { + arma_extra_debug_sigprint(); + + diagview& d = *this; + + Mat& d_m = const_cast< Mat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const Proxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "diagview: given object has incompatible size" + ); + + const bool is_alias = P.is_alias(d_m); + + if(is_alias) { arma_extra_debug_print("aliasing detected"); } + + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + { + const unwrap_check::stored_type> tmp(P.Q, is_alias); + const Mat& x = tmp.M; + + const eT* x_mem = x.memptr(); + + uword ii,jj; + for(ii=0, jj=1; jj < d_n_elem; ii+=2, jj+=2) + { + const eT tmp_i = x_mem[ii]; + const eT tmp_j = x_mem[jj]; + + d_m.at( ii + d_row_offset, ii + d_col_offset) -= tmp_i; + d_m.at( jj + d_row_offset, jj + d_col_offset) -= tmp_j; + } + + if(ii < d_n_elem) + { + d_m.at( ii + d_row_offset, ii + d_col_offset) -= x_mem[ii]; + } + } + else + { + typename Proxy::ea_type Pea = P.get_ea(); + + uword ii,jj; + for(ii=0, jj=1; jj < d_n_elem; ii+=2, jj+=2) + { + const eT tmp_i = Pea[ii]; + const eT tmp_j = Pea[jj]; + + d_m.at( ii + d_row_offset, ii + d_col_offset) -= tmp_i; + d_m.at( jj + d_row_offset, jj + d_col_offset) -= tmp_j; + } + + if(ii < d_n_elem) + { + d_m.at( ii + d_row_offset, ii + d_col_offset) -= Pea[ii]; + } + } + } + + + +template +template +inline +void +diagview::operator%=(const Base& o) + { + arma_extra_debug_sigprint(); + + diagview& d = *this; + + Mat& d_m = const_cast< Mat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const Proxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "diagview: given object has incompatible size" + ); + + const bool is_alias = P.is_alias(d_m); + + if(is_alias) { arma_extra_debug_print("aliasing detected"); } + + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + { + const unwrap_check::stored_type> tmp(P.Q, is_alias); + const Mat& x = tmp.M; + + const eT* x_mem = x.memptr(); + + uword ii,jj; + for(ii=0, jj=1; jj < d_n_elem; ii+=2, jj+=2) + { + const eT tmp_i = x_mem[ii]; + const eT tmp_j = x_mem[jj]; + + d_m.at( ii + d_row_offset, ii + d_col_offset) *= tmp_i; + d_m.at( jj + d_row_offset, jj + d_col_offset) *= tmp_j; + } + + if(ii < d_n_elem) + { + d_m.at( ii + d_row_offset, ii + d_col_offset) *= x_mem[ii]; + } + } + else + { + typename Proxy::ea_type Pea = P.get_ea(); + + uword ii,jj; + for(ii=0, jj=1; jj < d_n_elem; ii+=2, jj+=2) + { + const eT tmp_i = Pea[ii]; + const eT tmp_j = Pea[jj]; + + d_m.at( ii + d_row_offset, ii + d_col_offset) *= tmp_i; + d_m.at( jj + d_row_offset, jj + d_col_offset) *= tmp_j; + } + + if(ii < d_n_elem) + { + d_m.at( ii + d_row_offset, ii + d_col_offset) *= Pea[ii]; + } + } + } + + + +template +template +inline +void +diagview::operator/=(const Base& o) + { + arma_extra_debug_sigprint(); + + diagview& d = *this; + + Mat& d_m = const_cast< Mat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const Proxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "diagview: given object has incompatible size" + ); + + const bool is_alias = P.is_alias(d_m); + + if(is_alias) { arma_extra_debug_print("aliasing detected"); } + + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + { + const unwrap_check::stored_type> tmp(P.Q, is_alias); + const Mat& x = tmp.M; + + const eT* x_mem = x.memptr(); + + uword ii,jj; + for(ii=0, jj=1; jj < d_n_elem; ii+=2, jj+=2) + { + const eT tmp_i = x_mem[ii]; + const eT tmp_j = x_mem[jj]; + + d_m.at( ii + d_row_offset, ii + d_col_offset) /= tmp_i; + d_m.at( jj + d_row_offset, jj + d_col_offset) /= tmp_j; + } + + if(ii < d_n_elem) + { + d_m.at( ii + d_row_offset, ii + d_col_offset) /= x_mem[ii]; + } + } + else + { + typename Proxy::ea_type Pea = P.get_ea(); + + uword ii,jj; + for(ii=0, jj=1; jj < d_n_elem; ii+=2, jj+=2) + { + const eT tmp_i = Pea[ii]; + const eT tmp_j = Pea[jj]; + + d_m.at( ii + d_row_offset, ii + d_col_offset) /= tmp_i; + d_m.at( jj + d_row_offset, jj + d_col_offset) /= tmp_j; + } + + if(ii < d_n_elem) + { + d_m.at( ii + d_row_offset, ii + d_col_offset) /= Pea[ii]; + } + } + } + + + +//! extract a diagonal and store it as a column vector +template +inline +void +diagview::extract(Mat& out, const diagview& in) + { + arma_extra_debug_sigprint(); + + // NOTE: we're assuming that the matrix has already been set to the correct size and there is no aliasing; + // size setting and alias checking is done by either the Mat contructor or operator=() + + const Mat& in_m = in.m; + + const uword in_n_elem = in.n_elem; + const uword in_row_offset = in.row_offset; + const uword in_col_offset = in.col_offset; + + eT* out_mem = out.memptr(); + + uword i,j; + for(i=0, j=1; j < in_n_elem; i+=2, j+=2) + { + const eT tmp_i = in_m.at( i + in_row_offset, i + in_col_offset ); + const eT tmp_j = in_m.at( j + in_row_offset, j + in_col_offset ); + + out_mem[i] = tmp_i; + out_mem[j] = tmp_j; + } + + if(i < in_n_elem) + { + out_mem[i] = in_m.at( i + in_row_offset, i + in_col_offset ); + } + } + + + +//! X += Y.diag() +template +inline +void +diagview::plus_inplace(Mat& out, const diagview& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, in.n_rows, in.n_cols, "addition"); + + const Mat& in_m = in.m; + + const uword in_n_elem = in.n_elem; + const uword in_row_offset = in.row_offset; + const uword in_col_offset = in.col_offset; + + eT* out_mem = out.memptr(); + + uword i,j; + for(i=0, j=1; j < in_n_elem; i+=2, j+=2) + { + const eT tmp_i = in_m.at( i + in_row_offset, i + in_col_offset ); + const eT tmp_j = in_m.at( j + in_row_offset, j + in_col_offset ); + + out_mem[i] += tmp_i; + out_mem[j] += tmp_j; + } + + if(i < in_n_elem) + { + out_mem[i] += in_m.at( i + in_row_offset, i + in_col_offset ); + } + } + + + +//! X -= Y.diag() +template +inline +void +diagview::minus_inplace(Mat& out, const diagview& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, in.n_rows, in.n_cols, "subtraction"); + + const Mat& in_m = in.m; + + const uword in_n_elem = in.n_elem; + const uword in_row_offset = in.row_offset; + const uword in_col_offset = in.col_offset; + + eT* out_mem = out.memptr(); + + uword i,j; + for(i=0, j=1; j < in_n_elem; i+=2, j+=2) + { + const eT tmp_i = in_m.at( i + in_row_offset, i + in_col_offset ); + const eT tmp_j = in_m.at( j + in_row_offset, j + in_col_offset ); + + out_mem[i] -= tmp_i; + out_mem[j] -= tmp_j; + } + + if(i < in_n_elem) + { + out_mem[i] -= in_m.at( i + in_row_offset, i + in_col_offset ); + } + } + + + +//! X %= Y.diag() +template +inline +void +diagview::schur_inplace(Mat& out, const diagview& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, in.n_rows, in.n_cols, "element-wise multiplication"); + + const Mat& in_m = in.m; + + const uword in_n_elem = in.n_elem; + const uword in_row_offset = in.row_offset; + const uword in_col_offset = in.col_offset; + + eT* out_mem = out.memptr(); + + uword i,j; + for(i=0, j=1; j < in_n_elem; i+=2, j+=2) + { + const eT tmp_i = in_m.at( i + in_row_offset, i + in_col_offset ); + const eT tmp_j = in_m.at( j + in_row_offset, j + in_col_offset ); + + out_mem[i] *= tmp_i; + out_mem[j] *= tmp_j; + } + + if(i < in_n_elem) + { + out_mem[i] *= in_m.at( i + in_row_offset, i + in_col_offset ); + } + } + + + +//! X /= Y.diag() +template +inline +void +diagview::div_inplace(Mat& out, const diagview& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, in.n_rows, in.n_cols, "element-wise division"); + + const Mat& in_m = in.m; + + const uword in_n_elem = in.n_elem; + const uword in_row_offset = in.row_offset; + const uword in_col_offset = in.col_offset; + + eT* out_mem = out.memptr(); + + uword i,j; + for(i=0, j=1; j < in_n_elem; i+=2, j+=2) + { + const eT tmp_i = in_m.at( i + in_row_offset, i + in_col_offset ); + const eT tmp_j = in_m.at( j + in_row_offset, j + in_col_offset ); + + out_mem[i] /= tmp_i; + out_mem[j] /= tmp_j; + } + + if(i < in_n_elem) + { + out_mem[i] /= in_m.at( i + in_row_offset, i + in_col_offset ); + } + } + + + +template +arma_inline +eT +diagview::at_alt(const uword ii) const + { + return m.at(ii+row_offset, ii+col_offset); + } + + + +template +arma_inline +eT& +diagview::operator[](const uword ii) + { + return (const_cast< Mat& >(m)).at(ii+row_offset, ii+col_offset); + } + + + +template +arma_inline +eT +diagview::operator[](const uword ii) const + { + return m.at(ii+row_offset, ii+col_offset); + } + + + +template +arma_inline +eT& +diagview::at(const uword ii) + { + return (const_cast< Mat& >(m)).at(ii+row_offset, ii+col_offset); + } + + + +template +arma_inline +eT +diagview::at(const uword ii) const + { + return m.at(ii+row_offset, ii+col_offset); + } + + + +template +arma_inline +eT& +diagview::operator()(const uword ii) + { + arma_debug_check_bounds( (ii >= n_elem), "diagview::operator(): out of bounds" ); + + return (const_cast< Mat& >(m)).at(ii+row_offset, ii+col_offset); + } + + + +template +arma_inline +eT +diagview::operator()(const uword ii) const + { + arma_debug_check_bounds( (ii >= n_elem), "diagview::operator(): out of bounds" ); + + return m.at(ii+row_offset, ii+col_offset); + } + + + +template +arma_inline +eT& +diagview::at(const uword row, const uword) + { + return (const_cast< Mat& >(m)).at(row+row_offset, row+col_offset); + } + + + +template +arma_inline +eT +diagview::at(const uword row, const uword) const + { + return m.at(row+row_offset, row+col_offset); + } + + + +template +arma_inline +eT& +diagview::operator()(const uword row, const uword col) + { + arma_debug_check_bounds( ((row >= n_elem) || (col > 0)), "diagview::operator(): out of bounds" ); + + return (const_cast< Mat& >(m)).at(row+row_offset, row+col_offset); + } + + + +template +arma_inline +eT +diagview::operator()(const uword row, const uword col) const + { + arma_debug_check_bounds( ((row >= n_elem) || (col > 0)), "diagview::operator(): out of bounds" ); + + return m.at(row+row_offset, row+col_offset); + } + + + +template +inline +void +diagview::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + Mat& x = const_cast< Mat& >(m); + + const uword local_n_elem = n_elem; + + if(arma_isnan(old_val)) + { + for(uword ii=0; ii < local_n_elem; ++ii) + { + eT& val = x.at(ii+row_offset, ii+col_offset); + + val = (arma_isnan(val)) ? new_val : val; + } + } + else + { + for(uword ii=0; ii < local_n_elem; ++ii) + { + eT& val = x.at(ii+row_offset, ii+col_offset); + + val = (val == old_val) ? new_val : val; + } + } + } + + + +template +inline +void +diagview::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clean(threshold); + + (*this).operator=(tmp); + } + + + +template +inline +void +diagview::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clamp(min_val, max_val); + + (*this).operator=(tmp); + } + + + +template +inline +void +diagview::fill(const eT val) + { + arma_extra_debug_sigprint(); + + Mat& x = const_cast< Mat& >(m); + + const uword local_n_elem = n_elem; + + for(uword ii=0; ii < local_n_elem; ++ii) + { + x.at(ii+row_offset, ii+col_offset) = val; + } + } + + + +template +inline +void +diagview::zeros() + { + arma_extra_debug_sigprint(); + + (*this).fill(eT(0)); + } + + + +template +inline +void +diagview::ones() + { + arma_extra_debug_sigprint(); + + (*this).fill(eT(1)); + } + + + +template +inline +void +diagview::randu() + { + arma_extra_debug_sigprint(); + + Mat& x = const_cast< Mat& >(m); + + const uword local_n_elem = n_elem; + + for(uword ii=0; ii < local_n_elem; ++ii) + { + x.at(ii+row_offset, ii+col_offset) = eT(arma_rng::randu()); + } + } + + + +template +inline +void +diagview::randn() + { + arma_extra_debug_sigprint(); + + Mat& x = const_cast< Mat& >(m); + + const uword local_n_elem = n_elem; + + for(uword ii=0; ii < local_n_elem; ++ii) + { + x.at(ii+row_offset, ii+col_offset) = eT(arma_rng::randn()); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/diskio_bones.hpp b/src/armadillo/include/armadillo_bits/diskio_bones.hpp new file mode 100644 index 0000000..03e1ac5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/diskio_bones.hpp @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup diskio +//! @{ + + +//! class for saving and loading matrices and fields - INTERNAL USE ONLY! +class diskio + { + public: + + arma_deprecated inline static file_type guess_file_type(std::istream& f); + + + private: + + template friend class Mat; + template friend class Cube; + template friend class SpMat; + template friend class field; + + friend class Mat_aux; + friend class Cube_aux; + friend class SpMat_aux; + friend class field_aux; + + template arma_cold inline static std::string gen_txt_header(const Mat&); + template arma_cold inline static std::string gen_bin_header(const Mat&); + + template arma_cold inline static std::string gen_bin_header(const SpMat&); + + template arma_cold inline static std::string gen_txt_header(const Cube&); + template arma_cold inline static std::string gen_bin_header(const Cube&); + + arma_cold inline static file_type guess_file_type_internal(std::istream& f); + + arma_cold inline static std::string gen_tmp_name(const std::string& x); + + arma_cold inline static bool safe_rename(const std::string& old_name, const std::string& new_name); + + arma_cold inline static bool is_readable(const std::string& name); + + arma_cold inline static void sanitise_token(std::string& token); + + template inline static bool convert_token(eT& val, const std::string& token); + template inline static bool convert_token(std::complex& val, const std::string& token); + + template inline static bool convert_token_strict(eT& val, const std::string& token); + + template inline static std::streamsize prepare_stream(std::ostream& f); + + + // + // matrix saving + + template inline static bool save_raw_ascii (const Mat& x, const std::string& final_name); + template inline static bool save_raw_binary (const Mat& x, const std::string& final_name); + template inline static bool save_arma_ascii (const Mat& x, const std::string& final_name); + template inline static bool save_csv_ascii (const Mat& x, const std::string& final_name, const field& header, const bool with_header, const char separator); + template inline static bool save_coord_ascii(const Mat& x, const std::string& final_name); + template inline static bool save_arma_binary(const Mat& x, const std::string& final_name); + template inline static bool save_pgm_binary (const Mat& x, const std::string& final_name); + template inline static bool save_pgm_binary (const Mat< std::complex >& x, const std::string& final_name); + template inline static bool save_hdf5_binary(const Mat& x, const hdf5_name& spec, std::string& err_msg); + + template inline static bool save_raw_ascii (const Mat& x, std::ostream& f); + template inline static bool save_raw_binary (const Mat& x, std::ostream& f); + template inline static bool save_arma_ascii (const Mat& x, std::ostream& f); + template inline static bool save_csv_ascii (const Mat& x, std::ostream& f, const char separator); + template inline static bool save_csv_ascii (const Mat< std::complex >& x, std::ostream& f, const char separator); + template inline static bool save_coord_ascii(const Mat& x, std::ostream& f); + template inline static bool save_coord_ascii(const Mat< std::complex >& x, std::ostream& f); + template inline static bool save_arma_binary(const Mat& x, std::ostream& f); + template inline static bool save_pgm_binary (const Mat& x, std::ostream& f); + template inline static bool save_pgm_binary (const Mat< std::complex >& x, std::ostream& f); + + + // + // matrix loading + + template inline static bool load_raw_ascii (Mat& x, const std::string& name, std::string& err_msg); + template inline static bool load_raw_binary (Mat& x, const std::string& name, std::string& err_msg); + template inline static bool load_arma_ascii (Mat& x, const std::string& name, std::string& err_msg); + template inline static bool load_csv_ascii (Mat& x, const std::string& name, std::string& err_msg, field& header, const bool with_header, const char separator, const bool strict); + template inline static bool load_coord_ascii(Mat& x, const std::string& name, std::string& err_msg); + template inline static bool load_arma_binary(Mat& x, const std::string& name, std::string& err_msg); + template inline static bool load_pgm_binary (Mat& x, const std::string& name, std::string& err_msg); + template inline static bool load_pgm_binary (Mat< std::complex >& x, const std::string& name, std::string& err_msg); + template inline static bool load_hdf5_binary(Mat& x, const hdf5_name& spec, std::string& err_msg); + template inline static bool load_auto_detect(Mat& x, const std::string& name, std::string& err_msg); + + template inline static bool load_raw_ascii (Mat& x, std::istream& f, std::string& err_msg); + template inline static bool load_raw_binary (Mat& x, std::istream& f, std::string& err_msg); + template inline static bool load_arma_ascii (Mat& x, std::istream& f, std::string& err_msg); + template inline static bool load_csv_ascii (Mat& x, std::istream& f, std::string& err_msg, const char separator, const bool strict); + template inline static bool load_csv_ascii (Mat< std::complex >& x, std::istream& f, std::string& err_msg, const char separator, const bool strict); + template inline static bool load_coord_ascii(Mat& x, std::istream& f, std::string& err_msg); + template inline static bool load_coord_ascii(Mat< std::complex >& x, std::istream& f, std::string& err_msg); + template inline static bool load_arma_binary(Mat& x, std::istream& f, std::string& err_msg); + template inline static bool load_pgm_binary (Mat& x, std::istream& is, std::string& err_msg); + template inline static bool load_pgm_binary (Mat< std::complex >& x, std::istream& is, std::string& err_msg); + template inline static bool load_auto_detect(Mat& x, std::istream& f, std::string& err_msg); + + inline static void pnm_skip_comments(std::istream& f); + + + // + // sparse matrix saving + + template inline static bool save_csv_ascii (const SpMat& x, const std::string& final_name, const field& header, const bool with_header, const char separator); + template inline static bool save_coord_ascii(const SpMat& x, const std::string& final_name); + template inline static bool save_arma_binary(const SpMat& x, const std::string& final_name); + + template inline static bool save_csv_ascii (const SpMat& x, std::ostream& f, const char separator); + template inline static bool save_csv_ascii (const SpMat< std::complex >& x, std::ostream& f, const char separator); + template inline static bool save_coord_ascii(const SpMat& x, std::ostream& f); + template inline static bool save_coord_ascii(const SpMat< std::complex >& x, std::ostream& f); + template inline static bool save_arma_binary(const SpMat& x, std::ostream& f); + + + // + // sparse matrix loading + + template inline static bool load_csv_ascii (SpMat& x, const std::string& name, std::string& err_msg, field& header, const bool with_header, const char separator); + template inline static bool load_coord_ascii(SpMat& x, const std::string& name, std::string& err_msg); + template inline static bool load_arma_binary(SpMat& x, const std::string& name, std::string& err_msg); + + template inline static bool load_csv_ascii (SpMat& x, std::istream& f, std::string& err_msg, const char separator); + template inline static bool load_csv_ascii (SpMat< std::complex >& x, std::istream& f, std::string& err_msg, const char separator); + template inline static bool load_coord_ascii(SpMat& x, std::istream& f, std::string& err_msg); + template inline static bool load_coord_ascii(SpMat< std::complex >& x, std::istream& f, std::string& err_msg); + template inline static bool load_arma_binary(SpMat& x, std::istream& f, std::string& err_msg); + + + + // + // cube saving + + template inline static bool save_raw_ascii (const Cube& x, const std::string& name); + template inline static bool save_raw_binary (const Cube& x, const std::string& name); + template inline static bool save_arma_ascii (const Cube& x, const std::string& name); + template inline static bool save_arma_binary(const Cube& x, const std::string& name); + template inline static bool save_hdf5_binary(const Cube& x, const hdf5_name& spec, std::string& err_msg); + + template inline static bool save_raw_ascii (const Cube& x, std::ostream& f); + template inline static bool save_raw_binary (const Cube& x, std::ostream& f); + template inline static bool save_arma_ascii (const Cube& x, std::ostream& f); + template inline static bool save_arma_binary(const Cube& x, std::ostream& f); + + + // + // cube loading + + template inline static bool load_raw_ascii (Cube& x, const std::string& name, std::string& err_msg); + template inline static bool load_raw_binary (Cube& x, const std::string& name, std::string& err_msg); + template inline static bool load_arma_ascii (Cube& x, const std::string& name, std::string& err_msg); + template inline static bool load_arma_binary(Cube& x, const std::string& name, std::string& err_msg); + template inline static bool load_hdf5_binary(Cube& x, const hdf5_name& spec, std::string& err_msg); + template inline static bool load_auto_detect(Cube& x, const std::string& name, std::string& err_msg); + + template inline static bool load_raw_ascii (Cube& x, std::istream& f, std::string& err_msg); + template inline static bool load_raw_binary (Cube& x, std::istream& f, std::string& err_msg); + template inline static bool load_arma_ascii (Cube& x, std::istream& f, std::string& err_msg); + template inline static bool load_arma_binary(Cube& x, std::istream& f, std::string& err_msg); + template inline static bool load_auto_detect(Cube& x, std::istream& f, std::string& err_msg); + + + // + // field saving and loading + + template inline static bool save_arma_binary(const field& x, const std::string& name); + template inline static bool save_arma_binary(const field& x, std::ostream& f); + + template inline static bool load_arma_binary( field& x, const std::string& name, std::string& err_msg); + template inline static bool load_arma_binary( field& x, std::istream& f, std::string& err_msg); + + template inline static bool load_auto_detect( field& x, const std::string& name, std::string& err_msg); + template inline static bool load_auto_detect( field& x, std::istream& f, std::string& err_msg); + + inline static bool save_std_string(const field& x, const std::string& name); + inline static bool save_std_string(const field& x, std::ostream& f); + + inline static bool load_std_string( field& x, const std::string& name, std::string& err_msg); + inline static bool load_std_string( field& x, std::istream& f, std::string& err_msg); + + + + // + // handling of PPM images by cubes + + template inline static bool save_ppm_binary(const Cube& x, const std::string& final_name); + template inline static bool save_ppm_binary(const Cube& x, std::ostream& f); + + template inline static bool load_ppm_binary( Cube& x, const std::string& final_name, std::string& err_msg); + template inline static bool load_ppm_binary( Cube& x, std::istream& f, std::string& err_msg); + + + // + // handling of PPM images by fields + + template inline static bool save_ppm_binary(const field& x, const std::string& final_name); + template inline static bool save_ppm_binary(const field& x, std::ostream& f); + + template inline static bool load_ppm_binary( field& x, const std::string& final_name, std::string& err_msg); + template inline static bool load_ppm_binary( field& x, std::istream& f, std::string& err_msg); + + + + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/diskio_meat.hpp b/src/armadillo/include/armadillo_bits/diskio_meat.hpp new file mode 100644 index 0000000..4f716cc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/diskio_meat.hpp @@ -0,0 +1,5356 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup diskio +//! @{ + + +//! Generate the first line of the header used for saving matrices in text format. +//! Format: "ARMA_MAT_TXT_ABXYZ". +//! A is one of: I (for integral types) or F (for floating point types). +//! B is one of: U (for unsigned types), S (for signed types), N (for not applicable) or C (for complex types). +//! XYZ specifies the width of each element in terms of bytes, eg. "008" indicates eight bytes. +template +inline +std::string +diskio::gen_txt_header(const Mat&) + { + arma_type_check(( is_supported_elem_type::value == false )); + + const char* ARMA_MAT_TXT_IU001 = "ARMA_MAT_TXT_IU001"; + const char* ARMA_MAT_TXT_IS001 = "ARMA_MAT_TXT_IS001"; + const char* ARMA_MAT_TXT_IU002 = "ARMA_MAT_TXT_IU002"; + const char* ARMA_MAT_TXT_IS002 = "ARMA_MAT_TXT_IS002"; + const char* ARMA_MAT_TXT_IU004 = "ARMA_MAT_TXT_IU004"; + const char* ARMA_MAT_TXT_IS004 = "ARMA_MAT_TXT_IS004"; + const char* ARMA_MAT_TXT_IU008 = "ARMA_MAT_TXT_IU008"; + const char* ARMA_MAT_TXT_IS008 = "ARMA_MAT_TXT_IS008"; + const char* ARMA_MAT_TXT_FN004 = "ARMA_MAT_TXT_FN004"; + const char* ARMA_MAT_TXT_FN008 = "ARMA_MAT_TXT_FN008"; + const char* ARMA_MAT_TXT_FC008 = "ARMA_MAT_TXT_FC008"; + const char* ARMA_MAT_TXT_FC016 = "ARMA_MAT_TXT_FC016"; + + char* header = nullptr; + + if( is_u8::value) { header = const_cast(ARMA_MAT_TXT_IU001); } + else if( is_s8::value) { header = const_cast(ARMA_MAT_TXT_IS001); } + else if( is_u16::value) { header = const_cast(ARMA_MAT_TXT_IU002); } + else if( is_s16::value) { header = const_cast(ARMA_MAT_TXT_IS002); } + else if( is_u32::value) { header = const_cast(ARMA_MAT_TXT_IU004); } + else if( is_s32::value) { header = const_cast(ARMA_MAT_TXT_IS004); } + else if( is_u64::value) { header = const_cast(ARMA_MAT_TXT_IU008); } + else if( is_s64::value) { header = const_cast(ARMA_MAT_TXT_IS008); } + else if(is_ulng_t_32::value) { header = const_cast(ARMA_MAT_TXT_IU004); } + else if(is_slng_t_32::value) { header = const_cast(ARMA_MAT_TXT_IS004); } + else if(is_ulng_t_64::value) { header = const_cast(ARMA_MAT_TXT_IU008); } + else if(is_slng_t_64::value) { header = const_cast(ARMA_MAT_TXT_IS008); } + else if( is_float::value) { header = const_cast(ARMA_MAT_TXT_FN004); } + else if( is_double::value) { header = const_cast(ARMA_MAT_TXT_FN008); } + else if( is_cx_float::value) { header = const_cast(ARMA_MAT_TXT_FC008); } + else if(is_cx_double::value) { header = const_cast(ARMA_MAT_TXT_FC016); } + + return std::string(header); + } + + + +//! Generate the first line of the header used for saving matrices in binary format. +//! Format: "ARMA_MAT_BIN_ABXYZ". +//! A is one of: I (for integral types) or F (for floating point types). +//! B is one of: U (for unsigned types), S (for signed types), N (for not applicable) or C (for complex types). +//! XYZ specifies the width of each element in terms of bytes, eg. "008" indicates eight bytes. +template +inline +std::string +diskio::gen_bin_header(const Mat&) + { + arma_type_check(( is_supported_elem_type::value == false )); + + const char* ARMA_MAT_BIN_IU001 = "ARMA_MAT_BIN_IU001"; + const char* ARMA_MAT_BIN_IS001 = "ARMA_MAT_BIN_IS001"; + const char* ARMA_MAT_BIN_IU002 = "ARMA_MAT_BIN_IU002"; + const char* ARMA_MAT_BIN_IS002 = "ARMA_MAT_BIN_IS002"; + const char* ARMA_MAT_BIN_IU004 = "ARMA_MAT_BIN_IU004"; + const char* ARMA_MAT_BIN_IS004 = "ARMA_MAT_BIN_IS004"; + const char* ARMA_MAT_BIN_IU008 = "ARMA_MAT_BIN_IU008"; + const char* ARMA_MAT_BIN_IS008 = "ARMA_MAT_BIN_IS008"; + const char* ARMA_MAT_BIN_FN004 = "ARMA_MAT_BIN_FN004"; + const char* ARMA_MAT_BIN_FN008 = "ARMA_MAT_BIN_FN008"; + const char* ARMA_MAT_BIN_FC008 = "ARMA_MAT_BIN_FC008"; + const char* ARMA_MAT_BIN_FC016 = "ARMA_MAT_BIN_FC016"; + + char* header = nullptr; + + if( is_u8::value) { header = const_cast(ARMA_MAT_BIN_IU001); } + else if( is_s8::value) { header = const_cast(ARMA_MAT_BIN_IS001); } + else if( is_u16::value) { header = const_cast(ARMA_MAT_BIN_IU002); } + else if( is_s16::value) { header = const_cast(ARMA_MAT_BIN_IS002); } + else if( is_u32::value) { header = const_cast(ARMA_MAT_BIN_IU004); } + else if( is_s32::value) { header = const_cast(ARMA_MAT_BIN_IS004); } + else if( is_u64::value) { header = const_cast(ARMA_MAT_BIN_IU008); } + else if( is_s64::value) { header = const_cast(ARMA_MAT_BIN_IS008); } + else if(is_ulng_t_32::value) { header = const_cast(ARMA_MAT_BIN_IU004); } + else if(is_slng_t_32::value) { header = const_cast(ARMA_MAT_BIN_IS004); } + else if(is_ulng_t_64::value) { header = const_cast(ARMA_MAT_BIN_IU008); } + else if(is_slng_t_64::value) { header = const_cast(ARMA_MAT_BIN_IS008); } + else if( is_float::value) { header = const_cast(ARMA_MAT_BIN_FN004); } + else if( is_double::value) { header = const_cast(ARMA_MAT_BIN_FN008); } + else if( is_cx_float::value) { header = const_cast(ARMA_MAT_BIN_FC008); } + else if(is_cx_double::value) { header = const_cast(ARMA_MAT_BIN_FC016); } + + return std::string(header); + } + + + +//! Generate the first line of the header used for saving matrices in binary format. +//! Format: "ARMA_SPM_BIN_ABXYZ". +//! A is one of: I (for integral types) or F (for floating point types). +//! B is one of: U (for unsigned types), S (for signed types), N (for not applicable) or C (for complex types). +//! XYZ specifies the width of each element in terms of bytes, eg. "008" indicates eight bytes. +template +inline +std::string +diskio::gen_bin_header(const SpMat&) + { + arma_type_check(( is_supported_elem_type::value == false )); + + const char* ARMA_SPM_BIN_IU001 = "ARMA_SPM_BIN_IU001"; + const char* ARMA_SPM_BIN_IS001 = "ARMA_SPM_BIN_IS001"; + const char* ARMA_SPM_BIN_IU002 = "ARMA_SPM_BIN_IU002"; + const char* ARMA_SPM_BIN_IS002 = "ARMA_SPM_BIN_IS002"; + const char* ARMA_SPM_BIN_IU004 = "ARMA_SPM_BIN_IU004"; + const char* ARMA_SPM_BIN_IS004 = "ARMA_SPM_BIN_IS004"; + const char* ARMA_SPM_BIN_IU008 = "ARMA_SPM_BIN_IU008"; + const char* ARMA_SPM_BIN_IS008 = "ARMA_SPM_BIN_IS008"; + const char* ARMA_SPM_BIN_FN004 = "ARMA_SPM_BIN_FN004"; + const char* ARMA_SPM_BIN_FN008 = "ARMA_SPM_BIN_FN008"; + const char* ARMA_SPM_BIN_FC008 = "ARMA_SPM_BIN_FC008"; + const char* ARMA_SPM_BIN_FC016 = "ARMA_SPM_BIN_FC016"; + + char* header = nullptr; + + if( is_u8::value) { header = const_cast(ARMA_SPM_BIN_IU001); } + else if( is_s8::value) { header = const_cast(ARMA_SPM_BIN_IS001); } + else if( is_u16::value) { header = const_cast(ARMA_SPM_BIN_IU002); } + else if( is_s16::value) { header = const_cast(ARMA_SPM_BIN_IS002); } + else if( is_u32::value) { header = const_cast(ARMA_SPM_BIN_IU004); } + else if( is_s32::value) { header = const_cast(ARMA_SPM_BIN_IS004); } + else if( is_u64::value) { header = const_cast(ARMA_SPM_BIN_IU008); } + else if( is_s64::value) { header = const_cast(ARMA_SPM_BIN_IS008); } + else if(is_ulng_t_32::value) { header = const_cast(ARMA_SPM_BIN_IU004); } + else if(is_slng_t_32::value) { header = const_cast(ARMA_SPM_BIN_IS004); } + else if(is_ulng_t_64::value) { header = const_cast(ARMA_SPM_BIN_IU008); } + else if(is_slng_t_64::value) { header = const_cast(ARMA_SPM_BIN_IS008); } + else if( is_float::value) { header = const_cast(ARMA_SPM_BIN_FN004); } + else if( is_double::value) { header = const_cast(ARMA_SPM_BIN_FN008); } + else if( is_cx_float::value) { header = const_cast(ARMA_SPM_BIN_FC008); } + else if(is_cx_double::value) { header = const_cast(ARMA_SPM_BIN_FC016); } + + return std::string(header); + } + + +//! Generate the first line of the header used for saving cubes in text format. +//! Format: "ARMA_CUB_TXT_ABXYZ". +//! A is one of: I (for integral types) or F (for floating point types). +//! B is one of: U (for unsigned types), S (for signed types), N (for not applicable) or C (for complex types). +//! XYZ specifies the width of each element in terms of bytes, eg. "008" indicates eight bytes. +template +inline +std::string +diskio::gen_txt_header(const Cube&) + { + arma_type_check(( is_supported_elem_type::value == false )); + + const char* ARMA_CUB_TXT_IU001 = "ARMA_CUB_TXT_IU001"; + const char* ARMA_CUB_TXT_IS001 = "ARMA_CUB_TXT_IS001"; + const char* ARMA_CUB_TXT_IU002 = "ARMA_CUB_TXT_IU002"; + const char* ARMA_CUB_TXT_IS002 = "ARMA_CUB_TXT_IS002"; + const char* ARMA_CUB_TXT_IU004 = "ARMA_CUB_TXT_IU004"; + const char* ARMA_CUB_TXT_IS004 = "ARMA_CUB_TXT_IS004"; + const char* ARMA_CUB_TXT_IU008 = "ARMA_CUB_TXT_IU008"; + const char* ARMA_CUB_TXT_IS008 = "ARMA_CUB_TXT_IS008"; + const char* ARMA_CUB_TXT_FN004 = "ARMA_CUB_TXT_FN004"; + const char* ARMA_CUB_TXT_FN008 = "ARMA_CUB_TXT_FN008"; + const char* ARMA_CUB_TXT_FC008 = "ARMA_CUB_TXT_FC008"; + const char* ARMA_CUB_TXT_FC016 = "ARMA_CUB_TXT_FC016"; + + char* header = nullptr; + + if( is_u8::value) { header = const_cast(ARMA_CUB_TXT_IU001); } + else if( is_s8::value) { header = const_cast(ARMA_CUB_TXT_IS001); } + else if( is_u16::value) { header = const_cast(ARMA_CUB_TXT_IU002); } + else if( is_s16::value) { header = const_cast(ARMA_CUB_TXT_IS002); } + else if( is_u32::value) { header = const_cast(ARMA_CUB_TXT_IU004); } + else if( is_s32::value) { header = const_cast(ARMA_CUB_TXT_IS004); } + else if( is_u64::value) { header = const_cast(ARMA_CUB_TXT_IU008); } + else if( is_s64::value) { header = const_cast(ARMA_CUB_TXT_IS008); } + else if(is_ulng_t_32::value) { header = const_cast(ARMA_CUB_TXT_IU004); } + else if(is_slng_t_32::value) { header = const_cast(ARMA_CUB_TXT_IS004); } + else if(is_ulng_t_64::value) { header = const_cast(ARMA_CUB_TXT_IU008); } + else if(is_slng_t_64::value) { header = const_cast(ARMA_CUB_TXT_IS008); } + else if( is_float::value) { header = const_cast(ARMA_CUB_TXT_FN004); } + else if( is_double::value) { header = const_cast(ARMA_CUB_TXT_FN008); } + else if( is_cx_float::value) { header = const_cast(ARMA_CUB_TXT_FC008); } + else if(is_cx_double::value) { header = const_cast(ARMA_CUB_TXT_FC016); } + + return std::string(header); + } + + + +//! Generate the first line of the header used for saving cubes in binary format. +//! Format: "ARMA_CUB_BIN_ABXYZ". +//! A is one of: I (for integral types) or F (for floating point types). +//! B is one of: U (for unsigned types), S (for signed types), N (for not applicable) or C (for complex types). +//! XYZ specifies the width of each element in terms of bytes, eg. "008" indicates eight bytes. +template +inline +std::string +diskio::gen_bin_header(const Cube&) + { + arma_type_check(( is_supported_elem_type::value == false )); + + const char* ARMA_CUB_BIN_IU001 = "ARMA_CUB_BIN_IU001"; + const char* ARMA_CUB_BIN_IS001 = "ARMA_CUB_BIN_IS001"; + const char* ARMA_CUB_BIN_IU002 = "ARMA_CUB_BIN_IU002"; + const char* ARMA_CUB_BIN_IS002 = "ARMA_CUB_BIN_IS002"; + const char* ARMA_CUB_BIN_IU004 = "ARMA_CUB_BIN_IU004"; + const char* ARMA_CUB_BIN_IS004 = "ARMA_CUB_BIN_IS004"; + const char* ARMA_CUB_BIN_IU008 = "ARMA_CUB_BIN_IU008"; + const char* ARMA_CUB_BIN_IS008 = "ARMA_CUB_BIN_IS008"; + const char* ARMA_CUB_BIN_FN004 = "ARMA_CUB_BIN_FN004"; + const char* ARMA_CUB_BIN_FN008 = "ARMA_CUB_BIN_FN008"; + const char* ARMA_CUB_BIN_FC008 = "ARMA_CUB_BIN_FC008"; + const char* ARMA_CUB_BIN_FC016 = "ARMA_CUB_BIN_FC016"; + + char* header = nullptr; + + if( is_u8::value) { header = const_cast(ARMA_CUB_BIN_IU001); } + else if( is_s8::value) { header = const_cast(ARMA_CUB_BIN_IS001); } + else if( is_u16::value) { header = const_cast(ARMA_CUB_BIN_IU002); } + else if( is_s16::value) { header = const_cast(ARMA_CUB_BIN_IS002); } + else if( is_u32::value) { header = const_cast(ARMA_CUB_BIN_IU004); } + else if( is_s32::value) { header = const_cast(ARMA_CUB_BIN_IS004); } + else if( is_u64::value) { header = const_cast(ARMA_CUB_BIN_IU008); } + else if( is_s64::value) { header = const_cast(ARMA_CUB_BIN_IS008); } + else if(is_ulng_t_32::value) { header = const_cast(ARMA_CUB_BIN_IU004); } + else if(is_slng_t_32::value) { header = const_cast(ARMA_CUB_BIN_IS004); } + else if(is_ulng_t_64::value) { header = const_cast(ARMA_CUB_BIN_IU008); } + else if(is_slng_t_64::value) { header = const_cast(ARMA_CUB_BIN_IS008); } + else if( is_float::value) { header = const_cast(ARMA_CUB_BIN_FN004); } + else if( is_double::value) { header = const_cast(ARMA_CUB_BIN_FN008); } + else if( is_cx_float::value) { header = const_cast(ARMA_CUB_BIN_FC008); } + else if(is_cx_double::value) { header = const_cast(ARMA_CUB_BIN_FC016); } + + return std::string(header); + } + + + +inline +file_type +diskio::guess_file_type(std::istream& f) + { + arma_extra_debug_sigprint(); + + return diskio::guess_file_type_internal(f); + } + + + +inline +file_type +diskio::guess_file_type_internal(std::istream& f) + { + arma_extra_debug_sigprint(); + + f.clear(); + const std::fstream::pos_type pos1 = f.tellg(); + + f.clear(); + f.seekg(0, ios::end); + + f.clear(); + const std::fstream::pos_type pos2 = f.tellg(); + + const uword N_max = ( (pos1 >= 0) && (pos2 >= 0) && (pos2 > pos1) ) ? uword(pos2 - pos1) : uword(0); + + f.clear(); + f.seekg(pos1); + + if(N_max == 0) { return file_type_unknown; } + + const uword N_use = (std::min)(N_max, uword(4096)); + + podarray data(N_use); + data.zeros(); + + unsigned char* data_mem = data.memptr(); + + f.clear(); + f.read( reinterpret_cast(data_mem), std::streamsize(N_use) ); + + const bool load_okay = f.good(); + + f.clear(); + f.seekg(pos1); + + if(load_okay == false) { return file_type_unknown; } + + bool has_binary = false; + bool has_bracket = false; + bool has_comma = false; + bool has_semicolon = false; + + for(uword i=0; i= 123) ) { has_binary = true; break; } // the range checking can be made more elaborate + + if( (val == '(') || (val == ')') ) { has_bracket = true; } + + if( (val == ';') ) { has_semicolon = true; } + + if( (val == ',') ) { has_comma = true; } + } + + if(has_binary) { return raw_binary; } + + // ssv_ascii has to be before csv_ascii; + // if the data has semicolons, it suggests a CSV file with semicolon as the separating character; + // the semicolon may be used to allow the comma character to represent the decimal seperator (eg. 1,2345 vs 1.2345) + + if(has_semicolon && (has_bracket == false)) { return ssv_ascii; } + + if(has_comma && (has_bracket == false)) { return csv_ascii; } + + return raw_ascii; + } + + + +//! Append a quasi-random string to the given filename. +//! Avoiding use of rand() to preserve its state. +inline +std::string +diskio::gen_tmp_name(const std::string& x) + { + union { uword val; void* ptr; } u; + + u.val = uword(0); + u.ptr = const_cast(&x); + + const u16 a = u16( (u.val >> 8) & 0xFFFF ); + const u16 b = u16( (std::clock()) & 0xFFFF ); + + std::ostringstream ss; + + ss << x << ".tmp_"; + + ss.setf(std::ios_base::hex, std::ios_base::basefield); + + ss.width(4); + ss.fill('0'); + ss << a; + + ss.width(4); + ss.fill('0'); + ss << b; + + return ss.str(); + } + + + +//! Safely rename a file. +//! Before renaming, test if we can write to the final file. +//! This should prevent: +//! (i) overwriting files that are write protected, +//! (ii) overwriting directories. +inline +bool +diskio::safe_rename(const std::string& old_name, const std::string& new_name) + { + const char* new_name_c_str = new_name.c_str(); + + std::fstream f(new_name_c_str, std::fstream::out | std::fstream::app); + f.put(' '); + + if(f.good()) { f.close(); } else { return false; } + + if(std::remove( new_name_c_str) != 0) { return false; } + if(std::rename(old_name.c_str(), new_name_c_str) != 0) { return false; } + + return true; + } + + + +inline +bool +diskio::is_readable(const std::string& name) + { + std::ifstream f; + + f.open(name, std::fstream::binary); + + // std::ifstream destructor will close the file + + return (f.is_open()); + } + + + +inline +void +diskio::sanitise_token(std::string& token) + { + // remove spaces, tabs, carriage returns + + if(token.length() == 0) { return; } + + const char c_front = token.front(); + const char c_back = token.back(); + + if( (c_front == ' ') || (c_front == '\t') || (c_front == '\r') || (c_back == ' ') || (c_back == '\t') || (c_back == '\r') ) + { + token.erase(std::remove_if(token.begin(), token.end(), [](char c) { return ((c == ' ') || (c == '\t') || (c == '\r')); }), token.end()); + } + } + + + +template +inline +bool +diskio::convert_token(eT& val, const std::string& token) + { + const size_t N = size_t(token.length()); + + const char* str = token.c_str(); + + if( (N == 0) || ((N == 1) && (str[0] == '0')) ) { val = eT(0); return true; } + + if( (N == 3) || (N == 4) ) + { + const bool neg = (str[0] == '-'); + const bool pos = (str[0] == '+'); + + const size_t offset = ( (neg || pos) && (N == 4) ) ? 1 : 0; + + const char sig_a = str[offset ]; + const char sig_b = str[offset+1]; + const char sig_c = str[offset+2]; + + if( ((sig_a == 'i') || (sig_a == 'I')) && ((sig_b == 'n') || (sig_b == 'N')) && ((sig_c == 'f') || (sig_c == 'F')) ) + { + val = neg ? cond_rel< is_signed::value >::make_neg(Datum::inf) : Datum::inf; + + return true; + } + else + if( ((sig_a == 'n') || (sig_a == 'N')) && ((sig_b == 'a') || (sig_b == 'A')) && ((sig_c == 'n') || (sig_c == 'N')) ) + { + val = Datum::nan; + + return true; + } + } + + // #if (defined(ARMA_HAVE_CXX17) && (__cpp_lib_to_chars >= 201611L)) + // { + // // std::from_chars() doesn't handle leading whitespace + // // std::from_chars() doesn't handle leading + sign + // // std::from_chars() handles only the decimal point (.) as the decimal seperator + // + // const char str0 = str[0]; + // const bool start_ok = ((str0 != ' ') && (str0 != '\t') && (str0 != '+')); + // + // bool has_comma = false; + // for(uword i=0; i::value) + { + val = eT( std::strtod(str, &endptr) ); + } + else + { + if(is_signed::value) + { + // signed integer + + val = eT( std::strtoll(str, &endptr, 10) ); + } + else + { + // unsigned integer + + if((str[0] == '-') && (N >= 2)) + { + val = eT(0); + + if((str[1] == '-') || (str[1] == '+')) { return false; } + + const char* str_offset1 = &(str[1]); + + std::strtoull(str_offset1, &endptr, 10); + + if(str_offset1 == endptr) { return false; } + + return true; + } + + val = eT( std::strtoull(str, &endptr, 10) ); + } + } + + if(str == endptr) { return false; } + + return true; + } + + + +template +inline +bool +diskio::convert_token(std::complex& val, const std::string& token) + { + const size_t N = size_t(token.length()); + const size_t Nm1 = N-1; + + if(N == 0) { val = std::complex(0); return true; } + + const char* str = token.c_str(); + + // valid complex number formats: + // (real,imag) + // (real) + // () + + if( (token[0] != '(') || (token[Nm1] != ')') ) + { + // no brackets, so treat the token as a non-complex number + + T val_real; + + const bool state = diskio::convert_token(val_real, token); // use the non-complex version of this function + + val = std::complex(val_real); + + return state; + } + + // does the token contain only the () brackets? + if(N <= 2) { val = std::complex(0); return true; } + + size_t comma_loc = 0; + bool comma_found = false; + + for(size_t i=0; i(val_real); + } + else + { + const std::string token_real( &(str[1]), (comma_loc - 1 ) ); + const std::string token_imag( &(str[comma_loc+1]), (Nm1 - 1 - comma_loc) ); + + T val_real; + T val_imag; + + const bool state_real = diskio::convert_token(val_real, token_real); + const bool state_imag = diskio::convert_token(val_imag, token_imag); + + state = (state_real && state_imag); + + val = std::complex(val_real, val_imag); + } + + return state; + } + + + +template +inline +bool +diskio::convert_token_strict(eT& val, const std::string& token) + { + const size_t N = size_t(token.length()); + + const bool status = (N > 0) ? diskio::convert_token(val, token) : false; + + if(status == false) { val = Datum::nan; } + + return status; + } + + + +template +inline +std::streamsize +diskio::prepare_stream(std::ostream& f) + { + std::streamsize cell_width = f.width(); + + if(is_real::value) + { + f.unsetf(ios::fixed); + f.setf(ios::scientific); + f.fill(' '); + + f.precision(16); + cell_width = 24; + + // NOTE: for 'float' the optimum settings are f.precision(8) and cell_width = 15 + // NOTE: however, to avoid introducing errors in case single precision data is loaded as double precision, + // NOTE: the same settings must be used for both 'float' and 'double' + } + else + if(is_cx::value) + { + f.unsetf(ios::fixed); + f.setf(ios::scientific); + + f.precision(16); + } + + return cell_width; + } + + + + +//! Save a matrix as raw text (no header, human readable). +//! Matrices can be loaded in Matlab and Octave, as long as they don't have complex elements. +template +inline +bool +diskio::save_raw_ascii(const Mat& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_raw_ascii(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +//! Save a matrix as raw text (no header, human readable). +//! Matrices can be loaded in Matlab and Octave, as long as they don't have complex elements. +template +inline +bool +diskio::save_raw_ascii(const Mat& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(f); + + const std::streamsize cell_width = diskio::prepare_stream(f); + + for(uword row=0; row < x.n_rows; ++row) + { + for(uword col=0; col < x.n_cols; ++col) + { + f.put(' '); + + if(is_real::value) { f.width(cell_width); } + + arma_ostream::raw_print_elem(f, x.at(row,col)); + } + + f.put('\n'); + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +//! Save a matrix as raw binary (no header) +template +inline +bool +diskio::save_raw_binary(const Mat& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f(tmp_name, std::fstream::binary); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_raw_binary(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +template +inline +bool +diskio::save_raw_binary(const Mat& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + f.write( reinterpret_cast(x.mem), std::streamsize(x.n_elem*sizeof(eT)) ); + + return f.good(); + } + + + +//! Save a matrix in text format (human readable), +//! with a header that indicates the matrix type as well as its dimensions +template +inline +bool +diskio::save_arma_ascii(const Mat& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_arma_ascii(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +//! Save a matrix in text format (human readable), +//! with a header that indicates the matrix type as well as its dimensions +template +inline +bool +diskio::save_arma_ascii(const Mat& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(f); + + f << diskio::gen_txt_header(x) << '\n'; + f << x.n_rows << ' ' << x.n_cols << '\n'; + + const std::streamsize cell_width = diskio::prepare_stream(f); + + for(uword row=0; row < x.n_rows; ++row) + { + for(uword col=0; col < x.n_cols; ++col) + { + f.put(' '); + + if(is_real::value) { f.width(cell_width); } + + arma_ostream::raw_print_elem(f, x.at(row,col)); + } + + f.put('\n'); + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +//! Save a matrix in CSV text format (human readable) +template +inline +bool +diskio::save_csv_ascii(const Mat& x, const std::string& final_name, const field& header, const bool with_header, const char separator) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); + + bool save_okay = f.is_open(); + + if(save_okay == false) { return false; } + + if(with_header) + { + arma_extra_debug_print("diskio::save_csv_ascii(): writing header"); + + for(uword i=0; i < header.n_elem; ++i) + { + f << header.at(i); + + if(i != (header.n_elem-1)) { f.put(separator); } + } + + f.put('\n'); + + save_okay = f.good(); + } + + if(save_okay) { save_okay = diskio::save_csv_ascii(x, f, separator); } + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + + return save_okay; + } + + + +//! Save a matrix in CSV text format (human readable) +template +inline +bool +diskio::save_csv_ascii(const Mat& x, std::ostream& f, const char separator) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(f); + + diskio::prepare_stream(f); + + uword x_n_rows = x.n_rows; + uword x_n_cols = x.n_cols; + + for(uword row=0; row < x_n_rows; ++row) + { + for(uword col=0; col < x_n_cols; ++col) + { + arma_ostream::raw_print_elem(f, x.at(row,col)); + + if( col < (x_n_cols-1) ) { f.put(separator); } + } + + f.put('\n'); + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +//! Save a matrix in CSV text format (human readable); complex numbers stored in "a+bi" format +template +inline +bool +diskio::save_csv_ascii(const Mat< std::complex >& x, std::ostream& f, const char separator) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const arma_ostream_state stream_state(f); + + diskio::prepare_stream(f); + + uword x_n_rows = x.n_rows; + uword x_n_cols = x.n_cols; + + for(uword row=0; row < x_n_rows; ++row) + { + for(uword col=0; col < x_n_cols; ++col) + { + const eT& val = x.at(row,col); + + const T tmp_r = std::real(val); + const T tmp_i = std::imag(val); + const T tmp_i_abs = (tmp_i < T(0)) ? T(-tmp_i) : T(tmp_i); + const char tmp_sign = (tmp_i < T(0)) ? char('-') : char('+'); + + arma_ostream::raw_print_elem(f, tmp_r ); + f.put(tmp_sign); + arma_ostream::raw_print_elem(f, tmp_i_abs); + f.put('i'); + + if( col < (x_n_cols-1) ) { f.put(separator); } + } + + f.put('\n'); + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +template +inline +bool +diskio::save_coord_ascii(const Mat& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_coord_ascii(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +template +inline +bool +diskio::save_coord_ascii(const Mat& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(f); + + diskio::prepare_stream(f); + + for(uword col=0; col < x.n_cols; ++col) + for(uword row=0; row < x.n_rows; ++row) + { + const eT val = x.at(row,col); + + if(val != eT(0)) + { + f << row << ' ' << col << ' ' << val << '\n'; + } + } + + // make sure it's possible to figure out the matrix size later + if( (x.n_rows > 0) && (x.n_cols > 0) ) + { + const uword max_row = (x.n_rows > 0) ? x.n_rows-1 : 0; + const uword max_col = (x.n_cols > 0) ? x.n_cols-1 : 0; + + if( x.at(max_row, max_col) == eT(0) ) + { + f << max_row << ' ' << max_col << " 0\n"; + } + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +template +inline +bool +diskio::save_coord_ascii(const Mat< std::complex >& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const arma_ostream_state stream_state(f); + + diskio::prepare_stream(f); + + const eT eT_zero = eT(0); + + for(uword col=0; col < x.n_cols; ++col) + for(uword row=0; row < x.n_rows; ++row) + { + const eT& val = x.at(row,col); + + if(val != eT_zero) + { + f << row << ' ' << col << ' ' << val.real() << ' ' << val.imag() << '\n'; + } + } + + // make sure it's possible to figure out the matrix size later + if( (x.n_rows > 0) && (x.n_cols > 0) ) + { + const uword max_row = (x.n_rows > 0) ? x.n_rows-1 : 0; + const uword max_col = (x.n_cols > 0) ? x.n_cols-1 : 0; + + if( x.at(max_row, max_col) == eT_zero ) + { + f << max_row << ' ' << max_col << " 0 0\n"; + } + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +//! Save a matrix in binary format, +//! with a header that stores the matrix type as well as its dimensions +template +inline +bool +diskio::save_arma_binary(const Mat& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f(tmp_name, std::fstream::binary); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_arma_binary(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +//! Save a matrix in binary format, +//! with a header that stores the matrix type as well as its dimensions +template +inline +bool +diskio::save_arma_binary(const Mat& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + f << diskio::gen_bin_header(x) << '\n'; + f << x.n_rows << ' ' << x.n_cols << '\n'; + + f.write( reinterpret_cast(x.mem), std::streamsize(x.n_elem*sizeof(eT)) ); + + return f.good(); + } + + + +//! Save a matrix as a PGM greyscale image +template +inline +bool +diskio::save_pgm_binary(const Mat& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::fstream f(tmp_name, std::fstream::out | std::fstream::binary); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_pgm_binary(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +//! Save a matrix as a PGM greyscale image +template +inline +bool +diskio::save_pgm_binary(const Mat& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + f << "P5" << '\n'; + f << x.n_cols << ' ' << x.n_rows << '\n'; + f << 255 << '\n'; + + const uword n_elem = x.n_rows * x.n_cols; + podarray tmp(n_elem); + + uword i = 0; + + for(uword row=0; row < x.n_rows; ++row) + for(uword col=0; col < x.n_cols; ++col) + { + tmp[i] = u8( x.at(row,col) ); // TODO: add round() ? + ++i; + } + + f.write(reinterpret_cast(tmp.mem), std::streamsize(n_elem) ); + + return f.good(); + } + + + +//! Save a matrix as a PGM greyscale image +template +inline +bool +diskio::save_pgm_binary(const Mat< std::complex >& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const uchar_mat tmp = conv_to::from(x); + + return diskio::save_pgm_binary(tmp, final_name); + } + + + +//! Save a matrix as a PGM greyscale image +template +inline +bool +diskio::save_pgm_binary(const Mat< std::complex >& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + const uchar_mat tmp = conv_to::from(x); + + return diskio::save_pgm_binary(tmp, f); + } + + + +//! Save a matrix as part of a HDF5 file +template +inline +bool +diskio::save_hdf5_binary(const Mat& x, const hdf5_name& spec, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_HDF5) + { + hdf5_misc::hdf5_suspend_printing_errors hdf5_print_suspender; + + bool save_okay = false; + + const bool append = bool(spec.opts.flags & hdf5_opts::flag_append); + const bool replace = bool(spec.opts.flags & hdf5_opts::flag_replace); + + const bool use_existing_file = ((append || replace) && (H5Fis_hdf5(spec.filename.c_str()) > 0)); + + const std::string tmp_name = (use_existing_file) ? std::string() : diskio::gen_tmp_name(spec.filename); + + // Set up the file according to HDF5's preferences + hid_t file = (use_existing_file) ? H5Fopen(spec.filename.c_str(), H5F_ACC_RDWR, H5P_DEFAULT) : H5Fcreate(tmp_name.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT); + + if(file < 0) { return false; } + + // We need to create a dataset, datatype, and dataspace + hsize_t dims[2]; + dims[1] = x.n_rows; + dims[0] = x.n_cols; + + hid_t dataspace = H5Screate_simple(2, dims, NULL); // treat the matrix as a 2d array dataspace + hid_t datatype = hdf5_misc::get_hdf5_type(); + + // If this returned something invalid, well, it's time to crash. + arma_check(datatype == -1, "Mat::save(): unknown datatype for HDF5"); + + // MATLAB forces the users to specify a name at save time for HDF5; + // Octave will use the default of 'dataset' unless otherwise specified. + // If the user hasn't specified a dataset name, we will use 'dataset' + // We may have to split out the group name from the dataset name. + std::vector groups; + std::string full_name = spec.dsname; + size_t loc; + while((loc = full_name.find("/")) != std::string::npos) + { + // Create another group... + if(loc != 0) // Ignore the first /, if there is a leading /. + { + hid_t gid = H5Gcreate((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); + + if((gid < 0) && use_existing_file) + { + gid = H5Gopen((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT); + } + + groups.push_back(gid); + } + + full_name = full_name.substr(loc + 1); + } + + const std::string dataset_name = full_name.empty() ? std::string("dataset") : full_name; + + const hid_t last_group = (groups.size() == 0) ? file : groups[groups.size() - 1]; + + if(use_existing_file && replace) + { + H5Ldelete(last_group, dataset_name.c_str(), H5P_DEFAULT); + // NOTE: H5Ldelete() in HDF5 v1.8 doesn't reclaim the deleted space; use h5repack to reclaim space: h5repack oldfile.h5 newfile.h5 + // NOTE: has this behaviour changed in HDF5 1.10 ? + // NOTE: https://lists.hdfgroup.org/pipermail/hdf-forum_lists.hdfgroup.org/2017-August/010482.html + // NOTE: https://lists.hdfgroup.org/pipermail/hdf-forum_lists.hdfgroup.org/2017-August/010486.html + } + + hid_t dataset = H5Dcreate(last_group, dataset_name.c_str(), datatype, dataspace, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); + + if(dataset < 0) + { + save_okay = false; + + err_msg = "failed to create dataset"; + } + else + { + save_okay = (H5Dwrite(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, x.mem) >= 0); + + H5Dclose(dataset); + } + + H5Tclose(datatype); + H5Sclose(dataspace); + for(size_t i = 0; i < groups.size(); ++i) { H5Gclose(groups[i]); } + H5Fclose(file); + + if((use_existing_file == false) && (save_okay == true)) { save_okay = diskio::safe_rename(tmp_name, spec.filename); } + + return save_okay; + } + #else + { + arma_ignore(x); + arma_ignore(spec); + arma_ignore(err_msg); + + arma_stop_logic_error("Mat::save(): use of HDF5 must be enabled"); + + return false; + } + #endif + } + + + +//! Load a matrix as raw text (no header, human readable). +//! Can read matrices saved as text in Matlab and Octave. +//! NOTE: this is much slower than reading a file with a header. +template +inline +bool +diskio::load_raw_ascii(Mat& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_raw_ascii(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +//! Load a matrix as raw text (no header, human readable). +//! Can read matrices saved as text in Matlab and Octave. +//! NOTE: this is much slower than reading a file with a header. +template +inline +bool +diskio::load_raw_ascii(Mat& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + bool load_okay = f.good(); + + f.clear(); + const std::fstream::pos_type pos1 = f.tellg(); + + // + // work out the size + + uword f_n_rows = 0; + uword f_n_cols = 0; + + bool f_n_cols_found = false; + + std::string line_string; + std::stringstream line_stream; + + std::string token; + + while( f.good() && load_okay ) + { + std::getline(f, line_string); + + // TODO: does it make sense to stop processing the file if an empty line is found ? + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_n_cols = 0; + + while(line_stream >> token) { ++line_n_cols; } + + if(f_n_cols_found == false) + { + f_n_cols = line_n_cols; + f_n_cols_found = true; + } + else + { + if(line_n_cols != f_n_cols) + { + load_okay = false; + err_msg = "inconsistent number of columns"; + } + } + + ++f_n_rows; + } + + + if(load_okay) + { + f.clear(); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try { x.set_size(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + + for(uword row=0; ((row < x.n_rows) && load_okay); ++row) + for(uword col=0; ((col < x.n_cols) && load_okay); ++col) + { + f >> token; + + if(diskio::convert_token(x.at(row,col), token) == false) + { + load_okay = false; + err_msg = "data interpretation failure"; + } + } + } + + + // an empty file indicates an empty matrix + if( (f_n_cols_found == false) && (load_okay == true) ) { x.reset(); } + + + return load_okay; + } + + + +//! Load a matrix in binary format (no header); +//! the matrix is assumed to have one column +template +inline +bool +diskio::load_raw_binary(Mat& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + f.open(name, std::fstream::binary); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_raw_binary(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +template +inline +bool +diskio::load_raw_binary(Mat& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + f.clear(); + const std::streampos pos1 = f.tellg(); + + f.clear(); + f.seekg(0, ios::end); + + f.clear(); + const std::streampos pos2 = f.tellg(); + + const uword N = ( (pos1 >= 0) && (pos2 >= 0) ) ? uword(pos2 - pos1) : 0; + + f.clear(); + //f.seekg(0, ios::beg); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try { x.set_size(N / uword(sizeof(eT)), 1); } catch(...) { err_msg = "not enough memory"; return false; } + + f.clear(); + f.read( reinterpret_cast(x.memptr()), std::streamsize(x.n_elem * uword(sizeof(eT))) ); + + return f.good(); + } + + + +//! Load a matrix in text format (human readable), +//! with a header that indicates the matrix type as well as its dimensions +template +inline +bool +diskio::load_arma_ascii(Mat& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_arma_ascii(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +//! Load a matrix in text format (human readable), +//! with a header that indicates the matrix type as well as its dimensions +template +inline +bool +diskio::load_arma_ascii(Mat& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::streampos pos = f.tellg(); + + bool load_okay = true; + + std::string f_header; + uword f_n_rows; + uword f_n_cols; + + f >> f_header; + f >> f_n_rows; + f >> f_n_cols; + + if(f_header == diskio::gen_txt_header(x)) + { + try { x.zeros(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + + std::string token; + + for(uword row=0; row < x.n_rows; ++row) + for(uword col=0; col < x.n_cols; ++col) + { + f >> token; + + diskio::convert_token( x.at(row,col), token ); + } + + load_okay = f.good(); + } + else + { + load_okay = false; + err_msg = "incorrect header"; + } + + + // allow automatic conversion of u32/s32 matrices into u64/s64 matrices + + if(load_okay == false) + { + if( (sizeof(eT) == 8) && is_same_type::yes ) + { + Mat tmp; + std::string junk; + + f.clear(); + f.seekg(pos); + + load_okay = diskio::load_arma_ascii(tmp, f, junk); + + if(load_okay) { x = conv_to< Mat >::from(tmp); } + } + else + if( (sizeof(eT) == 8) && is_same_type::yes ) + { + Mat tmp; + std::string junk; + + f.clear(); + f.seekg(pos); + + load_okay = diskio::load_arma_ascii(tmp, f, junk); + + if(load_okay) { x = conv_to< Mat >::from(tmp); } + } + } + + return load_okay; + } + + + +//! Load a matrix in CSV text format (human readable) +template +inline +bool +diskio::load_csv_ascii(Mat& x, const std::string& name, std::string& err_msg, field& header, const bool with_header, const char separator, const bool strict) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); + + bool load_okay = f.is_open(); + + if(load_okay == false) { return false; } + + if(with_header) + { + arma_extra_debug_print("diskio::load_csv_ascii(): reading header"); + + std::string header_line; + std::stringstream header_stream; + std::vector header_tokens; + + std::getline(f, header_line); + + load_okay = f.good(); + + if(load_okay) + { + std::string token; + + header_stream.clear(); + header_stream.str(header_line); + + uword header_n_tokens = 0; + + while(header_stream.good()) + { + std::getline(header_stream, token, separator); + + diskio::sanitise_token(token); + + ++header_n_tokens; + + header_tokens.push_back(token); + } + + if(header_n_tokens == uword(0)) + { + header.reset(); + } + else + { + header.set_size(1,header_n_tokens); + + for(uword i=0; i < header_n_tokens; ++i) { header.at(i) = header_tokens[i]; } + } + } + } + + if(load_okay) + { + load_okay = diskio::load_csv_ascii(x, f, err_msg, separator, strict); + } + + f.close(); + + return load_okay; + } + + + +//! Load a matrix in CSV text format (human readable) +template +inline +bool +diskio::load_csv_ascii(Mat& x, std::istream& f, std::string& err_msg, const char separator, const bool strict) + { + arma_extra_debug_sigprint(); + + // TODO: replace with more efficient implementation + + if(f.good() == false) { return false; } + + f.clear(); + const std::fstream::pos_type pos1 = f.tellg(); + + // + // work out the size + + uword f_n_rows = 0; + uword f_n_cols = 0; + + std::string line_string; + std::stringstream line_stream; + + std::string token; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_n_cols = 0; + + while(line_stream.good()) + { + std::getline(line_stream, token, separator); + ++line_n_cols; + } + + if(f_n_cols < line_n_cols) { f_n_cols = line_n_cols; } + + ++f_n_rows; + } + + f.clear(); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try { x.zeros(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + + if(strict) { x.fill(Datum::nan); } // take into account that each row may have a unique number of columns + + const bool use_mp = (arma_config::openmp) && (f_n_rows >= 2) && (f_n_cols >= 64); + + field token_array; + + bool token_array_ok = false; + + if(use_mp) + { + try + { + token_array.set_size(f_n_cols); + + for(uword i=0; i < f_n_cols; ++i) { token_array(i).reserve(32); } + + token_array_ok = true; + } + catch(...) + { + token_array.reset(); + } + } + + if(use_mp && token_array_ok) + { + #if defined(ARMA_USE_OPENMP) + { + uword row = 0; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + for(uword i=0; i < f_n_cols; ++i) { token_array(i).clear(); } + + uword line_stream_col = 0; + + while(line_stream.good()) + { + std::getline(line_stream, token_array(line_stream_col), separator); + + ++line_stream_col; + } + + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col=0; col < line_stream_col; ++col) + { + eT& out_val = x.at(row,col); + + (strict) ? diskio::convert_token_strict( out_val, token_array(col) ) : diskio::convert_token( out_val, token_array(col) ); + } + + ++row; + } + } + #endif + } + else // serial implementation + { + uword row = 0; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword col = 0; + + while(line_stream.good()) + { + std::getline(line_stream, token, separator); + + eT& out_val = x.at(row,col); + + (strict) ? diskio::convert_token_strict( out_val, token ) : diskio::convert_token( out_val, token ); + + ++col; + } + + ++row; + } + } + + return true; + } + + + +//! Load a matrix in CSV text format (human readable); complex numbers stored in "a+bi" format +template +inline +bool +diskio::load_csv_ascii(Mat< std::complex >& x, std::istream& f, std::string& err_msg, const char separator, const bool strict) + { + arma_extra_debug_sigprint(); + + // TODO: replace with more efficient implementation + + if(f.good() == false) { return false; } + + f.clear(); + const std::fstream::pos_type pos1 = f.tellg(); + + // + // work out the size + + uword f_n_rows = 0; + uword f_n_cols = 0; + + std::string line_string; + std::stringstream line_stream; + + std::string token; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_n_cols = 0; + + while(line_stream.good()) + { + std::getline(line_stream, token, separator); + ++line_n_cols; + } + + if(f_n_cols < line_n_cols) { f_n_cols = line_n_cols; } + + ++f_n_rows; + } + + f.clear(); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try { x.zeros(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + + if(strict) { x.fill(Datum< std::complex >::nan); } // take into account that each row may have a unique number of columns + + uword row = 0; + + std::string str_real; + std::string str_imag; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword col = 0; + + while(line_stream.good()) + { + std::getline(line_stream, token, separator); + + diskio::sanitise_token(token); + + const size_t token_len = size_t( token.length() ); + + if(token_len == 0) { col++; continue; } + + // handle special cases: inf and nan, without the imaginary part + if( (token_len == 3) || (token_len == 4) ) + { + const char* str = token.c_str(); + + const bool neg = (str[0] == '-'); + const bool pos = (str[0] == '+'); + + const size_t offset = ( (neg || pos) && (token_len == 4) ) ? 1 : 0; + + const char sig_a = str[offset ]; + const char sig_b = str[offset+1]; + const char sig_c = str[offset+2]; + + bool found_val_real = false; + T val_real = T(0); + + if( ((sig_a == 'i') || (sig_a == 'I')) && ((sig_b == 'n') || (sig_b == 'N')) && ((sig_c == 'f') || (sig_c == 'F')) ) + { + val_real = (neg) ? -(Datum::inf) : Datum::inf; + + found_val_real = true; + } + else + if( ((sig_a == 'n') || (sig_a == 'N')) && ((sig_b == 'a') || (sig_b == 'A')) && ((sig_c == 'n') || (sig_c == 'N')) ) + { + val_real = Datum::nan; + + found_val_real = true; + } + + if(found_val_real) + { + const T val_imag = (strict) ? T(Datum::nan) : T(0); + + x.at(row,col) = std::complex(val_real, val_imag); + + col++; continue; // get next token + } + } + + bool found_x = false; + std::string::size_type loc_x = 0; // location of the separator (+ or -) between the real and imaginary part + + std::string::size_type loc_i = token.find_last_of('i'); // location of the imaginary part indicator + + if(loc_i == std::string::npos) + { + str_real = token; + str_imag.clear(); + } + else + { + bool found_plus = false; + bool found_minus = false; + + std::string::size_type loc_plus = token.find_last_of('+'); + + if(loc_plus != std::string::npos) + { + if(loc_plus >= 1) + { + const char prev_char = token.at(loc_plus-1); + + // make sure we're not looking at the sign of the exponent + if( (prev_char != 'e') && (prev_char != 'E') ) + { + found_plus = true; + } + else + { + // search again, omitting the exponent + loc_plus = token.find_last_of('+', loc_plus-1); + + if(loc_plus != std::string::npos) { found_plus = true; } + } + } + else + { + // loc_plus == 0, meaning we're at the start of the string + found_plus = true; + } + } + + std::string::size_type loc_minus = token.find_last_of('-'); + + if(loc_minus != std::string::npos) + { + if(loc_minus >= 1) + { + const char prev_char = token.at(loc_minus-1); + + // make sure we're not looking at the sign of the exponent + if( (prev_char != 'e') && (prev_char != 'E') ) + { + found_minus = true; + } + else + { + // search again, omitting the exponent + loc_minus = token.find_last_of('-', loc_minus-1); + + if(loc_minus != std::string::npos) { found_minus = true; } + } + } + else + { + // loc_minus == 0, meaning we're at the start of the string + found_minus = true; + } + } + + if(found_plus && found_minus) + { + if( (loc_i > loc_plus) && (loc_i > loc_minus) ) + { + // choose the sign closest to the "i" to be the separator between the real and imaginary part + loc_x = ( (loc_i - loc_plus) < (loc_i - loc_minus) ) ? loc_plus : loc_minus; + found_x = true; + } + } + else if(found_plus ) { loc_x = loc_plus; found_x = true; } + else if(found_minus) { loc_x = loc_minus; found_x = true; } + + if(found_x) + { + if( loc_x > 0 ) { str_real = token.substr(0,loc_x); } else { str_real.clear(); } + if((loc_x+1) < token.size()) { str_imag = token.substr(loc_x, token.size()-loc_x-1); } else { str_imag.clear(); } + } + else + { + str_real.clear(); + str_imag.clear(); + } + } + + T val_real = T(0); + T val_imag = T(0); + + (strict) ? diskio::convert_token_strict(val_real, str_real) : diskio::convert_token(val_real, str_real); + (strict) ? diskio::convert_token_strict(val_imag, str_imag) : diskio::convert_token(val_imag, str_imag); + + x.at(row,col) = std::complex(val_real, val_imag); + + ++col; + } + + ++row; + } + + return true; + } + + + +template +inline +bool +diskio::load_coord_ascii(Mat& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); + + bool load_okay = f.is_open(); + + if(load_okay == false) { return false; } + + if(load_okay) + { + load_okay = diskio::load_coord_ascii(x, f, err_msg); + } + + f.close(); + + return load_okay; + } + + + +//! Load a matrix in CSV text format (human readable) +template +inline +bool +diskio::load_coord_ascii(Mat& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + if(f.good() == false) { return false; } + + f.clear(); + const std::fstream::pos_type pos1 = f.tellg(); + + // work out the size + + uword f_n_rows = 0; + uword f_n_cols = 0; + + bool size_found = false; + + std::string line_string; + std::stringstream line_stream; + + std::string token; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + // a valid line in co-ord format has at least 2 entries + + line_stream >> line_row; + + if(line_stream.good() == false) { err_msg = "incorrect format"; return false; } + + line_stream >> line_col; + + size_found = true; + + if(f_n_rows < line_row) { f_n_rows = line_row; } + if(f_n_cols < line_col) { f_n_cols = line_col; } + } + + // take into account that indices start at 0 + if(size_found) { ++f_n_rows; ++f_n_cols; } + + f.clear(); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try + { + Mat tmp(f_n_rows, f_n_cols, arma_zeros_indicator()); + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + line_stream >> line_row; + line_stream >> line_col; + + eT val = eT(0); + + line_stream >> token; + + if(line_stream.fail() == false) { diskio::convert_token( val, token ); } + + if(val != eT(0)) { tmp(line_row,line_col) = val; } + } + + x.steal_mem(tmp); + } + catch(...) + { + err_msg = "not enough memory"; + return false; + } + + return true; + } + + + +template +inline +bool +diskio::load_coord_ascii(Mat< std::complex >& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + if(f.good() == false) { return false; } + + f.clear(); + const std::fstream::pos_type pos1 = f.tellg(); + + // work out the size + + uword f_n_rows = 0; + uword f_n_cols = 0; + + bool size_found = false; + + std::string line_string; + std::stringstream line_stream; + + std::string token_real; + std::string token_imag; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + // a valid line in co-ord format has at least 2 entries + + line_stream >> line_row; + + if(line_stream.good() == false) { err_msg = "incorrect format"; return false; } + + line_stream >> line_col; + + size_found = true; + + if(f_n_rows < line_row) f_n_rows = line_row; + if(f_n_cols < line_col) f_n_cols = line_col; + } + + // take into account that indices start at 0 + if(size_found) { ++f_n_rows; ++f_n_cols; } + + f.clear(); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try + { + Mat< std::complex > tmp(f_n_rows, f_n_cols, arma_zeros_indicator()); + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + line_stream >> line_row; + line_stream >> line_col; + + T val_real = T(0); + T val_imag = T(0); + + line_stream >> token_real; + + if(line_stream.fail() == false) { diskio::convert_token( val_real, token_real ); } + + line_stream >> token_imag; + + if(line_stream.fail() == false) { diskio::convert_token( val_imag, token_imag ); } + + if( (val_real != T(0)) || (val_imag != T(0)) ) + { + tmp(line_row,line_col) = std::complex(val_real, val_imag); + } + } + + x.steal_mem(tmp); + } + catch(...) + { + err_msg = "not enough memory"; + return false; + } + + return true; + } + + + +//! Load a matrix in binary format, +//! with a header that indicates the matrix type as well as its dimensions +template +inline +bool +diskio::load_arma_binary(Mat& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + f.open(name, std::fstream::binary); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_arma_binary(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +template +inline +bool +diskio::load_arma_binary(Mat& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::streampos pos = f.tellg(); + + bool load_okay = true; + + std::string f_header; + uword f_n_rows; + uword f_n_cols; + + f >> f_header; + f >> f_n_rows; + f >> f_n_cols; + + if(f_header == diskio::gen_bin_header(x)) + { + //f.seekg(1, ios::cur); // NOTE: this may not be portable, as on a Windows machine a newline could be two characters + f.get(); + + try { x.set_size(f_n_rows,f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + + f.read( reinterpret_cast(x.memptr()), std::streamsize(x.n_elem*sizeof(eT)) ); + + load_okay = f.good(); + } + else + { + load_okay = false; + err_msg = "incorrect header"; + } + + + // allow automatic conversion of u32/s32 matrices into u64/s64 matrices + + if(load_okay == false) + { + if( (sizeof(eT) == 8) && is_same_type::yes ) + { + Mat tmp; + std::string junk; + + f.clear(); + f.seekg(pos); + + load_okay = diskio::load_arma_binary(tmp, f, junk); + + if(load_okay) { x = conv_to< Mat >::from(tmp); } + } + else + if( (sizeof(eT) == 8) && is_same_type::yes ) + { + Mat tmp; + std::string junk; + + f.clear(); + f.seekg(pos); + + load_okay = diskio::load_arma_binary(tmp, f, junk); + + if(load_okay) { x = conv_to< Mat >::from(tmp); } + } + } + + return load_okay; + } + + + +inline +void +diskio::pnm_skip_comments(std::istream& f) + { + while( isspace(f.peek()) ) + { + while( isspace(f.peek()) ) { f.get(); } + + if(f.peek() == '#') + { + while( (f.peek() != '\r') && (f.peek() != '\n') ) { f.get(); } + } + } + } + + + +//! Load a PGM greyscale image as a matrix +template +inline +bool +diskio::load_pgm_binary(Mat& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::fstream f; + f.open(name, std::fstream::in | std::fstream::binary); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_pgm_binary(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +//! Load a PGM greyscale image as a matrix +template +inline +bool +diskio::load_pgm_binary(Mat& x, std::istream& f, std::string& err_msg) + { + bool load_okay = true; + + std::string f_header; + + f >> f_header; + + if(f_header == "P5") + { + uword f_n_rows = 0; + uword f_n_cols = 0; + int f_maxval = 0; + + diskio::pnm_skip_comments(f); + + f >> f_n_cols; + diskio::pnm_skip_comments(f); + + f >> f_n_rows; + diskio::pnm_skip_comments(f); + + f >> f_maxval; + f.get(); + + if( (f_maxval > 0) && (f_maxval <= 65535) ) + { + try { x.set_size(f_n_rows,f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + + if(f_maxval <= 255) + { + const uword n_elem = f_n_cols*f_n_rows; + podarray tmp(n_elem); + + f.read( reinterpret_cast(tmp.memptr()), std::streamsize(n_elem) ); + + uword i = 0; + + //cout << "f_n_cols = " << f_n_cols << endl; + //cout << "f_n_rows = " << f_n_rows << endl; + + for(uword row=0; row < f_n_rows; ++row) + for(uword col=0; col < f_n_cols; ++col) + { + x.at(row,col) = eT(tmp[i]); + ++i; + } + } + else + { + const uword n_elem = f_n_cols*f_n_rows; + podarray tmp(n_elem); + + f.read( reinterpret_cast(tmp.memptr()), std::streamsize(n_elem*2) ); + + uword i = 0; + + for(uword row=0; row < f_n_rows; ++row) + for(uword col=0; col < f_n_cols; ++col) + { + x.at(row,col) = eT(tmp[i]); + ++i; + } + } + } + else + { + load_okay = false; + err_msg = "functionality unimplemented"; + } + + if(f.good() == false) { load_okay = false; } + } + else + { + load_okay = false; + err_msg = "unsupported header"; + } + + return load_okay; + } + + + +//! Load a PGM greyscale image as a matrix +template +inline +bool +diskio::load_pgm_binary(Mat< std::complex >& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + uchar_mat tmp; + const bool load_okay = diskio::load_pgm_binary(tmp, name, err_msg); + + x = conv_to< Mat< std::complex > >::from(tmp); + + return load_okay; + } + + + +//! Load a PGM greyscale image as a matrix +template +inline +bool +diskio::load_pgm_binary(Mat< std::complex >& x, std::istream& is, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + uchar_mat tmp; + const bool load_okay = diskio::load_pgm_binary(tmp, is, err_msg); + + x = conv_to< Mat< std::complex > >::from(tmp); + + return load_okay; + } + + + +//! Load a HDF5 file as a matrix +template +inline +bool +diskio::load_hdf5_binary(Mat& x, const hdf5_name& spec, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_HDF5) + { + if(diskio::is_readable(spec.filename) == false) { return false; } + + hdf5_misc::hdf5_suspend_printing_errors hdf5_print_suspender; + + bool load_okay = false; + + hid_t fid = H5Fopen(spec.filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); + + if(fid >= 0) + { + // MATLAB HDF5 dataset names are user-specified; + // Octave tends to store the datasets in a group, with the actual dataset being referred to as "value". + // If the user hasn't specified a dataset, we will search for "dataset" and "value", + // and if those are not found we will take the first dataset we do find. + + std::vector searchNames; + + const bool exact = (spec.dsname.empty() == false); + + if(exact) + { + searchNames.push_back(spec.dsname); + } + else + { + searchNames.push_back("dataset"); + searchNames.push_back("value" ); + } + + hid_t dataset = hdf5_misc::search_hdf5_file(searchNames, fid, 2, exact); + + if(dataset >= 0) + { + hid_t filespace = H5Dget_space(dataset); + + // This must be <= 2 due to our search rules. + const int ndims = H5Sget_simple_extent_ndims(filespace); + + hsize_t dims[2]; + const herr_t query_status = H5Sget_simple_extent_dims(filespace, dims, NULL); + + // arma_check(query_status < 0, "Mat::load(): cannot get size of HDF5 dataset"); + if(query_status < 0) + { + err_msg = "cannot get size of HDF5 dataset"; + + H5Sclose(filespace); + H5Dclose(dataset); + H5Fclose(fid); + + return false; + } + + if(ndims == 1) { dims[1] = 1; } // Vector case; fake second dimension (one column). + + try { x.set_size(dims[1], dims[0]); } catch(...) { err_msg = "not enough memory"; return false; } + + // Now we have to see what type is stored to figure out how to load it. + hid_t datatype = H5Dget_type(dataset); + hid_t mat_type = hdf5_misc::get_hdf5_type(); + + // If these are the same type, it is simple. + if(H5Tequal(datatype, mat_type) > 0) + { + // Load directly; H5S_ALL used so that we load the entire dataset. + hid_t read_status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(x.memptr())); + + if(read_status >= 0) { load_okay = true; } + } + else + { + // Load into another array and convert its type accordingly. + hid_t read_status = hdf5_misc::load_and_convert_hdf5(x.memptr(), dataset, datatype, x.n_elem); + + if(read_status >= 0) { load_okay = true; } + } + + // Now clean up. + H5Tclose(datatype); + H5Tclose(mat_type); + H5Sclose(filespace); + } + + H5Dclose(dataset); + + H5Fclose(fid); + + if(load_okay == false) + { + err_msg = "unsupported or missing HDF5 data"; + } + } + else + { + err_msg = "cannot open"; + } + + return load_okay; + } + #else + { + arma_ignore(x); + arma_ignore(spec); + arma_ignore(err_msg); + + arma_stop_logic_error("Mat::load(): use of HDF5 must be enabled"); + + return false; + } + #endif + } + + + +//! Try to load a matrix by automatically determining its type +template +inline +bool +diskio::load_auto_detect(Mat& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + if(diskio::is_readable(name) == false) { return false; } + + #if defined(ARMA_USE_HDF5) + // We're currently using the C bindings for the HDF5 library, which don't support C++ streams + if( H5Fis_hdf5(name.c_str()) ) { return load_hdf5_binary(x, name, err_msg); } + #endif + + std::fstream f; + f.open(name, std::fstream::in | std::fstream::binary); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_auto_detect(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +//! Try to load a matrix by automatically determining its type +template +inline +bool +diskio::load_auto_detect(Mat& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + const char* ARMA_MAT_TXT_str = "ARMA_MAT_TXT"; + const char* ARMA_MAT_BIN_str = "ARMA_MAT_BIN"; + const char* P5_str = "P5"; + + const uword ARMA_MAT_TXT_len = uword(12); + const uword ARMA_MAT_BIN_len = uword(12); + const uword P5_len = uword(2); + + podarray header(ARMA_MAT_TXT_len + 1); + + char* header_mem = header.memptr(); + + std::streampos pos = f.tellg(); + + f.read( header_mem, std::streamsize(ARMA_MAT_TXT_len) ); + f.clear(); + f.seekg(pos); + + header_mem[ARMA_MAT_TXT_len] = '\0'; + + if( std::strncmp(ARMA_MAT_TXT_str, header_mem, size_t(ARMA_MAT_TXT_len)) == 0 ) + { + return load_arma_ascii(x, f, err_msg); + } + else + if( std::strncmp(ARMA_MAT_BIN_str, header_mem, size_t(ARMA_MAT_BIN_len)) == 0 ) + { + return load_arma_binary(x, f, err_msg); + } + else + if( std::strncmp(P5_str, header_mem, size_t(P5_len)) == 0 ) + { + return load_pgm_binary(x, f, err_msg); + } + else + { + const file_type ft = guess_file_type_internal(f); + + switch(ft) + { + case csv_ascii: + return load_csv_ascii(x, f, err_msg, char(','), false); + break; + + case ssv_ascii: + return load_csv_ascii(x, f, err_msg, char(';'), false); + break; + + case raw_binary: + return load_raw_binary(x, f, err_msg); + break; + + case raw_ascii: + return load_raw_ascii(x, f, err_msg); + break; + + default: + err_msg = "unknown data"; + return false; + } + } + + return false; + } + + + +// +// sparse matrices +// + + + +//! Save a sparse matrix in CSV format +template +inline +bool +diskio::save_csv_ascii(const SpMat& x, const std::string& final_name, const field& header, const bool with_header, const char separator) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); + + bool save_okay = f.is_open(); + + if(save_okay == false) { return false; } + + if(with_header) + { + arma_extra_debug_print("diskio::save_csv_ascii(): writing header"); + + for(uword i=0; i < header.n_elem; ++i) + { + f << header(i); + + if(i != (header.n_elem-1)) { f.put(separator); } + } + + f.put('\n'); + + save_okay = f.good(); + } + + if(save_okay) { save_okay = diskio::save_csv_ascii(x, f, separator); } + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + + return save_okay; + } + + + +//! Save a sparse matrix in CSV format +template +inline +bool +diskio::save_csv_ascii(const SpMat& x, std::ostream& f, const char separator) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(f); + + diskio::prepare_stream(f); + + x.sync(); + + uword x_n_rows = x.n_rows; + uword x_n_cols = x.n_cols; + + const eT eT_zero = eT(0); + + for(uword row=0; row < x_n_rows; ++row) + { + for(uword col=0; col < x_n_cols; ++col) + { + const eT val = x.at(row,col); + + if(val == eT_zero) + { + f.put('0'); + } + else + { + arma_ostream::raw_print_elem(f, val); + } + + if( col < (x_n_cols-1) ) { f.put(separator); } + } + + f.put('\n'); + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +//! Save a sparse matrix in CSV format (complex numbers) +template +inline +bool +diskio::save_csv_ascii(const SpMat< std::complex >& x, std::ostream& f, const char separator) + { + arma_extra_debug_sigprint(); + + arma_ignore(x); + arma_ignore(f); + arma_ignore(separator); + + arma_debug_warn_level(1, "saving complex sparse matrices as csv_ascii not yet implemented"); + + return false; + } + + + +//! Save a matrix in ASCII coord format +template +inline +bool +diskio::save_coord_ascii(const SpMat& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_coord_ascii(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +//! Save a matrix in ASCII coord format +template +inline +bool +diskio::save_coord_ascii(const SpMat& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(f); + + diskio::prepare_stream(f); + + typename SpMat::const_iterator iter = x.begin(); + typename SpMat::const_iterator iter_end = x.end(); + + for(; iter != iter_end; ++iter) + { + const eT val = (*iter); + + f << iter.row() << ' ' << iter.col() << ' ' << val << '\n'; + } + + + // make sure it's possible to figure out the matrix size later + if( (x.n_rows > 0) && (x.n_cols > 0) ) + { + const uword max_row = (x.n_rows > 0) ? x.n_rows-1 : 0; + const uword max_col = (x.n_cols > 0) ? x.n_cols-1 : 0; + + if( x.at(max_row, max_col) == eT(0) ) + { + f << max_row << ' ' << max_col << " 0\n"; + } + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +//! Save a matrix in ASCII coord format (complex numbers) +template +inline +bool +diskio::save_coord_ascii(const SpMat< std::complex >& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const arma_ostream_state stream_state(f); + + diskio::prepare_stream(f); + + typename SpMat::const_iterator iter = x.begin(); + typename SpMat::const_iterator iter_end = x.end(); + + for(; iter != iter_end; ++iter) + { + const eT val = (*iter); + + f << iter.row() << ' ' << iter.col() << ' ' << val.real() << ' ' << val.imag() << '\n'; + } + + // make sure it's possible to figure out the matrix size later + if( (x.n_rows > 0) && (x.n_cols > 0) ) + { + const uword max_row = (x.n_rows > 0) ? x.n_rows-1 : 0; + const uword max_col = (x.n_cols > 0) ? x.n_cols-1 : 0; + + if( x.at(max_row, max_col) == eT(0) ) + { + f << max_row << ' ' << max_col << " 0 0\n"; + } + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +//! Save a matrix in binary format, +//! with a header that stores the matrix type as well as its dimensions +template +inline +bool +diskio::save_arma_binary(const SpMat& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f(tmp_name, std::fstream::binary); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_arma_binary(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +//! Save a matrix in binary format, +//! with a header that stores the matrix type as well as its dimensions +template +inline +bool +diskio::save_arma_binary(const SpMat& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + f << diskio::gen_bin_header(x) << '\n'; + f << x.n_rows << ' ' << x.n_cols << ' ' << x.n_nonzero << '\n'; + + f.write( reinterpret_cast(x.values), std::streamsize(x.n_nonzero*sizeof(eT)) ); + f.write( reinterpret_cast(x.row_indices), std::streamsize(x.n_nonzero*sizeof(uword)) ); + f.write( reinterpret_cast(x.col_ptrs), std::streamsize((x.n_cols+1)*sizeof(uword)) ); + + return f.good(); + } + + + +template +inline +bool +diskio::load_csv_ascii(SpMat& x, const std::string& name, std::string& err_msg, field& header, const bool with_header, const char separator) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); + + bool load_okay = f.is_open(); + + if(load_okay == false) { return false; } + + if(with_header) + { + arma_extra_debug_print("diskio::load_csv_ascii(): reading header"); + + std::string header_line; + std::stringstream header_stream; + std::vector header_tokens; + + std::getline(f, header_line); + + load_okay = f.good(); + + if(load_okay) + { + std::string token; + + header_stream.clear(); + header_stream.str(header_line); + + uword header_n_tokens = 0; + + while(header_stream.good()) + { + std::getline(header_stream, token, separator); + + diskio::sanitise_token(token); + + ++header_n_tokens; + + header_tokens.push_back(token); + } + + if(header_n_tokens == uword(0)) + { + header.reset(); + } + else + { + header.set_size(1,header_n_tokens); + + for(uword i=0; i < header_n_tokens; ++i) { header.at(i) = header_tokens[i]; } + } + } + } + + if(load_okay) + { + load_okay = diskio::load_csv_ascii(x, f, err_msg, separator); + } + + f.close(); + + return load_okay; + } + + + +template +inline +bool +diskio::load_csv_ascii(SpMat& x, std::istream& f, std::string& err_msg, const char separator) + { + arma_extra_debug_sigprint(); + + // TODO: replace with more efficient implementation + + if(f.good() == false) { return false; } + + f.clear(); + const std::fstream::pos_type pos1 = f.tellg(); + + // + // work out the size + + uword f_n_rows = 0; + uword f_n_cols = 0; + + std::string line_string; + std::stringstream line_stream; + + std::string token; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_n_cols = 0; + + while(line_stream.good()) + { + std::getline(line_stream, token, separator); + ++line_n_cols; + } + + if(f_n_cols < line_n_cols) { f_n_cols = line_n_cols; } + + ++f_n_rows; + } + + f.clear(); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try + { + MapMat tmp(f_n_rows, f_n_cols); + + uword row = 0; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword col = 0; + + while(line_stream.good()) + { + std::getline(line_stream, token, separator); + + eT val = eT(0); + + diskio::convert_token( val, token ); + + if(val != eT(0)) { tmp(row,col) = val; } + + ++col; + } + + ++row; + } + + x = tmp; + } + catch(...) + { + err_msg = "not enough memory"; + return false; + } + + return true; + } + + + +template +inline +bool +diskio::load_csv_ascii(SpMat< std::complex >& x, std::istream& f, std::string& err_msg, const char separator) + { + arma_extra_debug_sigprint(); + + arma_ignore(x); + arma_ignore(f); + arma_ignore(err_msg); + arma_ignore(separator); + + arma_debug_warn_level(1, "loading complex sparse matrices as csv_ascii not yet implemented"); + + return false; + } + + + +template +inline +bool +diskio::load_coord_ascii(SpMat& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_coord_ascii(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +template +inline +bool +diskio::load_coord_ascii(SpMat& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + if(f.good() == false) { return false; } + + f.clear(); + const std::fstream::pos_type pos1 = f.tellg(); + + // work out the size + + uword f_n_rows = 0; + uword f_n_cols = 0; + + bool size_found = false; + + std::string line_string; + std::stringstream line_stream; + + std::string token; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + // a valid line in co-ord format has at least 2 entries + + line_stream >> line_row; + + if(line_stream.good() == false) { err_msg = "incorrect format"; return false; } + + line_stream >> line_col; + + size_found = true; + + if(f_n_rows < line_row) { f_n_rows = line_row; } + if(f_n_cols < line_col) { f_n_cols = line_col; } + } + + // take into account that indices start at 0 + if(size_found) { ++f_n_rows; ++f_n_cols; } + + f.clear(); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try + { + MapMat tmp(f_n_rows, f_n_cols); + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + line_stream >> line_row; + line_stream >> line_col; + + eT val = eT(0); + + line_stream >> token; + + if(line_stream.fail() == false) { diskio::convert_token( val, token ); } + + if(val != eT(0)) { tmp(line_row,line_col) = val; } + } + + x = tmp; + } + catch(...) + { + err_msg = "not enough memory"; + return false; + } + + return true; + } + + + +template +inline +bool +diskio::load_coord_ascii(SpMat< std::complex >& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + if(f.good() == false) { return false; } + + f.clear(); + const std::fstream::pos_type pos1 = f.tellg(); + + // work out the size + + uword f_n_rows = 0; + uword f_n_cols = 0; + + bool size_found = false; + + std::string line_string; + std::stringstream line_stream; + + std::string token_real; + std::string token_imag; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + // a valid line in co-ord format has at least 2 entries + + line_stream >> line_row; + + if(line_stream.good() == false) { err_msg = "incorrect format"; return false; } + + line_stream >> line_col; + + size_found = true; + + if(f_n_rows < line_row) f_n_rows = line_row; + if(f_n_cols < line_col) f_n_cols = line_col; + } + + // take into account that indices start at 0 + if(size_found) { ++f_n_rows; ++f_n_cols; } + + f.clear(); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try + { + MapMat< std::complex > tmp(f_n_rows, f_n_cols); + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + line_stream >> line_row; + line_stream >> line_col; + + T val_real = T(0); + T val_imag = T(0); + + line_stream >> token_real; + + if(line_stream.fail() == false) { diskio::convert_token( val_real, token_real ); } + + line_stream >> token_imag; + + if(line_stream.fail() == false) { diskio::convert_token( val_imag, token_imag ); } + + if( (val_real != T(0)) || (val_imag != T(0)) ) + { + tmp(line_row,line_col) = std::complex(val_real, val_imag); + } + } + + x = tmp; + } + catch(...) + { + err_msg = "not enough memory"; + return false; + } + + return true; + } + + + +//! Load a matrix in binary format, +//! with a header that indicates the matrix type as well as its dimensions +template +inline +bool +diskio::load_arma_binary(SpMat& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + f.open(name, std::fstream::binary); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_arma_binary(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +template +inline +bool +diskio::load_arma_binary(SpMat& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + bool load_okay = true; + + std::string f_header; + + f >> f_header; + + if(f_header == diskio::gen_bin_header(x)) + { + uword f_n_rows; + uword f_n_cols; + uword f_n_nz; + + f >> f_n_rows; + f >> f_n_cols; + f >> f_n_nz; + + //f.seekg(1, ios::cur); // NOTE: this may not be portable, as on a Windows machine a newline could be two characters + f.get(); + + try { x.reserve(f_n_rows, f_n_cols, f_n_nz); } catch(...) { err_msg = "not enough memory"; return false; } + + f.read( reinterpret_cast(access::rwp(x.values)), std::streamsize(x.n_nonzero*sizeof(eT)) ); + + std::streampos pos = f.tellg(); + + f.read( reinterpret_cast(access::rwp(x.row_indices)), std::streamsize(x.n_nonzero*sizeof(uword)) ); + f.read( reinterpret_cast(access::rwp(x.col_ptrs)), std::streamsize((x.n_cols+1)*sizeof(uword)) ); + + bool check1 = true; for(uword i=0; i < x.n_nonzero; ++i) { if(x.values[i] == eT(0)) { check1 = false; break; } } + bool check2 = true; for(uword i=0; i < x.n_cols; ++i) { if(x.col_ptrs[i+1] < x.col_ptrs[i]) { check2 = false; break; } } + bool check3 = (x.col_ptrs[x.n_cols] == x.n_nonzero); + + if((check1 == true) && ((check2 == false) || (check3 == false))) + { + if(sizeof(uword) == 8) + { + arma_extra_debug_print("detected inconsistent data while loading; re-reading integer parts as u32"); + + // inconstency could be due to a different uword size used during saving, + // so try loading the row_indices and col_ptrs under the assumption of 32 bit unsigned integers + + f.clear(); + f.seekg(pos); + + podarray tmp_a(x.n_nonzero ); tmp_a.zeros(); + podarray tmp_b(x.n_cols + 1); tmp_b.zeros(); + + f.read( reinterpret_cast(tmp_a.memptr()), std::streamsize( x.n_nonzero * sizeof(u32)) ); + f.read( reinterpret_cast(tmp_b.memptr()), std::streamsize((x.n_cols + 1) * sizeof(u32)) ); + + check2 = true; for(uword i=0; i < x.n_cols; ++i) { if(tmp_b[i+1] < tmp_b[i]) { check2 = false; break; } } + check3 = (tmp_b[x.n_cols] == x.n_nonzero); + + load_okay = f.good(); + + if( load_okay && (check2 == true) && (check3 == true) ) + { + arma_extra_debug_print("reading integer parts as u32 succeeded"); + + arrayops::convert(access::rwp(x.row_indices), tmp_a.memptr(), x.n_nonzero ); + arrayops::convert(access::rwp(x.col_ptrs), tmp_b.memptr(), x.n_cols + 1); + } + else + { + arma_extra_debug_print("reading integer parts as u32 failed"); + } + } + } + + if((check1 == false) || (check2 == false) || (check3 == false)) + { + load_okay = false; + err_msg = "inconsistent data"; + } + else + { + load_okay = f.good(); + } + } + else + { + load_okay = false; + err_msg = "incorrect header"; + } + + return load_okay; + } + + + +// cubes + + + +//! Save a cube as raw text (no header, human readable). +template +inline +bool +diskio::save_raw_ascii(const Cube& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = save_raw_ascii(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +//! Save a cube as raw text (no header, human readable). +template +inline +bool +diskio::save_raw_ascii(const Cube& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(f); + + const std::streamsize cell_width = diskio::prepare_stream(f); + + for(uword slice=0; slice < x.n_slices; ++slice) + { + for(uword row=0; row < x.n_rows; ++row) + { + for(uword col=0; col < x.n_cols; ++col) + { + f.put(' '); + + if(is_real::value) { f.width(cell_width); } + + arma_ostream::raw_print_elem(f, x.at(row,col,slice)); + } + + f.put('\n'); + } + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +//! Save a cube as raw binary (no header) +template +inline +bool +diskio::save_raw_binary(const Cube& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f(tmp_name, std::fstream::binary); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_raw_binary(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +template +inline +bool +diskio::save_raw_binary(const Cube& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + f.write( reinterpret_cast(x.mem), std::streamsize(x.n_elem*sizeof(eT)) ); + + return f.good(); + } + + + +//! Save a cube in text format (human readable), +//! with a header that indicates the cube type as well as its dimensions +template +inline +bool +diskio::save_arma_ascii(const Cube& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_arma_ascii(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +//! Save a cube in text format (human readable), +//! with a header that indicates the cube type as well as its dimensions +template +inline +bool +diskio::save_arma_ascii(const Cube& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(f); + + f << diskio::gen_txt_header(x) << '\n'; + f << x.n_rows << ' ' << x.n_cols << ' ' << x.n_slices << '\n'; + + const std::streamsize cell_width = diskio::prepare_stream(f); + + for(uword slice=0; slice < x.n_slices; ++slice) + { + for(uword row=0; row < x.n_rows; ++row) + { + for(uword col=0; col < x.n_cols; ++col) + { + f.put(' '); + + if(is_real::value) { f.width(cell_width); } + + arma_ostream::raw_print_elem(f, x.at(row,col,slice)); + } + + f.put('\n'); + } + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +//! Save a cube in binary format, +//! with a header that stores the cube type as well as its dimensions +template +inline +bool +diskio::save_arma_binary(const Cube& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f(tmp_name, std::fstream::binary); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_arma_binary(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +//! Save a cube in binary format, +//! with a header that stores the cube type as well as its dimensions +template +inline +bool +diskio::save_arma_binary(const Cube& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + f << diskio::gen_bin_header(x) << '\n'; + f << x.n_rows << ' ' << x.n_cols << ' ' << x.n_slices << '\n'; + + f.write( reinterpret_cast(x.mem), std::streamsize(x.n_elem*sizeof(eT)) ); + + return f.good(); + } + + + +//! Save a cube as part of a HDF5 file +template +inline +bool +diskio::save_hdf5_binary(const Cube& x, const hdf5_name& spec, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_HDF5) + { + hdf5_misc::hdf5_suspend_printing_errors hdf5_print_suspender; + + bool save_okay = false; + + const bool append = bool(spec.opts.flags & hdf5_opts::flag_append); + const bool replace = bool(spec.opts.flags & hdf5_opts::flag_replace); + + const bool use_existing_file = ((append || replace) && (H5Fis_hdf5(spec.filename.c_str()) > 0)); + + const std::string tmp_name = (use_existing_file) ? std::string() : diskio::gen_tmp_name(spec.filename); + + // Set up the file according to HDF5's preferences + hid_t file = (use_existing_file) ? H5Fopen(spec.filename.c_str(), H5F_ACC_RDWR, H5P_DEFAULT) : H5Fcreate(tmp_name.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT); + + if(file < 0) { return false; } + + // We need to create a dataset, datatype, and dataspace + hsize_t dims[3]; + dims[2] = x.n_rows; + dims[1] = x.n_cols; + dims[0] = x.n_slices; + + hid_t dataspace = H5Screate_simple(3, dims, NULL); // treat the cube as a 3d array dataspace + hid_t datatype = hdf5_misc::get_hdf5_type(); + + // If this returned something invalid, well, it's time to crash. + arma_check(datatype == -1, "Cube::save(): unknown datatype for HDF5"); + + // MATLAB forces the users to specify a name at save time for HDF5; + // Octave will use the default of 'dataset' unless otherwise specified. + // If the user hasn't specified a dataset name, we will use 'dataset' + // We may have to split out the group name from the dataset name. + std::vector groups; + std::string full_name = spec.dsname; + size_t loc; + while((loc = full_name.find("/")) != std::string::npos) + { + // Create another group... + if(loc != 0) // Ignore the first /, if there is a leading /. + { + hid_t gid = H5Gcreate((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); + + if((gid < 0) && use_existing_file) + { + gid = H5Gopen((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT); + } + + groups.push_back(gid); + } + + full_name = full_name.substr(loc + 1); + } + + const std::string dataset_name = full_name.empty() ? std::string("dataset") : full_name; + + const hid_t last_group = (groups.size() == 0) ? file : groups[groups.size() - 1]; + + if(use_existing_file && replace) + { + H5Ldelete(last_group, dataset_name.c_str(), H5P_DEFAULT); + // NOTE: H5Ldelete() in HDF5 v1.8 doesn't reclaim the deleted space; use h5repack to reclaim space: h5repack oldfile.h5 newfile.h5 + // NOTE: has this behaviour changed in HDF5 1.10 ? + // NOTE: https://lists.hdfgroup.org/pipermail/hdf-forum_lists.hdfgroup.org/2017-August/010482.html + // NOTE: https://lists.hdfgroup.org/pipermail/hdf-forum_lists.hdfgroup.org/2017-August/010486.html + } + + hid_t dataset = H5Dcreate(last_group, dataset_name.c_str(), datatype, dataspace, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); + + if(dataset < 0) + { + save_okay = false; + + err_msg = "failed to create dataset"; + } + else + { + save_okay = (H5Dwrite(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, x.mem) >= 0); + + H5Dclose(dataset); + } + + H5Tclose(datatype); + H5Sclose(dataspace); + for(size_t i = 0; i < groups.size(); ++i) { H5Gclose(groups[i]); } + H5Fclose(file); + + if((use_existing_file == false) && (save_okay == true)) { save_okay = diskio::safe_rename(tmp_name, spec.filename); } + + return save_okay; + } + #else + { + arma_ignore(x); + arma_ignore(spec); + arma_ignore(err_msg); + + arma_stop_logic_error("Cube::save(): use of HDF5 must be enabled"); + + return false; + } + #endif + } + + + +//! Load a cube as raw text (no header, human readable). +//! NOTE: this is much slower than reading a file with a header. +template +inline +bool +diskio::load_raw_ascii(Cube& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + Mat tmp; + const bool load_okay = diskio::load_raw_ascii(tmp, name, err_msg); + + if(load_okay) + { + if(tmp.is_empty() == false) + { + try { x.set_size(tmp.n_rows, tmp.n_cols, 1); } catch(...) { err_msg = "not enough memory"; return false; } + + x.slice(0) = tmp; + } + else + { + x.reset(); + } + } + + return load_okay; + } + + + +//! Load a cube as raw text (no header, human readable). +//! NOTE: this is much slower than reading a file with a header. +template +inline +bool +diskio::load_raw_ascii(Cube& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + Mat tmp; + const bool load_okay = diskio::load_raw_ascii(tmp, f, err_msg); + + if(load_okay) + { + if(tmp.is_empty() == false) + { + try { x.set_size(tmp.n_rows, tmp.n_cols, 1); } catch(...) { err_msg = "not enough memory"; return false; } + + x.slice(0) = tmp; + } + else + { + x.reset(); + } + } + + return load_okay; + } + + + +//! Load a cube in binary format (no header); +//! the cube is assumed to have one slice with one column +template +inline +bool +diskio::load_raw_binary(Cube& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + f.open(name, std::fstream::binary); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_raw_binary(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +template +inline +bool +diskio::load_raw_binary(Cube& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + f.clear(); + const std::streampos pos1 = f.tellg(); + + f.clear(); + f.seekg(0, ios::end); + + f.clear(); + const std::streampos pos2 = f.tellg(); + + const uword N = ( (pos1 >= 0) && (pos2 >= 0) ) ? uword(pos2 - pos1) : 0; + + f.clear(); + //f.seekg(0, ios::beg); + f.seekg(pos1); + + try { x.set_size(N / uword(sizeof(eT)), 1, 1); } catch(...) { err_msg = "not enough memory"; return false; } + + f.clear(); + f.read( reinterpret_cast(x.memptr()), std::streamsize(x.n_elem * uword(sizeof(eT))) ); + + return f.good(); + } + + + +//! Load a cube in text format (human readable), +//! with a header that indicates the cube type as well as its dimensions +template +inline +bool +diskio::load_arma_ascii(Cube& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_arma_ascii(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +//! Load a cube in text format (human readable), +//! with a header that indicates the cube type as well as its dimensions +template +inline +bool +diskio::load_arma_ascii(Cube& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::streampos pos = f.tellg(); + + bool load_okay = true; + + std::string f_header; + uword f_n_rows; + uword f_n_cols; + uword f_n_slices; + + f >> f_header; + f >> f_n_rows; + f >> f_n_cols; + f >> f_n_slices; + + if(f_header == diskio::gen_txt_header(x)) + { + try { x.set_size(f_n_rows, f_n_cols, f_n_slices); } catch(...) { err_msg = "not enough memory"; return false; } + + for(uword slice = 0; slice < x.n_slices; ++slice) + for(uword row = 0; row < x.n_rows; ++row ) + for(uword col = 0; col < x.n_cols; ++col ) + { + f >> x.at(row,col,slice); + } + + load_okay = f.good(); + } + else + { + load_okay = false; + err_msg = "incorrect header"; + } + + + // allow automatic conversion of u32/s32 cubes into u64/s64 cubes + + if(load_okay == false) + { + if( (sizeof(eT) == 8) && is_same_type::yes ) + { + Cube tmp; + std::string junk; + + f.clear(); + f.seekg(pos); + + load_okay = diskio::load_arma_ascii(tmp, f, junk); + + if(load_okay) { x = conv_to< Cube >::from(tmp); } + } + else + if( (sizeof(eT) == 8) && is_same_type::yes ) + { + Cube tmp; + std::string junk; + + f.clear(); + f.seekg(pos); + + load_okay = diskio::load_arma_ascii(tmp, f, junk); + + if(load_okay) { x = conv_to< Cube >::from(tmp); } + } + } + + return load_okay; + } + + + +//! Load a cube in binary format, +//! with a header that indicates the cube type as well as its dimensions +template +inline +bool +diskio::load_arma_binary(Cube& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + f.open(name, std::fstream::binary); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_arma_binary(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +template +inline +bool +diskio::load_arma_binary(Cube& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::streampos pos = f.tellg(); + + bool load_okay = true; + + std::string f_header; + uword f_n_rows; + uword f_n_cols; + uword f_n_slices; + + f >> f_header; + f >> f_n_rows; + f >> f_n_cols; + f >> f_n_slices; + + if(f_header == diskio::gen_bin_header(x)) + { + //f.seekg(1, ios::cur); // NOTE: this may not be portable, as on a Windows machine a newline could be two characters + f.get(); + + try { x.set_size(f_n_rows, f_n_cols, f_n_slices); } catch(...) { err_msg = "not enough memory"; return false; } + + f.read( reinterpret_cast(x.memptr()), std::streamsize(x.n_elem*sizeof(eT)) ); + + load_okay = f.good(); + } + else + { + load_okay = false; + err_msg = "incorrect header"; + } + + + // allow automatic conversion of u32/s32 cubes into u64/s64 cubes + + if(load_okay == false) + { + if( (sizeof(eT) == 8) && is_same_type::yes ) + { + Cube tmp; + std::string junk; + + f.clear(); + f.seekg(pos); + + load_okay = diskio::load_arma_binary(tmp, f, junk); + + if(load_okay) { x = conv_to< Cube >::from(tmp); } + } + else + if( (sizeof(eT) == 8) && is_same_type::yes ) + { + Cube tmp; + std::string junk; + + f.clear(); + f.seekg(pos); + + load_okay = diskio::load_arma_binary(tmp, f, junk); + + if(load_okay) { x = conv_to< Cube >::from(tmp); } + } + } + + return load_okay; + } + + + +//! Load a HDF5 file as a cube +template +inline +bool +diskio::load_hdf5_binary(Cube& x, const hdf5_name& spec, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_HDF5) + { + if(diskio::is_readable(spec.filename) == false) { return false; } + + hdf5_misc::hdf5_suspend_printing_errors hdf5_print_suspender; + + bool load_okay = false; + + hid_t fid = H5Fopen(spec.filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); + + if(fid >= 0) + { + // MATLAB HDF5 dataset names are user-specified; + // Octave tends to store the datasets in a group, with the actual dataset being referred to as "value". + // If the user hasn't specified a dataset, we will search for "dataset" and "value", + // and if those are not found we will take the first dataset we do find. + + std::vector searchNames; + + const bool exact = (spec.dsname.empty() == false); + + if(exact) + { + searchNames.push_back(spec.dsname); + } + else + { + searchNames.push_back("dataset"); + searchNames.push_back("value" ); + } + + hid_t dataset = hdf5_misc::search_hdf5_file(searchNames, fid, 3, exact); + + if(dataset >= 0) + { + hid_t filespace = H5Dget_space(dataset); + + // This must be <= 3 due to our search rules. + const int ndims = H5Sget_simple_extent_ndims(filespace); + + hsize_t dims[3]; + const herr_t query_status = H5Sget_simple_extent_dims(filespace, dims, NULL); + + // arma_check(query_status < 0, "Cube::load(): cannot get size of HDF5 dataset"); + if(query_status < 0) + { + err_msg = "cannot get size of HDF5 dataset"; + + H5Sclose(filespace); + H5Dclose(dataset); + H5Fclose(fid); + + return false; + } + + if(ndims == 1) { dims[1] = 1; dims[2] = 1; } // Vector case; one row/colum, several slices + if(ndims == 2) { dims[2] = 1; } // Matrix case; one column, several rows/slices + + try { x.set_size(dims[2], dims[1], dims[0]); } catch(...) { err_msg = "not enough memory"; return false; } + + // Now we have to see what type is stored to figure out how to load it. + hid_t datatype = H5Dget_type(dataset); + hid_t mat_type = hdf5_misc::get_hdf5_type(); + + // If these are the same type, it is simple. + if(H5Tequal(datatype, mat_type) > 0) + { + // Load directly; H5S_ALL used so that we load the entire dataset. + hid_t read_status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(x.memptr())); + + if(read_status >= 0) { load_okay = true; } + } + else + { + // Load into another array and convert its type accordingly. + hid_t read_status = hdf5_misc::load_and_convert_hdf5(x.memptr(), dataset, datatype, x.n_elem); + + if(read_status >= 0) { load_okay = true; } + } + + // Now clean up. + H5Tclose(datatype); + H5Tclose(mat_type); + H5Sclose(filespace); + } + + H5Dclose(dataset); + + H5Fclose(fid); + + if(load_okay == false) + { + err_msg = "unsupported or missing HDF5 data"; + } + } + else + { + err_msg = "cannot open"; + } + + return load_okay; + } + #else + { + arma_ignore(x); + arma_ignore(spec); + arma_ignore(err_msg); + + arma_stop_logic_error("Cube::load(): use of HDF5 must be enabled"); + + return false; + } + #endif + } + + + +//! Try to load a cube by automatically determining its type +template +inline +bool +diskio::load_auto_detect(Cube& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + if(diskio::is_readable(name) == false) { return false; } + + #if defined(ARMA_USE_HDF5) + // We're currently using the C bindings for the HDF5 library, which don't support C++ streams + if( H5Fis_hdf5(name.c_str()) ) { return load_hdf5_binary(x, name, err_msg); } + #endif + + std::fstream f; + f.open(name, std::fstream::in | std::fstream::binary); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_auto_detect(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +//! Try to load a cube by automatically determining its type +template +inline +bool +diskio::load_auto_detect(Cube& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + const char* ARMA_CUB_TXT_str = "ARMA_CUB_TXT"; + const char* ARMA_CUB_BIN_str = "ARMA_CUB_BIN"; + const char* P6_str = "P6"; + + const uword ARMA_CUB_TXT_len = uword(12); + const uword ARMA_CUB_BIN_len = uword(12); + const uword P6_len = uword(2); + + podarray header(ARMA_CUB_TXT_len + 1); + + char* header_mem = header.memptr(); + + std::streampos pos = f.tellg(); + + f.read( header_mem, std::streamsize(ARMA_CUB_TXT_len) ); + f.clear(); + f.seekg(pos); + + header_mem[ARMA_CUB_TXT_len] = '\0'; + + if( std::strncmp(ARMA_CUB_TXT_str, header_mem, size_t(ARMA_CUB_TXT_len)) == 0 ) + { + return load_arma_ascii(x, f, err_msg); + } + else + if( std::strncmp(ARMA_CUB_BIN_str, header_mem, size_t(ARMA_CUB_BIN_len)) == 0 ) + { + return load_arma_binary(x, f, err_msg); + } + else + if( std::strncmp(P6_str, header_mem, size_t(P6_len)) == 0 ) + { + return load_ppm_binary(x, f, err_msg); + } + else + { + const file_type ft = guess_file_type_internal(f); + + switch(ft) + { + // case csv_ascii: + // return load_csv_ascii(x, f, err_msg); + // break; + + case raw_binary: + return load_raw_binary(x, f, err_msg); + break; + + case raw_ascii: + return load_raw_ascii(x, f, err_msg); + break; + + default: + err_msg = "unknown data"; + return false; + } + } + + return false; + } + + + + + +// fields + + + +template +inline +bool +diskio::save_arma_binary(const field& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f( tmp_name, std::fstream::binary ); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_arma_binary(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +template +inline +bool +diskio::save_arma_binary(const field& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + arma_type_check(( (is_Mat::value == false) && (is_Cube::value == false) )); + + if(x.n_slices <= 1) + { + f << "ARMA_FLD_BIN" << '\n'; + f << x.n_rows << '\n'; + f << x.n_cols << '\n'; + } + else + { + f << "ARMA_FL3_BIN" << '\n'; + f << x.n_rows << '\n'; + f << x.n_cols << '\n'; + f << x.n_slices << '\n'; + } + + bool save_okay = true; + + for(uword i=0; i +inline +bool +diskio::load_arma_binary(field& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f( name, std::fstream::binary ); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_arma_binary(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +template +inline +bool +diskio::load_arma_binary(field& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + arma_type_check(( (is_Mat::value == false) && (is_Cube::value == false) )); + + bool load_okay = true; + + std::string f_type; + f >> f_type; + + if(f_type == "ARMA_FLD_BIN") + { + uword f_n_rows; + uword f_n_cols; + + f >> f_n_rows; + f >> f_n_cols; + + try { x.set_size(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + + f.get(); + + for(uword i=0; i> f_n_rows; + f >> f_n_cols; + f >> f_n_slices; + + try { x.set_size(f_n_rows, f_n_cols, f_n_slices); } catch(...) { err_msg = "not enough memory"; return false; } + + f.get(); + + for(uword i=0; i& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f( tmp_name, std::fstream::binary ); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_std_string(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +inline +bool +diskio::save_std_string(const field& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + for(uword row=0; row& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f(name); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_std_string(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +inline +bool +diskio::load_std_string(field& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + bool load_okay = true; + + // + // work out the size + + uword f_n_rows = 0; + uword f_n_cols = 0; + + bool f_n_cols_found = false; + + std::string line_string; + std::string token; + + while( f.good() && load_okay ) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + std::stringstream line_stream(line_string); + + uword line_n_cols = 0; + + while(line_stream >> token) { line_n_cols++; } + + if(f_n_cols_found == false) + { + f_n_cols = line_n_cols; + f_n_cols_found = true; + } + else + { + if(line_n_cols != f_n_cols) + { + load_okay = false; + err_msg = "inconsistent number of columns"; + } + } + + ++f_n_rows; + } + + if(load_okay) + { + f.clear(); + f.seekg(0, ios::beg); + //f.seekg(start); + + try { x.set_size(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + + for(uword row=0; row < x.n_rows; ++row) + for(uword col=0; col < x.n_cols; ++col) + { + f >> x.at(row,col); + } + } + + if(f.good() == false) { load_okay = false; } + + return load_okay; + } + + + +//! Try to load a field by automatically determining its type +template +inline +bool +diskio::load_auto_detect(field& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::fstream f; + f.open(name, std::fstream::in | std::fstream::binary); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_auto_detect(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +//! Try to load a field by automatically determining its type +template +inline +bool +diskio::load_auto_detect(field& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_Mat::value == false )); + + static const std::string ARMA_FLD_BIN = "ARMA_FLD_BIN"; + static const std::string ARMA_FL3_BIN = "ARMA_FL3_BIN"; + static const std::string P6 = "P6"; + + podarray raw_header(uword(ARMA_FLD_BIN.length()) + 1); + + std::streampos pos = f.tellg(); + + f.read( raw_header.memptr(), std::streamsize(ARMA_FLD_BIN.length()) ); + + f.clear(); + f.seekg(pos); + + raw_header[uword(ARMA_FLD_BIN.length())] = '\0'; + + const std::string header = raw_header.mem; + + if(ARMA_FLD_BIN == header.substr(0, ARMA_FLD_BIN.length())) + { + return load_arma_binary(x, f, err_msg); + } + else + if(ARMA_FL3_BIN == header.substr(0, ARMA_FL3_BIN.length())) + { + return load_arma_binary(x, f, err_msg); + } + else + if(P6 == header.substr(0, P6.length())) + { + return load_ppm_binary(x, f, err_msg); + } + else + { + err_msg = "unsupported header"; + return false; + } + } + + + +// +// handling of PPM images by cubes + + +template +inline +bool +diskio::load_ppm_binary(Cube& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::fstream f; + f.open(name, std::fstream::in | std::fstream::binary); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_ppm_binary(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +template +inline +bool +diskio::load_ppm_binary(Cube& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + bool load_okay = true; + + std::string f_header; + + f >> f_header; + + if(f_header == "P6") + { + uword f_n_rows = 0; + uword f_n_cols = 0; + int f_maxval = 0; + + diskio::pnm_skip_comments(f); + + f >> f_n_cols; + diskio::pnm_skip_comments(f); + + f >> f_n_rows; + diskio::pnm_skip_comments(f); + + f >> f_maxval; + f.get(); + + if( (f_maxval > 0) && (f_maxval <= 65535) ) + { + try { x.set_size(f_n_rows, f_n_cols, 3); } catch(...) { err_msg = "not enough memory"; return false; } + + if(f_maxval <= 255) + { + const uword n_elem = 3*f_n_cols*f_n_rows; + podarray tmp(n_elem); + + f.read( reinterpret_cast(tmp.memptr()), std::streamsize(n_elem) ); + + uword i = 0; + + //cout << "f_n_cols = " << f_n_cols << endl; + //cout << "f_n_rows = " << f_n_rows << endl; + + for(uword row=0; row < f_n_rows; ++row) + for(uword col=0; col < f_n_cols; ++col) + { + x.at(row,col,0) = eT(tmp[i+0]); + x.at(row,col,1) = eT(tmp[i+1]); + x.at(row,col,2) = eT(tmp[i+2]); + i+=3; + } + } + else + { + const uword n_elem = 3*f_n_cols*f_n_rows; + podarray tmp(n_elem); + + f.read( reinterpret_cast(tmp.memptr()), std::streamsize(2*n_elem) ); + + uword i = 0; + + for(uword row=0; row < f_n_rows; ++row) + for(uword col=0; col < f_n_cols; ++col) + { + x.at(row,col,0) = eT(tmp[i+0]); + x.at(row,col,1) = eT(tmp[i+1]); + x.at(row,col,2) = eT(tmp[i+2]); + i+=3; + } + } + } + else + { + load_okay = false; + err_msg = "functionality unimplemented"; + } + + if(f.good() == false) { load_okay = false; } + } + else + { + load_okay = false; + err_msg = "unsupported header"; + } + + return load_okay; + } + + + +template +inline +bool +diskio::save_ppm_binary(const Cube& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f( tmp_name, std::fstream::binary ); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_ppm_binary(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +template +inline +bool +diskio::save_ppm_binary(const Cube& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (x.n_slices != 3), "diskio::save_ppm_binary(): given cube must have exactly 3 slices" ); + + const uword n_elem = 3 * x.n_rows * x.n_cols; + podarray tmp(n_elem); + + uword i = 0; + for(uword row=0; row < x.n_rows; ++row) + { + for(uword col=0; col < x.n_cols; ++col) + { + tmp[i+0] = u8( access::tmp_real( x.at(row,col,0) ) ); + tmp[i+1] = u8( access::tmp_real( x.at(row,col,1) ) ); + tmp[i+2] = u8( access::tmp_real( x.at(row,col,2) ) ); + + i+=3; + } + } + + f << "P6" << '\n'; + f << x.n_cols << '\n'; + f << x.n_rows << '\n'; + f << 255 << '\n'; + + f.write( reinterpret_cast(tmp.mem), std::streamsize(n_elem) ); + + return f.good(); + } + + + +// +// handling of PPM images by fields + + + +template +inline +bool +diskio::load_ppm_binary(field& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::fstream f; + f.open(name, std::fstream::in | std::fstream::binary); + + bool load_okay = f.is_open(); + + if(load_okay) + { + load_okay = diskio::load_ppm_binary(x, f, err_msg); + f.close(); + } + + return load_okay; + } + + + +template +inline +bool +diskio::load_ppm_binary(field& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_Mat::value == false )); + typedef typename T1::elem_type eT; + + bool load_okay = true; + + std::string f_header; + + f >> f_header; + + if(f_header == "P6") + { + uword f_n_rows = 0; + uword f_n_cols = 0; + int f_maxval = 0; + + diskio::pnm_skip_comments(f); + + f >> f_n_cols; + diskio::pnm_skip_comments(f); + + f >> f_n_rows; + diskio::pnm_skip_comments(f); + + f >> f_maxval; + f.get(); + + if( (f_maxval > 0) && (f_maxval <= 65535) ) + { + x.set_size(3); + Mat& R = x(0); + Mat& G = x(1); + Mat& B = x(2); + + try { R.set_size(f_n_rows,f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + try { G.set_size(f_n_rows,f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + try { B.set_size(f_n_rows,f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + + if(f_maxval <= 255) + { + const uword n_elem = 3*f_n_cols*f_n_rows; + podarray tmp(n_elem); + + f.read( reinterpret_cast(tmp.memptr()), std::streamsize(n_elem) ); + + uword i = 0; + + //cout << "f_n_cols = " << f_n_cols << endl; + //cout << "f_n_rows = " << f_n_rows << endl; + + + for(uword row=0; row < f_n_rows; ++row) + { + for(uword col=0; col < f_n_cols; ++col) + { + R.at(row,col) = eT(tmp[i+0]); + G.at(row,col) = eT(tmp[i+1]); + B.at(row,col) = eT(tmp[i+2]); + i+=3; + } + + } + } + else + { + const uword n_elem = 3*f_n_cols*f_n_rows; + podarray tmp(n_elem); + + f.read( reinterpret_cast(tmp.memptr()), std::streamsize(2*n_elem) ); + + uword i = 0; + + for(uword row=0; row < f_n_rows; ++row) + for(uword col=0; col < f_n_cols; ++col) + { + R.at(row,col) = eT(tmp[i+0]); + G.at(row,col) = eT(tmp[i+1]); + B.at(row,col) = eT(tmp[i+2]); + i+=3; + } + } + } + else + { + load_okay = false; + err_msg = "functionality unimplemented"; + } + + if(f.good() == false) { load_okay = false; } + } + else + { + load_okay = false; + err_msg = "unsupported header"; + } + + return load_okay; + } + + + +template +inline +bool +diskio::save_ppm_binary(const field& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + std::ofstream f( tmp_name, std::fstream::binary ); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_ppm_binary(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +template +inline +bool +diskio::save_ppm_binary(const field& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_Mat::value == false )); + + typedef typename T1::elem_type eT; + + arma_debug_check( (x.n_elem != 3), "diskio::save_ppm_binary(): given field must have exactly 3 matrices of equal size" ); + + bool same_size = true; + for(uword i=1; i<3; ++i) + { + if( (x(0).n_rows != x(i).n_rows) || (x(0).n_cols != x(i).n_cols) ) + { + same_size = false; + break; + } + } + + arma_debug_check( (same_size != true), "diskio::save_ppm_binary(): given field must have exactly 3 matrices of equal size" ); + + const Mat& R = x(0); + const Mat& G = x(1); + const Mat& B = x(2); + + f << "P6" << '\n'; + f << R.n_cols << '\n'; + f << R.n_rows << '\n'; + f << 255 << '\n'; + + const uword n_elem = 3 * R.n_rows * R.n_cols; + podarray tmp(n_elem); + + uword i = 0; + for(uword row=0; row < R.n_rows; ++row) + for(uword col=0; col < R.n_cols; ++col) + { + tmp[i+0] = u8( access::tmp_real( R.at(row,col) ) ); + tmp[i+1] = u8( access::tmp_real( G.at(row,col) ) ); + tmp[i+2] = u8( access::tmp_real( B.at(row,col) ) ); + + i+=3; + } + + f.write( reinterpret_cast(tmp.mem), std::streamsize(n_elem) ); + + return f.good(); + } + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/distr_param.hpp b/src/armadillo/include/armadillo_bits/distr_param.hpp new file mode 100644 index 0000000..61f3c23 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/distr_param.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup distr_param +//! @{ + + + +class distr_param + { + public: + + const uword state; + + private: + + int a_int; + int b_int; + + double a_double; + double b_double; + + public: + + inline distr_param() + : state (0) + , a_int (0) + , b_int (0) + , a_double(0) + , b_double(0) + { + } + + + inline explicit distr_param(const int a, const int b) + : state (1) + , a_int (a) + , b_int (b) + , a_double(double(a)) + , b_double(double(b)) + { + } + + + inline explicit distr_param(const double a, const double b) + : state (2) + , a_int (int(a)) + , b_int (int(b)) + , a_double(a) + , b_double(b) + { + } + + + inline void get_int_vals(int& out_a, int& out_b) const + { + if(state == 0) { return; } + + out_a = a_int; + out_b = b_int; + } + + + inline void get_double_vals(double& out_a, double& out_b) const + { + if(state == 0) { return; } + + out_a = a_double; + out_b = b_double; + } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eGlueCube_bones.hpp b/src/armadillo/include/armadillo_bits/eGlueCube_bones.hpp new file mode 100644 index 0000000..8c157d8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eGlueCube_bones.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eGlueCube +//! @{ + + +template +class eGlueCube : public BaseCube< typename T1::elem_type, eGlueCube > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at ); + static constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp ); + static constexpr bool has_subview = (ProxyCube::has_subview || ProxyCube::has_subview); + + arma_aligned const ProxyCube P1; + arma_aligned const ProxyCube P2; + + arma_inline ~eGlueCube(); + arma_inline eGlueCube(const T1& in_A, const T2& in_B); + + arma_inline uword get_n_rows() const; + arma_inline uword get_n_cols() const; + arma_inline uword get_n_elem_slice() const; + arma_inline uword get_n_slices() const; + arma_inline uword get_n_elem() const; + + arma_inline elem_type operator[] (const uword i) const; + arma_inline elem_type at (const uword row, const uword col, const uword slice) const; + arma_inline elem_type at_alt (const uword i) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eGlueCube_meat.hpp b/src/armadillo/include/armadillo_bits/eGlueCube_meat.hpp new file mode 100644 index 0000000..59b30d3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eGlueCube_meat.hpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eGlueCube +//! @{ + + + +template +arma_inline +eGlueCube::~eGlueCube() + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +eGlueCube::eGlueCube(const T1& in_A, const T2& in_B) + : P1(in_A) + , P2(in_B) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size + ( + P1.get_n_rows(), P1.get_n_cols(), P1.get_n_slices(), + P2.get_n_rows(), P2.get_n_cols(), P2.get_n_slices(), + eglue_type::text() + ); + } + + + +template +arma_inline +uword +eGlueCube::get_n_rows() const + { + return P1.get_n_rows(); + } + + + +template +arma_inline +uword +eGlueCube::get_n_cols() const + { + return P1.get_n_cols(); + } + + + +template +arma_inline +uword +eGlueCube::get_n_slices() const + { + return P1.get_n_slices(); + } + + + +template +arma_inline +uword +eGlueCube::get_n_elem_slice() const + { + return P1.get_n_elem_slice(); + } + + + +template +arma_inline +uword +eGlueCube::get_n_elem() const + { + return P1.get_n_elem(); + } + + + +template +arma_inline +typename T1::elem_type +eGlueCube::operator[] (const uword i) const + { + // the optimiser will keep only one return statement + + typedef typename T1::elem_type eT; + + if(is_same_type::yes) { return P1[i] + P2[i]; } + else if(is_same_type::yes) { return P1[i] - P2[i]; } + else if(is_same_type::yes) { return P1[i] / P2[i]; } + else if(is_same_type::yes) { return P1[i] * P2[i]; } + else return eT(0); + } + + +template +arma_inline +typename T1::elem_type +eGlueCube::at(const uword row, const uword col, const uword slice) const + { + // the optimiser will keep only one return statement + + typedef typename T1::elem_type eT; + + if(is_same_type::yes) { return P1.at(row,col,slice) + P2.at(row,col,slice); } + else if(is_same_type::yes) { return P1.at(row,col,slice) - P2.at(row,col,slice); } + else if(is_same_type::yes) { return P1.at(row,col,slice) / P2.at(row,col,slice); } + else if(is_same_type::yes) { return P1.at(row,col,slice) * P2.at(row,col,slice); } + else return eT(0); + } + + + +template +arma_inline +typename T1::elem_type +eGlueCube::at_alt(const uword i) const + { + // the optimiser will keep only one return statement + + typedef typename T1::elem_type eT; + + if(is_same_type::yes) { return P1.at_alt(i) + P2.at_alt(i); } + else if(is_same_type::yes) { return P1.at_alt(i) - P2.at_alt(i); } + else if(is_same_type::yes) { return P1.at_alt(i) / P2.at_alt(i); } + else if(is_same_type::yes) { return P1.at_alt(i) * P2.at_alt(i); } + else return eT(0); + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eGlue_bones.hpp b/src/armadillo/include/armadillo_bits/eGlue_bones.hpp new file mode 100644 index 0000000..097dc6c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eGlue_bones.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eGlue +//! @{ + + +template +class eGlue : public Base< typename T1::elem_type, eGlue > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Proxy proxy1_type; + typedef Proxy proxy2_type; + + static constexpr bool use_at = (Proxy::use_at || Proxy::use_at ); + static constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp ); + static constexpr bool has_subview = (Proxy::has_subview || Proxy::has_subview); + + static constexpr bool is_col = (Proxy::is_col || Proxy::is_col ); + static constexpr bool is_row = (Proxy::is_row || Proxy::is_row ); + static constexpr bool is_xvec = (Proxy::is_xvec || Proxy::is_xvec); + + arma_aligned const Proxy P1; + arma_aligned const Proxy P2; + + arma_inline ~eGlue(); + arma_inline eGlue(const T1& in_A, const T2& in_B); + + arma_inline uword get_n_rows() const; + arma_inline uword get_n_cols() const; + arma_inline uword get_n_elem() const; + + arma_inline elem_type operator[] (const uword ii) const; + arma_inline elem_type at (const uword row, const uword col) const; + arma_inline elem_type at_alt (const uword ii) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eGlue_meat.hpp b/src/armadillo/include/armadillo_bits/eGlue_meat.hpp new file mode 100644 index 0000000..30fb507 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eGlue_meat.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eGlue +//! @{ + + + +template +arma_inline +eGlue::~eGlue() + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +eGlue::eGlue(const T1& in_A, const T2& in_B) + : P1(in_A) + , P2(in_B) + { + arma_extra_debug_sigprint(); + + // arma_debug_assert_same_size( P1, P2, eglue_type::text() ); + arma_debug_assert_same_size + ( + P1.get_n_rows(), P1.get_n_cols(), + P2.get_n_rows(), P2.get_n_cols(), + eglue_type::text() + ); + } + + + +template +arma_inline +uword +eGlue::get_n_rows() const + { + return is_row ? 1 : P1.get_n_rows(); + } + + + +template +arma_inline +uword +eGlue::get_n_cols() const + { + return is_col ? 1 : P1.get_n_cols(); + } + + + +template +arma_inline +uword +eGlue::get_n_elem() const + { + return P1.get_n_elem(); + } + + + +template +arma_inline +typename T1::elem_type +eGlue::operator[] (const uword ii) const + { + // the optimiser will keep only one return statement + + typedef typename T1::elem_type eT; + + if(is_same_type::yes) { return P1[ii] + P2[ii]; } + else if(is_same_type::yes) { return P1[ii] - P2[ii]; } + else if(is_same_type::yes) { return P1[ii] / P2[ii]; } + else if(is_same_type::yes) { return P1[ii] * P2[ii]; } + else return eT(0); + } + + + +template +arma_inline +typename T1::elem_type +eGlue::at(const uword row, const uword col) const + { + // the optimiser will keep only one return statement + + typedef typename T1::elem_type eT; + + if(is_same_type::yes) { return P1.at(row,col) + P2.at(row,col); } + else if(is_same_type::yes) { return P1.at(row,col) - P2.at(row,col); } + else if(is_same_type::yes) { return P1.at(row,col) / P2.at(row,col); } + else if(is_same_type::yes) { return P1.at(row,col) * P2.at(row,col); } + else return eT(0); + } + + + +template +arma_inline +typename T1::elem_type +eGlue::at_alt(const uword ii) const + { + // the optimiser will keep only one return statement + + typedef typename T1::elem_type eT; + + if(is_same_type::yes) { return P1.at_alt(ii) + P2.at_alt(ii); } + else if(is_same_type::yes) { return P1.at_alt(ii) - P2.at_alt(ii); } + else if(is_same_type::yes) { return P1.at_alt(ii) / P2.at_alt(ii); } + else if(is_same_type::yes) { return P1.at_alt(ii) * P2.at_alt(ii); } + else return eT(0); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eOpCube_bones.hpp b/src/armadillo/include/armadillo_bits/eOpCube_bones.hpp new file mode 100644 index 0000000..b6bcaba --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eOpCube_bones.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eOpCube +//! @{ + + + +template +class eOpCube : public BaseCube< typename T1::elem_type, eOpCube > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool use_at = ProxyCube::use_at; + static constexpr bool use_mp = ProxyCube::use_mp || eop_type::use_mp; + static constexpr bool has_subview = ProxyCube::has_subview; + + arma_aligned const ProxyCube P; + arma_aligned elem_type aux; //!< storage of auxiliary data, user defined format + arma_aligned uword aux_uword_a; //!< storage of auxiliary data, uword format + arma_aligned uword aux_uword_b; //!< storage of auxiliary data, uword format + arma_aligned uword aux_uword_c; //!< storage of auxiliary data, uword format + + inline ~eOpCube(); + inline explicit eOpCube(const BaseCube& in_m); + inline eOpCube(const BaseCube& in_m, const elem_type in_aux); + inline eOpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); + inline eOpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c); + inline eOpCube(const BaseCube& in_m, const elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c); + + arma_inline uword get_n_rows() const; + arma_inline uword get_n_cols() const; + arma_inline uword get_n_elem_slice() const; + arma_inline uword get_n_slices() const; + arma_inline uword get_n_elem() const; + + arma_inline elem_type operator[] (const uword i) const; + arma_inline elem_type at (const uword row, const uword col, const uword slice) const; + arma_inline elem_type at_alt (const uword i) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eOpCube_meat.hpp b/src/armadillo/include/armadillo_bits/eOpCube_meat.hpp new file mode 100644 index 0000000..6a165f4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eOpCube_meat.hpp @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eOpCube +//! @{ + + + +template +inline +eOpCube::eOpCube(const BaseCube& in_m) + : P (in_m.get_ref()) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eOpCube::eOpCube(const BaseCube& in_m, const typename T1::elem_type in_aux) + : P (in_m.get_ref()) + , aux (in_aux) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eOpCube::eOpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b) + : P (in_m.get_ref()) + , aux_uword_a (in_aux_uword_a) + , aux_uword_b (in_aux_uword_b) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eOpCube::eOpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c) + : P (in_m.get_ref()) + , aux_uword_a (in_aux_uword_a) + , aux_uword_b (in_aux_uword_b) + , aux_uword_c (in_aux_uword_c) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eOpCube::eOpCube(const BaseCube& in_m, const typename T1::elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c) + : P (in_m.get_ref()) + , aux (in_aux) + , aux_uword_a (in_aux_uword_a) + , aux_uword_b (in_aux_uword_b) + , aux_uword_c (in_aux_uword_c) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eOpCube::~eOpCube() + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +uword +eOpCube::get_n_rows() const + { + return P.get_n_rows(); + } + + + +template +arma_inline +uword +eOpCube::get_n_cols() const + { + return P.get_n_cols(); + } + + + +template +arma_inline +uword +eOpCube::get_n_elem_slice() const + { + return P.get_n_elem_slice(); + } + + + +template +arma_inline +uword +eOpCube::get_n_slices() const + { + return P.get_n_slices(); + } + + + +template +arma_inline +uword +eOpCube::get_n_elem() const + { + return P.get_n_elem(); + } + + + +template +arma_inline +typename T1::elem_type +eOpCube::operator[] (const uword i) const + { + return eop_core::process(P[i], aux); + } + + + +template +arma_inline +typename T1::elem_type +eOpCube::at(const uword row, const uword col, const uword slice) const + { + return eop_core::process(P.at(row, col, slice), aux); + } + + + +template +arma_inline +typename T1::elem_type +eOpCube::at_alt(const uword i) const + { + return eop_core::process(P.at_alt(i), aux); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eOp_bones.hpp b/src/armadillo/include/armadillo_bits/eOp_bones.hpp new file mode 100644 index 0000000..d32abdd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eOp_bones.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eOp +//! @{ + + + +template +class eOp : public Base< typename T1::elem_type, eOp > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Proxy proxy_type; + + static constexpr bool use_at = Proxy::use_at; + static constexpr bool use_mp = Proxy::use_mp || eop_type::use_mp; + static constexpr bool has_subview = Proxy::has_subview; + + static constexpr bool is_row = Proxy::is_row; + static constexpr bool is_col = Proxy::is_col; + static constexpr bool is_xvec = Proxy::is_xvec; + + arma_aligned const Proxy P; + + arma_aligned elem_type aux; //!< storage of auxiliary data, user defined format + arma_aligned uword aux_uword_a; //!< storage of auxiliary data, uword format + arma_aligned uword aux_uword_b; //!< storage of auxiliary data, uword format + + inline ~eOp(); + inline explicit eOp(const T1& in_m); + inline eOp(const T1& in_m, const elem_type in_aux); + inline eOp(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); + inline eOp(const T1& in_m, const elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b); + + arma_inline uword get_n_rows() const; + arma_inline uword get_n_cols() const; + arma_inline uword get_n_elem() const; + + arma_inline elem_type operator[] (const uword ii) const; + arma_inline elem_type at (const uword row, const uword col) const; + arma_inline elem_type at_alt (const uword ii) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eOp_meat.hpp b/src/armadillo/include/armadillo_bits/eOp_meat.hpp new file mode 100644 index 0000000..e087505 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eOp_meat.hpp @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eOp +//! @{ + + + +template +inline +eOp::eOp(const T1& in_m) + : P(in_m) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eOp::eOp(const T1& in_m, const typename T1::elem_type in_aux) + : P(in_m) + , aux(in_aux) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eOp::eOp(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b) + : P(in_m) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eOp::eOp(const T1& in_m, const typename T1::elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b) + : P(in_m) + , aux(in_aux) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eOp::~eOp() + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +uword +eOp::get_n_rows() const + { + return is_row ? 1 : P.get_n_rows(); + } + + + +template +arma_inline +uword +eOp::get_n_cols() const + { + return is_col ? 1 : P.get_n_cols(); + } + + + +template +arma_inline +uword +eOp::get_n_elem() const + { + return P.get_n_elem(); + } + + + +template +arma_inline +typename T1::elem_type +eOp::operator[] (const uword ii) const + { + return eop_core::process(P[ii], aux); + } + + + +template +arma_inline +typename T1::elem_type +eOp::at(const uword row, const uword col) const + { + if(is_row) + { + return eop_core::process(P.at(0, col), aux); + } + else + if(is_col) + { + return eop_core::process(P.at(row, 0), aux); + } + else + { + return eop_core::process(P.at(row, col), aux); + } + } + + + +template +arma_inline +typename T1::elem_type +eOp::at_alt(const uword ii) const + { + return eop_core::process(P.at_alt(ii), aux); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eglue_core_bones.hpp b/src/armadillo/include/armadillo_bits/eglue_core_bones.hpp new file mode 100644 index 0000000..67db243 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eglue_core_bones.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eglue_core +//! @{ + + + +template +struct eglue_core + { + + // matrices + + template arma_hot inline static void apply(outT& out, const eGlue& x); + + template arma_hot inline static void apply_inplace_plus (Mat& out, const eGlue& x); + template arma_hot inline static void apply_inplace_minus(Mat& out, const eGlue& x); + template arma_hot inline static void apply_inplace_schur(Mat& out, const eGlue& x); + template arma_hot inline static void apply_inplace_div (Mat& out, const eGlue& x); + + + // cubes + + template arma_hot inline static void apply(Cube& out, const eGlueCube& x); + + template arma_hot inline static void apply_inplace_plus (Cube& out, const eGlueCube& x); + template arma_hot inline static void apply_inplace_minus(Cube& out, const eGlueCube& x); + template arma_hot inline static void apply_inplace_schur(Cube& out, const eGlueCube& x); + template arma_hot inline static void apply_inplace_div (Cube& out, const eGlueCube& x); + }; + + + +class eglue_plus : public eglue_core + { + public: + + inline static const char* text() { return "addition"; } + }; + + + +class eglue_minus : public eglue_core + { + public: + + inline static const char* text() { return "subtraction"; } + }; + + + +class eglue_div : public eglue_core + { + public: + + inline static const char* text() { return "element-wise division"; } + }; + + + +class eglue_schur : public eglue_core + { + public: + + inline static const char* text() { return "element-wise multiplication"; } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eglue_core_meat.hpp b/src/armadillo/include/armadillo_bits/eglue_core_meat.hpp new file mode 100644 index 0000000..a36f389 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eglue_core_meat.hpp @@ -0,0 +1,1250 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eglue_core +//! @{ + + + +#undef arma_applier_1u +#undef arma_applier_1a +#undef arma_applier_2 +#undef arma_applier_3 +#undef operatorA +#undef operatorB + +#undef arma_applier_1_mp +#undef arma_applier_2_mp +#undef arma_applier_3_mp + + +#if defined(ARMA_SIMPLE_LOOPS) + #define arma_applier_1u(operatorA, operatorB) \ + {\ + for(uword i=0; i +template +inline +void +eglue_core::apply(outT& out, const eGlue& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + constexpr bool use_at = (Proxy::use_at || Proxy::use_at); + constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::openmp); + + // NOTE: we're assuming that the matrix has already been set to the correct size and there is no aliasing; + // size setting and alias checking is done by either the Mat contructor or operator=() + + + eT* out_mem = out.memptr(); + + if(use_at == false) + { + const uword n_elem = x.get_n_elem(); + + if(use_mp && mp_gate::use_mp && Proxy::use_mp)>::eval(n_elem)) + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1_mp(=, +); } + else if(is_same_type::yes) { arma_applier_1_mp(=, -); } + else if(is_same_type::yes) { arma_applier_1_mp(=, /); } + else if(is_same_type::yes) { arma_applier_1_mp(=, *); } + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P1.is_aligned() && x.P2.is_aligned()) + { + typename Proxy::aligned_ea_type P1 = x.P1.get_aligned_ea(); + typename Proxy::aligned_ea_type P2 = x.P2.get_aligned_ea(); + + if(is_same_type::yes) { arma_applier_1a(=, +); } + else if(is_same_type::yes) { arma_applier_1a(=, -); } + else if(is_same_type::yes) { arma_applier_1a(=, /); } + else if(is_same_type::yes) { arma_applier_1a(=, *); } + } + else + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(=, +); } + else if(is_same_type::yes) { arma_applier_1u(=, -); } + else if(is_same_type::yes) { arma_applier_1u(=, /); } + else if(is_same_type::yes) { arma_applier_1u(=, *); } + } + } + else + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(=, +); } + else if(is_same_type::yes) { arma_applier_1u(=, -); } + else if(is_same_type::yes) { arma_applier_1u(=, /); } + else if(is_same_type::yes) { arma_applier_1u(=, *); } + } + } + } + else + { + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + + const Proxy& P1 = x.P1; + const Proxy& P2 = x.P2; + + if(use_mp && mp_gate::use_mp && Proxy::use_mp)>::eval(x.get_n_elem())) + { + if(is_same_type::yes) { arma_applier_2_mp(=, +); } + else if(is_same_type::yes) { arma_applier_2_mp(=, -); } + else if(is_same_type::yes) { arma_applier_2_mp(=, /); } + else if(is_same_type::yes) { arma_applier_2_mp(=, *); } + } + else + { + if(is_same_type::yes) { arma_applier_2(=, +); } + else if(is_same_type::yes) { arma_applier_2(=, -); } + else if(is_same_type::yes) { arma_applier_2(=, /); } + else if(is_same_type::yes) { arma_applier_2(=, *); } + } + } + } + + + +template +template +inline +void +eglue_core::apply_inplace_plus(Mat& out, const eGlue& x) + { + arma_extra_debug_sigprint(); + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "addition"); + + typedef typename T1::elem_type eT; + + eT* out_mem = out.memptr(); + + constexpr bool use_at = (Proxy::use_at || Proxy::use_at); + constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::openmp); + + if(use_at == false) + { + const uword n_elem = x.get_n_elem(); + + if(use_mp && mp_gate::use_mp && Proxy::use_mp)>::eval(n_elem)) + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1_mp(+=, +); } + else if(is_same_type::yes) { arma_applier_1_mp(+=, -); } + else if(is_same_type::yes) { arma_applier_1_mp(+=, /); } + else if(is_same_type::yes) { arma_applier_1_mp(+=, *); } + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P1.is_aligned() && x.P2.is_aligned()) + { + typename Proxy::aligned_ea_type P1 = x.P1.get_aligned_ea(); + typename Proxy::aligned_ea_type P2 = x.P2.get_aligned_ea(); + + if(is_same_type::yes) { arma_applier_1a(+=, +); } + else if(is_same_type::yes) { arma_applier_1a(+=, -); } + else if(is_same_type::yes) { arma_applier_1a(+=, /); } + else if(is_same_type::yes) { arma_applier_1a(+=, *); } + } + else + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(+=, +); } + else if(is_same_type::yes) { arma_applier_1u(+=, -); } + else if(is_same_type::yes) { arma_applier_1u(+=, /); } + else if(is_same_type::yes) { arma_applier_1u(+=, *); } + } + } + else + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(+=, +); } + else if(is_same_type::yes) { arma_applier_1u(+=, -); } + else if(is_same_type::yes) { arma_applier_1u(+=, /); } + else if(is_same_type::yes) { arma_applier_1u(+=, *); } + } + } + } + else + { + const Proxy& P1 = x.P1; + const Proxy& P2 = x.P2; + + if(use_mp && mp_gate::use_mp && Proxy::use_mp)>::eval(x.get_n_elem())) + { + if(is_same_type::yes) { arma_applier_2_mp(+=, +); } + else if(is_same_type::yes) { arma_applier_2_mp(+=, -); } + else if(is_same_type::yes) { arma_applier_2_mp(+=, /); } + else if(is_same_type::yes) { arma_applier_2_mp(+=, *); } + } + else + { + if(is_same_type::yes) { arma_applier_2(+=, +); } + else if(is_same_type::yes) { arma_applier_2(+=, -); } + else if(is_same_type::yes) { arma_applier_2(+=, /); } + else if(is_same_type::yes) { arma_applier_2(+=, *); } + } + } + } + + + +template +template +inline +void +eglue_core::apply_inplace_minus(Mat& out, const eGlue& x) + { + arma_extra_debug_sigprint(); + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "subtraction"); + + typedef typename T1::elem_type eT; + + eT* out_mem = out.memptr(); + + constexpr bool use_at = (Proxy::use_at || Proxy::use_at); + constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::openmp); + + if(use_at == false) + { + const uword n_elem = x.get_n_elem(); + + if(use_mp && mp_gate::use_mp && Proxy::use_mp)>::eval(n_elem)) + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1_mp(-=, +); } + else if(is_same_type::yes) { arma_applier_1_mp(-=, -); } + else if(is_same_type::yes) { arma_applier_1_mp(-=, /); } + else if(is_same_type::yes) { arma_applier_1_mp(-=, *); } + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P1.is_aligned() && x.P2.is_aligned()) + { + typename Proxy::aligned_ea_type P1 = x.P1.get_aligned_ea(); + typename Proxy::aligned_ea_type P2 = x.P2.get_aligned_ea(); + + if(is_same_type::yes) { arma_applier_1a(-=, +); } + else if(is_same_type::yes) { arma_applier_1a(-=, -); } + else if(is_same_type::yes) { arma_applier_1a(-=, /); } + else if(is_same_type::yes) { arma_applier_1a(-=, *); } + } + else + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(-=, +); } + else if(is_same_type::yes) { arma_applier_1u(-=, -); } + else if(is_same_type::yes) { arma_applier_1u(-=, /); } + else if(is_same_type::yes) { arma_applier_1u(-=, *); } + } + } + else + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(-=, +); } + else if(is_same_type::yes) { arma_applier_1u(-=, -); } + else if(is_same_type::yes) { arma_applier_1u(-=, /); } + else if(is_same_type::yes) { arma_applier_1u(-=, *); } + } + } + } + else + { + const Proxy& P1 = x.P1; + const Proxy& P2 = x.P2; + + if(use_mp && mp_gate::use_mp && Proxy::use_mp)>::eval(x.get_n_elem())) + { + if(is_same_type::yes) { arma_applier_2_mp(-=, +); } + else if(is_same_type::yes) { arma_applier_2_mp(-=, -); } + else if(is_same_type::yes) { arma_applier_2_mp(-=, /); } + else if(is_same_type::yes) { arma_applier_2_mp(-=, *); } + } + else + { + if(is_same_type::yes) { arma_applier_2(-=, +); } + else if(is_same_type::yes) { arma_applier_2(-=, -); } + else if(is_same_type::yes) { arma_applier_2(-=, /); } + else if(is_same_type::yes) { arma_applier_2(-=, *); } + } + } + } + + + +template +template +inline +void +eglue_core::apply_inplace_schur(Mat& out, const eGlue& x) + { + arma_extra_debug_sigprint(); + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "element-wise multiplication"); + + typedef typename T1::elem_type eT; + + eT* out_mem = out.memptr(); + + constexpr bool use_at = (Proxy::use_at || Proxy::use_at); + constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::openmp); + + if(use_at == false) + { + const uword n_elem = x.get_n_elem(); + + if(use_mp && mp_gate::use_mp && Proxy::use_mp)>::eval(n_elem)) + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1_mp(*=, +); } + else if(is_same_type::yes) { arma_applier_1_mp(*=, -); } + else if(is_same_type::yes) { arma_applier_1_mp(*=, /); } + else if(is_same_type::yes) { arma_applier_1_mp(*=, *); } + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P1.is_aligned() && x.P2.is_aligned()) + { + typename Proxy::aligned_ea_type P1 = x.P1.get_aligned_ea(); + typename Proxy::aligned_ea_type P2 = x.P2.get_aligned_ea(); + + if(is_same_type::yes) { arma_applier_1a(*=, +); } + else if(is_same_type::yes) { arma_applier_1a(*=, -); } + else if(is_same_type::yes) { arma_applier_1a(*=, /); } + else if(is_same_type::yes) { arma_applier_1a(*=, *); } + } + else + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(*=, +); } + else if(is_same_type::yes) { arma_applier_1u(*=, -); } + else if(is_same_type::yes) { arma_applier_1u(*=, /); } + else if(is_same_type::yes) { arma_applier_1u(*=, *); } + } + } + else + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(*=, +); } + else if(is_same_type::yes) { arma_applier_1u(*=, -); } + else if(is_same_type::yes) { arma_applier_1u(*=, /); } + else if(is_same_type::yes) { arma_applier_1u(*=, *); } + } + } + } + else + { + const Proxy& P1 = x.P1; + const Proxy& P2 = x.P2; + + if(use_mp && mp_gate::use_mp && Proxy::use_mp)>::eval(x.get_n_elem())) + { + if(is_same_type::yes) { arma_applier_2_mp(*=, +); } + else if(is_same_type::yes) { arma_applier_2_mp(*=, -); } + else if(is_same_type::yes) { arma_applier_2_mp(*=, /); } + else if(is_same_type::yes) { arma_applier_2_mp(*=, *); } + } + else + { + if(is_same_type::yes) { arma_applier_2(*=, +); } + else if(is_same_type::yes) { arma_applier_2(*=, -); } + else if(is_same_type::yes) { arma_applier_2(*=, /); } + else if(is_same_type::yes) { arma_applier_2(*=, *); } + } + } + } + + + +template +template +inline +void +eglue_core::apply_inplace_div(Mat& out, const eGlue& x) + { + arma_extra_debug_sigprint(); + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "element-wise division"); + + typedef typename T1::elem_type eT; + + eT* out_mem = out.memptr(); + + constexpr bool use_at = (Proxy::use_at || Proxy::use_at); + constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::openmp); + + if(use_at == false) + { + const uword n_elem = x.get_n_elem(); + + if(use_mp && mp_gate::use_mp && Proxy::use_mp)>::eval(n_elem)) + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1_mp(/=, +); } + else if(is_same_type::yes) { arma_applier_1_mp(/=, -); } + else if(is_same_type::yes) { arma_applier_1_mp(/=, /); } + else if(is_same_type::yes) { arma_applier_1_mp(/=, *); } + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P1.is_aligned() && x.P2.is_aligned()) + { + typename Proxy::aligned_ea_type P1 = x.P1.get_aligned_ea(); + typename Proxy::aligned_ea_type P2 = x.P2.get_aligned_ea(); + + if(is_same_type::yes) { arma_applier_1a(/=, +); } + else if(is_same_type::yes) { arma_applier_1a(/=, -); } + else if(is_same_type::yes) { arma_applier_1a(/=, /); } + else if(is_same_type::yes) { arma_applier_1a(/=, *); } + } + else + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(/=, +); } + else if(is_same_type::yes) { arma_applier_1u(/=, -); } + else if(is_same_type::yes) { arma_applier_1u(/=, /); } + else if(is_same_type::yes) { arma_applier_1u(/=, *); } + } + } + else + { + typename Proxy::ea_type P1 = x.P1.get_ea(); + typename Proxy::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(/=, +); } + else if(is_same_type::yes) { arma_applier_1u(/=, -); } + else if(is_same_type::yes) { arma_applier_1u(/=, /); } + else if(is_same_type::yes) { arma_applier_1u(/=, *); } + } + } + } + else + { + const Proxy& P1 = x.P1; + const Proxy& P2 = x.P2; + + if(use_mp && mp_gate::use_mp && Proxy::use_mp)>::eval(x.get_n_elem())) + { + if(is_same_type::yes) { arma_applier_2_mp(/=, +); } + else if(is_same_type::yes) { arma_applier_2_mp(/=, -); } + else if(is_same_type::yes) { arma_applier_2_mp(/=, /); } + else if(is_same_type::yes) { arma_applier_2_mp(/=, *); } + } + else + { + if(is_same_type::yes) { arma_applier_2(/=, +); } + else if(is_same_type::yes) { arma_applier_2(/=, -); } + else if(is_same_type::yes) { arma_applier_2(/=, /); } + else if(is_same_type::yes) { arma_applier_2(/=, *); } + } + } + } + + + +// +// cubes + + + +template +template +inline +void +eglue_core::apply(Cube& out, const eGlueCube& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::openmp); + + // NOTE: we're assuming that the cube has already been set to the correct size and there is no aliasing; + // size setting and alias checking is done by either the Cube contructor or operator=() + + + eT* out_mem = out.memptr(); + + if(use_at == false) + { + const uword n_elem = out.n_elem; + + if(use_mp && mp_gate::use_mp && ProxyCube::use_mp)>::eval(n_elem)) + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1_mp(=, +); } + else if(is_same_type::yes) { arma_applier_1_mp(=, -); } + else if(is_same_type::yes) { arma_applier_1_mp(=, /); } + else if(is_same_type::yes) { arma_applier_1_mp(=, *); } + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P1.is_aligned() && x.P2.is_aligned()) + { + typename ProxyCube::aligned_ea_type P1 = x.P1.get_aligned_ea(); + typename ProxyCube::aligned_ea_type P2 = x.P2.get_aligned_ea(); + + if(is_same_type::yes) { arma_applier_1a(=, +); } + else if(is_same_type::yes) { arma_applier_1a(=, -); } + else if(is_same_type::yes) { arma_applier_1a(=, /); } + else if(is_same_type::yes) { arma_applier_1a(=, *); } + } + else + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(=, +); } + else if(is_same_type::yes) { arma_applier_1u(=, -); } + else if(is_same_type::yes) { arma_applier_1u(=, /); } + else if(is_same_type::yes) { arma_applier_1u(=, *); } + } + } + else + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(=, +); } + else if(is_same_type::yes) { arma_applier_1u(=, -); } + else if(is_same_type::yes) { arma_applier_1u(=, /); } + else if(is_same_type::yes) { arma_applier_1u(=, *); } + } + } + } + else + { + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + const uword n_slices = x.get_n_slices(); + + const ProxyCube& P1 = x.P1; + const ProxyCube& P2 = x.P2; + + if(use_mp && mp_gate::use_mp && ProxyCube::use_mp)>::eval(x.get_n_elem())) + { + if(is_same_type::yes) { arma_applier_3_mp(=, +); } + else if(is_same_type::yes) { arma_applier_3_mp(=, -); } + else if(is_same_type::yes) { arma_applier_3_mp(=, /); } + else if(is_same_type::yes) { arma_applier_3_mp(=, *); } + } + else + { + if(is_same_type::yes) { arma_applier_3(=, +); } + else if(is_same_type::yes) { arma_applier_3(=, -); } + else if(is_same_type::yes) { arma_applier_3(=, /); } + else if(is_same_type::yes) { arma_applier_3(=, *); } + } + } + } + + + +template +template +inline +void +eglue_core::apply_inplace_plus(Cube& out, const eGlueCube& x) + { + arma_extra_debug_sigprint(); + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + const uword n_slices = x.get_n_slices(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "addition"); + + typedef typename T1::elem_type eT; + + eT* out_mem = out.memptr(); + + constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::openmp); + + if(use_at == false) + { + const uword n_elem = out.n_elem; + + if(use_mp && mp_gate::use_mp && ProxyCube::use_mp)>::eval(n_elem)) + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1_mp(+=, +); } + else if(is_same_type::yes) { arma_applier_1_mp(+=, -); } + else if(is_same_type::yes) { arma_applier_1_mp(+=, /); } + else if(is_same_type::yes) { arma_applier_1_mp(+=, *); } + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P1.is_aligned() && x.P2.is_aligned()) + { + typename ProxyCube::aligned_ea_type P1 = x.P1.get_aligned_ea(); + typename ProxyCube::aligned_ea_type P2 = x.P2.get_aligned_ea(); + + if(is_same_type::yes) { arma_applier_1a(+=, +); } + else if(is_same_type::yes) { arma_applier_1a(+=, -); } + else if(is_same_type::yes) { arma_applier_1a(+=, /); } + else if(is_same_type::yes) { arma_applier_1a(+=, *); } + } + else + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(+=, +); } + else if(is_same_type::yes) { arma_applier_1u(+=, -); } + else if(is_same_type::yes) { arma_applier_1u(+=, /); } + else if(is_same_type::yes) { arma_applier_1u(+=, *); } + } + } + else + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(+=, +); } + else if(is_same_type::yes) { arma_applier_1u(+=, -); } + else if(is_same_type::yes) { arma_applier_1u(+=, /); } + else if(is_same_type::yes) { arma_applier_1u(+=, *); } + } + } + } + else + { + const ProxyCube& P1 = x.P1; + const ProxyCube& P2 = x.P2; + + if(use_mp && mp_gate::use_mp && ProxyCube::use_mp)>::eval(x.get_n_elem())) + { + if(is_same_type::yes) { arma_applier_3_mp(+=, +); } + else if(is_same_type::yes) { arma_applier_3_mp(+=, -); } + else if(is_same_type::yes) { arma_applier_3_mp(+=, /); } + else if(is_same_type::yes) { arma_applier_3_mp(+=, *); } + } + else + { + if(is_same_type::yes) { arma_applier_3(+=, +); } + else if(is_same_type::yes) { arma_applier_3(+=, -); } + else if(is_same_type::yes) { arma_applier_3(+=, /); } + else if(is_same_type::yes) { arma_applier_3(+=, *); } + } + } + } + + + +template +template +inline +void +eglue_core::apply_inplace_minus(Cube& out, const eGlueCube& x) + { + arma_extra_debug_sigprint(); + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + const uword n_slices = x.get_n_slices(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "subtraction"); + + typedef typename T1::elem_type eT; + + eT* out_mem = out.memptr(); + + constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::openmp); + + if(use_at == false) + { + const uword n_elem = out.n_elem; + + if(use_mp && mp_gate::use_mp && ProxyCube::use_mp)>::eval(n_elem)) + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1_mp(-=, +); } + else if(is_same_type::yes) { arma_applier_1_mp(-=, -); } + else if(is_same_type::yes) { arma_applier_1_mp(-=, /); } + else if(is_same_type::yes) { arma_applier_1_mp(-=, *); } + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P1.is_aligned() && x.P2.is_aligned()) + { + typename ProxyCube::aligned_ea_type P1 = x.P1.get_aligned_ea(); + typename ProxyCube::aligned_ea_type P2 = x.P2.get_aligned_ea(); + + if(is_same_type::yes) { arma_applier_1a(-=, +); } + else if(is_same_type::yes) { arma_applier_1a(-=, -); } + else if(is_same_type::yes) { arma_applier_1a(-=, /); } + else if(is_same_type::yes) { arma_applier_1a(-=, *); } + } + else + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(-=, +); } + else if(is_same_type::yes) { arma_applier_1u(-=, -); } + else if(is_same_type::yes) { arma_applier_1u(-=, /); } + else if(is_same_type::yes) { arma_applier_1u(-=, *); } + } + } + else + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(-=, +); } + else if(is_same_type::yes) { arma_applier_1u(-=, -); } + else if(is_same_type::yes) { arma_applier_1u(-=, /); } + else if(is_same_type::yes) { arma_applier_1u(-=, *); } + } + } + } + else + { + const ProxyCube& P1 = x.P1; + const ProxyCube& P2 = x.P2; + + if(use_mp && mp_gate::use_mp && ProxyCube::use_mp)>::eval(x.get_n_elem())) + { + if(is_same_type::yes) { arma_applier_3_mp(-=, +); } + else if(is_same_type::yes) { arma_applier_3_mp(-=, -); } + else if(is_same_type::yes) { arma_applier_3_mp(-=, /); } + else if(is_same_type::yes) { arma_applier_3_mp(-=, *); } + } + else + { + if(is_same_type::yes) { arma_applier_3(-=, +); } + else if(is_same_type::yes) { arma_applier_3(-=, -); } + else if(is_same_type::yes) { arma_applier_3(-=, /); } + else if(is_same_type::yes) { arma_applier_3(-=, *); } + } + } + } + + + +template +template +inline +void +eglue_core::apply_inplace_schur(Cube& out, const eGlueCube& x) + { + arma_extra_debug_sigprint(); + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + const uword n_slices = x.get_n_slices(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "element-wise multiplication"); + + typedef typename T1::elem_type eT; + + eT* out_mem = out.memptr(); + + constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::openmp); + + if(use_at == false) + { + const uword n_elem = out.n_elem; + + if(use_mp && mp_gate::use_mp && ProxyCube::use_mp)>::eval(n_elem)) + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1_mp(*=, +); } + else if(is_same_type::yes) { arma_applier_1_mp(*=, -); } + else if(is_same_type::yes) { arma_applier_1_mp(*=, /); } + else if(is_same_type::yes) { arma_applier_1_mp(*=, *); } + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P1.is_aligned() && x.P2.is_aligned()) + { + typename ProxyCube::aligned_ea_type P1 = x.P1.get_aligned_ea(); + typename ProxyCube::aligned_ea_type P2 = x.P2.get_aligned_ea(); + + if(is_same_type::yes) { arma_applier_1a(*=, +); } + else if(is_same_type::yes) { arma_applier_1a(*=, -); } + else if(is_same_type::yes) { arma_applier_1a(*=, /); } + else if(is_same_type::yes) { arma_applier_1a(*=, *); } + } + else + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(*=, +); } + else if(is_same_type::yes) { arma_applier_1u(*=, -); } + else if(is_same_type::yes) { arma_applier_1u(*=, /); } + else if(is_same_type::yes) { arma_applier_1u(*=, *); } + } + } + else + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(*=, +); } + else if(is_same_type::yes) { arma_applier_1u(*=, -); } + else if(is_same_type::yes) { arma_applier_1u(*=, /); } + else if(is_same_type::yes) { arma_applier_1u(*=, *); } + } + } + } + else + { + const ProxyCube& P1 = x.P1; + const ProxyCube& P2 = x.P2; + + if(use_mp && mp_gate::use_mp && ProxyCube::use_mp)>::eval(x.get_n_elem())) + { + if(is_same_type::yes) { arma_applier_3_mp(*=, +); } + else if(is_same_type::yes) { arma_applier_3_mp(*=, -); } + else if(is_same_type::yes) { arma_applier_3_mp(*=, /); } + else if(is_same_type::yes) { arma_applier_3_mp(*=, *); } + } + else + { + if(is_same_type::yes) { arma_applier_3(*=, +); } + else if(is_same_type::yes) { arma_applier_3(*=, -); } + else if(is_same_type::yes) { arma_applier_3(*=, /); } + else if(is_same_type::yes) { arma_applier_3(*=, *); } + } + } + } + + + +template +template +inline +void +eglue_core::apply_inplace_div(Cube& out, const eGlueCube& x) + { + arma_extra_debug_sigprint(); + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + const uword n_slices = x.get_n_slices(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "element-wise division"); + + typedef typename T1::elem_type eT; + + eT* out_mem = out.memptr(); + + constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::openmp); + + if(use_at == false) + { + const uword n_elem = out.n_elem; + + if(use_mp && mp_gate::use_mp && ProxyCube::use_mp)>::eval(n_elem)) + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1_mp(/=, +); } + else if(is_same_type::yes) { arma_applier_1_mp(/=, -); } + else if(is_same_type::yes) { arma_applier_1_mp(/=, /); } + else if(is_same_type::yes) { arma_applier_1_mp(/=, *); } + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P1.is_aligned() && x.P2.is_aligned()) + { + typename ProxyCube::aligned_ea_type P1 = x.P1.get_aligned_ea(); + typename ProxyCube::aligned_ea_type P2 = x.P2.get_aligned_ea(); + + if(is_same_type::yes) { arma_applier_1a(/=, +); } + else if(is_same_type::yes) { arma_applier_1a(/=, -); } + else if(is_same_type::yes) { arma_applier_1a(/=, /); } + else if(is_same_type::yes) { arma_applier_1a(/=, *); } + } + else + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(/=, +); } + else if(is_same_type::yes) { arma_applier_1u(/=, -); } + else if(is_same_type::yes) { arma_applier_1u(/=, /); } + else if(is_same_type::yes) { arma_applier_1u(/=, *); } + } + } + else + { + typename ProxyCube::ea_type P1 = x.P1.get_ea(); + typename ProxyCube::ea_type P2 = x.P2.get_ea(); + + if(is_same_type::yes) { arma_applier_1u(/=, +); } + else if(is_same_type::yes) { arma_applier_1u(/=, -); } + else if(is_same_type::yes) { arma_applier_1u(/=, /); } + else if(is_same_type::yes) { arma_applier_1u(/=, *); } + } + } + } + else + { + const ProxyCube& P1 = x.P1; + const ProxyCube& P2 = x.P2; + + if(use_mp && mp_gate::use_mp && ProxyCube::use_mp)>::eval(x.get_n_elem())) + { + if(is_same_type::yes) { arma_applier_3_mp(/=, +); } + else if(is_same_type::yes) { arma_applier_3_mp(/=, -); } + else if(is_same_type::yes) { arma_applier_3_mp(/=, /); } + else if(is_same_type::yes) { arma_applier_3_mp(/=, *); } + } + else + { + if(is_same_type::yes) { arma_applier_3(/=, +); } + else if(is_same_type::yes) { arma_applier_3(/=, -); } + else if(is_same_type::yes) { arma_applier_3(/=, /); } + else if(is_same_type::yes) { arma_applier_3(/=, *); } + } + } + } + + + +#undef arma_applier_1u +#undef arma_applier_1a +#undef arma_applier_2 +#undef arma_applier_3 + +#undef arma_applier_1_mp +#undef arma_applier_2_mp +#undef arma_applier_3_mp + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eop_aux.hpp b/src/armadillo/include/armadillo_bits/eop_aux.hpp new file mode 100644 index 0000000..2b66ef2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eop_aux.hpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eop_aux +//! @{ + + + +//! use of the SFINAE approach to work around compiler limitations +//! http://en.wikipedia.org/wiki/SFINAE + +class eop_aux + { + public: + + template arma_inline static typename arma_integral_only::result acos (const eT x) { return eT( std::acos(double(x)) ); } + template arma_inline static typename arma_integral_only::result asin (const eT x) { return eT( std::asin(double(x)) ); } + template arma_inline static typename arma_integral_only::result atan (const eT x) { return eT( std::atan(double(x)) ); } + + template arma_inline static typename arma_real_only::result acos (const eT x) { return std::acos(x); } + template arma_inline static typename arma_real_only::result asin (const eT x) { return std::asin(x); } + template arma_inline static typename arma_real_only::result atan (const eT x) { return std::atan(x); } + + template arma_inline static typename arma_cx_only::result acos (const eT x) { return std::acos(x); } + template arma_inline static typename arma_cx_only::result asin (const eT x) { return std::asin(x); } + template arma_inline static typename arma_cx_only::result atan (const eT x) { return std::atan(x); } + + template arma_inline static typename arma_integral_only::result acosh (const eT x) { return eT( std::acosh(double(x)) ); } + template arma_inline static typename arma_integral_only::result asinh (const eT x) { return eT( std::asinh(double(x)) ); } + template arma_inline static typename arma_integral_only::result atanh (const eT x) { return eT( std::atanh(double(x)) ); } + + template arma_inline static typename arma_real_or_cx_only::result acosh (const eT x) { return std::acosh(x); } + template arma_inline static typename arma_real_or_cx_only::result asinh (const eT x) { return std::asinh(x); } + template arma_inline static typename arma_real_or_cx_only::result atanh (const eT x) { return std::atanh(x); } + + template arma_inline static typename arma_not_cx::result conj(const eT x) { return x; } + template arma_inline static std::complex conj(const std::complex& x) { return std::conj(x); } + + template arma_inline static typename arma_integral_only::result sqrt (const eT x) { return eT( std::sqrt (double(x)) ); } + template arma_inline static typename arma_integral_only::result log10 (const eT x) { return eT( std::log10(double(x)) ); } + template arma_inline static typename arma_integral_only::result log (const eT x) { return eT( std::log (double(x)) ); } + template arma_inline static typename arma_integral_only::result exp (const eT x) { return eT( std::exp (double(x)) ); } + template arma_inline static typename arma_integral_only::result cos (const eT x) { return eT( std::cos (double(x)) ); } + template arma_inline static typename arma_integral_only::result sin (const eT x) { return eT( std::sin (double(x)) ); } + template arma_inline static typename arma_integral_only::result tan (const eT x) { return eT( std::tan (double(x)) ); } + template arma_inline static typename arma_integral_only::result cosh (const eT x) { return eT( std::cosh (double(x)) ); } + template arma_inline static typename arma_integral_only::result sinh (const eT x) { return eT( std::sinh (double(x)) ); } + template arma_inline static typename arma_integral_only::result tanh (const eT x) { return eT( std::tanh (double(x)) ); } + + template arma_inline static typename arma_real_or_cx_only::result sqrt (const eT x) { return std::sqrt (x); } + template arma_inline static typename arma_real_or_cx_only::result log10 (const eT x) { return std::log10(x); } + template arma_inline static typename arma_real_or_cx_only::result log (const eT x) { return std::log (x); } + template arma_inline static typename arma_real_or_cx_only::result exp (const eT x) { return std::exp (x); } + template arma_inline static typename arma_real_or_cx_only::result cos (const eT x) { return std::cos (x); } + template arma_inline static typename arma_real_or_cx_only::result sin (const eT x) { return std::sin (x); } + template arma_inline static typename arma_real_or_cx_only::result tan (const eT x) { return std::tan (x); } + template arma_inline static typename arma_real_or_cx_only::result cosh (const eT x) { return std::cosh (x); } + template arma_inline static typename arma_real_or_cx_only::result sinh (const eT x) { return std::sinh (x); } + template arma_inline static typename arma_real_or_cx_only::result tanh (const eT x) { return std::tanh (x); } + + template arma_inline static typename arma_unsigned_integral_only::result neg (const eT x) { return x; } + template arma_inline static typename arma_signed_only::result neg (const eT x) { return -x; } + + template arma_inline static typename arma_integral_only::result floor (const eT x) { return x; } + template arma_inline static typename arma_real_only::result floor (const eT x) { return std::floor(x); } + template arma_inline static typename arma_cx_only::result floor (const eT& x) { return eT( std::floor(x.real()), std::floor(x.imag()) ); } + + template arma_inline static typename arma_integral_only::result ceil (const eT x) { return x; } + template arma_inline static typename arma_real_only::result ceil (const eT x) { return std::ceil(x); } + template arma_inline static typename arma_cx_only::result ceil (const eT& x) { return eT( std::ceil(x.real()), std::ceil(x.imag()) ); } + + template arma_inline static typename arma_integral_only::result round (const eT x) { return x; } + template arma_inline static typename arma_real_only::result round (const eT x) { return std::round(x); } + template arma_inline static typename arma_cx_only::result round (const eT& x) { return eT( std::round(x.real()), std::round(x.imag()) ); } + + template arma_inline static typename arma_integral_only::result trunc (const eT x) { return x; } + template arma_inline static typename arma_real_only::result trunc (const eT x) { return std::trunc(x); } + template arma_inline static typename arma_cx_only::result trunc (const eT& x) { return eT( std::trunc(x.real()), std::trunc(x.imag()) ); } + + template arma_inline static typename arma_integral_only::result log2 (const eT x) { return eT( std::log2(double(x)) ); } + template arma_inline static typename arma_real_only::result log2 (const eT x) { return std::log2(x); } + template arma_inline static typename arma_cx_only::result log2 (const eT& x) { typedef typename get_pod_type::result T; return std::log(x) / T(0.69314718055994530942); } + + template arma_inline static typename arma_integral_only::result log1p (const eT x) { return eT( std::log1p(double(x)) ); } + template arma_inline static typename arma_real_only::result log1p (const eT x) { return std::log1p(x); } + template arma_inline static typename arma_cx_only::result log1p (const eT& x) { arma_ignore(x); return eT(0); } + + template arma_inline static typename arma_integral_only::result exp2 (const eT x) { return eT( std::exp2(double(x)) ); } + template arma_inline static typename arma_real_only::result exp2 (const eT x) { return std::exp2(x); } + template arma_inline static typename arma_cx_only::result exp2 (const eT& x) { typedef typename get_pod_type::result T; return std::pow( T(2), x); } + + template arma_inline static typename arma_integral_only::result exp10 (const eT x) { return eT( std::pow(double(10), double(x)) ); } + template arma_inline static typename arma_real_or_cx_only::result exp10 (const eT x) { typedef typename get_pod_type::result T; return std::pow( T(10), x); } + + template arma_inline static typename arma_integral_only::result expm1 (const eT x) { return eT( std::expm1(double(x)) ); } + template arma_inline static typename arma_real_only::result expm1 (const eT x) { return std::expm1(x); } + template arma_inline static typename arma_cx_only::result expm1 (const eT& x) { arma_ignore(x); return eT(0); } + + template arma_inline static typename arma_unsigned_integral_only::result arma_abs (const eT x) { return x; } + template arma_inline static typename arma_signed_integral_only::result arma_abs (const eT x) { return std::abs(x); } + template arma_inline static typename arma_real_only::result arma_abs (const eT x) { return std::abs(x); } + template arma_inline static typename arma_real_only< T>::result arma_abs (const std::complex& x) { return std::abs(x); } + + template arma_inline static typename arma_integral_only::result erf (const eT x) { return eT( std::erf(double(x)) ); } + template arma_inline static typename arma_real_only::result erf (const eT x) { return std::erf(x); } + template arma_inline static typename arma_cx_only::result erf (const eT& x) { arma_ignore(x); return eT(0); } + + template arma_inline static typename arma_integral_only::result erfc (const eT x) { return eT( std::erfc(double(x)) ); } + template arma_inline static typename arma_real_only::result erfc (const eT x) { return std::erfc(x); } + template arma_inline static typename arma_cx_only::result erfc (const eT& x) { arma_ignore(x); return eT(0); } + + template arma_inline static typename arma_integral_only::result lgamma (const eT x) { return eT( std::lgamma(double(x)) ); } + template arma_inline static typename arma_real_only::result lgamma (const eT x) { return std::lgamma(x); } + template arma_inline static typename arma_cx_only::result lgamma (const eT& x) { arma_ignore(x); return eT(0); } + + template arma_inline static typename arma_integral_only::result tgamma (const eT x) { return eT( std::tgamma(double(x)) ); } + template arma_inline static typename arma_real_only::result tgamma (const eT x) { return std::tgamma(x); } + template arma_inline static typename arma_cx_only::result tgamma (const eT& x) { arma_ignore(x); return eT(0); } + + template arma_inline static typename arma_integral_only::result pow (const T1 base, const T2 exponent) { return T1( std::pow( double(base), double(exponent) ) ); } + template arma_inline static typename arma_real_or_cx_only::result pow (const T1 base, const T2 exponent) { return T1( std::pow( base, exponent ) ); } + + + template + arma_inline + static + typename arma_integral_only::result + direct_eps(const eT) + { + return eT(0); + } + + + template + inline + static + typename arma_real_only::result + direct_eps(const eT x) + { + //arma_extra_debug_sigprint(); + + // acording to IEEE Standard for Floating-Point Arithmetic (IEEE 754) + // the mantissa length for double is 53 bits = std::numeric_limits::digits + // the mantissa length for float is 24 bits = std::numeric_limits::digits + + //return std::pow( std::numeric_limits::radix, (std::floor(std::log10(std::abs(x))/std::log10(std::numeric_limits::radix))-(std::numeric_limits::digits-1)) ); + + const eT radix_eT = eT(std::numeric_limits::radix); + const eT digits_m1_eT = eT(std::numeric_limits::digits - 1); + + // return std::pow( radix_eT, eT(std::floor(std::log10(std::abs(x))/std::log10(radix_eT)) - digits_m1_eT) ); + return eop_aux::pow( radix_eT, eT(std::floor(std::log10(std::abs(x))/std::log10(radix_eT)) - digits_m1_eT) ); + } + + + template + inline + static + typename arma_real_only::result + direct_eps(const std::complex& x) + { + //arma_extra_debug_sigprint(); + + //return std::pow( std::numeric_limits::radix, (std::floor(std::log10(std::abs(x))/std::log10(std::numeric_limits::radix))-(std::numeric_limits::digits-1)) ); + + const T radix_T = T(std::numeric_limits::radix); + const T digits_m1_T = T(std::numeric_limits::digits - 1); + + return std::pow( radix_T, T(std::floor(std::log10(std::abs(x))/std::log10(radix_T)) - digits_m1_T) ); + } + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/eop_core_bones.hpp b/src/armadillo/include/armadillo_bits/eop_core_bones.hpp new file mode 100644 index 0000000..8b9c75c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eop_core_bones.hpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eop_core +//! @{ + + + +template +class eop_core + { + public: + + // matrices + + template arma_hot inline static void apply(outT& out, const eOp& x); + + template arma_hot inline static void apply_inplace_plus (Mat& out, const eOp& x); + template arma_hot inline static void apply_inplace_minus(Mat& out, const eOp& x); + template arma_hot inline static void apply_inplace_schur(Mat& out, const eOp& x); + template arma_hot inline static void apply_inplace_div (Mat& out, const eOp& x); + + + // cubes + + template arma_hot inline static void apply(Cube& out, const eOpCube& x); + + template arma_hot inline static void apply_inplace_plus (Cube& out, const eOpCube& x); + template arma_hot inline static void apply_inplace_minus(Cube& out, const eOpCube& x); + template arma_hot inline static void apply_inplace_schur(Cube& out, const eOpCube& x); + template arma_hot inline static void apply_inplace_div (Cube& out, const eOpCube& x); + + + // common + + template arma_inline static eT process(const eT val, const eT k); + }; + + +struct eop_use_mp_true { static constexpr bool use_mp = true; }; +struct eop_use_mp_false { static constexpr bool use_mp = false; }; + + +class eop_neg : public eop_core , public eop_use_mp_false {}; +class eop_scalar_plus : public eop_core , public eop_use_mp_false {}; +class eop_scalar_minus_pre : public eop_core , public eop_use_mp_false {}; +class eop_scalar_minus_post : public eop_core , public eop_use_mp_false {}; +class eop_scalar_times : public eop_core , public eop_use_mp_false {}; +class eop_scalar_div_pre : public eop_core , public eop_use_mp_false {}; +class eop_scalar_div_post : public eop_core , public eop_use_mp_false {}; +class eop_square : public eop_core , public eop_use_mp_false {}; +class eop_sqrt : public eop_core , public eop_use_mp_true {}; +class eop_pow : public eop_core , public eop_use_mp_false {}; // for pow(), use_mp is selectively enabled in eop_core_meat.hpp +class eop_log : public eop_core , public eop_use_mp_true {}; +class eop_log2 : public eop_core , public eop_use_mp_true {}; +class eop_log10 : public eop_core , public eop_use_mp_true {}; +class eop_trunc_log : public eop_core , public eop_use_mp_true {}; +class eop_log1p : public eop_core , public eop_use_mp_true {}; +class eop_exp : public eop_core , public eop_use_mp_true {}; +class eop_exp2 : public eop_core , public eop_use_mp_true {}; +class eop_exp10 : public eop_core , public eop_use_mp_true {}; +class eop_trunc_exp : public eop_core , public eop_use_mp_true {}; +class eop_expm1 : public eop_core , public eop_use_mp_true {}; +class eop_cos : public eop_core , public eop_use_mp_true {}; +class eop_sin : public eop_core , public eop_use_mp_true {}; +class eop_tan : public eop_core , public eop_use_mp_true {}; +class eop_acos : public eop_core , public eop_use_mp_true {}; +class eop_asin : public eop_core , public eop_use_mp_true {}; +class eop_atan : public eop_core , public eop_use_mp_true {}; +class eop_cosh : public eop_core , public eop_use_mp_true {}; +class eop_sinh : public eop_core , public eop_use_mp_true {}; +class eop_tanh : public eop_core , public eop_use_mp_true {}; +class eop_acosh : public eop_core , public eop_use_mp_true {}; +class eop_asinh : public eop_core , public eop_use_mp_true {}; +class eop_atanh : public eop_core , public eop_use_mp_true {}; +class eop_sinc : public eop_core , public eop_use_mp_true {}; +class eop_eps : public eop_core , public eop_use_mp_true {}; +class eop_abs : public eop_core , public eop_use_mp_false {}; +class eop_arg : public eop_core , public eop_use_mp_false {}; +class eop_conj : public eop_core , public eop_use_mp_false {}; +class eop_floor : public eop_core , public eop_use_mp_false {}; +class eop_ceil : public eop_core , public eop_use_mp_false {}; +class eop_round : public eop_core , public eop_use_mp_false {}; +class eop_trunc : public eop_core , public eop_use_mp_false {}; +class eop_sign : public eop_core , public eop_use_mp_false {}; +class eop_erf : public eop_core , public eop_use_mp_true {}; +class eop_erfc : public eop_core , public eop_use_mp_true {}; +class eop_lgamma : public eop_core , public eop_use_mp_true {}; +class eop_tgamma : public eop_core , public eop_use_mp_true {}; + + + +// the classes below are currently not used; reserved for potential future use +class eop_log_approx {}; +class eop_exp_approx {}; +class eop_approx_log {}; +class eop_approx_exp {}; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/eop_core_meat.hpp b/src/armadillo/include/armadillo_bits/eop_core_meat.hpp new file mode 100644 index 0000000..4bc0c7f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/eop_core_meat.hpp @@ -0,0 +1,1168 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup eop_core +//! @{ + + +#undef arma_applier_1u +#undef arma_applier_1a +#undef arma_applier_2 +#undef arma_applier_3 +#undef operatorA + +#undef arma_applier_1_mp +#undef arma_applier_2_mp +#undef arma_applier_3_mp + + +#if defined(ARMA_SIMPLE_LOOPS) + #define arma_applier_1u(operatorA) \ + {\ + for(uword i=0; i::process(P[i], k);\ + }\ + } +#else + #define arma_applier_1u(operatorA) \ + {\ + uword i,j;\ + \ + for(i=0, j=1; j::process(tmp_i, k);\ + tmp_j = eop_core::process(tmp_j, k);\ + \ + out_mem[i] operatorA tmp_i;\ + out_mem[j] operatorA tmp_j;\ + }\ + \ + if(i < n_elem)\ + {\ + out_mem[i] operatorA eop_core::process(P[i], k);\ + }\ + } +#endif + + + +#if defined(ARMA_SIMPLE_LOOPS) + #define arma_applier_1a(operatorA) \ + {\ + for(uword i=0; i::process(P.at_alt(i), k);\ + }\ + } +#else + #define arma_applier_1a(operatorA) \ + {\ + uword i,j;\ + \ + for(i=0, j=1; j::process(tmp_i, k);\ + tmp_j = eop_core::process(tmp_j, k);\ + \ + out_mem[i] operatorA tmp_i;\ + out_mem[j] operatorA tmp_j;\ + }\ + \ + if(i < n_elem)\ + {\ + out_mem[i] operatorA eop_core::process(P.at_alt(i), k);\ + }\ + } +#endif + + +#define arma_applier_2(operatorA) \ + {\ + if(n_rows != 1)\ + {\ + for(uword col=0; col::process(tmp_i, k);\ + tmp_j = eop_core::process(tmp_j, k);\ + \ + *out_mem operatorA tmp_i; out_mem++;\ + *out_mem operatorA tmp_j; out_mem++;\ + }\ + \ + if(i < n_rows)\ + {\ + *out_mem operatorA eop_core::process(P.at(i,col), k); out_mem++;\ + }\ + }\ + }\ + else\ + {\ + for(uword count=0; count < n_cols; ++count)\ + {\ + out_mem[count] operatorA eop_core::process(P.at(0,count), k);\ + }\ + }\ + } + + + +#define arma_applier_3(operatorA) \ + {\ + for(uword slice=0; slice::process(tmp_i, k);\ + tmp_j = eop_core::process(tmp_j, k);\ + \ + *out_mem operatorA tmp_i; out_mem++; \ + *out_mem operatorA tmp_j; out_mem++; \ + }\ + \ + if(i < n_rows)\ + {\ + *out_mem operatorA eop_core::process(P.at(i,col,slice), k); out_mem++; \ + }\ + }\ + }\ + } + + + +#if defined(ARMA_USE_OPENMP) + + #define arma_applier_1_mp(operatorA) \ + {\ + const int n_threads = mp_thread_limit::get();\ + _Pragma("omp parallel for schedule(static) num_threads(n_threads)")\ + for(uword i=0; i::process(P[i], k);\ + }\ + } + + #define arma_applier_2_mp(operatorA) \ + {\ + const int n_threads = mp_thread_limit::get();\ + if(n_cols == 1)\ + {\ + _Pragma("omp parallel for schedule(static) num_threads(n_threads)")\ + for(uword count=0; count < n_rows; ++count)\ + {\ + out_mem[count] operatorA eop_core::process(P.at(count,0), k);\ + }\ + }\ + else\ + if(n_rows == 1)\ + {\ + _Pragma("omp parallel for schedule(static) num_threads(n_threads)")\ + for(uword count=0; count < n_cols; ++count)\ + {\ + out_mem[count] operatorA eop_core::process(P.at(0,count), k);\ + }\ + }\ + else\ + {\ + _Pragma("omp parallel for schedule(static) num_threads(n_threads)")\ + for(uword col=0; col < n_cols; ++col)\ + {\ + for(uword row=0; row < n_rows; ++row)\ + {\ + out.at(row,col) operatorA eop_core::process(P.at(row,col), k);\ + }\ + }\ + }\ + } + + #define arma_applier_3_mp(operatorA) \ + {\ + const int n_threads = mp_thread_limit::get();\ + _Pragma("omp parallel for schedule(static) num_threads(n_threads)")\ + for(uword slice=0; slice::process(P.at(row,col,slice), k);\ + }\ + }\ + } + +#else + + #define arma_applier_1_mp(operatorA) arma_applier_1u(operatorA) + #define arma_applier_2_mp(operatorA) arma_applier_2(operatorA) + #define arma_applier_3_mp(operatorA) arma_applier_3(operatorA) + +#endif + + + +// +// matrices + + + +template +template +inline +void +eop_core::apply(outT& out, const eOp& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // NOTE: we're assuming that the matrix has already been set to the correct size and there is no aliasing; + // size setting and alias checking is done by either the Mat contructor or operator=() + + const eT k = x.aux; + eT* out_mem = out.memptr(); + + const bool use_mp = (arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + + if(Proxy::use_at == false) + { + const uword n_elem = x.get_n_elem(); + + if(use_mp && mp_gate::eval(n_elem)) + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1_mp(=); + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P.is_aligned()) + { + typename Proxy::aligned_ea_type P = x.P.get_aligned_ea(); + + arma_applier_1a(=); + } + else + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1u(=); + } + } + else + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1u(=); + } + } + } + else + { + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + + const Proxy& P = x.P; + + if(use_mp && mp_gate::eval(x.get_n_elem())) + { + arma_applier_2_mp(=); + } + else + { + arma_applier_2(=); + } + } + } + + + +template +template +inline +void +eop_core::apply_inplace_plus(Mat& out, const eOp& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "addition"); + + const eT k = x.aux; + eT* out_mem = out.memptr(); + + const bool use_mp = (arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + + if(Proxy::use_at == false) + { + const uword n_elem = x.get_n_elem(); + + if(use_mp && mp_gate::eval(n_elem)) + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1_mp(+=); + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P.is_aligned()) + { + typename Proxy::aligned_ea_type P = x.P.get_aligned_ea(); + + arma_applier_1a(+=); + } + else + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1u(+=); + } + } + else + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1u(+=); + } + } + } + else + { + const Proxy& P = x.P; + + if(use_mp && mp_gate::eval(x.get_n_elem())) + { + arma_applier_2_mp(+=); + } + else + { + arma_applier_2(+=); + } + } + } + + + +template +template + +inline +void +eop_core::apply_inplace_minus(Mat& out, const eOp& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "subtraction"); + + const eT k = x.aux; + eT* out_mem = out.memptr(); + + const bool use_mp = (arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + + if(Proxy::use_at == false) + { + const uword n_elem = x.get_n_elem(); + + if(use_mp && mp_gate::eval(n_elem)) + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1_mp(-=); + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P.is_aligned()) + { + typename Proxy::aligned_ea_type P = x.P.get_aligned_ea(); + + arma_applier_1a(-=); + } + else + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1u(-=); + } + } + else + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1u(-=); + } + } + } + else + { + const Proxy& P = x.P; + + if(use_mp && mp_gate::eval(x.get_n_elem())) + { + arma_applier_2_mp(-=); + } + else + { + arma_applier_2(-=); + } + } + } + + + +template +template + +inline +void +eop_core::apply_inplace_schur(Mat& out, const eOp& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "element-wise multiplication"); + + const eT k = x.aux; + eT* out_mem = out.memptr(); + + const bool use_mp = (arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + + if(Proxy::use_at == false) + { + const uword n_elem = x.get_n_elem(); + + if(use_mp && mp_gate::eval(n_elem)) + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1_mp(*=); + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P.is_aligned()) + { + typename Proxy::aligned_ea_type P = x.P.get_aligned_ea(); + + arma_applier_1a(*=); + } + else + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1u(*=); + } + } + else + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1u(*=); + } + } + } + else + { + const Proxy& P = x.P; + + if(use_mp && mp_gate::eval(x.get_n_elem())) + { + arma_applier_2_mp(*=); + } + else + { + arma_applier_2(*=); + } + } + } + + + +template +template + +inline +void +eop_core::apply_inplace_div(Mat& out, const eOp& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "element-wise division"); + + const eT k = x.aux; + eT* out_mem = out.memptr(); + + const bool use_mp = (arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + + if(Proxy::use_at == false) + { + const uword n_elem = x.get_n_elem(); + + if(use_mp && mp_gate::eval(n_elem)) + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1_mp(/=); + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P.is_aligned()) + { + typename Proxy::aligned_ea_type P = x.P.get_aligned_ea(); + + arma_applier_1a(/=); + } + else + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1u(/=); + } + } + else + { + typename Proxy::ea_type P = x.P.get_ea(); + + arma_applier_1u(/=); + } + } + } + else + { + const Proxy& P = x.P; + + if(use_mp && mp_gate::eval(x.get_n_elem())) + { + arma_applier_2_mp(/=); + } + else + { + arma_applier_2(/=); + } + } + } + + + +// +// cubes + + + +template +template + +inline +void +eop_core::apply(Cube& out, const eOpCube& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // NOTE: we're assuming that the matrix has already been set to the correct size and there is no aliasing; + // size setting and alias checking is done by either the Mat contructor or operator=() + + const eT k = x.aux; + eT* out_mem = out.memptr(); + + const bool use_mp = (arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + + if(ProxyCube::use_at == false) + { + const uword n_elem = out.n_elem; + + if(use_mp && mp_gate::eval(n_elem)) + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1_mp(=); + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P.is_aligned()) + { + typename ProxyCube::aligned_ea_type P = x.P.get_aligned_ea(); + + arma_applier_1a(=); + } + else + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1u(=); + } + } + else + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1u(=); + } + } + } + else + { + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + const uword n_slices = x.get_n_slices(); + + const ProxyCube& P = x.P; + + if(use_mp && mp_gate::eval(x.get_n_elem())) + { + arma_applier_3_mp(=); + } + else + { + arma_applier_3(=); + } + } + } + + + +template +template + +inline +void +eop_core::apply_inplace_plus(Cube& out, const eOpCube& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + const uword n_slices = x.get_n_slices(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "addition"); + + const eT k = x.aux; + eT* out_mem = out.memptr(); + + const bool use_mp = (arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + + if(ProxyCube::use_at == false) + { + const uword n_elem = out.n_elem; + + if(use_mp && mp_gate::eval(n_elem)) + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1_mp(+=); + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P.is_aligned()) + { + typename ProxyCube::aligned_ea_type P = x.P.get_aligned_ea(); + + arma_applier_1a(+=); + } + else + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1u(+=); + } + } + else + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1u(+=); + } + } + } + else + { + const ProxyCube& P = x.P; + + if(use_mp && mp_gate::eval(x.get_n_elem())) + { + arma_applier_3_mp(+=); + } + else + { + arma_applier_3(+=); + } + } + } + + + +template +template + +inline +void +eop_core::apply_inplace_minus(Cube& out, const eOpCube& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + const uword n_slices = x.get_n_slices(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "subtraction"); + + const eT k = x.aux; + eT* out_mem = out.memptr(); + + const bool use_mp = (arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + + if(ProxyCube::use_at == false) + { + const uword n_elem = out.n_elem; + + if(use_mp && mp_gate::eval(n_elem)) + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1_mp(-=); + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P.is_aligned()) + { + typename ProxyCube::aligned_ea_type P = x.P.get_aligned_ea(); + + arma_applier_1a(-=); + } + else + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1u(-=); + } + } + else + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1u(-=); + } + } + } + else + { + const ProxyCube& P = x.P; + + if(use_mp && mp_gate::eval(x.get_n_elem())) + { + arma_applier_3_mp(-=); + } + else + { + arma_applier_3(-=); + } + } + } + + + +template +template + +inline +void +eop_core::apply_inplace_schur(Cube& out, const eOpCube& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + const uword n_slices = x.get_n_slices(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "element-wise multiplication"); + + const eT k = x.aux; + eT* out_mem = out.memptr(); + + const bool use_mp = (arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + + if(ProxyCube::use_at == false) + { + const uword n_elem = out.n_elem; + + if(use_mp && mp_gate::eval(n_elem)) + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1_mp(*=); + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P.is_aligned()) + { + typename ProxyCube::aligned_ea_type P = x.P.get_aligned_ea(); + + arma_applier_1a(*=); + } + else + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1u(*=); + } + } + else + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1u(*=); + } + } + } + else + { + const ProxyCube& P = x.P; + + if(use_mp && mp_gate::eval(x.get_n_elem())) + { + arma_applier_3_mp(*=); + } + else + { + arma_applier_3(*=); + } + } + } + + + +template +template + +inline +void +eop_core::apply_inplace_div(Cube& out, const eOpCube& x) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = x.get_n_rows(); + const uword n_cols = x.get_n_cols(); + const uword n_slices = x.get_n_slices(); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "element-wise division"); + + const eT k = x.aux; + eT* out_mem = out.memptr(); + + const bool use_mp = (arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + + if(ProxyCube::use_at == false) + { + const uword n_elem = out.n_elem; + + if(use_mp && mp_gate::eval(n_elem)) + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1_mp(/=); + } + else + { + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + if(x.P.is_aligned()) + { + typename ProxyCube::aligned_ea_type P = x.P.get_aligned_ea(); + + arma_applier_1a(/=); + } + else + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1u(/=); + } + } + else + { + typename ProxyCube::ea_type P = x.P.get_ea(); + + arma_applier_1u(/=); + } + } + } + else + { + const ProxyCube& P = x.P; + + if(use_mp && mp_gate::eval(x.get_n_elem())) + { + arma_applier_3_mp(/=); + } + else + { + arma_applier_3(/=); + } + } + } + + + +// +// common + + + +template +template +arma_inline +eT +eop_core::process(const eT, const eT) + { + arma_stop_logic_error("eop_core::process(): unhandled eop_type"); + return eT(0); + } + + + +template<> template arma_inline eT +eop_core::process(const eT val, const eT k) { return val + k; } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT k) { return k - val; } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT k) { return val - k; } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT k) { return val * k; } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT k) { return k / val; } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT k) { return val / k; } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return val*val; } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::neg(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::sqrt(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::log(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::log2(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::log10(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return arma::trunc_log(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::log1p(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::exp(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::exp2(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::exp10(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return arma::trunc_exp(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::expm1(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::cos(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::sin(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::tan(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::acos(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::asin(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::atan(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::cosh(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::sinh(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::tanh(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::acosh(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::asinh(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::atanh(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return arma_sinc(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::direct_eps(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::arma_abs(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return arma_arg::eval(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::conj(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT k) { return eop_aux::pow(val, k); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::floor(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::ceil(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::round(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::trunc(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return arma_sign(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::erf(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::erfc(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::lgamma(val); } + +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::tgamma(val); } + + +#undef arma_applier_1u +#undef arma_applier_1a +#undef arma_applier_2 +#undef arma_applier_3 + +#undef arma_applier_1_mp +#undef arma_applier_2_mp +#undef arma_applier_3_mp + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fft_engine_fftw3.hpp b/src/armadillo/include/armadillo_bits/fft_engine_fftw3.hpp new file mode 100644 index 0000000..cabe5c9 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fft_engine_fftw3.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ------------------------------------------------------------------------ + + +//! \addtogroup fft_engine_fftw3 +//! @{ + + +#if defined(ARMA_USE_FFTW3) + +template +class fft_engine_fftw3 + { + public: + + constexpr static int fftw3_sign_forward = -1; + constexpr static int fftw3_sign_backward = +1; + + constexpr static unsigned int fftw3_flag_destroy = (1u << 0); + constexpr static unsigned int fftw3_flag_preserve = (1u << 4); + constexpr static unsigned int fftw3_flag_estimate = (1u << 6); + + const uword N; + + void_ptr fftw3_plan; + + podarray X_work; // for storing copy of input (can be overwritten by FFTW3) + podarray Y_work; // for storing output + + inline + ~fft_engine_fftw3() + { + arma_extra_debug_sigprint(); + + if(fftw3_plan != nullptr) { fftw3::destroy_plan(fftw3_plan); } + + // fftw3::cleanup(); // NOTE: this also removes any wisdom acquired by FFTW3 + } + + inline + fft_engine_fftw3(const uword in_N) + : N (in_N ) + , fftw3_plan(nullptr) + { + arma_extra_debug_sigprint(); + + if(N == 0) { return; } + + if(N > uword(std::numeric_limits::max())) + { + arma_stop_runtime_error("integer overflow: FFT size too large for integer type used by FFTW3"); + } + + arma_extra_debug_print("fft_engine_fftw3::constructor: allocating work arrays"); + X_work.set_size(N); + Y_work.set_size(N); + + const int fftw3_sign = (inverse) ? fftw3_sign_backward : fftw3_sign_forward; + const int fftw3_flags = fftw3_flag_destroy | fftw3_flag_estimate; + + arma_extra_debug_print("fft_engine_fftw3::constructor: generating 1D plan"); + fftw3_plan = fftw3::plan_dft_1d(N, X_work.memptr(), Y_work.memptr(), fftw3_sign, fftw3_flags); + + if(fftw3_plan == nullptr) { arma_stop_runtime_error("fft_engine_fftw3::constructor: failed to create plan"); } + } + + inline + void + run(cx_type* Y, const cx_type* X) + { + arma_extra_debug_sigprint(); + + if(fftw3_plan == nullptr) { return; } + + arma_extra_debug_print("fft_engine_fftw3::run(): copying input array"); + arrayops::copy(X_work.memptr(), X, N); + + arma_extra_debug_print("fft_engine_fftw3::run(): executing plan"); + fftw3::execute(fftw3_plan); + + arma_extra_debug_print("fft_engine_fftw3::run(): copying output array"); + arrayops::copy(Y, Y_work.memptr(), N); + } + }; + +#endif + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fft_engine_kissfft.hpp b/src/armadillo/include/armadillo_bits/fft_engine_kissfft.hpp new file mode 100644 index 0000000..0c8c40c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fft_engine_kissfft.hpp @@ -0,0 +1,392 @@ +// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ------------------------------------------------------------------------ +// +// This file includes portions of Kiss FFT software, +// licensed under the following conditions. +// +// Copyright (c) 2003-2010 Mark Borgerding +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the author nor the names of any contributors may be used to +// endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +// OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +// EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ------------------------------------------------------------------------ + + +//! \addtogroup fft_engine_kissfft +//! @{ + + +template +class fft_engine_kissfft + { + public: + + typedef typename get_pod_type::result T; + + const uword N; + + podarray coeffs_array; + podarray tmp_array; + + podarray residue; + podarray radix; + + + template + inline + uword + calc_radix() + { + uword i = 0; + + for(uword n = N, r=4; n >= 2; ++i) + { + while( (n % r) > 0 ) + { + switch(r) + { + case 2: r = 3; break; + case 4: r = 2; break; + default: r += 2; break; + } + + if(r*r > n) { r = n; } + } + + n /= r; + + if(fill) + { + residue[i] = n; + radix[i] = r; + } + } + + return i; + } + + + + inline + fft_engine_kissfft(const uword in_N) + : N(in_N) + { + arma_extra_debug_sigprint(); + + const uword len = calc_radix(); + + residue.set_size(len); + radix.set_size(len); + + calc_radix(); + + + // calculate the constant coefficients + + coeffs_array.set_size(N); + + cx_type* coeffs = coeffs_array.memptr(); + + const T k = T( (inverse) ? +2 : -2 ) * std::acos( T(-1) ) / T(N); + + for(uword i=0; i < N; ++i) { coeffs[i] = std::exp( cx_type(T(0), i*k) ); } + } + + + + arma_hot + inline + void + butterfly_2(cx_type* Y, const uword stride, const uword m) const + { + // arma_extra_debug_sigprint(); + + const cx_type* coeffs = coeffs_array.memptr(); + + for(uword i=0; i < m; ++i) + { + const cx_type t = Y[i+m] * coeffs[i*stride]; + + Y[i+m] = Y[i] - t; + Y[i ] += t; + } + } + + + + arma_hot + inline + void + butterfly_3(cx_type* Y, const uword stride, const uword m) const + { + // arma_extra_debug_sigprint(); + + arma_aligned cx_type tmp[5]; + + const cx_type* coeffs1 = coeffs_array.memptr(); + const cx_type* coeffs2 = coeffs1; + + const T coeff_sm_imag = coeffs1[stride*m].imag(); + + const uword n = m*2; + + // TODO: rearrange the indices within tmp[] into a more sane order + + for(uword i = m; i > 0; --i) + { + tmp[1] = Y[m] * (*coeffs1); + tmp[2] = Y[n] * (*coeffs2); + + tmp[0] = tmp[1] - tmp[2]; + tmp[0] *= coeff_sm_imag; + + tmp[3] = tmp[1] + tmp[2]; + + Y[m] = cx_type( (Y[0].real() - (T(0.5)*tmp[3].real())), (Y[0].imag() - (T(0.5)*tmp[3].imag())) ); + + Y[0] += tmp[3]; + + + Y[n] = cx_type( (Y[m].real() + tmp[0].imag()), (Y[m].imag() - tmp[0].real()) ); + + Y[m] += cx_type( -tmp[0].imag(), tmp[0].real() ); + + Y++; + + coeffs1 += stride; + coeffs2 += stride*2; + } + } + + + + arma_hot + inline + void + butterfly_4(cx_type* Y, const uword stride, const uword m) const + { + // arma_extra_debug_sigprint(); + + arma_aligned cx_type tmp[7]; + + const cx_type* coeffs = coeffs_array.memptr(); + + const uword m2 = m*2; + const uword m3 = m*3; + + // TODO: rearrange the indices within tmp[] into a more sane order + + for(uword i=0; i < m; ++i) + { + tmp[0] = Y[i + m ] * coeffs[i*stride ]; + tmp[2] = Y[i + m3] * coeffs[i*stride*3]; + tmp[3] = tmp[0] + tmp[2]; + + //tmp[4] = tmp[0] - tmp[2]; + //tmp[4] = (inverse) ? cx_type( -(tmp[4].imag()), tmp[4].real() ) : cx_type( tmp[4].imag(), -tmp[4].real() ); + + tmp[4] = (inverse) + ? cx_type( (tmp[2].imag() - tmp[0].imag()), (tmp[0].real() - tmp[2].real()) ) + : cx_type( (tmp[0].imag() - tmp[2].imag()), (tmp[2].real() - tmp[0].real()) ); + + tmp[1] = Y[i + m2] * coeffs[i*stride*2]; + tmp[5] = Y[i] - tmp[1]; + + + Y[i ] += tmp[1]; + Y[i + m2] = Y[i] - tmp[3]; + Y[i ] += tmp[3]; + Y[i + m ] = tmp[5] + tmp[4]; + Y[i + m3] = tmp[5] - tmp[4]; + } + } + + + + arma_hot + inline + void + butterfly_5(cx_type* Y, const uword stride, const uword m) const + { + // arma_extra_debug_sigprint(); + + arma_aligned cx_type tmp[13]; + + const cx_type* coeffs = coeffs_array.memptr(); + + const T a_real = coeffs[stride*1*m].real(); + const T a_imag = coeffs[stride*1*m].imag(); + + const T b_real = coeffs[stride*2*m].real(); + const T b_imag = coeffs[stride*2*m].imag(); + + cx_type* Y0 = Y; + cx_type* Y1 = Y + 1*m; + cx_type* Y2 = Y + 2*m; + cx_type* Y3 = Y + 3*m; + cx_type* Y4 = Y + 4*m; + + for(uword i=0; i < m; ++i) + { + tmp[0] = (*Y0); + + tmp[1] = (*Y1) * coeffs[stride*1*i]; + tmp[2] = (*Y2) * coeffs[stride*2*i]; + tmp[3] = (*Y3) * coeffs[stride*3*i]; + tmp[4] = (*Y4) * coeffs[stride*4*i]; + + tmp[7] = tmp[1] + tmp[4]; + tmp[8] = tmp[2] + tmp[3]; + tmp[9] = tmp[2] - tmp[3]; + tmp[10] = tmp[1] - tmp[4]; + + (*Y0) += tmp[7]; + (*Y0) += tmp[8]; + + tmp[5] = tmp[0] + cx_type( ( (tmp[7].real() * a_real) + (tmp[8].real() * b_real) ), ( (tmp[7].imag() * a_real) + (tmp[8].imag() * b_real) ) ); + + tmp[6] = cx_type( ( (tmp[10].imag() * a_imag) + (tmp[9].imag() * b_imag) ), ( -(tmp[10].real() * a_imag) - (tmp[9].real() * b_imag) ) ); + + (*Y1) = tmp[5] - tmp[6]; + (*Y4) = tmp[5] + tmp[6]; + + tmp[11] = tmp[0] + cx_type( ( (tmp[7].real() * b_real) + (tmp[8].real() * a_real) ), ( (tmp[7].imag() * b_real) + (tmp[8].imag() * a_real) ) ); + + tmp[12] = cx_type( ( -(tmp[10].imag() * b_imag) + (tmp[9].imag() * a_imag) ), ( (tmp[10].real() * b_imag) - (tmp[9].real() * a_imag) ) ); + + (*Y2) = tmp[11] + tmp[12]; + (*Y3) = tmp[11] - tmp[12]; + + Y0++; + Y1++; + Y2++; + Y3++; + Y4++; + } + } + + + + arma_hot + inline + void + butterfly_N(cx_type* Y, const uword stride, const uword m, const uword r) + { + // arma_extra_debug_sigprint(); + + const cx_type* coeffs = coeffs_array.memptr(); + + tmp_array.set_min_size(r); + cx_type* tmp = tmp_array.memptr(); + + for(uword u=0; u < m; ++u) + { + uword k = u; + + for(uword v=0; v < r; ++v) + { + tmp[v] = Y[k]; + k += m; + } + + k = u; + + for(uword v=0; v < r; ++v) + { + Y[k] = tmp[0]; + + uword j = 0; + + for(uword w=1; w < r; ++w) + { + j += stride * k; + + if(j >= N) { j -= N; } + + Y[k] += tmp[w] * coeffs[j]; + } + + k += m; + } + } + } + + + + inline + void + run(cx_type* Y, const cx_type* X, const uword stage = 0, const uword stride = 1) + { + arma_extra_debug_sigprint(); + + const uword m = residue[stage]; + const uword r = radix[stage]; + + const cx_type *Y_end = Y + r*m; + + if(m == 1) + { + for(cx_type* Yi = Y; Yi != Y_end; Yi++, X += stride) { (*Yi) = (*X); } + } + else + { + const uword next_stage = stage + 1; + const uword next_stride = stride * r; + + for(cx_type* Yi = Y; Yi != Y_end; Yi += m, X += stride) { run(Yi, X, next_stage, next_stride); } + } + + switch(r) + { + case 2: butterfly_2(Y, stride, m ); break; + case 3: butterfly_3(Y, stride, m ); break; + case 4: butterfly_4(Y, stride, m ); break; + case 5: butterfly_5(Y, stride, m ); break; + default: butterfly_N(Y, stride, m, r); break; + } + } + + + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/field_bones.hpp b/src/armadillo/include/armadillo_bits/field_bones.hpp new file mode 100644 index 0000000..d3d2b05 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/field_bones.hpp @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup field +//! @{ + + + +struct field_prealloc_n_elem + { + static constexpr uword val = 16; + }; + + + +//! A lightweight 1D/2D/3D container for arbitrary objects +//! (the objects must have a copy constructor) + +template +class field + { + public: + + typedef oT object_type; + + const uword n_rows; //!< number of rows (read-only) + const uword n_cols; //!< number of columns (read-only) + const uword n_slices; //!< number of slices (read-only) + const uword n_elem; //!< number of elements (read-only) + + + private: + + arma_aligned oT** mem; //!< pointers to stored objects + arma_aligned oT* mem_local[ field_prealloc_n_elem::val ]; //!< local storage, for small fields + + + public: + + inline ~field(); + inline field(); + + inline field(const field& x); + inline field& operator=(const field& x); + + inline field(const subview_field& x); + inline field& operator=(const subview_field& x); + + inline explicit field(const uword n_elem_in); + inline explicit field(const uword n_rows_in, const uword n_cols_in); + inline explicit field(const uword n_rows_in, const uword n_cols_in, const uword n_slices_in); + inline explicit field(const SizeMat& s); + inline explicit field(const SizeCube& s); + + inline field& set_size(const uword n_obj_in); + inline field& set_size(const uword n_rows_in, const uword n_cols_in); + inline field& set_size(const uword n_rows_in, const uword n_cols_in, const uword n_slices_in); + inline field& set_size(const SizeMat& s); + inline field& set_size(const SizeCube& s); + + inline field(const std::vector& x); + inline field& operator=(const std::vector& x); + + inline field(const std::initializer_list& list); + inline field& operator=(const std::initializer_list& list); + + inline field(const std::initializer_list< std::initializer_list >& list); + inline field& operator=(const std::initializer_list< std::initializer_list >& list); + + inline field(field&& X); + inline field& operator=(field&& X); + + template + inline field& copy_size(const field& x); + + arma_warn_unused arma_inline oT& operator[](const uword i); + arma_warn_unused arma_inline const oT& operator[](const uword i) const; + + arma_warn_unused arma_inline oT& at(const uword i); + arma_warn_unused arma_inline const oT& at(const uword i) const; + + arma_warn_unused arma_inline oT& operator()(const uword i); + arma_warn_unused arma_inline const oT& operator()(const uword i) const; + + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline oT& operator[](const uword row, const uword col); + arma_warn_unused arma_inline const oT& operator[](const uword row, const uword col) const; + #endif + + arma_warn_unused arma_inline oT& at(const uword row, const uword col); + arma_warn_unused arma_inline const oT& at(const uword row, const uword col) const; + + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline oT& operator[](const uword row, const uword col, const uword slice); + arma_warn_unused arma_inline const oT& operator[](const uword row, const uword col, const uword slice) const; + #endif + + arma_warn_unused arma_inline oT& at(const uword row, const uword col, const uword slice); + arma_warn_unused arma_inline const oT& at(const uword row, const uword col, const uword slice) const; + + arma_warn_unused arma_inline oT& operator()(const uword row, const uword col); + arma_warn_unused arma_inline const oT& operator()(const uword row, const uword col) const; + + arma_warn_unused arma_inline oT& operator()(const uword row, const uword col, const uword slice); + arma_warn_unused arma_inline const oT& operator()(const uword row, const uword col, const uword slice) const; + + + arma_warn_unused arma_inline oT& front(); + arma_warn_unused arma_inline const oT& front() const; + + arma_warn_unused arma_inline oT& back(); + arma_warn_unused arma_inline const oT& back() const; + + + arma_frown("use braced initialiser list instead") inline field_injector operator<<(const oT& val); + arma_frown("use braced initialiser list instead") inline field_injector operator<<(const injector_end_of_row<>& x); + + + inline subview_field row(const uword row_num); + inline const subview_field row(const uword row_num) const; + + inline subview_field col(const uword col_num); + inline const subview_field col(const uword col_num) const; + + inline subview_field slice(const uword slice_num); + inline const subview_field slice(const uword slice_num) const; + + inline subview_field rows(const uword in_row1, const uword in_row2); + inline const subview_field rows(const uword in_row1, const uword in_row2) const; + + inline subview_field cols(const uword in_col1, const uword in_col2); + inline const subview_field cols(const uword in_col1, const uword in_col2) const; + + inline subview_field slices(const uword in_slice1, const uword in_slice2); + inline const subview_field slices(const uword in_slice1, const uword in_slice2) const; + + inline subview_field subfield(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2); + inline const subview_field subfield(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const; + + inline subview_field subfield(const uword in_row1, const uword in_col1, const uword in_slice1, const uword in_row2, const uword in_col2, const uword in_slice2); + inline const subview_field subfield(const uword in_row1, const uword in_col1, const uword in_slice1, const uword in_row2, const uword in_col2, const uword in_slice2) const; + + inline subview_field subfield(const uword in_row1, const uword in_col1, const SizeMat& s); + inline const subview_field subfield(const uword in_row1, const uword in_col1, const SizeMat& s) const; + + inline subview_field subfield(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s); + inline const subview_field subfield(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) const; + + inline subview_field subfield(const span& row_span, const span& col_span); + inline const subview_field subfield(const span& row_span, const span& col_span) const; + + inline subview_field subfield(const span& row_span, const span& col_span, const span& slice_span); + inline const subview_field subfield(const span& row_span, const span& col_span, const span& slice_span) const; + + inline subview_field operator()(const span& row_span, const span& col_span); + inline const subview_field operator()(const span& row_span, const span& col_span) const; + + inline subview_field operator()(const span& row_span, const span& col_span, const span& slice_span); + inline const subview_field operator()(const span& row_span, const span& col_span, const span& slice_span) const; + + inline subview_field operator()(const uword in_row1, const uword in_col1, const SizeMat& s); + inline const subview_field operator()(const uword in_row1, const uword in_col1, const SizeMat& s) const; + + inline subview_field operator()(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s); + inline const subview_field operator()(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) const; + + + arma_cold inline void print( const std::string extra_text = "") const; + arma_cold inline void print(std::ostream& user_stream, const std::string extra_text = "") const; + + inline field& for_each(const std::function< void( oT&) >& F); + inline const field& for_each(const std::function< void(const oT&) >& F) const; + + inline field& fill(const oT& x); + + inline void reset(); + inline void reset_objects(); + + arma_warn_unused arma_inline bool is_empty() const; + + + arma_warn_unused arma_inline bool in_range(const uword i) const; + arma_warn_unused arma_inline bool in_range(const span& x) const; + + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const span& col_span) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const span& col_span) const; + + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const SizeMat& s) const; + + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const uword in_slice) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const span& col_span, const span& slice_span) const; + + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const uword in_slice, const SizeCube& s) const; + + + arma_cold inline bool save(const std::string name, const file_type type = arma_binary) const; + arma_cold inline bool save( std::ostream& os, const file_type type = arma_binary) const; + + arma_cold inline bool load(const std::string name, const file_type type = auto_detect); + arma_cold inline bool load( std::istream& is, const file_type type = auto_detect); + + + arma_deprecated inline bool quiet_save(const std::string name, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; + + arma_deprecated inline bool quiet_load(const std::string name, const file_type type = auto_detect); + arma_deprecated inline bool quiet_load( std::istream& is, const file_type type = auto_detect); + + + // for container-like functionality + + typedef oT value_type; + typedef uword size_type; + + + class iterator + { + public: + + inline iterator(field& in_M, const bool at_end = false); + + inline oT& operator* (); + + inline iterator& operator++(); + inline void operator++(int); + + inline iterator& operator--(); + inline void operator--(int); + + inline bool operator!=(const iterator& X) const; + inline bool operator==(const iterator& X) const; + + arma_aligned field& M; + arma_aligned uword i; + }; + + + class const_iterator + { + public: + + const_iterator(const field& in_M, const bool at_end = false); + const_iterator(const iterator& X); + + inline const oT& operator*() const; + + inline const_iterator& operator++(); + inline void operator++(int); + + inline const_iterator& operator--(); + inline void operator--(int); + + inline bool operator!=(const const_iterator& X) const; + inline bool operator==(const const_iterator& X) const; + + arma_aligned const field& M; + arma_aligned uword i; + }; + + inline iterator begin(); + inline const_iterator begin() const; + inline const_iterator cbegin() const; + + inline iterator end(); + inline const_iterator end() const; + inline const_iterator cend() const; + + inline void clear(); + inline bool empty() const; + inline uword size() const; + + + private: + + inline void init(const field& x); + inline void init(const uword n_rows_in, const uword n_cols_in); + inline void init(const uword n_rows_in, const uword n_cols_in, const uword n_slices_in); + + inline void delete_objects(); + inline void create_objects(); + + friend class field_aux; + friend class subview_field; + + + public: + + #if defined(ARMA_EXTRA_FIELD_PROTO) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_FIELD_PROTO) + #endif + }; + + + +class field_aux + { + public: + + template inline static void reset_objects(field< oT >& x); + template inline static void reset_objects(field< Mat >& x); + template inline static void reset_objects(field< Col >& x); + template inline static void reset_objects(field< Row >& x); + template inline static void reset_objects(field< Cube >& x); + inline static void reset_objects(field< std::string >& x); + + + template inline static bool save(const field< oT >& x, const std::string& name, const file_type type, std::string& err_msg); + template inline static bool save(const field< oT >& x, std::ostream& os, const file_type type, std::string& err_msg); + template inline static bool load( field< oT >& x, const std::string& name, const file_type type, std::string& err_msg); + template inline static bool load( field< oT >& x, std::istream& is, const file_type type, std::string& err_msg); + + template inline static bool save(const field< Mat >& x, const std::string& name, const file_type type, std::string& err_msg); + template inline static bool save(const field< Mat >& x, std::ostream& os, const file_type type, std::string& err_msg); + template inline static bool load( field< Mat >& x, const std::string& name, const file_type type, std::string& err_msg); + template inline static bool load( field< Mat >& x, std::istream& is, const file_type type, std::string& err_msg); + + template inline static bool save(const field< Col >& x, const std::string& name, const file_type type, std::string& err_msg); + template inline static bool save(const field< Col >& x, std::ostream& os, const file_type type, std::string& err_msg); + template inline static bool load( field< Col >& x, const std::string& name, const file_type type, std::string& err_msg); + template inline static bool load( field< Col >& x, std::istream& is, const file_type type, std::string& err_msg); + + template inline static bool save(const field< Row >& x, const std::string& name, const file_type type, std::string& err_msg); + template inline static bool save(const field< Row >& x, std::ostream& os, const file_type type, std::string& err_msg); + template inline static bool load( field< Row >& x, const std::string& name, const file_type type, std::string& err_msg); + template inline static bool load( field< Row >& x, std::istream& is, const file_type type, std::string& err_msg); + + template inline static bool save(const field< Cube >& x, const std::string& name, const file_type type, std::string& err_msg); + template inline static bool save(const field< Cube >& x, std::ostream& os, const file_type type, std::string& err_msg); + template inline static bool load( field< Cube >& x, const std::string& name, const file_type type, std::string& err_msg); + template inline static bool load( field< Cube >& x, std::istream& is, const file_type type, std::string& err_msg); + + inline static bool save(const field< std::string >& x, const std::string& name, const file_type type, std::string& err_msg); + inline static bool save(const field< std::string >& x, std::ostream& os, const file_type type, std::string& err_msg); + inline static bool load( field< std::string >& x, const std::string& name, const file_type type, std::string& err_msg); + inline static bool load( field< std::string >& x, std::istream& is, const file_type type, std::string& err_msg); + + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/field_meat.hpp b/src/armadillo/include/armadillo_bits/field_meat.hpp new file mode 100644 index 0000000..b65bb84 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/field_meat.hpp @@ -0,0 +1,2999 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup field +//! @{ + + +template +inline +field::~field() + { + arma_extra_debug_sigprint_this(this); + + delete_objects(); + + if(n_elem > field_prealloc_n_elem::val) { delete [] mem; } + + // try to expose buggy user code that accesses deleted objects + if(arma_config::debug) { mem = nullptr; } + } + + + +template +inline +field::field() + : n_rows(0) + , n_cols(0) + , n_slices(0) + , n_elem(0) + , mem(nullptr) + { + arma_extra_debug_sigprint_this(this); + } + + + +//! construct a field from a given field +template +inline +field::field(const field& x) + : n_rows(0) + , n_cols(0) + , n_slices(0) + , n_elem(0) + , mem(nullptr) + { + arma_extra_debug_sigprint(arma_str::format("this = %x x = %x") % this % &x); + + init(x); + } + + + +//! construct a field from a given field +template +inline +field& +field::operator=(const field& x) + { + arma_extra_debug_sigprint(); + + init(x); + + return *this; + } + + + +//! construct a field from subview_field (eg. construct a field from a delayed subfield operation) +template +inline +field::field(const subview_field& X) + : n_rows(0) + , n_cols(0) + , n_slices(0) + , n_elem(0) + , mem(nullptr) + { + arma_extra_debug_sigprint_this(this); + + this->operator=(X); + } + + + +//! construct a field from subview_field (eg. construct a field from a delayed subfield operation) +template +inline +field& +field::operator=(const subview_field& X) + { + arma_extra_debug_sigprint(); + + subview_field::extract(*this, X); + + return *this; + } + + + +//! construct the field with the specified number of elements, +//! assuming a column-major layout +template +inline +field::field(const uword n_elem_in) + : n_rows(0) + , n_cols(0) + , n_slices(0) + , n_elem(0) + , mem(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init(n_elem_in, 1); + } + + + +//! construct the field with the specified dimensions +template +inline +field::field(const uword n_rows_in, const uword n_cols_in) + : n_rows(0) + , n_cols(0) + , n_slices(0) + , n_elem(0) + , mem(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init(n_rows_in, n_cols_in); + } + + + +//! construct the field with the specified dimensions +template +inline +field::field(const uword n_rows_in, const uword n_cols_in, const uword n_slices_in) + : n_rows(0) + , n_cols(0) + , n_slices(0) + , n_elem(0) + , mem(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init(n_rows_in, n_cols_in, n_slices_in); + } + + + +template +inline +field::field(const SizeMat& s) + : n_rows(0) + , n_cols(0) + , n_slices(0) + , n_elem(0) + , mem(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init(s.n_rows, s.n_cols); + } + + + +template +inline +field::field(const SizeCube& s) + : n_rows(0) + , n_cols(0) + , n_slices(0) + , n_elem(0) + , mem(nullptr) + { + arma_extra_debug_sigprint_this(this); + + init(s.n_rows, s.n_cols, s.n_slices); + } + + + +//! change the field to have the specified number of elements, +//! assuming a column-major layout (data is not preserved) +template +inline +field& +field::set_size(const uword n_elem_in) + { + arma_extra_debug_sigprint(arma_str::format("n_elem_in = %u") % n_elem_in); + + init(n_elem_in, 1); + + return *this; + } + + + +//! change the field to have the specified dimensions (data is not preserved) +template +inline +field& +field::set_size(const uword n_rows_in, const uword n_cols_in) + { + arma_extra_debug_sigprint(arma_str::format("n_rows_in = %u, n_cols_in = %u") % n_rows_in % n_cols_in); + + init(n_rows_in, n_cols_in); + + return *this; + } + + + +//! change the field to have the specified dimensions (data is not preserved) +template +inline +field& +field::set_size(const uword n_rows_in, const uword n_cols_in, const uword n_slices_in) + { + arma_extra_debug_sigprint(arma_str::format("n_rows_in = %u, n_cols_in = %u, n_slices_in = %u") % n_rows_in % n_cols_in % n_slices_in); + + init(n_rows_in, n_cols_in, n_slices_in); + + return *this; + } + + + +template +inline +field& +field::set_size(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + init(s.n_rows, s.n_cols); + + return *this; + } + + + +template +inline +field& +field::set_size(const SizeCube& s) + { + arma_extra_debug_sigprint(); + + init(s.n_rows, s.n_cols, s.n_slices); + + return *this; + } + + + +template +inline +field::field(const std::vector& x) + : n_rows (0) + , n_cols (0) + , n_slices(0) + , n_elem (0) + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(x); + } + + + +template +inline +field& +field::operator=(const std::vector& x) + { + arma_extra_debug_sigprint(); + + const uword N = uword(x.size()); + + set_size(N, 1); + + for(uword i=0; i +inline +field::field(const std::initializer_list& list) + : n_rows (0) + , n_cols (0) + , n_slices(0) + , n_elem (0) + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(list); + } + + + +template +inline +field& +field::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + const uword N = uword(list.size()); + + set_size(1, N); + + const oT* item_ptr = list.begin(); + + for(uword i=0; i +inline +field::field(const std::initializer_list< std::initializer_list >& list) + : n_rows (0) + , n_cols (0) + , n_slices(0) + , n_elem (0) + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(list); + } + + + +template +inline +field& +field::operator=(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); + + uword x_n_rows = uword(list.size()); + uword x_n_cols = 0; + + auto it = list.begin(); + auto it_end = list.end(); + + for(; it != it_end; ++it) { x_n_cols = (std::max)(x_n_cols, uword((*it).size())); } + + field& t = (*this); + + t.set_size(x_n_rows, x_n_cols); + + uword row_num = 0; + + auto row_it = list.begin(); + auto row_it_end = list.end(); + + for(; row_it != row_it_end; ++row_it) + { + uword col_num = 0; + + auto col_it = (*row_it).begin(); + auto col_it_end = (*row_it).end(); + + for(; col_it != col_it_end; ++col_it) + { + t.at(row_num, col_num) = (*col_it); + ++col_num; + } + + for(uword c=col_num; c < x_n_cols; ++c) + { + t.at(row_num, c) = oT(); + } + + ++row_num; + } + + return *this; + } + + + +template +inline +field::field(field&& X) + : n_rows (X.n_rows ) + , n_cols (X.n_cols ) + , n_slices(X.n_slices) + , n_elem (X.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + + if(n_elem > field_prealloc_n_elem::val) + { + mem = X.mem; + } + else + { + arrayops::copy(&mem_local[0], &X.mem_local[0], n_elem); + mem = mem_local; + } + + access::rw(X.n_rows ) = 0; + access::rw(X.n_cols ) = 0; + access::rw(X.n_slices) = 0; + access::rw(X.n_elem ) = 0; + access::rw(X.mem ) = nullptr; + } + + + +template +inline +field& +field::operator=(field&& X) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + + if(this == &X) { return *this; } + + reset(); + + access::rw(n_rows ) = X.n_rows; + access::rw(n_cols ) = X.n_cols; + access::rw(n_slices) = X.n_slices; + access::rw(n_elem ) = X.n_elem; + + if(n_elem > field_prealloc_n_elem::val) + { + mem = X.mem; + } + else + { + arrayops::copy(&mem_local[0], &X.mem_local[0], n_elem); + mem = mem_local; + } + + access::rw(X.n_rows ) = 0; + access::rw(X.n_cols ) = 0; + access::rw(X.n_elem ) = 0; + access::rw(X.n_slices) = 0; + access::rw(X.mem ) = nullptr; + + return *this; + } + + + +//! change the field to have the specified dimensions (data is not preserved) +template +template +inline +field& +field::copy_size(const field& x) + { + arma_extra_debug_sigprint(); + + init(x.n_rows, x.n_cols, x.n_slices); + + return *this; + } + + + +//! linear element accessor (treats the field as a vector); no bounds check +template +arma_inline +oT& +field::operator[] (const uword i) + { + return (*mem[i]); + } + + + +//! linear element accessor (treats the field as a vector); no bounds check +template +arma_inline +const oT& +field::operator[] (const uword i) const + { + return (*mem[i]); + } + + + +//! linear element accessor (treats the field as a vector); no bounds check +template +arma_inline +oT& +field::at(const uword i) + { + return (*mem[i]); + } + + + +//! linear element accessor (treats the field as a vector); no bounds check +template +arma_inline +const oT& +field::at(const uword i) const + { + return (*mem[i]); + } + + + +//! linear element accessor (treats the field as a vector); bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +oT& +field::operator() (const uword i) + { + arma_debug_check_bounds( (i >= n_elem), "field::operator(): index out of bounds" ); + + return (*mem[i]); + } + + + +//! linear element accessor (treats the field as a vector); bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +const oT& +field::operator() (const uword i) const + { + arma_debug_check_bounds( (i >= n_elem), "field::operator(): index out of bounds" ); + + return (*mem[i]); + } + + + +//! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +oT& +field::operator() (const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (0 >= n_slices) ), "field::operator(): index out of bounds" ); + + return (*mem[in_row + in_col*n_rows]); + } + + + +//! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +const oT& +field::operator() (const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (0 >= n_slices) ), "field::operator(): index out of bounds" ); + + return (*mem[in_row + in_col*n_rows]); + } + + + +//! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +oT& +field::operator() (const uword in_row, const uword in_col, const uword in_slice) + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "field::operator(): index out of bounds" ); + + return (*mem[in_row + in_col*n_rows + in_slice*(n_rows*n_cols)]); + } + + + +//! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined +template +arma_inline +const oT& +field::operator() (const uword in_row, const uword in_col, const uword in_slice) const + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "field::operator(): index out of bounds" ); + + return (*mem[in_row + in_col*n_rows + in_slice*(n_rows*n_cols)]); + } + + + +#if defined(__cpp_multidimensional_subscript) + + //! element accessor; no bounds check + template + arma_inline + oT& + field::operator[] (const uword in_row, const uword in_col) + { + return (*mem[in_row + in_col*n_rows]); + } + + + + //! element accessor; no bounds check + template + arma_inline + const oT& + field::operator[] (const uword in_row, const uword in_col) const + { + return (*mem[in_row + in_col*n_rows]); + } + +#endif + + + +//! element accessor; no bounds check +template +arma_inline +oT& +field::at(const uword in_row, const uword in_col) + { + return (*mem[in_row + in_col*n_rows]); + } + + + +//! element accessor; no bounds check +template +arma_inline +const oT& +field::at(const uword in_row, const uword in_col) const + { + return (*mem[in_row + in_col*n_rows]); + } + + + +#if defined(__cpp_multidimensional_subscript) + + //! element accessor; no bounds check + template + arma_inline + oT& + field::operator[] (const uword in_row, const uword in_col, const uword in_slice) + { + return (*mem[in_row + in_col*n_rows + in_slice*(n_rows*n_cols)]); + } + + + + //! element accessor; no bounds check + template + arma_inline + const oT& + field::operator[] (const uword in_row, const uword in_col, const uword in_slice) const + { + return (*mem[in_row + in_col*n_rows + in_slice*(n_rows*n_cols)]); + } + +#endif + + + +//! element accessor; no bounds check +template +arma_inline +oT& +field::at(const uword in_row, const uword in_col, const uword in_slice) + { + return (*mem[in_row + in_col*n_rows + in_slice*(n_rows*n_cols)]); + } + + + +//! element accessor; no bounds check +template +arma_inline +const oT& +field::at(const uword in_row, const uword in_col, const uword in_slice) const + { + return (*mem[in_row + in_col*n_rows + in_slice*(n_rows*n_cols)]); + } + + + +template +arma_inline +oT& +field::front() + { + arma_debug_check( (n_elem == 0), "field::front(): field is empty" ); + + return (*mem[0]); + } + + + +template +arma_inline +const oT& +field::front() const + { + arma_debug_check( (n_elem == 0), "field::front(): field is empty" ); + + return (*mem[0]); + } + + + +template +arma_inline +oT& +field::back() + { + arma_debug_check( (n_elem == 0), "field::back(): field is empty" ); + + return (*mem[n_elem-1]); + } + + + +template +arma_inline +const oT& +field::back() const + { + arma_debug_check( (n_elem == 0), "field::back(): field is empty" ); + + return (*mem[n_elem-1]); + } + + + +template +inline +field_injector< field > +field::operator<<(const oT& val) + { + return field_injector< field >(*this, val); + } + + + +template +inline +field_injector< field > +field::operator<<(const injector_end_of_row<>& x) + { + return field_injector< field >(*this, x); + } + + + +//! creation of subview_field (row of a field) +template +inline +subview_field +field::row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::row(): field must be 2D" ); + + arma_debug_check_bounds( (row_num >= n_rows), "field::row(): row out of bounds" ); + + return subview_field(*this, row_num, 0, 1, n_cols); + } + + + +//! creation of subview_field (row of a field) +template +inline +const subview_field +field::row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::row(): field must be 2D" ); + + arma_debug_check_bounds( (row_num >= n_rows), "field::row(): row out of bounds" ); + + return subview_field(*this, row_num, 0, 1, n_cols); + } + + + +//! creation of subview_field (column of a field) +template +inline +subview_field +field::col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::col(): field must be 2D" ); + + arma_debug_check_bounds( (col_num >= n_cols), "field::col(): out of bounds" ); + + return subview_field(*this, 0, col_num, n_rows, 1); + } + + + +//! creation of subview_field (column of a field) +template +inline +const subview_field +field::col(const uword col_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::col(): field must be 2D" ); + + arma_debug_check_bounds( (col_num >= n_cols), "field::col(): out of bounds" ); + + return subview_field(*this, 0, col_num, n_rows, 1); + } + + + +//! creation of subview_field (slice of a field) +template +inline +subview_field +field::slice(const uword slice_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (slice_num >= n_slices), "field::slice(): out of bounds" ); + + return subview_field(*this, 0, 0, slice_num, n_rows, n_cols, 1); + } + + + +//! creation of subview_field (slice of a field) +template +inline +const subview_field +field::slice(const uword slice_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (slice_num >= n_slices), "field::slice(): out of bounds" ); + + return subview_field(*this, 0, 0, slice_num, n_rows, n_cols, 1); + } + + + +//! creation of subview_field (subfield comprised of specified rows) +template +inline +subview_field +field::rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::rows(): field must be 2D" ); + + arma_debug_check_bounds + ( + ( (in_row1 > in_row2) || (in_row2 >= n_rows) ), + "field::rows(): indicies out of bounds or incorrectly used" + ); + + const uword sub_n_rows = in_row2 - in_row1 + 1; + + return subview_field(*this, in_row1, 0, sub_n_rows, n_cols); + } + + + +//! creation of subview_field (subfield comprised of specified rows) +template +inline +const subview_field +field::rows(const uword in_row1, const uword in_row2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::rows(): field must be 2D" ); + + arma_debug_check_bounds + ( + ( (in_row1 > in_row2) || (in_row2 >= n_rows) ), + "field::rows(): indicies out of bounds or incorrectly used" + ); + + const uword sub_n_rows = in_row2 - in_row1 + 1; + + return subview_field(*this, in_row1, 0, sub_n_rows, n_cols); + } + + + +//! creation of subview_field (subfield comprised of specified columns) +template +inline +subview_field +field::cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::cols(): field must be 2D" ); + + arma_debug_check_bounds + ( + ( (in_col1 > in_col2) || (in_col2 >= n_cols) ), + "field::cols(): indicies out of bounds or incorrectly used" + ); + + const uword sub_n_cols = in_col2 - in_col1 + 1; + + return subview_field(*this, 0, in_col1, n_rows, sub_n_cols); + } + + + +//! creation of subview_field (subfield comprised of specified columns) +template +inline +const subview_field +field::cols(const uword in_col1, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::cols(): field must be 2D" ); + + arma_debug_check_bounds + ( + ( (in_col1 > in_col2) || (in_col2 >= n_cols) ), + "field::cols(): indicies out of bounds or incorrectly used" + ); + + const uword sub_n_cols = in_col2 - in_col1 + 1; + + return subview_field(*this, 0, in_col1, n_rows, sub_n_cols); + } + + + +//! creation of subview_field (subfield comprised of specified slices) +template +inline +subview_field +field::slices(const uword in_slice1, const uword in_slice2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + ( (in_slice1 > in_slice2) || (in_slice2 >= n_slices) ), + "field::slices(): indicies out of bounds or incorrectly used" + ); + + const uword sub_n_slices = in_slice2 - in_slice1 + 1; + + return subview_field(*this, 0, 0, in_slice1, n_rows, n_cols, sub_n_slices); + } + + + +//! creation of subview_field (subfield comprised of specified slices) +template +inline +const subview_field +field::slices(const uword in_slice1, const uword in_slice2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + ( (in_slice1 > in_slice2) || (in_slice2 >= n_slices) ), + "field::slices(): indicies out of bounds or incorrectly used" + ); + + const uword sub_n_slices = in_slice2 - in_slice1 + 1; + + return subview_field(*this, 0, 0, in_slice1, n_rows, n_cols, sub_n_slices); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +subview_field +field::subfield(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::subfield(): field must be 2D" ); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), + "field::subfield(): indices out of bounds or incorrectly used" + ); + + const uword sub_n_rows = in_row2 - in_row1 + 1; + const uword sub_n_cols = in_col2 - in_col1 + 1; + + return subview_field(*this, in_row1, in_col1, sub_n_rows, sub_n_cols); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +const subview_field +field::subfield(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::subfield(): field must be 2D" ); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), + "field::subfield(): indices out of bounds or incorrectly used" + ); + + const uword sub_n_rows = in_row2 - in_row1 + 1; + const uword sub_n_cols = in_col2 - in_col1 + 1; + + return subview_field(*this, in_row1, in_col1, sub_n_rows, sub_n_cols); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +subview_field +field::subfield(const uword in_row1, const uword in_col1, const uword in_slice1, const uword in_row2, const uword in_col2, const uword in_slice2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_slice1 > in_slice2) || (in_row2 >= n_rows) || (in_col2 >= n_cols) || (in_slice2 >= n_slices), + "field::subfield(): indices out of bounds or incorrectly used" + ); + + const uword sub_n_rows = in_row2 - in_row1 + 1; + const uword sub_n_cols = in_col2 - in_col1 + 1; + const uword sub_n_slices = in_slice2 - in_slice1 + 1; + + return subview_field(*this, in_row1, in_col1, in_slice1, sub_n_rows, sub_n_cols, sub_n_slices); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +const subview_field +field::subfield(const uword in_row1, const uword in_col1, const uword in_slice1, const uword in_row2, const uword in_col2, const uword in_slice2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_slice1 > in_slice2) || (in_row2 >= n_rows) || (in_col2 >= n_cols) || (in_slice2 >= n_slices), + "field::subfield(): indices out of bounds or incorrectly used" + ); + + const uword sub_n_rows = in_row2 - in_row1 + 1; + const uword sub_n_cols = in_col2 - in_col1 + 1; + const uword sub_n_slices = in_slice2 - in_slice1 + 1; + + return subview_field(*this, in_row1, in_col1, in_slice1, sub_n_rows, sub_n_cols, sub_n_slices); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +subview_field +field::subfield(const uword in_row1, const uword in_col1, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::subfield(): field must be 2D" ); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + arma_debug_check_bounds + ( + ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), + "field::subfield(): indices or size out of bounds" + ); + + return subview_field(*this, in_row1, in_col1, s_n_rows, s_n_cols); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +const subview_field +field::subfield(const uword in_row1, const uword in_col1, const SizeMat& s) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::subfield(): field must be 2D" ); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + arma_debug_check_bounds + ( + ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), + "field::subfield(): indices or size out of bounds" + ); + + return subview_field(*this, in_row1, in_col1, s_n_rows, s_n_cols); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +subview_field +field::subfield(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) + { + arma_extra_debug_sigprint(); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + const uword l_n_slices = n_slices; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + const uword sub_n_slices = s.n_slices; + + arma_debug_check_bounds + ( + ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || (in_slice1 >= l_n_slices) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols) || ((in_slice1 + sub_n_slices) > l_n_slices)), + "field::subfield(): indices or size out of bounds" + ); + + return subview_field(*this, in_row1, in_col1, in_slice1, s_n_rows, s_n_cols, sub_n_slices); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +const subview_field +field::subfield(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) const + { + arma_extra_debug_sigprint(); + + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + const uword l_n_slices = n_slices; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + const uword sub_n_slices = s.n_slices; + + arma_debug_check_bounds + ( + ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || (in_slice1 >= l_n_slices) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols) || ((in_slice1 + sub_n_slices) > l_n_slices)), + "field::subfield(): indices or size out of bounds" + ); + + return subview_field(*this, in_row1, in_col1, in_slice1, s_n_rows, s_n_cols, sub_n_slices); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +subview_field +field::subfield(const span& row_span, const span& col_span) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::subfield(): field must be 2D" ); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword sub_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword sub_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "field::subfield(): indices out of bounds or incorrectly used" + ); + + return subview_field(*this, in_row1, in_col1, sub_n_rows, sub_n_cols); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +const subview_field +field::subfield(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (n_slices >= 2), "field::subfield(): field must be 2D" ); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword sub_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword sub_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "field::subfield(): indices out of bounds or incorrectly used" + ); + + return subview_field(*this, in_row1, in_col1, sub_n_rows, sub_n_cols); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +subview_field +field::subfield(const span& row_span, const span& col_span, const span& slice_span) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + const bool slice_all = slice_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword sub_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword sub_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + const uword in_slice1 = slice_all ? 0 : slice_span.a; + const uword in_slice2 = slice_span.b; + const uword sub_n_slices = slice_all ? local_n_slices : in_slice2 - in_slice1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + || + ( slice_all ? false : ((in_slice1 > in_slice2) || (in_slice2 >= local_n_slices)) ) + , + "field::subfield(): indices out of bounds or incorrectly used" + ); + + return subview_field(*this, in_row1, in_col1, in_slice1, sub_n_rows, sub_n_cols, sub_n_slices); + } + + + +//! creation of subview_field (subfield with arbitrary dimensions) +template +inline +const subview_field +field::subfield(const span& row_span, const span& col_span, const span& slice_span) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + const bool slice_all = slice_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword sub_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword sub_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + const uword in_slice1 = slice_all ? 0 : slice_span.a; + const uword in_slice2 = slice_span.b; + const uword sub_n_slices = slice_all ? local_n_slices : in_slice2 - in_slice1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + || + ( slice_all ? false : ((in_slice1 > in_slice2) || (in_slice2 >= local_n_slices)) ) + , + "field::subfield(): indices out of bounds or incorrectly used" + ); + + return subview_field(*this, in_row1, in_col1, in_slice1, sub_n_rows, sub_n_cols, sub_n_slices); + } + + + +template +inline +subview_field +field::operator()(const span& row_span, const span& col_span) + { + arma_extra_debug_sigprint(); + + return (*this).subfield(row_span, col_span); + } + + + +template +inline +const subview_field +field::operator()(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + return (*this).subfield(row_span, col_span); + } + + + +template +inline +subview_field +field::operator()(const span& row_span, const span& col_span, const span& slice_span) + { + arma_extra_debug_sigprint(); + + return (*this).subfield(row_span, col_span, slice_span); + } + + + +template +inline +const subview_field +field::operator()(const span& row_span, const span& col_span, const span& slice_span) const + { + arma_extra_debug_sigprint(); + + return (*this).subfield(row_span, col_span, slice_span); + } + + + +template +inline +subview_field +field::operator()(const uword in_row1, const uword in_col1, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return (*this).subfield(in_row1, in_col1, s); + } + + + +template +inline +const subview_field +field::operator()(const uword in_row1, const uword in_col1, const SizeMat& s) const + { + arma_extra_debug_sigprint(); + + return (*this).subfield(in_row1, in_col1, s); + } + + + +template +inline +subview_field +field::operator()(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) + { + arma_extra_debug_sigprint(); + + return (*this).subfield(in_row1, in_col1, in_slice1, s); + } + + + +template +inline +const subview_field +field::operator()(const uword in_row1, const uword in_col1, const uword in_slice1, const SizeCube& s) const + { + arma_extra_debug_sigprint(); + + return (*this).subfield(in_row1, in_col1, in_slice1, s); + } + + + +//! print contents of the field (to the cout stream), +//! optionally preceding with a user specified line of text. +//! the field class preserves the stream's flags +//! but the associated operator<< function for type oT +//! may still modify the stream's parameters. +//! NOTE: this function assumes that type oT can be printed, +//! ie. the function "std::ostream& operator<< (std::ostream&, const oT&)" +//! has been defined. + +template +inline +void +field::print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), *this); + } + + + +//! print contents of the field to a user specified stream, +//! optionally preceding with a user specified line of text. +//! the field class preserves the stream's flags +//! but the associated operator<< function for type oT +//! may still modify the stream's parameters. +//! NOTE: this function assumes that type oT can be printed, +//! ie. the function "std::ostream& operator<< (std::ostream&, const oT&)" +//! has been defined. + +template +inline +void +field::print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, *this); + } + + + +//! apply a lambda function to each object +template +inline +field& +field::for_each(const std::function< void(oT&) >& F) + { + arma_extra_debug_sigprint(); + + for(uword i=0; i < n_elem; ++i) { F(operator[](i)); } + + return *this; + } + + + +template +inline +const field& +field::for_each(const std::function< void(const oT&) >& F) const + { + arma_extra_debug_sigprint(); + + for(uword i=0; i < n_elem; ++i) { F(operator[](i)); } + + return *this; + } + + + +//! fill the field with an object +template +inline +field& +field::fill(const oT& x) + { + arma_extra_debug_sigprint(); + + field& t = *this; + + for(uword i=0; i +inline +void +field::reset() + { + arma_extra_debug_sigprint(); + + init(0,0,0); + } + + + +//! reset each object +template +inline +void +field::reset_objects() + { + arma_extra_debug_sigprint(); + + field_aux::reset_objects(*this); + } + + + +//! returns true if the field has no objects +template +arma_inline +bool +field::is_empty() const + { + return (n_elem == 0); + } + + + +//! returns true if the given index is currently in range +template +arma_inline +bool +field::in_range(const uword i) const + { + return (i < n_elem); + } + + + +//! returns true if the given start and end indices are currently in range +template +arma_inline +bool +field::in_range(const span& x) const + { + arma_extra_debug_sigprint(); + + if(x.whole) + { + return true; + } + else + { + const uword a = x.a; + const uword b = x.b; + + return ( (a <= b) && (b < n_elem) ); + } + } + + + +//! returns true if the given location is currently in range +template +arma_inline +bool +field::in_range(const uword in_row, const uword in_col) const + { + return ( (in_row < n_rows) && (in_col < n_cols) ); + } + + + +template +arma_inline +bool +field::in_range(const span& row_span, const uword in_col) const + { + arma_extra_debug_sigprint(); + + if(row_span.whole) + { + return (in_col < n_cols); + } + else + { + const uword in_row1 = row_span.a; + const uword in_row2 = row_span.b; + + return ( (in_row1 <= in_row2) && (in_row2 < n_rows) && (in_col < n_cols) ); + } + } + + + +template +arma_inline +bool +field::in_range(const uword in_row, const span& col_span) const + { + arma_extra_debug_sigprint(); + + if(col_span.whole) + { + return (in_row < n_rows); + } + else + { + const uword in_col1 = col_span.a; + const uword in_col2 = col_span.b; + + return ( (in_row < n_rows) && (in_col1 <= in_col2) && (in_col2 < n_cols) ); + } + } + + + +template +arma_inline +bool +field::in_range(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + const uword in_row1 = row_span.a; + const uword in_row2 = row_span.b; + + const uword in_col1 = col_span.a; + const uword in_col2 = col_span.b; + + const bool rows_ok = row_span.whole ? true : ( (in_row1 <= in_row2) && (in_row2 < n_rows) ); + const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2) && (in_col2 < n_cols) ); + + return ( rows_ok && cols_ok ); + } + + + +template +arma_inline +bool +field::in_range(const uword in_row, const uword in_col, const SizeMat& s) const + { + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + + if( (in_row >= l_n_rows) || (in_col >= l_n_cols) || ((in_row + s.n_rows) > l_n_rows) || ((in_col + s.n_cols) > l_n_cols) ) + { + return false; + } + else + { + return true; + } + } + + + +template +arma_inline +bool +field::in_range(const uword in_row, const uword in_col, const uword in_slice) const + { + return ( (in_row < n_rows) && (in_col < n_cols) && (in_slice < n_slices) ); + } + + + +template +arma_inline +bool +field::in_range(const span& row_span, const span& col_span, const span& slice_span) const + { + arma_extra_debug_sigprint(); + + const uword in_row1 = row_span.a; + const uword in_row2 = row_span.b; + + const uword in_col1 = col_span.a; + const uword in_col2 = col_span.b; + + const uword in_slice1 = slice_span.a; + const uword in_slice2 = slice_span.b; + + const bool rows_ok = row_span.whole ? true : ( (in_row1 <= in_row2 ) && (in_row2 < n_rows ) ); + const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2 ) && (in_col2 < n_cols ) ); + const bool slices_ok = slice_span.whole ? true : ( (in_slice1 <= in_slice2) && (in_slice2 < n_slices) ); + + return ( rows_ok && cols_ok && slices_ok ); + } + + + +template +arma_inline +bool +field::in_range(const uword in_row, const uword in_col, const uword in_slice, const SizeCube& s) const + { + const uword l_n_rows = n_rows; + const uword l_n_cols = n_cols; + const uword l_n_slices = n_slices; + + if( (in_row >= l_n_rows) || (in_col >= l_n_cols) || (in_slice >= l_n_slices) || ((in_row + s.n_rows) > l_n_rows) || ((in_col + s.n_cols) > l_n_cols) || ((in_slice + s.n_slices) > l_n_slices) ) + { + return false; + } + else + { + return true; + } + } + + + +template +inline +bool +field::save(const std::string name, const file_type type) const + { + arma_extra_debug_sigprint(); + + std::string err_msg; + + const bool save_okay = field_aux::save(*this, name, type, err_msg); + + if(save_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "field::save(): ", err_msg, "; file: ", name); + } + else + { + arma_debug_warn_level(3, "field::save(): write failed; file: ", name); + } + } + + return save_okay; + } + + + +template +inline +bool +field::save(std::ostream& os, const file_type type) const + { + arma_extra_debug_sigprint(); + + std::string err_msg; + + const bool save_okay = field_aux::save(*this, os, type, err_msg); + + if(save_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "field::save(): ", err_msg); + } + else + { + arma_debug_warn_level(3, "field::save(): stream write failed"); + } + } + + return save_okay; + } + + + +template +inline +bool +field::load(const std::string name, const file_type type) + { + arma_extra_debug_sigprint(); + + std::string err_msg; + + const bool load_okay = field_aux::load(*this, name, type, err_msg); + + if(load_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "field::load(): ", err_msg, "; file: ", name); + } + else + { + arma_debug_warn_level(3, "field::load(): read failed; file: ", name); + } + } + + if(load_okay == false) { (*this).reset(); } + + return load_okay; + } + + + +template +inline +bool +field::load(std::istream& is, const file_type type) + { + arma_extra_debug_sigprint(); + + std::string err_msg; + const bool load_okay = field_aux::load(*this, is, type, err_msg); + + if(load_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "field::load(): ", err_msg); + } + else + { + arma_debug_warn_level(3, "field::load(): stream read failed"); + } + } + + if(load_okay == false) { (*this).reset(); } + + return load_okay; + } + + + +template +inline +bool +field::quiet_save(const std::string name, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(name, type); + } + + + +template +inline +bool +field::quiet_save(std::ostream& os, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(os, type); + } + + + +template +inline +bool +field::quiet_load(const std::string name, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(name, type); + } + + + +template +inline +bool +field::quiet_load(std::istream& is, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(is, type); + } + + + +//! construct a field from a given field +template +inline +void +field::init(const field& x) + { + arma_extra_debug_sigprint(); + + if(this != &x) + { + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + const uword x_n_slices = x.n_slices; + + init(x_n_rows, x_n_cols, x_n_slices); + + field& t = *this; + + if(x_n_slices == 1) + { + for(uword ucol=0; ucol < x_n_cols; ++ucol) + for(uword urow=0; urow < x_n_rows; ++urow) + { + t.at(urow,ucol) = x.at(urow,ucol); + } + } + else + { + for(uword uslice=0; uslice < x_n_slices; ++uslice) + for(uword ucol=0; ucol < x_n_cols; ++ucol ) + for(uword urow=0; urow < x_n_rows; ++urow ) + { + t.at(urow,ucol,uslice) = x.at(urow,ucol,uslice); + } + } + } + } + + + +template +inline +void +field::init(const uword n_rows_in, const uword n_cols_in) + { + (*this).init(n_rows_in, n_cols_in, 1); + } + + + +template +inline +void +field::init(const uword n_rows_in, const uword n_cols_in, const uword n_slices_in) + { + arma_extra_debug_sigprint( arma_str::format("n_rows_in = %u, n_cols_in = %u, n_slices_in = %u") % n_rows_in % n_cols_in % n_slices_in ); + + #if defined(ARMA_64BIT_WORD) + const char* error_message = "field::init(): requested size is too large"; + #else + const char* error_message = "field::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; + #endif + + arma_debug_check + ( + ( + ( (n_rows_in > 0x0FFF) || (n_cols_in > 0x0FFF) || (n_slices_in > 0xFF) ) + ? ( (double(n_rows_in) * double(n_cols_in) * double(n_slices_in)) > double(ARMA_MAX_UWORD) ) + : false + ), + error_message + ); + + const uword n_elem_new = n_rows_in * n_cols_in * n_slices_in; + + if(n_elem == n_elem_new) + { + // delete_objects(); + // create_objects(); + access::rw(n_rows) = n_rows_in; + access::rw(n_cols) = n_cols_in; + access::rw(n_slices) = n_slices_in; + } + else + { + delete_objects(); + + if(n_elem > field_prealloc_n_elem::val) { delete [] mem; } + + if(n_elem_new <= field_prealloc_n_elem::val) + { + mem = (n_elem_new == 0) ? nullptr : mem_local; + } + else + { + mem = new(std::nothrow) oT* [n_elem_new]; + + arma_check_bad_alloc( (mem == nullptr), "field::init(): out of memory" ); + } + + access::rw(n_rows) = n_rows_in; + access::rw(n_cols) = n_cols_in; + access::rw(n_slices) = n_slices_in; + access::rw(n_elem) = n_elem_new; + + create_objects(); + } + } + + + +template +inline +void +field::delete_objects() + { + arma_extra_debug_sigprint( arma_str::format("n_elem = %u") % n_elem ); + + for(uword i=0; i +inline +void +field::create_objects() + { + arma_extra_debug_sigprint( arma_str::format("n_elem = %u") % n_elem ); + + for(uword i=0; i +inline +field::iterator::iterator(field& in_M, const bool at_end) + : M(in_M) + , i( (at_end == false) ? 0 : in_M.n_elem ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +oT& +field::iterator::operator*() + { + return M[i]; + } + + + +template +inline +typename field::iterator& +field::iterator::operator++() + { + ++i; + + return *this; + } + + + +template +inline +void +field::iterator::operator++(int) + { + operator++(); + } + + + +template +inline +typename field::iterator& +field::iterator::operator--() + { + if(i > 0) { --i; } + + return *this; + } + + + +template +inline +void +field::iterator::operator--(int) + { + operator--(); + } + + + +template +inline +bool +field::iterator::operator!=(const typename field::iterator& X) const + { + return (i != X.i); + } + + + +template +inline +bool +field::iterator::operator==(const typename field::iterator& X) const + { + return (i == X.i); + } + + + +template +inline +field::const_iterator::const_iterator(const field& in_M, const bool at_end) + : M(in_M) + , i( (at_end == false) ? 0 : in_M.n_elem ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +field::const_iterator::const_iterator(const typename field::iterator& X) + : M(X.M) + , i(X.i) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +const oT& +field::const_iterator::operator*() const + { + return M[i]; + } + + + +template +inline +typename field::const_iterator& +field::const_iterator::operator++() + { + ++i; + + return *this; + } + + + +template +inline +void +field::const_iterator::operator++(int) + { + operator++(); + } + + + +template +inline +typename field::const_iterator& +field::const_iterator::operator--() + { + if(i > 0) { --i; } + + return *this; + } + + + +template +inline +void +field::const_iterator::operator--(int) + { + operator--(); + } + + + +template +inline +bool +field::const_iterator::operator!=(const typename field::const_iterator& X) const + { + return (i != X.i); + } + + + +template +inline +bool +field::const_iterator::operator==(const typename field::const_iterator& X) const + { + return (i == X.i); + } + + + +template +inline +typename field::iterator +field::begin() + { + arma_extra_debug_sigprint(); + + return field::iterator(*this); + } + + + +template +inline +typename field::const_iterator +field::begin() const + { + arma_extra_debug_sigprint(); + + return field::const_iterator(*this); + } + + + +template +inline +typename field::const_iterator +field::cbegin() const + { + arma_extra_debug_sigprint(); + + return field::const_iterator(*this); + } + + + +template +inline +typename field::iterator +field::end() + { + arma_extra_debug_sigprint(); + + return field::iterator(*this, true); + } + + + +template +inline +typename field::const_iterator +field::end() const + { + arma_extra_debug_sigprint(); + + return field::const_iterator(*this, true); + } + + + +template +inline +typename field::const_iterator +field::cend() const + { + arma_extra_debug_sigprint(); + + return field::const_iterator(*this, true); + } + + + +template +inline +void +field::clear() + { + reset(); + } + + + +template +inline +bool +field::empty() const + { + return (n_elem == 0); + } + + + +template +inline +uword +field::size() const + { + return n_elem; + } + + + +// +// +// + + + +template +inline +void +field_aux::reset_objects(field& x) + { + arma_extra_debug_sigprint(); + + x.delete_objects(); + x.create_objects(); + } + + + +template +inline +void +field_aux::reset_objects(field< Mat >& x) + { + arma_extra_debug_sigprint(); + + for(uword i=0; i < x.n_elem; ++i) { (*(x.mem[i])).reset(); } + } + + + +template +inline +void +field_aux::reset_objects(field< Col >& x) + { + arma_extra_debug_sigprint(); + + for(uword i=0; i < x.n_elem; ++i) { (*(x.mem[i])).reset(); } + } + + + +template +inline +void +field_aux::reset_objects(field< Row >& x) + { + arma_extra_debug_sigprint(); + + for(uword i=0; i < x.n_elem; ++i) { (*(x.mem[i])).reset(); } + } + + + +template +inline +void +field_aux::reset_objects(field< Cube >& x) + { + arma_extra_debug_sigprint(); + + for(uword i=0; i < x.n_elem; ++i) { (*(x.mem[i])).reset(); } + } + + + +inline +void +field_aux::reset_objects(field< std::string >& x) + { + arma_extra_debug_sigprint(); + + for(uword i=0; i < x.n_elem; ++i) { (*(x.mem[i])).clear(); } + } + + + +// +// +// + + + +template +inline +bool +field_aux::save(const field&, const std::string&, const file_type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + err_msg = "saving/loading this type of field is currently not supported"; + + return false; + } + + + +template +inline +bool +field_aux::save(const field&, std::ostream&, const file_type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + err_msg = "saving/loading this type of field is currently not supported"; + + return false; + } + + + +template +inline +bool +field_aux::load(field&, const std::string&, const file_type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + err_msg = "saving/loading this type of field is currently not supported"; + + return false; + } + + + +template +inline +bool +field_aux::load(field&, std::istream&, const file_type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + err_msg = "saving/loading this type of field is currently not supported"; + + return false; + } + + + +template +inline +bool +field_aux::save(const field< Mat >& x, const std::string& name, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case arma_binary: + return diskio::save_arma_binary(x, name); + break; + + case ppm_binary: + return diskio::save_ppm_binary(x, name); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::save(const field< Mat >& x, std::ostream& os, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case arma_binary: + return diskio::save_arma_binary(x, os); + break; + + case ppm_binary: + return diskio::save_ppm_binary(x, os); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::load(field< Mat >& x, const std::string& name, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case auto_detect: + return diskio::load_auto_detect(x, name, err_msg); + break; + + case arma_binary: + return diskio::load_arma_binary(x, name, err_msg); + break; + + case ppm_binary: + return diskio::load_ppm_binary(x, name, err_msg); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::load(field< Mat >& x, std::istream& is, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case auto_detect: + return diskio::load_auto_detect(x, is, err_msg); + break; + + case arma_binary: + return diskio::load_arma_binary(x, is, err_msg); + break; + + case ppm_binary: + return diskio::load_ppm_binary(x, is, err_msg); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::save(const field< Col >& x, const std::string& name, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case arma_binary: + return diskio::save_arma_binary(x, name); + break; + + case ppm_binary: + return diskio::save_ppm_binary(x, name); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::save(const field< Col >& x, std::ostream& os, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case arma_binary: + return diskio::save_arma_binary(x, os); + break; + + case ppm_binary: + return diskio::save_ppm_binary(x, os); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::load(field< Col >& x, const std::string& name, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case auto_detect: + return diskio::load_auto_detect(x, name, err_msg); + break; + + case arma_binary: + return diskio::load_arma_binary(x, name, err_msg); + break; + + case ppm_binary: + return diskio::load_ppm_binary(x, name, err_msg); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::load(field< Col >& x, std::istream& is, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case auto_detect: + return diskio::load_auto_detect(x, is, err_msg); + break; + + case arma_binary: + return diskio::load_arma_binary(x, is, err_msg); + break; + + case ppm_binary: + return diskio::load_ppm_binary(x, is, err_msg); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::save(const field< Row >& x, const std::string& name, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case arma_binary: + return diskio::save_arma_binary(x, name); + break; + + case ppm_binary: + return diskio::save_ppm_binary(x, name); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::save(const field< Row >& x, std::ostream& os, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case arma_binary: + return diskio::save_arma_binary(x, os); + break; + + case ppm_binary: + return diskio::save_ppm_binary(x, os); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::load(field< Row >& x, const std::string& name, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case auto_detect: + return diskio::load_auto_detect(x, name, err_msg); + break; + + case arma_binary: + return diskio::load_arma_binary(x, name, err_msg); + break; + + case ppm_binary: + return diskio::load_ppm_binary(x, name, err_msg); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::load(field< Row >& x, std::istream& is, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case auto_detect: + return diskio::load_auto_detect(x, is, err_msg); + break; + + case arma_binary: + return diskio::load_arma_binary(x, is, err_msg); + break; + + case ppm_binary: + return diskio::load_ppm_binary(x, is, err_msg); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::save(const field< Cube >& x, const std::string& name, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case arma_binary: + return diskio::save_arma_binary(x, name); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::save(const field< Cube >& x, std::ostream& os, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case arma_binary: + return diskio::save_arma_binary(x, os); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::load(field< Cube >& x, const std::string& name, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case auto_detect: + case arma_binary: + return diskio::load_arma_binary(x, name, err_msg); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +template +inline +bool +field_aux::load(field< Cube >& x, std::istream& is, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + switch(type) + { + case auto_detect: + case arma_binary: + return diskio::load_arma_binary(x, is, err_msg); + break; + + default: + err_msg = "unsupported type"; + return false; + } + } + + + +inline +bool +field_aux::save(const field< std::string >& x, const std::string& name, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + arma_ignore(type); + + err_msg.clear(); + + return diskio::save_std_string(x, name); + } + + + +inline +bool +field_aux::save(const field< std::string >& x, std::ostream& os, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + arma_ignore(type); + + err_msg.clear(); + + return diskio::save_std_string(x, os); + } + + + +inline +bool +field_aux::load(field< std::string >& x, const std::string& name, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + arma_ignore(type); + + return diskio::load_std_string(x, name, err_msg); + } + + + +inline +bool +field_aux::load(field< std::string >& x, std::istream& is, const file_type type, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + arma_ignore(type); + + return diskio::load_std_string(x, is, err_msg); + } + + + +#if defined(ARMA_EXTRA_FIELD_MEAT) + #include ARMA_INCFILE_WRAP(ARMA_EXTRA_FIELD_MEAT) +#endif + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fill.hpp b/src/armadillo/include/armadillo_bits/fill.hpp new file mode 100644 index 0000000..8b41097 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fill.hpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fill +//! @{ + + +namespace fill + { + struct fill_none {}; + struct fill_zeros {}; + struct fill_ones {}; + struct fill_eye {}; + struct fill_randu {}; + struct fill_randn {}; + + template + struct fill_class { inline constexpr fill_class() {} }; + + static constexpr fill_class none; + static constexpr fill_class zeros; + static constexpr fill_class ones; + static constexpr fill_class eye; + static constexpr fill_class randu; + static constexpr fill_class randn; + + // + + template + struct allow_conversion + { + static constexpr bool value = true; + }; + + template<> struct allow_conversion, double> { static constexpr bool value = false; }; + template<> struct allow_conversion, float > { static constexpr bool value = false; }; + template<> struct allow_conversion, u64 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s64 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u32 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s32 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u16 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s16 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u8 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s8 > { static constexpr bool value = false; }; + + template<> struct allow_conversion, double> { static constexpr bool value = false; }; + template<> struct allow_conversion, float > { static constexpr bool value = false; }; + template<> struct allow_conversion, u64 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s64 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u32 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s32 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u16 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s16 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u8 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s8 > { static constexpr bool value = false; }; + + // + + template inline bool isfinite_wrapper(eT ) { return true; } + template<> inline bool isfinite_wrapper(float x) { return std::isfinite(x); } + template<> inline bool isfinite_wrapper(double x) { return std::isfinite(x); } + template inline bool isfinite_wrapper(std::complex& x) { return std::isfinite(x.real()) && std::isfinite(x.imag()); } + + // + + template + struct scalar_holder + { + const scalar_type1 scalar; + + inline explicit scalar_holder(const scalar_type1& in_scalar) : scalar(in_scalar) {} + + inline scalar_holder() = delete; + + template + < + typename scalar_type2, + typename arma::enable_if2::value, int>::result = 0 + > + inline + operator scalar_holder() const + { + const bool ok_conversion = (std::is_integral::value && std::is_floating_point::value) ? isfinite_wrapper(scalar) : true; + + return scalar_holder( ok_conversion ? scalar_type2(scalar) : scalar_type2(0) ); + } + }; + + // + + template + inline + typename enable_if2< is_supported_elem_type::value, scalar_holder >::result + value(const scalar_type& in_scalar) + { + return scalar_holder(in_scalar); + } + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_accu.hpp b/src/armadillo/include/armadillo_bits/fn_accu.hpp new file mode 100644 index 0000000..957459c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_accu.hpp @@ -0,0 +1,1002 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_accu +//! @{ + + + +template +arma_hot +inline +typename T1::elem_type +accu_proxy_linear(const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + eT val = eT(0); + + typename Proxy::ea_type Pea = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + if( arma_config::openmp && Proxy::use_mp && mp_gate::eval(n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + // NOTE: using parallelisation with manual reduction workaround to take into account complex numbers; + // NOTE: OpenMP versions lower than 4.0 do not support user-defined reduction + + const int n_threads_max = mp_thread_limit::get(); + const uword n_threads_use = (std::min)(uword(podarray_prealloc_n_elem::val), uword(n_threads_max)); + const uword chunk_size = n_elem / n_threads_use; + + podarray partial_accs(n_threads_use); + + #pragma omp parallel for schedule(static) num_threads(int(n_threads_use)) + for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) + { + const uword start = (thread_id+0) * chunk_size; + const uword endp1 = (thread_id+1) * chunk_size; + + eT acc = eT(0); + for(uword i=start; i < endp1; ++i) { acc += Pea[i]; } + + partial_accs[thread_id] = acc; + } + + for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) { val += partial_accs[thread_id]; } + + for(uword i=(n_threads_use*chunk_size); i < n_elem; ++i) { val += Pea[i]; } + } + #endif + } + else + { + #if defined(__FAST_MATH__) + { + if(P.is_aligned()) + { + typename Proxy::aligned_ea_type Pea_aligned = P.get_aligned_ea(); + + for(uword i=0; i +arma_hot +inline +typename T1::elem_type +accu_proxy_at_mp(const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + eT val = eT(0); + + #if defined(ARMA_USE_OPENMP) + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + if(n_cols == 1) + { + const int n_threads_max = mp_thread_limit::get(); + const uword n_threads_use = (std::min)(uword(podarray_prealloc_n_elem::val), uword(n_threads_max)); + const uword chunk_size = n_rows / n_threads_use; + + podarray partial_accs(n_threads_use); + + #pragma omp parallel for schedule(static) num_threads(int(n_threads_use)) + for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) + { + const uword start = (thread_id+0) * chunk_size; + const uword endp1 = (thread_id+1) * chunk_size; + + eT acc = eT(0); + for(uword i=start; i < endp1; ++i) { acc += P.at(i,0); } + + partial_accs[thread_id] = acc; + } + + for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) { val += partial_accs[thread_id]; } + + for(uword i=(n_threads_use*chunk_size); i < n_rows; ++i) { val += P.at(i,0); } + } + else + if(n_rows == 1) + { + const int n_threads_max = mp_thread_limit::get(); + const uword n_threads_use = (std::min)(uword(podarray_prealloc_n_elem::val), uword(n_threads_max)); + const uword chunk_size = n_cols / n_threads_use; + + podarray partial_accs(n_threads_use); + + #pragma omp parallel for schedule(static) num_threads(int(n_threads_use)) + for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) + { + const uword start = (thread_id+0) * chunk_size; + const uword endp1 = (thread_id+1) * chunk_size; + + eT acc = eT(0); + for(uword i=start; i < endp1; ++i) { acc += P.at(0,i); } + + partial_accs[thread_id] = acc; + } + + for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) { val += partial_accs[thread_id]; } + + for(uword i=(n_threads_use*chunk_size); i < n_cols; ++i) { val += P.at(0,i); } + } + else + { + podarray col_accs(n_cols); + + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col=0; col < n_cols; ++col) + { + eT val1 = eT(0); + eT val2 = eT(0); + + uword i,j; + for(i=0, j=1; j < n_rows; i+=2, j+=2) { val1 += P.at(i,col); val2 += P.at(j,col); } + + if(i < n_rows) { val1 += P.at(i,col); } + + col_accs[col] = val1 + val2; + } + + val = arrayops::accumulate(col_accs.memptr(), n_cols); + } + } + #else + { + arma_ignore(P); + } + #endif + + return val; + } + + + +template +arma_hot +inline +typename T1::elem_type +accu_proxy_at(const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(arma_config::openmp && Proxy::use_mp && mp_gate::eval(P.get_n_elem())) + { + return accu_proxy_at_mp(P); + } + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + eT val = eT(0); + + if(n_rows != 1) + { + eT val1 = eT(0); + eT val2 = eT(0); + + for(uword col=0; col < n_cols; ++col) + { + uword i,j; + for(i=0, j=1; j < n_rows; i+=2, j+=2) { val1 += P.at(i,col); val2 += P.at(j,col); } + + if(i < n_rows) { val1 += P.at(i,col); } + } + + val = val1 + val2; + } + else + { + for(uword col=0; col < n_cols; ++col) { val += P.at(0,col); } + } + + return val; + } + + + +//! accumulate the elements of a matrix +template +arma_warn_unused +arma_hot +inline +typename enable_if2< is_arma_type::value, typename T1::elem_type >::result +accu(const T1& X) + { + arma_extra_debug_sigprint(); + + const Proxy P(X); + + if(is_Mat::stored_type>::value || is_subview_col::stored_type>::value) + { + const quasi_unwrap::stored_type> tmp(P.Q); + + return arrayops::accumulate(tmp.M.memptr(), tmp.M.n_elem); + } + + return (Proxy::use_at) ? accu_proxy_at(P) : accu_proxy_linear(P); + } + + + +//! explicit handling of multiply-and-accumulate +template +arma_warn_unused +inline +typename T1::elem_type +accu(const eGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef eGlue expr_type; + + typedef typename expr_type::proxy1_type::stored_type P1_stored_type; + typedef typename expr_type::proxy2_type::stored_type P2_stored_type; + + const bool have_direct_mem_1 = (is_Mat::value) || (is_subview_col::value); + const bool have_direct_mem_2 = (is_Mat::value) || (is_subview_col::value); + + if(have_direct_mem_1 && have_direct_mem_2) + { + const quasi_unwrap tmp1(expr.P1.Q); + const quasi_unwrap tmp2(expr.P2.Q); + + return op_dot::direct_dot(tmp1.M.n_elem, tmp1.M.memptr(), tmp2.M.memptr()); + } + + const Proxy P(expr); + + return (Proxy::use_at) ? accu_proxy_at(P) : accu_proxy_linear(P); + } + + + +//! explicit handling of Hamming norm (also known as zero norm) +template +arma_warn_unused +inline +uword +accu(const mtOp& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const eT val = X.aux; + + const Proxy P(X.m); + + uword n_nonzero = 0; + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + const uword n_elem = P.get_n_elem(); + + for(uword i=0; i +arma_warn_unused +inline +uword +accu(const mtOp& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const eT val = X.aux; + + const Proxy P(X.m); + + uword n_nonzero = 0; + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + const uword n_elem = P.get_n_elem(); + + for(uword i=0; i +arma_warn_unused +inline +uword +accu(const mtGlue& X) + { + arma_extra_debug_sigprint(); + + const Proxy PA(X.A); + const Proxy PB(X.B); + + arma_debug_assert_same_size(PA, PB, "operator!="); + + uword n_nonzero = 0; + + if( (Proxy::use_at == false) && (Proxy::use_at == false) ) + { + typedef typename Proxy::ea_type PA_ea_type; + typedef typename Proxy::ea_type PB_ea_type; + + PA_ea_type A = PA.get_ea(); + PB_ea_type B = PB.get_ea(); + const uword n_elem = PA.get_n_elem(); + + for(uword i=0; i < n_elem; ++i) + { + n_nonzero += (A[i] != B[i]) ? uword(1) : uword(0); + } + } + else + { + const uword PA_n_cols = PA.get_n_cols(); + const uword PA_n_rows = PA.get_n_rows(); + + if(PA_n_rows == 1) + { + for(uword col=0; col < PA_n_cols; ++col) + { + n_nonzero += (PA.at(0,col) != PB.at(0,col)) ? uword(1) : uword(0); + } + } + else + { + for(uword col=0; col < PA_n_cols; ++col) + for(uword row=0; row < PA_n_rows; ++row) + { + n_nonzero += (PA.at(row,col) != PB.at(row,col)) ? uword(1) : uword(0); + } + } + } + + return n_nonzero; + } + + + +template +arma_warn_unused +inline +uword +accu(const mtGlue& X) + { + arma_extra_debug_sigprint(); + + const Proxy PA(X.A); + const Proxy PB(X.B); + + arma_debug_assert_same_size(PA, PB, "operator=="); + + uword n_nonzero = 0; + + if( (Proxy::use_at == false) && (Proxy::use_at == false) ) + { + typedef typename Proxy::ea_type PA_ea_type; + typedef typename Proxy::ea_type PB_ea_type; + + PA_ea_type A = PA.get_ea(); + PB_ea_type B = PB.get_ea(); + const uword n_elem = PA.get_n_elem(); + + for(uword i=0; i < n_elem; ++i) + { + n_nonzero += (A[i] == B[i]) ? uword(1) : uword(0); + } + } + else + { + const uword PA_n_cols = PA.get_n_cols(); + const uword PA_n_rows = PA.get_n_rows(); + + if(PA_n_rows == 1) + { + for(uword col=0; col < PA_n_cols; ++col) + { + n_nonzero += (PA.at(0,col) == PB.at(0,col)) ? uword(1) : uword(0); + } + } + else + { + for(uword col=0; col < PA_n_cols; ++col) + for(uword row=0; row < PA_n_rows; ++row) + { + n_nonzero += (PA.at(row,col) == PB.at(row,col)) ? uword(1) : uword(0); + } + } + } + + return n_nonzero; + } + + + +//! accumulate the elements of a subview (submatrix) +template +arma_warn_unused +arma_hot +inline +eT +accu(const subview& X) + { + arma_extra_debug_sigprint(); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(X_n_rows == 1) + { + const Mat& m = X.m; + + const uword col_offset = X.aux_col1; + const uword row_offset = X.aux_row1; + + eT val1 = eT(0); + eT val2 = eT(0); + + uword i,j; + for(i=0, j=1; j < X_n_cols; i+=2, j+=2) + { + val1 += m.at(row_offset, col_offset + i); + val2 += m.at(row_offset, col_offset + j); + } + + if(i < X_n_cols) { val1 += m.at(row_offset, col_offset + i); } + + return val1 + val2; + } + + if(X_n_cols == 1) { return arrayops::accumulate( X.colptr(0), X_n_rows ); } + + eT val = eT(0); + + for(uword col=0; col < X_n_cols; ++col) + { + val += arrayops::accumulate( X.colptr(col), X_n_rows ); + } + + return val; + } + + + +template +arma_warn_unused +arma_hot +inline +eT +accu(const subview_col& X) + { + arma_extra_debug_sigprint(); + + return arrayops::accumulate( X.colmem, X.n_rows ); + } + + + +// + + + +template +arma_hot +inline +typename T1::elem_type +accu_cube_proxy_linear(const ProxyCube& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + eT val = eT(0); + + typename ProxyCube::ea_type Pea = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + if( arma_config::openmp && ProxyCube::use_mp && mp_gate::eval(n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + // NOTE: using parallelisation with manual reduction workaround to take into account complex numbers; + // NOTE: OpenMP versions lower than 4.0 do not support user-defined reduction + + const int n_threads_max = mp_thread_limit::get(); + const uword n_threads_use = (std::min)(uword(podarray_prealloc_n_elem::val), uword(n_threads_max)); + const uword chunk_size = n_elem / n_threads_use; + + podarray partial_accs(n_threads_use); + + #pragma omp parallel for schedule(static) num_threads(int(n_threads_use)) + for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) + { + const uword start = (thread_id+0) * chunk_size; + const uword endp1 = (thread_id+1) * chunk_size; + + eT acc = eT(0); + for(uword i=start; i < endp1; ++i) { acc += Pea[i]; } + + partial_accs[thread_id] = acc; + } + + for(uword thread_id=0; thread_id < n_threads_use; ++thread_id) { val += partial_accs[thread_id]; } + + for(uword i=(n_threads_use*chunk_size); i < n_elem; ++i) { val += Pea[i]; } + } + #endif + } + else + { + #if defined(__FAST_MATH__) + { + if(P.is_aligned()) + { + typename ProxyCube::aligned_ea_type Pea_aligned = P.get_aligned_ea(); + + for(uword i=0; i +arma_hot +inline +typename T1::elem_type +accu_cube_proxy_at_mp(const ProxyCube& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + eT val = eT(0); + + #if defined(ARMA_USE_OPENMP) + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + podarray slice_accs(n_slices); + + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword slice = 0; slice < n_slices; ++slice) + { + eT val1 = eT(0); + eT val2 = eT(0); + + for(uword col = 0; col < n_cols; ++col) + { + uword i,j; + for(i=0, j=1; j +arma_hot +inline +typename T1::elem_type +accu_cube_proxy_at(const ProxyCube& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(arma_config::openmp && ProxyCube::use_mp && mp_gate::eval(P.get_n_elem())) + { + return accu_cube_proxy_at_mp(P); + } + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + eT val1 = eT(0); + eT val2 = eT(0); + + for(uword slice = 0; slice < n_slices; ++slice) + for(uword col = 0; col < n_cols; ++col ) + { + uword i,j; + for(i=0, j=1; j +arma_warn_unused +arma_hot +inline +typename T1::elem_type +accu(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + const ProxyCube P(X.get_ref()); + + if(is_Cube::stored_type>::value) + { + unwrap_cube::stored_type> tmp(P.Q); + + return arrayops::accumulate(tmp.M.memptr(), tmp.M.n_elem); + } + + return (ProxyCube::use_at) ? accu_cube_proxy_at(P) : accu_cube_proxy_linear(P); + } + + + +//! explicit handling of multiply-and-accumulate (cube version) +template +arma_warn_unused +inline +typename T1::elem_type +accu(const eGlueCube& expr) + { + arma_extra_debug_sigprint(); + + typedef eGlueCube expr_type; + + typedef typename ProxyCube::stored_type P1_stored_type; + typedef typename ProxyCube::stored_type P2_stored_type; + + if(is_Cube::value && is_Cube::value) + { + const unwrap_cube tmp1(expr.P1.Q); + const unwrap_cube tmp2(expr.P2.Q); + + return op_dot::direct_dot(tmp1.M.n_elem, tmp1.M.memptr(), tmp2.M.memptr()); + } + + const ProxyCube P(expr); + + return (ProxyCube::use_at) ? accu_cube_proxy_at(P) : accu_cube_proxy_linear(P); + } + + + +// + + + +template +arma_warn_unused +inline +typename arma_scalar_only::result +accu(const T& x) + { + return x; + } + + + +//! accumulate values in a sparse object +template +arma_warn_unused +inline +typename T1::elem_type +accu(const SpBase& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy P(expr.get_ref()); + + const uword N = P.get_n_nonzero(); + + if(N == 0) { return eT(0); } + + if(SpProxy::use_iterator == false) + { + // direct counting + return arrayops::accumulate(P.get_values(), N); + } + + if(is_SpSubview::stored_type>::value) + { + const SpSubview& sv = reinterpret_cast< const SpSubview& >(P.Q); + + if(sv.n_rows == sv.m.n_rows) + { + const SpMat& m = sv.m; + const uword col = sv.aux_col1; + + return arrayops::accumulate(&(m.values[ m.col_ptrs[col] ]), N); + } + } + + typename SpProxy::const_iterator_type it = P.begin(); + + eT val = eT(0); + + for(uword i=0; i < N; ++i) { val += (*it); ++it; } + + return val; + } + + + +//! explicit handling of accu(A + B), where A and B are sparse matrices +template +arma_warn_unused +inline +typename T1::elem_type +accu(const SpGlue& expr) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "addition"); + + return (accu(UA.M) + accu(UB.M)); + } + + + +//! explicit handling of accu(A - B), where A and B are sparse matrices +template +arma_warn_unused +inline +typename T1::elem_type +accu(const SpGlue& expr) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "subtraction"); + + return (accu(UA.M) - accu(UB.M)); + } + + + +//! explicit handling of accu(A % B), where A and B are sparse matrices +template +arma_warn_unused +inline +typename T1::elem_type +accu(const SpGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy px(expr.A); + const SpProxy py(expr.B); + + typename SpProxy::const_iterator_type x_it = px.begin(); + typename SpProxy::const_iterator_type x_it_end = px.end(); + + typename SpProxy::const_iterator_type y_it = py.begin(); + typename SpProxy::const_iterator_type y_it_end = py.end(); + + eT acc = eT(0); + + while( (x_it != x_it_end) || (y_it != y_it_end) ) + { + if(x_it == y_it) + { + acc += ((*x_it) * (*y_it)); + + ++x_it; + ++y_it; + } + else + { + const uword x_it_col = x_it.col(); + const uword x_it_row = x_it.row(); + + const uword y_it_col = y_it.col(); + const uword y_it_row = y_it.row(); + + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + ++x_it; + } + else // x is closer to the end + { + ++y_it; + } + } + } + + return acc; + } + + + +template +arma_warn_unused +inline +typename T1::elem_type +accu(const SpOp& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const bool is_vectorise = \ + (is_same_type::yes) + || (is_same_type::yes) + || (is_same_type::yes); + + if(is_vectorise) + { + return accu(expr.m); + } + + const SpMat tmp = expr; + + return accu(tmp); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_all.hpp b/src/armadillo/include/armadillo_bits/fn_all.hpp new file mode 100644 index 0000000..c69095d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_all.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_all +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + bool + >::result +all(const T1& X) + { + arma_extra_debug_sigprint(); + + return op_all::all_vec(X); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const mtOp + >::result +all(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +all(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return mtOp(X, dim, 0); + } + + + +//! for compatibility purposes: allows compiling user code designed for earlier versions of Armadillo +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_supported_elem_type::value, + bool + >::result +all(const T& val) + { + return (val != T(0)); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_any.hpp b/src/armadillo/include/armadillo_bits/fn_any.hpp new file mode 100644 index 0000000..9038059 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_any.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_any +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + bool + >::result +any(const T1& X) + { + arma_extra_debug_sigprint(); + + return op_any::any_vec(X); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const mtOp + >::result +any(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +any(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return mtOp(X, dim, 0); + } + + + +//! for compatibility purposes: allows compiling user code designed for earlier versions of Armadillo +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_supported_elem_type::value, + bool + >::result +any(const T& val) + { + return (val != T(0)); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_approx_equal.hpp b/src/armadillo/include/armadillo_bits/fn_approx_equal.hpp new file mode 100644 index 0000000..e92b94f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_approx_equal.hpp @@ -0,0 +1,471 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_approx_equal +//! @{ + + + +template +arma_inline +bool +internal_approx_equal_abs_diff(const eT& x, const eT& y, const typename get_pod_type::result tol) + { + typedef typename get_pod_type::result T; + + if(x != y) + { + if(is_real::value) // also true for eT = std::complex or eT = std::complex + { + if( arma_isnan(x) || arma_isnan(y) || (eop_aux::arma_abs(x - y) > tol) ) { return false; } + } + else + { + if( eop_aux::arma_abs( ( cond_rel< is_cx::no >::gt(x, y) ) ? (x-y) : (y-x) ) > tol ) { return false; } + } + } + + return true; + } + + + +template +arma_inline +bool +internal_approx_equal_rel_diff(const eT& a, const eT& b, const typename get_pod_type::result tol) + { + typedef typename get_pod_type::result T; + + if(a != b) + { + if(is_real::value) // also true for eT = std::complex or eT = std::complex + { + if( arma_isnan(a) || arma_isnan(b) ) { return false; } + + const T abs_a = eop_aux::arma_abs(a); + const T abs_b = eop_aux::arma_abs(b); + + const T max_c = (std::max)(abs_a,abs_b); + + const T abs_d = eop_aux::arma_abs(a - b); + + if(max_c >= T(1)) + { + if( abs_d > (tol * max_c) ) { return false; } + } + else + { + if( (abs_d / max_c) > tol ) { return false; } + } + } + else + { + const T abs_a = eop_aux::arma_abs(a); + const T abs_b = eop_aux::arma_abs(b); + + const T max_c = (std::max)(abs_a,abs_b); + + const T abs_d = eop_aux::arma_abs( ( cond_rel< is_cx::no >::gt(a, b) ) ? (a-b) : (b-a) ); + + if( abs_d > (tol * max_c) ) { return false; } + } + } + + return true; + } + + + +template +inline +bool +internal_approx_equal_worker + ( + const Base& A, + const Base& B, + const typename T1::pod_type abs_tol, + const typename T1::pod_type rel_tol + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + arma_debug_check( ((use_abs_diff == false) && (use_rel_diff == false)), "internal_approx_equal_worker(): both 'use_abs_diff' and 'use_rel_diff' are false" ); + + if(use_abs_diff) { arma_debug_check( cond_rel< is_signed::value >::lt(abs_tol, T(0)), "approx_equal(): argument 'abs_tol' must be >= 0" ); } + if(use_rel_diff) { arma_debug_check( cond_rel< is_signed::value >::lt(rel_tol, T(0)), "approx_equal(): argument 'rel_tol' must be >= 0" ); } + + const Proxy PA(A.get_ref()); + const Proxy PB(B.get_ref()); + + if( (PA.get_n_rows() != PB.get_n_rows()) || (PA.get_n_cols() != PB.get_n_cols()) ) { return false; } + + if( (Proxy::use_at == false) && (Proxy::use_at == false) ) + { + const uword N = PA.get_n_elem(); + + const typename Proxy::ea_type PA_ea = PA.get_ea(); + const typename Proxy::ea_type PB_ea = PB.get_ea(); + + for(uword i=0; i +inline +bool +internal_approx_equal_worker + ( + const BaseCube& A, + const BaseCube& B, + const typename T1::pod_type abs_tol, + const typename T1::pod_type rel_tol + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + arma_debug_check( ((use_abs_diff == false) && (use_rel_diff == false)), "internal_approx_equal_worker(): both 'use_abs_diff' and 'use_rel_diff' are false" ); + + if(use_abs_diff) { arma_debug_check( cond_rel< is_signed::value >::lt(abs_tol, T(0)), "approx_equal(): argument 'abs_tol' must be >= 0" ); } + if(use_rel_diff) { arma_debug_check( cond_rel< is_signed::value >::lt(rel_tol, T(0)), "approx_equal(): argument 'rel_tol' must be >= 0" ); } + + const ProxyCube PA(A.get_ref()); + const ProxyCube PB(B.get_ref()); + + if( (PA.get_n_rows() != PB.get_n_rows()) || (PA.get_n_cols() != PB.get_n_cols()) || (PA.get_n_slices() != PB.get_n_slices()) ) { return false; } + + if( (ProxyCube::use_at == false) && (ProxyCube::use_at == false) ) + { + const uword N = PA.get_n_elem(); + + const typename ProxyCube::ea_type PA_ea = PA.get_ea(); + const typename ProxyCube::ea_type PB_ea = PB.get_ea(); + + for(uword i=0; i +inline +bool +internal_approx_equal_handler(const T1& A, const T2& B, const char* method, const typename T1::pod_type abs_tol, const typename T1::pod_type rel_tol) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 'a') && (sig != 'r') && (sig != 'b')), "approx_equal(): argument 'method' must be \"absdiff\" or \"reldiff\" or \"both\"" ); + + bool status = false; + + if(sig == 'a') + { + status = internal_approx_equal_worker(A, B, abs_tol, T(0)); + } + else + if(sig == 'r') + { + status = internal_approx_equal_worker(A, B, T(0), rel_tol); + } + else + if(sig == 'b') + { + status = internal_approx_equal_worker(A, B, abs_tol, rel_tol); + } + + return status; + } + + + +template +inline +bool +internal_approx_equal_handler(const T1& A, const T2& B, const char* method, const typename T1::pod_type tol) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 'a') && (sig != 'r') && (sig != 'b')), "approx_equal(): argument 'method' must be \"absdiff\" or \"reldiff\" or \"both\"" ); + + arma_debug_check( (sig == 'b'), "approx_equal(): argument 'method' is \"both\", but only one 'tol' argument has been given" ); + + bool status = false; + + if(sig == 'a') + { + status = internal_approx_equal_worker(A, B, tol, T(0)); + } + else + if(sig == 'r') + { + status = internal_approx_equal_worker(A, B, T(0), tol); + } + + return status; + } + + + +template +arma_warn_unused +inline +bool +approx_equal(const Base& A, const Base& B, const char* method, const typename T1::pod_type tol) + { + arma_extra_debug_sigprint(); + + return internal_approx_equal_handler(A.get_ref(), B.get_ref(), method, tol); + } + + + +template +arma_warn_unused +inline +bool +approx_equal(const BaseCube& A, const BaseCube& B, const char* method, const typename T1::pod_type tol) + { + arma_extra_debug_sigprint(); + + return internal_approx_equal_handler(A.get_ref(), B.get_ref(), method, tol); + } + + + +template +arma_warn_unused +inline +bool +approx_equal(const Base& A, const Base& B, const char* method, const typename T1::pod_type abs_tol, const typename T1::pod_type rel_tol) + { + arma_extra_debug_sigprint(); + + return internal_approx_equal_handler(A.get_ref(), B.get_ref(), method, abs_tol, rel_tol); + } + + + +template +arma_warn_unused +inline +bool +approx_equal(const BaseCube& A, const BaseCube& B, const char* method, const typename T1::pod_type abs_tol, const typename T1::pod_type rel_tol) + { + arma_extra_debug_sigprint(); + + return internal_approx_equal_handler(A.get_ref(), B.get_ref(), method, abs_tol, rel_tol); + } + + + +template +arma_warn_unused +inline +bool +approx_equal(const SpBase& A, const SpBase& B, const char* method, const typename T1::pod_type tol) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 'a') && (sig != 'r') && (sig != 'b')), "approx_equal(): argument 'method' must be \"absdiff\" or \"reldiff\" or \"both\"" ); + + arma_debug_check( (sig == 'b'), "approx_equal(): argument 'method' is \"both\", but only one 'tol' argument has been given" ); + + arma_debug_check( (sig == 'r'), "approx_equal(): only the \"absdiff\" method is currently implemented for sparse matrices" ); + + arma_debug_check( cond_rel< is_signed::value >::lt(tol, T(0)), "approx_equal(): argument 'tol' must be >= 0" ); + + const unwrap_spmat UA(A.get_ref()); + const unwrap_spmat UB(B.get_ref()); + + if( (UA.M.n_rows != UB.M.n_rows) || (UA.M.n_cols != UB.M.n_cols) ) { return false; } + + const SpMat C = UA.M - UB.M; + + typename SpMat::const_iterator it = C.begin(); + typename SpMat::const_iterator it_end = C.end(); + + while(it != it_end) + { + const eT val = (*it); + + if( arma_isnan(val) || (eop_aux::arma_abs(val) > tol) ) { return false; } + + ++it; + } + + return true; + } + + + +template +arma_warn_unused +inline +bool +approx_equal(const SpBase& A, const SpBase& B, const char* method, const typename T1::pod_type abs_tol, const typename T1::pod_type rel_tol) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 'a') && (sig != 'r') && (sig != 'b')), "approx_equal(): argument 'method' must be \"absdiff\" or \"reldiff\" or \"both\"" ); + + arma_debug_check( ((sig == 'r') || (sig == 'b')), "approx_equal(): only the \"absdiff\" method is currently implemented for sparse matrices" ); + + arma_debug_check( cond_rel< is_signed::value >::lt(abs_tol, T(0)), "approx_equal(): argument 'abs_tol' must be >= 0" ); + arma_debug_check( cond_rel< is_signed::value >::lt(rel_tol, T(0)), "approx_equal(): argument 'rel_tol' must be >= 0" ); + + return approx_equal(A.get_ref(), B.get_ref(), "abs", abs_tol); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_as_scalar.hpp b/src/armadillo/include/armadillo_bits/fn_as_scalar.hpp new file mode 100644 index 0000000..b59dbfd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_as_scalar.hpp @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_as_scalar +//! @{ + + + +template +struct as_scalar_redirect + { + template + inline static typename T1::elem_type apply(const T1& X); + }; + + + +template<> +struct as_scalar_redirect<2> + { + template + inline static typename T1::elem_type apply(const Glue& X); + }; + + +template<> +struct as_scalar_redirect<3> + { + template + inline static typename T1::elem_type apply(const Glue< Glue, T3, glue_times>& X); + }; + + + +template +template +inline +typename T1::elem_type +as_scalar_redirect::apply(const T1& X) + { + arma_extra_debug_sigprint(); + + const Proxy P(X); + + arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression must evaluate to exactly one element" ); + + return (Proxy::use_at) ? P.at(0,0) : P[0]; + } + + + +template +inline +typename T1::elem_type +as_scalar_redirect<2>::apply(const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // T1 must result in a matrix with one row + // T2 must result in a matrix with one column + + const bool has_all_mat = (is_Mat::value || is_Mat_trans::value) && (is_Mat::value || is_Mat_trans::value); + + const bool use_at = (Proxy::use_at || Proxy::use_at); + + const bool do_partial_unwrap = (has_all_mat || use_at); + + if(do_partial_unwrap) + { + const partial_unwrap tmp1(X.A); + const partial_unwrap tmp2(X.B); + + typedef typename partial_unwrap::stored_type TA; + typedef typename partial_unwrap::stored_type TB; + + const TA& A = tmp1.M; + const TB& B = tmp2.M; + + const uword A_n_rows = (tmp1.do_trans == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols); + const uword A_n_cols = (tmp1.do_trans == false) ? (TA::is_col ? 1 : A.n_cols) : (TA::is_row ? 1 : A.n_rows); + + const uword B_n_rows = (tmp2.do_trans == false) ? (TB::is_row ? 1 : B.n_rows) : (TB::is_col ? 1 : B.n_cols); + const uword B_n_cols = (tmp2.do_trans == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows); + + arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" ); + + const eT val = op_dot::direct_dot(A.n_elem, A.memptr(), B.memptr()); + + return (tmp1.do_times || tmp2.do_times) ? (val * tmp1.get_val() * tmp2.get_val()) : val; + } + else + { + const Proxy PA(X.A); + const Proxy PB(X.B); + + arma_debug_check + ( + (PA.get_n_rows() != 1) || (PB.get_n_cols() != 1) || (PA.get_n_cols() != PB.get_n_rows()), + "as_scalar(): incompatible dimensions" + ); + + return op_dot::apply_proxy(PA,PB); + } + } + + + +template +inline +typename T1::elem_type +as_scalar_redirect<3>::apply(const Glue< Glue, T3, glue_times >& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // T1 * T2 must result in a matrix with one row + // T3 must result in a matrix with one column + + typedef typename strip_inv ::stored_type T2_stripped_1; + typedef typename strip_diagmat::stored_type T2_stripped_2; + + const strip_inv strip1(X.A.B); + const strip_diagmat strip2(strip1.M); + + const bool tmp2_do_inv_gen = strip1.do_inv_gen && arma_config::optimise_invexpr; + const bool tmp2_do_diagmat = strip2.do_diagmat; + + if(tmp2_do_diagmat == false) + { + const Mat tmp(X); + + arma_debug_check( (tmp.n_elem != 1), "as_scalar(): expression must evaluate to exactly one element" ); + + return tmp[0]; + } + else + { + const partial_unwrap tmp1(X.A.A); + const partial_unwrap tmp2(strip2.M); + const partial_unwrap tmp3(X.B); + + const Mat& A = tmp1.M; + const Mat& B = tmp2.M; + const Mat& C = tmp3.M; + + const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; + const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; + + const bool B_is_vec = B.is_vec(); + + const uword B_n_rows = (B_is_vec) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols ); + const uword B_n_cols = (B_is_vec) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows ); + + const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols; + const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows; + + const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val(); + + arma_debug_check + ( + (A_n_rows != 1) || + (C_n_cols != 1) || + (A_n_cols != B_n_rows) || + (B_n_cols != C_n_rows) + , + "as_scalar(): incompatible dimensions" + ); + + + if(B_is_vec) + { + if(tmp2_do_inv_gen) + { + return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem); + } + else + { + return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem); + } + } + else + { + if(tmp2_do_inv_gen) + { + return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem); + } + else + { + return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem); + } + } + } + } + + + +template +inline +typename T1::elem_type +as_scalar_diag(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + arma_debug_check( (A.n_elem != 1), "as_scalar(): expression must evaluate to exactly one element" ); + + return A.mem[0]; + } + + + +template +inline +typename T1::elem_type +as_scalar_diag(const Glue< Glue, T3, glue_times >& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // T1 * T2 must result in a matrix with one row + // T3 must result in a matrix with one column + + typedef typename strip_diagmat::stored_type T2_stripped; + + const strip_diagmat strip(X.A.B); + + const partial_unwrap tmp1(X.A.A); + const partial_unwrap tmp2(strip.M); + const partial_unwrap tmp3(X.B); + + const Mat& A = tmp1.M; + const Mat& B = tmp2.M; + const Mat& C = tmp3.M; + + + const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; + const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; + + const bool B_is_vec = B.is_vec(); + + const uword B_n_rows = (B_is_vec) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols ); + const uword B_n_cols = (B_is_vec) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows ); + + const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols; + const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows; + + const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val(); + + arma_debug_check + ( + (A_n_rows != 1) || + (C_n_cols != 1) || + (A_n_cols != B_n_rows) || + (B_n_cols != C_n_rows) + , + "as_scalar(): incompatible dimensions" + ); + + + if(B_is_vec) + { + return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem); + } + else + { + return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem); + } + } + + + +template +arma_warn_unused +inline +typename T1::elem_type +as_scalar(const Glue& X, const typename arma_not_cx::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(is_glue_times_diag::value) { return as_scalar_diag(X); } + + constexpr uword N_mat = 1 + depth_lhs< glue_times, Glue >::num; + + arma_extra_debug_print(arma_str::format("N_mat = %u") % N_mat); + + return as_scalar_redirect::apply(X); + } + + + +template +arma_warn_unused +inline +typename T1::elem_type +as_scalar(const Base& X) + { + arma_extra_debug_sigprint(); + + const Proxy P(X.get_ref()); + + arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression must evaluate to exactly one element" ); + + return (Proxy::use_at) ? P.at(0,0) : P[0]; + } + + +template +arma_warn_unused +inline +typename T1::elem_type +as_scalar(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + const ProxyCube P(X.get_ref()); + + arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression must evaluate to exactly one element" ); + + return (ProxyCube::use_at) ? P.at(0,0,0) : P[0]; + } + + + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +as_scalar(const T& x) + { + return x; + } + + + +template +arma_warn_unused +inline +typename T1::elem_type +as_scalar(const SpBase& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat tmp(X.get_ref()); + const SpMat& A = tmp.M; + + arma_debug_check( (A.n_elem != 1), "as_scalar(): expression must evaluate to exactly one element" ); + + return A.at(0,0); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_chi2rnd.hpp b/src/armadillo/include/armadillo_bits/fn_chi2rnd.hpp new file mode 100644 index 0000000..9da08f4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_chi2rnd.hpp @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_chi2rnd +//! @{ + + + +arma_warn_unused +inline +double +chi2rnd(const double df) + { + arma_extra_debug_sigprint(); + + op_chi2rnd_varying_df generator; + + return generator(df); + } + + + +template +arma_warn_unused +inline +typename arma_real_only::result +chi2rnd(const eT df) + { + arma_extra_debug_sigprint(); + + op_chi2rnd_varying_df generator; + + return generator(df); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && is_real::value), + const Op + >::result +chi2rnd(const T1& expr) + { + arma_extra_debug_sigprint(); + + return Op(expr); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_Mat::value && is_real::value), + obj_type + >::result +chi2rnd(const typename obj_type::elem_type df, const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + if(is_Col::value) + { + arma_debug_check( (n_cols != 1), "chi2rnd(): incompatible size" ); + } + else + if(is_Row::value) + { + arma_debug_check( (n_rows != 1), "chi2rnd(): incompatible size" ); + } + + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); + + op_chi2rnd::fill_constant_df(out, df); + + return out; + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_Mat::value && is_real::value), + obj_type + >::result +chi2rnd(const typename obj_type::elem_type df, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return chi2rnd(df, s.n_rows, s.n_cols); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_Mat::value && is_real::value), + obj_type + >::result +chi2rnd(const typename obj_type::elem_type df, const uword n_elem) + { + arma_extra_debug_sigprint(); + + if(is_Row::value) + { + return chi2rnd(df, 1, n_elem); + } + else + { + return chi2rnd(df, n_elem, 1); + } + } + + + +arma_warn_unused +inline +mat +chi2rnd(const double df, const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + return chi2rnd(df, n_rows, n_cols); + } + + + +arma_warn_unused +inline +mat +chi2rnd(const double df, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return chi2rnd(df, s.n_rows, s.n_cols); + } + + + +arma_warn_unused +inline +vec +chi2rnd(const double df, const uword n_elem) + { + arma_extra_debug_sigprint(); + + return chi2rnd(df, n_elem, 1); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_chol.hpp b/src/armadillo/include/armadillo_bits/fn_chol.hpp new file mode 100644 index 0000000..dfd9e6e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_chol.hpp @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_chol +//! @{ + + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +chol + ( + const Base& X, + const char* layout = "upper" + ) + { + arma_extra_debug_sigprint(); + + const char sig = (layout != nullptr) ? layout[0] : char(0); + + arma_debug_check( ((sig != 'u') && (sig != 'l')), "chol(): layout must be \"upper\" or \"lower\"" ); + + return Op(X.get_ref(), ((sig == 'u') ? 0 : 1), 0 ); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +chol + ( + Mat& out, + const Base& X, + const char* layout = "upper" + ) + { + arma_extra_debug_sigprint(); + + const char sig = (layout != nullptr) ? layout[0] : char(0); + + arma_debug_check( ((sig != 'u') && (sig != 'l')), "chol(): layout must be \"upper\" or \"lower\"" ); + + const bool status = op_chol::apply_direct(out, X.get_ref(), ((sig == 'u') ? 0 : 1)); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "chol(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +chol + ( + Mat& out, + Mat& P, + const Base& X, + const char* layout = "upper", + const char* P_mode = "matrix" + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const char sig_layout = (layout != nullptr) ? layout[0] : char(0); + const char sig_P_mode = (P_mode != nullptr) ? P_mode[0] : char(0); + + arma_debug_check( ((sig_layout != 'u') && (sig_layout != 'l')), "chol(): argument 'layout' must be \"upper\" or \"lower\"" ); + arma_debug_check( ((sig_P_mode != 'm') && (sig_P_mode != 'v')), "chol(): argument 'P_mode' must be \"vector\" or \"matrix\"" ); + + out = X.get_ref(); + + arma_debug_check( (out.is_square() == false), "chol(): given matrix must be square sized", [&](){ out.soft_reset(); } ); + + if(out.is_empty()) + { + P.reset(); + return true; + } + + if((arma_config::debug) && (auxlib::rudimentary_sym_check(out) == false)) + { + if(is_cx::no ) { arma_debug_warn_level(1, "chol(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "chol(): given matrix is not hermitian"); } + } + + bool status = false; + + if(sig_P_mode == 'v') + { + status = auxlib::chol_pivot(out, P, ((sig_layout == 'u') ? 0 : 1)); + } + else + if(sig_P_mode == 'm') + { + Mat P_vec; + + status = auxlib::chol_pivot(out, P_vec, ((sig_layout == 'u') ? 0 : 1)); + + if(status) + { + // construct P + + const uword N = P_vec.n_rows; + + P.zeros(N,N); + + for(uword i=0; i < N; ++i) { P.at(P_vec[i], i) = uword(1); } + } + } + + if(status == false) + { + out.soft_reset(); + P.soft_reset(); + arma_debug_warn_level(3, "chol(): decomposition failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_clamp.hpp b/src/armadillo/include/armadillo_bits/fn_clamp.hpp new file mode 100644 index 0000000..a7da6d4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_clamp.hpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_clamp +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && is_cx::no, + const mtOp + >::result +clamp(const T1& X, const typename T1::elem_type min_val, const typename T1::elem_type max_val) + { + arma_extra_debug_sigprint(); + + return mtOp(mtOp_dual_aux_indicator(), X, min_val, max_val); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && is_cx::yes, + const mtOp + >::result +clamp(const T1& X, const typename T1::elem_type min_val, const typename T1::elem_type max_val) + { + arma_extra_debug_sigprint(); + + return mtOp(mtOp_dual_aux_indicator(), X, min_val, max_val); + } + + + +template +arma_warn_unused +inline +const mtOpCube +clamp(const BaseCube& X, const typename T1::elem_type min_val, const typename T1::elem_type max_val, typename arma_not_cx::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return mtOpCube(mtOpCube_dual_aux_indicator(), X.get_ref(), min_val, max_val); + } + + + +template +arma_warn_unused +inline +const mtOpCube +clamp(const BaseCube& X, const typename T1::elem_type min_val, const typename T1::elem_type max_val, typename arma_cx_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return mtOpCube(mtOpCube_dual_aux_indicator(), X.get_ref(), min_val, max_val); + } + + + +template +arma_warn_unused +inline +SpMat +clamp(const SpBase& X, const typename T1::elem_type min_val, const typename T1::elem_type max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + SpMat out = X.get_ref(); + + out.clamp(min_val, max_val); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_cond_rcond.hpp b/src/armadillo/include/armadillo_bits/fn_cond_rcond.hpp new file mode 100644 index 0000000..fae0a06 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_cond_rcond.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_cond +//! @{ + + + +template +arma_warn_unused +inline +typename enable_if2::value, typename T1::pod_type>::result +cond(const Base& X) + { + arma_extra_debug_sigprint(); + + return op_cond::apply(X.get_ref()); + } + + + +template +arma_warn_unused +inline +typename enable_if2::value, typename T1::pod_type>::result +rcond(const Base& X) + { + arma_extra_debug_sigprint(); + + return op_rcond::apply(X.get_ref()); + } + + + +// template +// arma_warn_unused +// inline +// typename enable_if2::value, typename T1::pod_type>::result +// rcond(const SpBase& X) +// { +// arma_extra_debug_sigprint(); +// +// return sp_auxlib::rcond(X.get_ref()); +// } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_conv.hpp b/src/armadillo/include/armadillo_bits/fn_conv.hpp new file mode 100644 index 0000000..44b7a0e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_conv.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_conv +//! @{ + + + +//! Convolution, which is also equivalent to polynomial multiplication and FIR digital filtering. + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && is_same_type::value), + const Glue + >::result +conv(const T1& A, const T2& B, const char* shape = "full") + { + arma_extra_debug_sigprint(); + + const char sig = (shape != nullptr) ? shape[0] : char(0); + + arma_debug_check( ((sig != 'f') && (sig != 's')), "conv(): unsupported value of 'shape' parameter" ); + + const uword mode = (sig == 's') ? uword(1) : uword(0); + + return Glue(A, B, mode); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && is_same_type::value), + const Glue + >::result +conv2(const T1& A, const T2& B, const char* shape = "full") + { + arma_extra_debug_sigprint(); + + const char sig = (shape != nullptr) ? shape[0] : char(0); + + arma_debug_check( ((sig != 'f') && (sig != 's')), "conv2(): unsupported value of 'shape' parameter" ); + + const uword mode = (sig == 's') ? uword(1) : uword(0); + + return Glue(A, B, mode); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_conv_to.hpp b/src/armadillo/include/armadillo_bits/fn_conv_to.hpp new file mode 100644 index 0000000..dbfc7fe --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_conv_to.hpp @@ -0,0 +1,720 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_conv_to +//! @{ + + + +//! conversion from Armadillo Base and BaseCube objects to scalars +//! NOTE: use as_scalar() instead; this functionality is kept only for compatibility with old user code +template +class conv_to + { + public: + + template + inline static out_eT from(const Base& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static out_eT from(const Base& in, const typename arma_cx_only::result* junk = nullptr); + + template + inline static out_eT from(const BaseCube& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static out_eT from(const BaseCube& in, const typename arma_cx_only::result* junk = nullptr); + }; + + + +template +template +arma_warn_unused +inline +out_eT +conv_to::from(const Base& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + arma_type_check(( is_supported_elem_type::value == false )); + + const Proxy P(in.get_ref()); + + arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object does not have exactly one element" ); + + return out_eT(Proxy::use_at ? P.at(0,0) : P[0]); + } + + + +template +template +arma_warn_unused +inline +out_eT +conv_to::from(const Base& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + arma_type_check(( is_supported_elem_type::value == false )); + + const Proxy P(in.get_ref()); + + arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object does not have exactly one element" ); + + out_eT out; + + arrayops::convert_cx_scalar(out, (Proxy::use_at ? P.at(0,0) : P[0])); + + return out; + } + + + +template +template +arma_warn_unused +inline +out_eT +conv_to::from(const BaseCube& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + arma_type_check(( is_supported_elem_type::value == false )); + + const ProxyCube P(in.get_ref()); + + arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object does not have exactly one element" ); + + return out_eT(ProxyCube::use_at ? P.at(0,0,0) : P[0]); + } + + + +template +template +arma_warn_unused +inline +out_eT +conv_to::from(const BaseCube& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + arma_type_check(( is_supported_elem_type::value == false )); + + const ProxyCube P(in.get_ref()); + + arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object does not have exactly one element" ); + + out_eT out; + + arrayops::convert_cx_scalar(out, (ProxyCube::use_at ? P.at(0,0,0) : P[0])); + + return out; + } + + + +//! conversion to Armadillo matrices from Armadillo Base objects, as well as from std::vector +template +class conv_to< Mat > + { + public: + + template + inline static Mat from(const Base& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static Mat from(const Base& in, const typename arma_cx_only::result* junk = nullptr); + + template + inline static Mat from(const SpBase& in); + + + + template + inline static Mat from(const std::vector& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static Mat from(const std::vector& in, const typename arma_cx_only::result* junk = nullptr); + }; + + + +template +template +arma_warn_unused +inline +Mat +conv_to< Mat >::from(const Base& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const quasi_unwrap tmp(in.get_ref()); + const Mat& X = tmp.M; + + Mat out(X.n_rows, X.n_cols, arma_nozeros_indicator()); + + arrayops::convert( out.memptr(), X.memptr(), X.n_elem ); + + return out; + } + + + +template +template +arma_warn_unused +inline +Mat +conv_to< Mat >::from(const Base& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const quasi_unwrap tmp(in.get_ref()); + const Mat& X = tmp.M; + + Mat out(X.n_rows, X.n_cols, arma_nozeros_indicator()); + + arrayops::convert_cx( out.memptr(), X.memptr(), X.n_elem ); + + return out; + } + + + +template +template +arma_warn_unused +inline +Mat +conv_to< Mat >::from(const SpBase& in) + { + arma_extra_debug_sigprint(); + + return Mat(in.get_ref()); + } + + + +template +template +arma_warn_unused +inline +Mat +conv_to< Mat >::from(const std::vector& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword N = uword( in.size() ); + + Mat out(N, 1, arma_nozeros_indicator()); + + if(N > 0) + { + arrayops::convert( out.memptr(), &(in[0]), N ); + } + + return out; + } + + + +template +template +arma_warn_unused +inline +Mat +conv_to< Mat >::from(const std::vector& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword N = uword( in.size() ); + + Mat out(N, 1, arma_nozeros_indicator()); + + if(N > 0) + { + arrayops::convert_cx( out.memptr(), &(in[0]), N ); + } + + return out; + } + + + +//! conversion to Armadillo row vectors from Armadillo Base objects, as well as from std::vector +template +class conv_to< Row > + { + public: + + template + inline static Row from(const Base& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static Row from(const Base& in, const typename arma_cx_only::result* junk = nullptr); + + + + template + inline static Row from(const std::vector& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static Row from(const std::vector& in, const typename arma_cx_only::result* junk = nullptr); + }; + + + +template +template +arma_warn_unused +inline +Row +conv_to< Row >::from(const Base& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const quasi_unwrap tmp(in.get_ref()); + const Mat& X = tmp.M; + + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); + + Row out(X.n_elem, arma_nozeros_indicator()); + + arrayops::convert( out.memptr(), X.memptr(), X.n_elem ); + + return out; + } + + + +template +template +arma_warn_unused +inline +Row +conv_to< Row >::from(const Base& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const quasi_unwrap tmp(in.get_ref()); + const Mat& X = tmp.M; + + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); + + Row out(X.n_rows, X.n_cols, arma_nozeros_indicator()); + + arrayops::convert_cx( out.memptr(), X.memptr(), X.n_elem ); + + return out; + } + + + +template +template +arma_warn_unused +inline +Row +conv_to< Row >::from(const std::vector& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword N = uword( in.size() ); + + Row out(N, arma_nozeros_indicator()); + + if(N > 0) + { + arrayops::convert( out.memptr(), &(in[0]), N ); + } + + return out; + } + + + +template +template +arma_warn_unused +inline +Row +conv_to< Row >::from(const std::vector& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword N = uword( in.size() ); + + Row out(N, arma_nozeros_indicator()); + + if(N > 0) + { + arrayops::convert_cx( out.memptr(), &(in[0]), N ); + } + + return out; + } + + + +//! conversion to Armadillo column vectors from Armadillo Base objects, as well as from std::vector +template +class conv_to< Col > + { + public: + + template + inline static Col from(const Base& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static Col from(const Base& in, const typename arma_cx_only::result* junk = nullptr); + + + + template + inline static Col from(const std::vector& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static Col from(const std::vector& in, const typename arma_cx_only::result* junk = nullptr); + }; + + + +template +template +arma_warn_unused +inline +Col +conv_to< Col >::from(const Base& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const quasi_unwrap tmp(in.get_ref()); + const Mat& X = tmp.M; + + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); + + Col out(X.n_elem, arma_nozeros_indicator()); + + arrayops::convert( out.memptr(), X.memptr(), X.n_elem ); + + return out; + } + + + +template +template +arma_warn_unused +inline +Col +conv_to< Col >::from(const Base& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const quasi_unwrap tmp(in.get_ref()); + const Mat& X = tmp.M; + + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); + + Col out(X.n_rows, X.n_cols, arma_nozeros_indicator()); + + arrayops::convert_cx( out.memptr(), X.memptr(), X.n_elem ); + + return out; + } + + + +template +template +arma_warn_unused +inline +Col +conv_to< Col >::from(const std::vector& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword N = uword( in.size() ); + + Col out(N, arma_nozeros_indicator()); + + if(N > 0) + { + arrayops::convert( out.memptr(), &(in[0]), N ); + } + + return out; + } + + + +template +template +arma_warn_unused +inline +Col +conv_to< Col >::from(const std::vector& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword N = uword( in.size() ); + + Col out(N, arma_nozeros_indicator()); + + if(N > 0) + { + arrayops::convert_cx( out.memptr(), &(in[0]), N ); + } + + return out; + } + + + +//! convert between SpMat types +template +class conv_to< SpMat > + { + public: + + template + inline static SpMat from(const SpBase& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static SpMat from(const SpBase& in, const typename arma_cx_only::result* junk = nullptr); + + template + inline static SpMat from(const Base& in); + }; + + + +template +template +arma_warn_unused +inline +SpMat +conv_to< SpMat >::from(const SpBase& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const unwrap_spmat tmp(in.get_ref()); + const SpMat& X = tmp.M; + + SpMat out(arma_layout_indicator(), X); + + arrayops::convert( access::rwp(out.values), X.values, X.n_nonzero ); + + out.remove_zeros(); + + return out; + } + + + +template +template +arma_warn_unused +inline +SpMat +conv_to< SpMat >::from(const SpBase& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const unwrap_spmat tmp(in.get_ref()); + const SpMat& X = tmp.M; + + SpMat out(arma_layout_indicator(), X); + + arrayops::convert_cx( access::rwp(out.values), X.values, X.n_nonzero ); + + out.remove_zeros(); + + return out; + } + + + +template +template +arma_warn_unused +inline +SpMat +conv_to< SpMat >::from(const Base& in) + { + arma_extra_debug_sigprint(); + + return SpMat(in.get_ref()); + } + + + +//! conversion to Armadillo cubes from Armadillo BaseCube objects +template +class conv_to< Cube > + { + public: + + template + inline static Cube from(const BaseCube& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static Cube from(const BaseCube& in, const typename arma_cx_only::result* junk = nullptr); + }; + + + +template +template +arma_warn_unused +inline +Cube +conv_to< Cube >::from(const BaseCube& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const unwrap_cube tmp( in.get_ref() ); + const Cube& X = tmp.M; + + Cube out(X.n_rows, X.n_cols, X.n_slices, arma_nozeros_indicator()); + + arrayops::convert( out.memptr(), X.memptr(), X.n_elem ); + + return out; + } + + + +template +template +arma_warn_unused +inline +Cube +conv_to< Cube >::from(const BaseCube& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const unwrap_cube tmp( in.get_ref() ); + const Cube& X = tmp.M; + + Cube out(X.n_rows, X.n_cols, X.n_slices, arma_nozeros_indicator()); + + arrayops::convert_cx( out.memptr(), X.memptr(), X.n_elem ); + + return out; + } + + + +//! conversion to std::vector from Armadillo Base objects +template +class conv_to< std::vector > + { + public: + + template + inline static std::vector from(const Base& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static std::vector from(const Base& in, const typename arma_cx_only::result* junk = nullptr); + }; + + + +template +template +arma_warn_unused +inline +std::vector +conv_to< std::vector >::from(const Base& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const quasi_unwrap tmp(in.get_ref()); + const Mat& X = tmp.M; + + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); + + const uword N = X.n_elem; + + std::vector out(N); + + if(N > 0) + { + arrayops::convert( &(out[0]), X.memptr(), N ); + } + + return out; + } + + + +template +template +arma_warn_unused +inline +std::vector +conv_to< std::vector >::from(const Base& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const quasi_unwrap tmp(in.get_ref()); + const Mat& X = tmp.M; + + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); + + const uword N = X.n_elem; + + std::vector out(N); + + if(N > 0) + { + arrayops::convert_cx( &(out[0]), X.memptr(), N ); + } + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_cor.hpp b/src/armadillo/include/armadillo_bits/fn_cor.hpp new file mode 100644 index 0000000..18cd2fa --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_cor.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_cor +//! @{ + + + +template +arma_warn_unused +inline +const Op +cor(const Base& X, const uword norm_type = 0) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (norm_type > 1), "cor(): parameter 'norm_type' must be 0 or 1" ); + + return Op(X.get_ref(), norm_type, 0); + } + + + +template +arma_warn_unused +inline +const Glue +cor(const Base& A, const Base& B, const uword norm_type = 0) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (norm_type > 1), "cor(): parameter 'norm_type' must be 0 or 1" ); + + return Glue(A.get_ref(), B.get_ref(), norm_type); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_cov.hpp b/src/armadillo/include/armadillo_bits/fn_cov.hpp new file mode 100644 index 0000000..ee61c00 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_cov.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_cov +//! @{ + + + +template +arma_warn_unused +inline +const Op +cov(const Base& X, const uword norm_type = 0) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (norm_type > 1), "cov(): parameter 'norm_type' must be 0 or 1" ); + + return Op(X.get_ref(), norm_type, 0); + } + + + +template +arma_warn_unused +inline +const Glue +cov(const Base& A, const Base& B, const uword norm_type = 0) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (norm_type > 1), "cov(): parameter 'norm_type' must be 0 or 1" ); + + return Glue(A.get_ref(), B.get_ref(), norm_type); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_cross.hpp b/src/armadillo/include/armadillo_bits/fn_cross.hpp new file mode 100644 index 0000000..bb08e17 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_cross.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_cross +//! @{ + + + +//! cross product (only valid for 3 dimensional vectors) +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_same_type::value, + const Glue + >::result +cross(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return Glue(X, Y); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_cumprod.hpp b/src/armadillo/include/armadillo_bits/fn_cumprod.hpp new file mode 100644 index 0000000..f6cd1e1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_cumprod.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_cumprod +//! @{ + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + const Op + >::result +cumprod(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const Op + >::result +cumprod(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value, + const Op + >::result +cumprod(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return Op(X, dim, 0); + } + + + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +cumprod(const T& x) + { + return x; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_cumsum.hpp b/src/armadillo/include/armadillo_bits/fn_cumsum.hpp new file mode 100644 index 0000000..ad6c637 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_cumsum.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_cumsum +//! @{ + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + const Op + >::result +cumsum(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const Op + >::result +cumsum(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value, + const Op + >::result +cumsum(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return Op(X, dim, 0); + } + + + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +cumsum(const T& x) + { + return x; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_det.hpp b/src/armadillo/include/armadillo_bits/fn_det.hpp new file mode 100644 index 0000000..3941a85 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_det.hpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_det +//! @{ + + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, typename T1::elem_type >::result +det(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + eT out_val = eT(0); + + const bool status = op_det::apply_direct(out_val, X.get_ref()); + + if(status == false) + { + out_val = eT(0); + arma_stop_runtime_error("det(): failed to find determinant"); + } + + return out_val; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +det(typename T1::elem_type& out_val, const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const bool status = op_det::apply_direct(out_val, X.get_ref()); + + if(status == false) + { + out_val = eT(0); + arma_debug_warn_level(3, "det(): failed to find determinant"); + } + + return status; + } + + + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +det(const T& x) + { + return x; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_diagmat.hpp b/src/armadillo/include/armadillo_bits/fn_diagmat.hpp new file mode 100644 index 0000000..7d6c5b0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_diagmat.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_diagmat +//! @{ + + +//! interpret a matrix or a vector as a diagonal matrix (ie. off-diagonal entries are zero) +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value, + const Op + >::result +diagmat(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X); + } + + + +//! create a matrix with the k-th diagonal set to the given vector +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value, + const Op + >::result +diagmat(const T1& X, const sword k) + { + arma_extra_debug_sigprint(); + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + return Op(X, row_offset, col_offset); + } + + + +template +arma_warn_unused +inline +const SpOp +diagmat(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref()); + } + + + +template +arma_warn_unused +inline +const SpOp +diagmat(const SpBase& X, const sword k) + { + arma_extra_debug_sigprint(); + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + return SpOp(X.get_ref(), row_offset, col_offset); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_diags_spdiags.hpp b/src/armadillo/include/armadillo_bits/fn_diags_spdiags.hpp new file mode 100644 index 0000000..ceb8fd0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_diags_spdiags.hpp @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_diags_spdiags +//! @{ + + + +template +inline +Mat +diags(const Base& V_expr, const Base& D_expr, const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UV(V_expr.get_ref()); + const Mat& V = UV.M; + + const quasi_unwrap UD(D_expr.get_ref()); + const Mat& D = UD.M; + + arma_debug_check( ((D.is_vec() == false) && (D.is_empty() == false)), "D must be a vector" ); + + arma_debug_check( (V.n_cols != D.n_elem), "number of colums in matrix V must match the length of vector D" ); + + Mat out(n_rows, n_cols, fill::zeros); + + for(uword i=0; i < D.n_elem; ++i) + { + const sword diag_id = D[i]; + + const uword row_offset = (diag_id < 0) ? uword(-diag_id) : 0; + const uword col_offset = (diag_id > 0) ? uword( diag_id) : 0; + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "diags(): requested diagonal out of bounds" + ); + + const uword diag_len = (std::min)(n_rows - row_offset, n_cols - col_offset); + + const uword V_start = (diag_id < 0) ? uword(0) : uword(diag_id); + + const eT* V_colmem = V.colptr(i); + + for(uword j=0; j < diag_len; ++j) + { + const uword V_index = V_start + j; + + if(V_index >= V.n_rows) { break; } + + out.at(j + row_offset, j + col_offset) = V_colmem[V_index]; + } + } + + return out; + } + + + +template +inline +SpMat +spdiags(const Base& V_expr, const Base& D_expr, const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UV(V_expr.get_ref()); + const Mat& V = UV.M; + + const quasi_unwrap UD(D_expr.get_ref()); + const Mat& D = UD.M; + + arma_debug_check( ((D.is_vec() == false) && (D.is_empty() == false)), "D must be a vector" ); + + arma_debug_check( (V.n_cols != D.n_elem), "number of colums in matrix V must match the length of vector D" ); + + MapMat tmp(n_rows, n_cols); + + for(uword i=0; i < D.n_elem; ++i) + { + const sword diag_id = D[i]; + + const uword row_offset = (diag_id < 0) ? uword(-diag_id) : 0; + const uword col_offset = (diag_id > 0) ? uword( diag_id) : 0; + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "diags(): requested diagonal out of bounds" + ); + + const uword diag_len = (std::min)(n_rows - row_offset, n_cols - col_offset); + + const uword V_start = (diag_id < 0) ? uword(0) : uword(diag_id); + + const eT* V_colmem = V.colptr(i); + + for(uword j=0; j < diag_len; ++j) + { + const uword V_index = V_start + j; + + if(V_index >= V.n_rows) { break; } + + tmp.at(j + row_offset, j + col_offset) = V_colmem[V_index]; + } + } + + return SpMat(tmp); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_diagvec.hpp b/src/armadillo/include/armadillo_bits/fn_diagvec.hpp new file mode 100644 index 0000000..873c350 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_diagvec.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_diagvec +//! @{ + + +//! extract main diagonal from matrix +template +arma_warn_unused +arma_inline +const Op +diagvec(const Base& X) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + +//! extract arbitrary diagonal from matrix +template +arma_warn_unused +arma_inline +const Op +diagvec(const Base& X, const sword diag_id) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref(), ((diag_id < 0) ? -diag_id : diag_id), ((diag_id < 0) ? 1 : 0) ); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +diagvec(const SpBase& X, const sword diag_id = 0) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref(), ((diag_id < 0) ? -diag_id : diag_id), ((diag_id < 0) ? 1 : 0) ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_diff.hpp b/src/armadillo/include/armadillo_bits/fn_diff.hpp new file mode 100644 index 0000000..2d7e8d3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_diff.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_diff +//! @{ + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + const Op + >::result +diff + ( + const T1& X, + const uword k = 1 + ) + { + arma_extra_debug_sigprint(); + + return Op(X, k, 0); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const Op + >::result +diff + ( + const T1& X, + const uword k = 1 + ) + { + arma_extra_debug_sigprint(); + + return Op(X, k, 0); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value, + const Op + >::result +diff + ( + const T1& X, + const uword k, + const uword dim + ) + { + arma_extra_debug_sigprint(); + + return Op(X, k, dim); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_dot.hpp b/src/armadillo/include/armadillo_bits/fn_dot.hpp new file mode 100644 index 0000000..d2cbfc8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_dot.hpp @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_dot +//! @{ + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_same_type::yes, + typename T1::elem_type + >::result +dot + ( + const T1& A, + const T2& B + ) + { + arma_extra_debug_sigprint(); + + return op_dot::apply(A,B); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_same_type::no, + typename promote_type::result + >::result +dot + ( + const T1& A, + const T2& B + ) + { + arma_extra_debug_sigprint(); + + return op_dot_mixed::apply(A,B); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_same_type::value, + typename T1::elem_type + >::result +norm_dot + ( + const T1& A, + const T2& B + ) + { + arma_extra_debug_sigprint(); + + return op_norm_dot::apply(A,B); + } + + + +// +// cdot + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_same_type::value && is_cx::no, + typename T1::elem_type + >::result +cdot + ( + const T1& A, + const T2& B + ) + { + arma_extra_debug_sigprint(); + + return op_dot::apply(A,B); + } + + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_same_type::value && is_cx::yes, + typename T1::elem_type + >::result +cdot + ( + const T1& A, + const T2& B + ) + { + arma_extra_debug_sigprint(); + + return op_cdot::apply(A,B); + } + + + +// convert dot(htrans(x), y) to cdot(x,y) + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_same_type::value && is_cx::yes, + typename T1::elem_type + >::result +dot + ( + const Op& A, + const T2& B + ) + { + arma_extra_debug_sigprint(); + + return cdot(A.m, B); + } + + + +// +// for sparse matrices +// + + + +namespace priv + { + + template + arma_hot + inline + typename T1::elem_type + dot_helper(const SpProxy& pa, const SpProxy& pb) + { + typedef typename T1::elem_type eT; + + // Iterate over both objects and see when they are the same + eT result = eT(0); + + typename SpProxy::const_iterator_type a_it = pa.begin(); + typename SpProxy::const_iterator_type a_end = pa.end(); + + typename SpProxy::const_iterator_type b_it = pb.begin(); + typename SpProxy::const_iterator_type b_end = pb.end(); + + while((a_it != a_end) && (b_it != b_end)) + { + if(a_it == b_it) + { + result += (*a_it) * (*b_it); + + ++a_it; + ++b_it; + } + else if((a_it.col() < b_it.col()) || ((a_it.col() == b_it.col()) && (a_it.row() < b_it.row()))) + { + // a_it is "behind" + ++a_it; + } + else + { + // b_it is "behind" + ++b_it; + } + } + + return result; + } + + } + + + +//! dot product of two sparse objects +template +arma_warn_unused +arma_hot +inline +typename +enable_if2 + <(is_arma_sparse_type::value) && (is_arma_sparse_type::value) && (is_same_type::value), + typename T1::elem_type + >::result +dot + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + const SpProxy pa(x); + const SpProxy pb(y); + + arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "dot()"); + + typedef typename T1::elem_type eT; + + typedef typename SpProxy::stored_type pa_Q_type; + typedef typename SpProxy::stored_type pb_Q_type; + + if( + ( (SpProxy::use_iterator == false) && (SpProxy::use_iterator == false) ) + && ( (is_SpMat::value == true ) && (is_SpMat::value == true ) ) + ) + { + const unwrap_spmat tmp_a(pa.Q); + const unwrap_spmat tmp_b(pb.Q); + + const SpMat& A = tmp_a.M; + const SpMat& B = tmp_b.M; + + if( &A == &B ) + { + // We can do it directly! + return op_dot::direct_dot_arma(A.n_nonzero, A.values, A.values); + } + else + { + return priv::dot_helper(pa,pb); + } + } + else + { + return priv::dot_helper(pa,pb); + } + } + + + +//! dot product of one dense and one sparse object +template +arma_warn_unused +arma_hot +inline +typename +enable_if2 + <(is_arma_type::value) && (is_arma_sparse_type::value) && (is_same_type::value), + typename T1::elem_type + >::result +dot + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + const Proxy pa(x); + const SpProxy pb(y); + + arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "dot()"); + + typedef typename T1::elem_type eT; + + eT result = eT(0); + + typename SpProxy::const_iterator_type it = pb.begin(); + typename SpProxy::const_iterator_type it_end = pb.end(); + + // use_at == false won't save us operations + while(it != it_end) + { + result += (*it) * pa.at(it.row(), it.col()); + ++it; + } + + return result; + } + + + +//! dot product of one sparse and one dense object +template +arma_warn_unused +arma_hot +inline +typename +enable_if2 + <(is_arma_sparse_type::value) && (is_arma_type::value) && (is_same_type::value), + typename T1::elem_type + >::result +dot + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + // this is commutative + return dot(y, x); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_eig_gen.hpp b/src/armadillo/include/armadillo_bits/fn_eig_gen.hpp new file mode 100644 index 0000000..ad22821 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_eig_gen.hpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_eig_gen +//! @{ + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, Col< std::complex > >::result +eig_gen + ( + const Base& expr, + const char* option = "nobalance" + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + const char sig = (option != nullptr) ? option[0] : char(0); + + arma_debug_check( ((sig != 'n') && (sig != 'b')), "eig_gen(): unknown option" ); + + if( auxlib::crippled_lapack(expr) && (sig == 'b') ) { arma_debug_warn_level(1, "eig_gen(): 'balance' option ignored due to linking with crippled lapack"); } + + Col eigvals; + Mat eigvecs; + + const bool status = (sig == 'b') ? auxlib::eig_gen_balance(eigvals, eigvecs, false, expr.get_ref()) : auxlib::eig_gen(eigvals, eigvecs, false, expr.get_ref()); + + if(status == false) + { + eigvals.soft_reset(); + arma_stop_runtime_error("eig_gen(): decomposition failed"); + } + + return eigvals; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +eig_gen + ( + Col< std::complex >& eigvals, + const Base< typename T1::elem_type, T1>& expr, + const char* option = "nobalance" + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + const char sig = (option != nullptr) ? option[0] : char(0); + + arma_debug_check( ((sig != 'n') && (sig != 'b')), "eig_gen(): unknown option" ); + + if( auxlib::crippled_lapack(expr) && (sig == 'b') ) { arma_debug_warn_level(1, "eig_gen(): 'balance' option ignored due to linking with crippled lapack"); } + + Mat eigvecs; + + const bool status = (sig == 'b') ? auxlib::eig_gen_balance(eigvals, eigvecs, false, expr.get_ref()) : auxlib::eig_gen(eigvals, eigvecs, false, expr.get_ref()); + + if(status == false) + { + eigvals.soft_reset(); + arma_debug_warn_level(3, "eig_gen(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +eig_gen + ( + Col< std::complex >& eigvals, + Mat< std::complex >& eigvecs, + const Base& expr, + const char* option = "nobalance" + ) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (void_ptr(&eigvals) == void_ptr(&eigvecs)), "eig_gen(): parameter 'eigval' is an alias of parameter 'eigvec'" ); + + const char sig = (option != nullptr) ? option[0] : char(0); + + arma_debug_check( ((sig != 'n') && (sig != 'b')), "eig_gen(): unknown option" ); + + if( auxlib::crippled_lapack(expr) && (sig == 'b') ) { arma_debug_warn_level(1, "eig_gen(): 'balance' option ignored due to linking with crippled lapack"); } + + const bool status = (sig == 'b') ? auxlib::eig_gen_balance(eigvals, eigvecs, true, expr.get_ref()) : auxlib::eig_gen(eigvals, eigvecs, true, expr.get_ref()); + + if(status == false) + { + eigvals.soft_reset(); + eigvecs.soft_reset(); + arma_debug_warn_level(3, "eig_gen(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +eig_gen + ( + Col< std::complex >& eigvals, + Mat< std::complex >& leigvecs, + Mat< std::complex >& reigvecs, + const Base& expr, + const char* option = "nobalance" + ) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (void_ptr(&eigvals) == void_ptr(&leigvecs)), "eig_gen(): parameter 'eigval' is an alias of parameter 'leigvec'" ); + arma_debug_check( (void_ptr(&eigvals) == void_ptr(&reigvecs)), "eig_gen(): parameter 'eigval' is an alias of parameter 'reigvec'" ); + arma_debug_check( (void_ptr(&leigvecs) == void_ptr(&reigvecs)), "eig_gen(): parameter 'leigvec' is an alias of parameter 'reigvec'" ); + + const char sig = (option != nullptr) ? option[0] : char(0); + + arma_debug_check( ((sig != 'n') && (sig != 'b')), "eig_gen(): unknown option" ); + + if( auxlib::crippled_lapack(expr) && (sig == 'b') ) { arma_debug_warn_level(1, "eig_gen(): 'balance' option ignored due to linking with crippled lapack"); } + + const bool status = (sig == 'b') ? auxlib::eig_gen_twosided_balance(eigvals, leigvecs, reigvecs, expr.get_ref()) : auxlib::eig_gen_twosided(eigvals, leigvecs, reigvecs, expr.get_ref()); + + if(status == false) + { + eigvals.soft_reset(); + leigvecs.soft_reset(); + reigvecs.soft_reset(); + arma_debug_warn_level(3, "eig_gen(): decomposition failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_eig_pair.hpp b/src/armadillo/include/armadillo_bits/fn_eig_pair.hpp new file mode 100644 index 0000000..cef1389 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_eig_pair.hpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_eig_pair +//! @{ + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, Col< std::complex > >::result +eig_pair + ( + const Base& A_expr, + const Base& B_expr + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Col< std::complex > eigvals; + Mat< std::complex > eigvecs; + + const bool status = auxlib::eig_pair(eigvals, eigvecs, false, A_expr.get_ref(), B_expr.get_ref()); + + if(status == false) + { + eigvals.soft_reset(); + arma_stop_runtime_error("eig_pair(): decomposition failed"); + } + + return eigvals; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +eig_pair + ( + Col< std::complex >& eigvals, + const Base< typename T1::elem_type, T1 >& A_expr, + const Base< typename T1::elem_type, T2 >& B_expr + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat< std::complex > eigvecs; + + const bool status = auxlib::eig_pair(eigvals, eigvecs, false, A_expr.get_ref(), B_expr.get_ref()); + + if(status == false) + { + eigvals.soft_reset(); + arma_debug_warn_level(3, "eig_pair(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +eig_pair + ( + Col< std::complex >& eigvals, + Mat< std::complex >& eigvecs, + const Base< typename T1::elem_type, T1 >& A_expr, + const Base< typename T1::elem_type, T2 >& B_expr + ) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (void_ptr(&eigvals) == void_ptr(&eigvecs)), "eig_pair(): parameter 'eigval' is an alias of parameter 'eigvec'" ); + + const bool status = auxlib::eig_pair(eigvals, eigvecs, true, A_expr.get_ref(), B_expr.get_ref()); + + if(status == false) + { + eigvals.soft_reset(); + eigvecs.soft_reset(); + arma_debug_warn_level(3, "eig_pair(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +eig_pair + ( + Col< std::complex >& eigvals, + Mat< std::complex >& leigvecs, + Mat< std::complex >& reigvecs, + const Base< typename T1::elem_type, T1 >& A_expr, + const Base< typename T1::elem_type, T2 >& B_expr + ) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (void_ptr(&eigvals) == void_ptr(&leigvecs)), "eig_pair(): parameter 'eigval' is an alias of parameter 'leigvec'" ); + arma_debug_check( (void_ptr(&eigvals) == void_ptr(&reigvecs)), "eig_pair(): parameter 'eigval' is an alias of parameter 'reigvec'" ); + arma_debug_check( (void_ptr(&leigvecs) == void_ptr(&reigvecs)), "eig_pair(): parameter 'leigvec' is an alias of parameter 'reigvec'" ); + + const bool status = auxlib::eig_pair_twosided(eigvals, leigvecs, reigvecs, A_expr.get_ref(), B_expr.get_ref()); + + if(status == false) + { + eigvals.soft_reset(); + leigvecs.soft_reset(); + reigvecs.soft_reset(); + arma_debug_warn_level(3, "eig_pair(): decomposition failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_eig_sym.hpp b/src/armadillo/include/armadillo_bits/fn_eig_sym.hpp new file mode 100644 index 0000000..12043b6 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_eig_sym.hpp @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_eig_sym +//! @{ + + +//! Eigenvalues of real/complex symmetric/hermitian matrix X +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +eig_sym + ( + Col& eigval, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + Mat A(X.get_ref()); + + const bool status = auxlib::eig_sym(eigval, A); + + if(status == false) + { + eigval.soft_reset(); + arma_debug_warn_level(3, "eig_sym(): decomposition failed"); + } + + return status; + } + + + +//! Eigenvalues of real/complex symmetric/hermitian matrix X +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, Col >::result +eig_sym + ( + const Base& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + Col< T> eigval; + Mat A(X.get_ref()); + + const bool status = auxlib::eig_sym(eigval, A); + + if(status == false) + { + eigval.reset(); + arma_stop_runtime_error("eig_sym(): decomposition failed"); + } + + return eigval; + } + + + +//! internal helper function +template +inline +bool +eig_sym_helper + ( + Col::result>& eigval, + Mat& eigvec, + const Mat& X, + const char method_sig, + const char* caller_sig + ) + { + arma_extra_debug_sigprint(); + + if((arma_config::debug) && (auxlib::rudimentary_sym_check(X) == false)) + { + if(is_cx::no ) { arma_debug_warn_level(1, caller_sig, ": given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, caller_sig, ": given matrix is not hermitian"); } + } + + bool status = false; + + if(method_sig == 'd') { status = auxlib::eig_sym_dc(eigval, eigvec, X); } + + if(status == false) { status = auxlib::eig_sym(eigval, eigvec, X); } + + return status; + } + + + +//! Eigenvalues and eigenvectors of real/complex symmetric/hermitian matrix X +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +eig_sym + ( + Col& eigval, + Mat& eigvec, + const Base& expr, + const char* method = "dc" + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 's') && (sig != 'd')), "eig_sym(): unknown method specified" ); + arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eig_sym(): parameter 'eigval' is an alias of parameter 'eigvec'" ); + + const quasi_unwrap U(expr.get_ref()); + + const bool is_alias = U.is_alias(eigvec); + + Mat eigvec_tmp; + Mat& eigvec_out = (is_alias == false) ? eigvec : eigvec_tmp; + + const bool status = eig_sym_helper(eigval, eigvec_out, U.M, sig, "eig_sym()"); + + if(status == false) + { + eigval.soft_reset(); + eigvec.soft_reset(); + arma_debug_warn_level(3, "eig_sym(): decomposition failed"); + } + else + { + if(is_alias) { eigvec.steal_mem(eigvec_tmp); } + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_eigs_gen.hpp b/src/armadillo/include/armadillo_bits/fn_eigs_gen.hpp new file mode 100644 index 0000000..6f1a617 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_eigs_gen.hpp @@ -0,0 +1,425 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_eigs_gen +//! @{ + + +//! eigenvalues of general sparse matrix X +template +arma_warn_unused +inline +typename enable_if2< is_real::value, Col< std::complex > >::result +eigs_gen + ( + const SpBase& X, + const uword n_eigvals, + const char* form = "lm", + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat< std::complex > eigvec; + Col< std::complex > eigval; + + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, form_val, opts); + + if(status == false) + { + eigval.soft_reset(); + arma_stop_runtime_error("eigs_gen(): decomposition failed"); + } + + return eigval; + } + + + +//! this form is deprecated; use eigs_gen(X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, Col< std::complex > >::result +eigs_gen + ( + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::pod_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_gen(X, n_eigvals, form, opts); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_real::value, Col< std::complex > >::result +eigs_gen + ( + const SpBase& X, + const uword n_eigvals, + const std::complex sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat< std::complex > eigvec; + Col< std::complex > eigval; + + bool status = false; + + // If X is real and sigma is truly complex, treat X as complex. + // The reason is that we are still not able to apply truly complex shifts to real matrices + if( (is_real::yes) && (std::imag(sigma) != T(0)) ) + { + status = sp_auxlib::eigs_gen(eigval, eigvec, conv_to< SpMat< std::complex > >::from(X), n_eigvals, sigma, opts); + } + else + { + status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, sigma, opts); + } + + if(status == false) + { + eigval.soft_reset(); + arma_stop_runtime_error("eigs_gen(): decomposition failed"); + } + + return eigval; + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_real::value, Col< std::complex > >::result +eigs_gen + ( + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat< std::complex > eigvec; + Col< std::complex > eigval; + + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, std::complex(T(sigma)), opts); + + if(status == false) + { + eigval.soft_reset(); + arma_stop_runtime_error("eigs_gen(): decomposition failed"); + } + + return eigval; + } + + + +//! eigenvalues of general sparse matrix X +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + const SpBase& X, + const uword n_eigvals, + const char* form = "lm", + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat< std::complex > eigvec; + + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, form_val, opts); + + if(status == false) + { + eigval.soft_reset(); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); + } + + return status; + } + + + +//! this form is deprecated; use eigs_gen(eigval, X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::pod_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_gen(eigval, X, n_eigvals, form, opts); + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + const SpBase& X, + const uword n_eigvals, + const std::complex sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat< std::complex > eigvec; + + bool status = false; + + // If X is real and sigma is truly complex, treat X as complex. + // The reason is that we are still not able to apply truly complex shifts to real matrices + if( (is_real::yes) && (std::imag(sigma) != T(0)) ) + { + status = sp_auxlib::eigs_gen(eigval, eigvec, conv_to< SpMat< std::complex > >::from(X), n_eigvals, sigma, opts); + } + else + { + status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, sigma, opts); + } + + if(status == false) + { + eigval.soft_reset(); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat< std::complex > eigvec; + + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, std::complex(T(sigma)), opts); + + if(status == false) + { + eigval.soft_reset(); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); + } + + return status; + } + + + +//! eigenvalues and eigenvectors of general sparse matrix X +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + Mat< std::complex >& eigvec, + const SpBase& X, + const uword n_eigvals, + const char* form = "lm", + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + // typedef typename T1::pod_type T; + + arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eigs_gen(): parameter 'eigval' is an alias of parameter 'eigvec'" ); + + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, form_val, opts); + + if(status == false) + { + eigval.soft_reset(); + eigvec.soft_reset(); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); + } + + return status; + } + + + +//! this form is deprecated; use eigs_gen(eigval, eigvec, X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + Mat< std::complex >& eigvec, + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::pod_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_gen(eigval, eigvec, X, n_eigvals, form, opts); + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + Mat< std::complex >& eigvec, + const SpBase& X, + const uword n_eigvals, + const std::complex sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eigs_gen(): parameter 'eigval' is an alias of parameter 'eigvec'" ); + + bool status = false; + + // If X is real and sigma is truly complex, treat X as complex. + // The reason is that we are still not able to apply truly complex shifts to real matrices + if( (is_real::yes) && (std::imag(sigma) != T(0)) ) + { + status = sp_auxlib::eigs_gen(eigval, eigvec, conv_to< SpMat< std::complex > >::from(X), n_eigvals, sigma, opts); + } + else + { + status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, sigma, opts); + } + + if(status == false) + { + eigval.soft_reset(); + eigvec.soft_reset(); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + Mat< std::complex >& eigvec, + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eigs_gen(): parameter 'eigval' is an alias of parameter 'eigvec'" ); + + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, std::complex(T(sigma)), opts); + + if(status == false) + { + eigval.soft_reset(); + eigvec.soft_reset(); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_eigs_sym.hpp b/src/armadillo/include/armadillo_bits/fn_eigs_sym.hpp new file mode 100644 index 0000000..935aab9 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_eigs_sym.hpp @@ -0,0 +1,290 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_eigs_sym +//! @{ + + +//! eigenvalues of symmetric real sparse matrix X +template +arma_warn_unused +inline +typename enable_if2< is_real::value, Col >::result +eigs_sym + ( + const SpBase& X, + const uword n_eigvals, + const char* form = "lm", + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + Mat eigvec; + Col eigval; + + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, form_val, opts); + + if(status == false) + { + eigval.soft_reset(); + arma_stop_runtime_error("eigs_sym(): decomposition failed"); + } + + return eigval; + } + + + +//! this form is deprecated; use eigs_sym(X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, Col >::result +eigs_sym + ( + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::elem_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_sym(X, n_eigvals, form, opts); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_real::value, Col >::result +eigs_sym + ( + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat eigvec; + Col eigval; + + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, T(sigma), opts); + + if(status == false) + { + eigval.soft_reset(); + arma_stop_runtime_error("eigs_sym(): decomposition failed"); + } + + return eigval; + } + + + +//! eigenvalues of symmetric real sparse matrix X +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_sym + ( + Col& eigval, + const SpBase& X, + const uword n_eigvals, + const char* form = "lm", + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + Mat eigvec; + + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, form_val, opts); + + if(status == false) + { + eigval.soft_reset(); + arma_debug_warn_level(3, "eigs_sym(): decomposition failed"); + } + + return status; + } + + + +//! this form is deprecated; use eigs_sym(eigval, X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, bool >::result +eigs_sym + ( + Col& eigval, + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::elem_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_sym(eigval, X, n_eigvals, form, opts); + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_sym + ( + Col& eigval, + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat eigvec; + + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, T(sigma), opts); + + if(status == false) + { + eigval.soft_reset(); + arma_debug_warn_level(3, "eigs_sym(): decomposition failed"); + } + + return status; + } + + + +//! eigenvalues and eigenvectors of symmetric real sparse matrix X +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_sym + ( + Col& eigval, + Mat& eigvec, + const SpBase& X, + const uword n_eigvals, + const char* form = "lm", + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eigs_sym(): parameter 'eigval' is an alias of parameter 'eigvec'" ); + + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, form_val, opts); + + if(status == false) + { + eigval.soft_reset(); + eigvec.soft_reset(); + arma_debug_warn_level(3, "eigs_sym(): decomposition failed"); + } + + return status; + } + + + +//! this form is deprecated; use eigs_sym(eigval, eigvec, X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, bool >::result +eigs_sym + ( + Col& eigval, + Mat& eigvec, + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::elem_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_sym(eigval, eigvec, X, n_eigvals, form, opts); + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_sym + ( + Col& eigval, + Mat& eigvec, + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eigs_sym(): parameter 'eigval' is an alias of parameter 'eigvec'" ); + + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, T(sigma), opts); + + if(status == false) + { + eigval.soft_reset(); + eigvec.soft_reset(); + arma_debug_warn_level(3, "eigs_sym(): decomposition failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_elem.hpp b/src/armadillo/include/armadillo_bits/fn_elem.hpp new file mode 100644 index 0000000..917537f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_elem.hpp @@ -0,0 +1,1167 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_elem +//! @{ + + +// +// real + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_arma_type::value && is_cx::no), const T1& >::result +real(const T1& X) + { + arma_extra_debug_sigprint(); + + return X; + } + + + +template +arma_warn_unused +arma_inline +const T1& +real(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + return X.get_ref(); + } + + + +template +arma_warn_unused +arma_inline +const T1& +real(const SpBase& A) + { + arma_extra_debug_sigprint(); + + return A.get_ref(); + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_arma_type::value && is_cx::yes), const mtOp >::result +real(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp( X ); + } + + + +template +arma_warn_unused +inline +const mtOpCube +real(const BaseCube, T1>& X) + { + arma_extra_debug_sigprint(); + + return mtOpCube( X.get_ref() ); + } + + + +template +arma_warn_unused +arma_inline +const mtSpOp +real(const SpBase,T1>& A) + { + arma_extra_debug_sigprint(); + + return mtSpOp(A.get_ref()); + } + + + +// +// imag + +template +arma_warn_unused +inline +const Gen< Mat, gen_zeros > +imag(const Base& X) + { + arma_extra_debug_sigprint(); + + const Proxy A(X.get_ref()); + + return Gen< Mat, gen_zeros>(A.get_n_rows(), A.get_n_cols()); + } + + + +template +arma_warn_unused +inline +const GenCube +imag(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + const ProxyCube A(X.get_ref()); + + return GenCube(A.get_n_rows(), A.get_n_cols(), A.get_n_slices()); + } + + + +template +arma_warn_unused +inline +SpMat +imag(const SpBase& A) + { + arma_extra_debug_sigprint(); + + const SpProxy P(A.get_ref()); + + return SpMat(P.get_n_rows(), P.get_n_cols()); + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_arma_type::value && is_cx::yes), const mtOp >::result +imag(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp( X ); + } + + + +template +arma_warn_unused +inline +const mtOpCube +imag(const BaseCube,T1>& X) + { + arma_extra_debug_sigprint(); + + return mtOpCube( X.get_ref() ); + } + + + +template +arma_warn_unused +arma_inline +const mtSpOp +imag(const SpBase,T1>& A) + { + arma_extra_debug_sigprint(); + + return mtSpOp(A.get_ref()); + } + + + +// +// log + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +log(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +log(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// log2 + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +log2(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +log2(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// log10 + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +log10(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +log10(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// log1p + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_arma_type::value && is_cx::no), const eOp >::result +log1p(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::no, const eOpCube >::result +log1p(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// exp + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +exp(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +exp(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// exp2 + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +exp2(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +exp2(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// exp10 + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +exp10(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +exp10(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// expm1 + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_arma_type::value && is_cx::no), const eOp >::result +expm1(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::no, const eOpCube >::result +expm1(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// abs + + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_arma_type::value && is_cx::no), const eOp >::result +abs(const T1& X) + { + arma_extra_debug_sigprint(); + + return eOp(X); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +abs(const BaseCube& X, const typename arma_not_cx::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + + arma_ignore(junk); + + return eOpCube(X.get_ref()); + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_arma_type::value && is_cx::yes), const mtOp >::result +abs(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X); + } + + + +template +arma_warn_unused +inline +const mtOpCube +abs(const BaseCube< std::complex,T1>& X, const typename arma_cx_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + + arma_ignore(junk); + + return mtOpCube( X.get_ref() ); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +abs(const SpBase& X, const typename arma_not_cx::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return SpOp(X.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const mtSpOp +abs(const SpBase< std::complex, T1>& X, const typename arma_cx_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return mtSpOp(X.get_ref()); + } + + + +// +// arg + + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_arma_type::value && is_cx::no), const eOp >::result +arg(const T1& X) + { + arma_extra_debug_sigprint(); + + return eOp(X); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +arg(const BaseCube& X, const typename arma_not_cx::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + + arma_ignore(junk); + + return eOpCube(X.get_ref()); + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_arma_type::value && is_cx::yes), const mtOp >::result +arg(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X); + } + + + +template +arma_warn_unused +inline +const mtOpCube +arg(const BaseCube< std::complex,T1>& X, const typename arma_cx_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + + arma_ignore(junk); + + return mtOpCube( X.get_ref() ); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +arg(const SpBase& X, const typename arma_not_cx::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return SpOp(X.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const mtSpOp +arg(const SpBase< std::complex, T1>& X, const typename arma_cx_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return mtSpOp(X.get_ref()); + } + + + +// +// square + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +square(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +square(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +square(const SpBase& A) + { + arma_extra_debug_sigprint(); + + return SpOp(A.get_ref()); + } + + + +// +// sqrt + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +sqrt(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +sqrt(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +sqrt(const SpBase& A) + { + arma_extra_debug_sigprint(); + + return SpOp(A.get_ref()); + } + + + +// +// conj + +template +arma_warn_unused +arma_inline +const T1& +conj(const Base& A) + { + arma_extra_debug_sigprint(); + + return A.get_ref(); + } + + + +template +arma_warn_unused +arma_inline +const T1& +conj(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return A.get_ref(); + } + + + +template +arma_warn_unused +arma_inline +const T1& +conj(const SpBase& A) + { + arma_extra_debug_sigprint(); + + return A.get_ref(); + } + + + +template +arma_warn_unused +arma_inline +const eOp +conj(const Base,T1>& A) + { + arma_extra_debug_sigprint(); + + return eOp(A.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +conj(const BaseCube,T1>& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +conj(const SpBase,T1>& A) + { + arma_extra_debug_sigprint(); + + return SpOp(A.get_ref()); + } + + + +// pow + +template +arma_warn_unused +arma_inline +const eOp +pow(const Base& A, const typename T1::elem_type exponent) + { + arma_extra_debug_sigprint(); + + return eOp(A.get_ref(), exponent); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +pow(const BaseCube& A, const typename T1::elem_type exponent) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref(), exponent); + } + + + +// pow, specialised handling (non-complex exponent for complex matrices) + +template +arma_warn_unused +arma_inline +const eOp +pow(const Base& A, const typename T1::elem_type::value_type exponent) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + return eOp(A.get_ref(), eT(exponent)); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +pow(const BaseCube& A, const typename T1::elem_type::value_type exponent) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + return eOpCube(A.get_ref(), eT(exponent)); + } + + + +// +// floor + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +floor(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +floor(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +floor(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref()); + } + + + +// +// ceil + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +ceil(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +ceil(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +ceil(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref()); + } + + + +// +// round + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +round(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +round(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +round(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref()); + } + + + +// +// trunc + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +trunc(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +trunc(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +trunc(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref()); + } + + + +// +// sign + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +sign(const eT x) + { + arma_extra_debug_sigprint(); + + return arma_sign(x); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +sign(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +sign(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +sign(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref()); + } + + + +// +// erf + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_arma_type::value && is_cx::no), const eOp >::result +erf(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::no, const eOpCube >::result +erf(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// erfc + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_arma_type::value && is_cx::no), const eOp >::result +erfc(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::no, const eOpCube >::result +erfc(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// lgamma + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_arma_type::value && is_cx::no), const eOp >::result +lgamma(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::no, const eOpCube >::result +lgamma(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// tgamma + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_arma_type::value && is_cx::no), const eOp >::result +tgamma(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::no, const eOpCube >::result +tgamma(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// the functions below are currently unused; reserved for potential future use + +template void exp_approx(const T1&) { arma_stop_logic_error("unimplemented"); } +template void log_approx(const T1&) { arma_stop_logic_error("unimplemented"); } +template void approx_exp(const T1&) { arma_stop_logic_error("unimplemented"); } +template void approx_log(const T1&) { arma_stop_logic_error("unimplemented"); } + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_eps.hpp b/src/armadillo/include/armadillo_bits/fn_eps.hpp new file mode 100644 index 0000000..f68ba5d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_eps.hpp @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup fn_eps +//! @{ + + + +template +arma_warn_unused +inline +const eOp +eps(const Base& X, const typename arma_not_cx::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return eOp(X.get_ref()); + } + + + +template +arma_warn_unused +inline +Mat< typename T1::pod_type > +eps(const Base< std::complex, T1>& X, const typename arma_cx_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + const unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + Mat out(A.n_rows, A.n_cols, arma_nozeros_indicator()); + + T* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + + const uword n_elem = A.n_elem; + + for(uword i=0; i +arma_warn_unused +arma_inline +typename arma_integral_only::result +eps(const eT& x) + { + arma_ignore(x); + + return eT(0); + } + + + +template +arma_warn_unused +arma_inline +typename arma_real_only::result +eps(const eT& x) + { + return eop_aux::direct_eps(x); + } + + + +template +arma_warn_unused +arma_inline +typename arma_real_only::result +eps(const std::complex& x) + { + return eop_aux::direct_eps(x); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_expmat.hpp b/src/armadillo/include/armadillo_bits/fn_expmat.hpp new file mode 100644 index 0000000..5e5909b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_expmat.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_expmat +//! @{ + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_real::value, + const Op + >::result +expmat(const Base& A) + { + arma_extra_debug_sigprint(); + + return Op(A.get_ref()); + } + + + +template +inline +typename +enable_if2 + < + is_real::value, + bool + >::result +expmat(Mat& B, const Base& A) + { + arma_extra_debug_sigprint(); + + const bool status = op_expmat::apply_direct(B, A); + + if(status == false) + { + B.soft_reset(); + arma_debug_warn_level(3, "expmat(): given matrix appears ill-conditioned"); + } + + return status; + } + + + +// + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +expmat_sym(const Base& X) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +expmat_sym(Mat& Y, const Base& X) + { + arma_extra_debug_sigprint(); + + const bool status = op_expmat_sym::apply_direct(Y, X.get_ref()); + + if(status == false) + { + Y.soft_reset(); + arma_debug_warn_level(3, "expmat_sym(): transformation failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_eye.hpp b/src/armadillo/include/armadillo_bits/fn_eye.hpp new file mode 100644 index 0000000..4252ffa --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_eye.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_eye +//! @{ + + + +arma_warn_unused +arma_inline +const Gen +eye(const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + return Gen(n_rows, n_cols); + } + + + +arma_warn_unused +arma_inline +const Gen +eye(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return Gen(s.n_rows, s.n_cols); + } + + + +template +arma_warn_unused +arma_inline +const Gen +eye(const uword n_rows, const uword n_cols, const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(is_Col::value) { arma_debug_check( (n_cols != 1), "eye(): incompatible size" ); } + if(is_Row::value) { arma_debug_check( (n_rows != 1), "eye(): incompatible size" ); } + + return Gen(n_rows, n_cols); + } + + + +template +arma_warn_unused +arma_inline +const Gen +eye(const SizeMat& s, const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return eye(s.n_rows, s.n_cols); + } + + + +template +arma_warn_unused +inline +obj_type +eye(const uword n_rows, const uword n_cols, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(is_SpCol::value) { arma_debug_check( (n_cols != 1), "eye(): incompatible size" ); } + if(is_SpRow::value) { arma_debug_check( (n_rows != 1), "eye(): incompatible size" ); } + + obj_type out; + + out.eye(n_rows, n_cols); + + return out; + } + + + +template +arma_warn_unused +inline +obj_type +eye(const SizeMat& s, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return eye(s.n_rows, s.n_cols); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_fft.hpp b/src/armadillo/include/armadillo_bits/fn_fft.hpp new file mode 100644 index 0000000..d2d11fb --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_fft.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_fft +//! @{ + + + +// 1D FFT & 1D IFFT + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && is_real::value), + const mtOp, T1, op_fft_real> + >::result +fft(const T1& A) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_fft_real>(A, uword(0), uword(1)); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && is_real::value), + const mtOp, T1, op_fft_real> + >::result +fft(const T1& A, const uword N) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_fft_real>(A, N, uword(0)); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx_float::yes || is_cx_double::yes)), + const Op + >::result +fft(const T1& A) + { + arma_extra_debug_sigprint(); + + return Op(A, uword(0), uword(1)); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx_float::yes || is_cx_double::yes)), + const Op + >::result +fft(const T1& A, const uword N) + { + arma_extra_debug_sigprint(); + + return Op(A, N, uword(0)); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx_float::yes || is_cx_double::yes)), + const Op + >::result +ifft(const T1& A) + { + arma_extra_debug_sigprint(); + + return Op(A, uword(0), uword(1)); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx_float::yes || is_cx_double::yes)), + const Op + >::result +ifft(const T1& A, const uword N) + { + arma_extra_debug_sigprint(); + + return Op(A, N, uword(0)); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_fft2.hpp b/src/armadillo/include/armadillo_bits/fn_fft2.hpp new file mode 100644 index 0000000..51ea0da --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_fft2.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_fft2 +//! @{ + + + +// 2D FFT & 2D IFFT + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + Mat< std::complex > + >::result +fft2(const T1& A) + { + arma_extra_debug_sigprint(); + + // not exactly efficient, but "better-than-nothing" implementation + + typedef typename T1::pod_type T; + + Mat< std::complex > B = fft(A); + + // for square matrices, strans() will work out that an inplace transpose can be done, + // hence we can potentially avoid creating a temporary matrix + + B = strans(B); + + return strans( fft(B) ); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + Mat< std::complex > + >::result +fft2(const T1& A, const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap tmp(A); + const Mat& B = tmp.M; + + const bool do_resize = (B.n_rows != n_rows) || (B.n_cols != n_cols); + + return (do_resize) ? fft2(resize(B,n_rows,n_cols)) : fft2(B); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx_float::yes || is_cx_double::yes)), + Mat< std::complex > + >::result +ifft2(const T1& A) + { + arma_extra_debug_sigprint(); + + // not exactly efficient, but "better-than-nothing" implementation + + typedef typename T1::pod_type T; + + Mat< std::complex > B = ifft(A); + + // for square matrices, strans() will work out that an inplace transpose can be done, + // hence we can potentially avoid creating a temporary matrix + + B = strans(B); + + return strans( ifft(B) ); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx_float::yes || is_cx_double::yes)), + Mat< std::complex > + >::result +ifft2(const T1& A, const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap tmp(A); + const Mat& B = tmp.M; + + const bool do_resize = (B.n_rows != n_rows) || (B.n_cols != n_cols); + + return (do_resize) ? ifft2(resize(B,n_rows,n_cols)) : ifft2(B); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_find.hpp b/src/armadillo/include/armadillo_bits/fn_find.hpp new file mode 100644 index 0000000..5efb254 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_find.hpp @@ -0,0 +1,469 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_find +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +find(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X); + } + + + +template +arma_warn_unused +inline +const mtOp +find(const Base& X, const uword k, const char* direction = "first") + { + arma_extra_debug_sigprint(); + + const char sig = (direction != nullptr) ? direction[0] : char(0); + + arma_debug_check + ( + ( (sig != 'f') && (sig != 'F') && (sig != 'l') && (sig != 'L') ), + "find(): direction must be \"first\" or \"last\"" + ); + + const uword type = ( (sig == 'f') || (sig == 'F') ) ? 0 : 1; + + return mtOp(X.get_ref(), k, type); + } + + + +// + + + +template +arma_warn_unused +inline +uvec +find(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube tmp(X.get_ref()); + + const Mat R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false ); + + return find(R); + } + + + +template +arma_warn_unused +inline +uvec +find(const BaseCube& X, const uword k, const char* direction = "first") + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube tmp(X.get_ref()); + + const Mat R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false ); + + return find(R, k, direction); + } + + + +template +arma_warn_unused +inline +uvec +find(const mtOpCube& X, const uword k = 0, const char* direction = "first") + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube tmp(X.m); + + const Mat R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false ); + + return find( mtOp, op_rel_type>(R, X.aux), k, direction ); + } + + + +template +arma_warn_unused +inline +uvec +find(const mtGlueCube& X, const uword k = 0, const char* direction = "first") + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + const unwrap_cube tmp1(X.A); + const unwrap_cube tmp2(X.B); + + arma_debug_assert_same_size( tmp1.M, tmp2.M, "relational operator" ); + + const Mat R1( const_cast< eT1* >(tmp1.M.memptr()), tmp1.M.n_elem, 1, false ); + const Mat R2( const_cast< eT2* >(tmp2.M.memptr()), tmp2.M.n_elem, 1, false ); + + return find( mtGlue, Mat, glue_rel_type>(R1, R2), k, direction ); + } + + + +// + + + +template +arma_warn_unused +inline +Col +find(const SpBase& X, const uword k = 0) + { + arma_extra_debug_sigprint(); + + const SpProxy P(X.get_ref()); + + const uword n_rows = P.get_n_rows(); + const uword n_nz = P.get_n_nonzero(); + + Mat tmp(n_nz, 1, arma_nozeros_indicator()); + + uword* tmp_mem = tmp.memptr(); + + typename SpProxy::const_iterator_type it = P.begin(); + + for(uword i=0; i out; + + const uword count = (k == 0) ? uword(n_nz) : uword( (std::min)(n_nz, k) ); + + out.steal_mem_col(tmp, count); + + return out; + } + + + +template +arma_warn_unused +inline +Col +find(const SpBase& X, const uword k, const char* direction) + { + arma_extra_debug_sigprint(); + + arma_ignore(X); + arma_ignore(k); + arma_ignore(direction); + + arma_check(true, "find(SpBase,k,direction): not implemented yet"); // TODO + + Col out; + + return out; + } + + + +// + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +find_finite(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +find_nonfinite(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +find_nan(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X); + } + + + +// + + + +template +arma_warn_unused +inline +uvec +find_finite(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube tmp(X.get_ref()); + + const Mat R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false ); + + return find_finite(R); + } + + + +template +arma_warn_unused +inline +uvec +find_nonfinite(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube tmp(X.get_ref()); + + const Mat R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false ); + + return find_nonfinite(R); + } + + + +template +arma_warn_unused +inline +uvec +find_nan(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube tmp(X.get_ref()); + + const Mat R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false ); + + return find_nan(R); + } + + + +// + + + +template +arma_warn_unused +inline +Col +find_finite(const SpBase& X) + { + arma_extra_debug_sigprint(); + + const SpProxy P(X.get_ref()); + + const uword n_rows = P.get_n_rows(); + const uword n_nz = P.get_n_nonzero(); + + Mat tmp(n_nz, 1, arma_nozeros_indicator()); + + uword* tmp_mem = tmp.memptr(); + + typename SpProxy::const_iterator_type it = P.begin(); + + uword count = 0; + + for(uword i=0; i out; + + if(count > 0) { out.steal_mem_col(tmp, count); } + + return out; + } + + + +template +arma_warn_unused +inline +Col +find_nonfinite(const SpBase& X) + { + arma_extra_debug_sigprint(); + + const SpProxy P(X.get_ref()); + + const uword n_rows = P.get_n_rows(); + const uword n_nz = P.get_n_nonzero(); + + Mat tmp(n_nz, 1, arma_nozeros_indicator()); + + uword* tmp_mem = tmp.memptr(); + + typename SpProxy::const_iterator_type it = P.begin(); + + uword count = 0; + + for(uword i=0; i out; + + if(count > 0) { out.steal_mem_col(tmp, count); } + + return out; + } + + + +template +arma_warn_unused +inline +Col +find_nan(const SpBase& X) + { + arma_extra_debug_sigprint(); + + const SpProxy P(X.get_ref()); + + const uword n_rows = P.get_n_rows(); + const uword n_nz = P.get_n_nonzero(); + + Mat tmp(n_nz, 1, arma_nozeros_indicator()); + + uword* tmp_mem = tmp.memptr(); + + typename SpProxy::const_iterator_type it = P.begin(); + + uword count = 0; + + for(uword i=0; i out; + + if(count > 0) { out.steal_mem_col(tmp, count); } + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_find_unique.hpp b/src/armadillo/include/armadillo_bits/fn_find_unique.hpp new file mode 100644 index 0000000..4d90ca1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_find_unique.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_find_unique +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +find_unique + ( + const T1& X, + const bool ascending_indices = true + ) + { + arma_extra_debug_sigprint(); + + return mtOp(X, ((ascending_indices) ? uword(1) : uword(0)), uword(0)); + } + + + +template +arma_warn_unused +inline +uvec +find_unique + ( + const BaseCube& X, + const bool ascending_indices = true + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube tmp(X.get_ref()); + + const Mat R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false ); + + return find_unique(R,ascending_indices); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_flip.hpp b/src/armadillo/include/armadillo_bits/fn_flip.hpp new file mode 100644 index 0000000..811a5e0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_flip.hpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_flip +//! @{ + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const Op >::result +flipud(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const Op >::result +fliplr(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +flipud(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +fliplr(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref()); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_hess.hpp b/src/armadillo/include/armadillo_bits/fn_hess.hpp new file mode 100644 index 0000000..adf31d6 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_hess.hpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_hess +//! @{ + + +template +inline +bool +hess + ( + Mat& H, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + Col tao; + + const bool status = auxlib::hess(H, X.get_ref(), tao); + + if(H.n_rows > 2) + { + for(uword i=0; i < H.n_rows-2; ++i) + { + H(span(i+2, H.n_rows-1), i).zeros(); + } + } + + if(status == false) + { + H.soft_reset(); + arma_debug_warn_level(3, "hess(): decomposition failed"); + } + + return status; + } + + + +template +arma_warn_unused +inline +Mat +hess + ( + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + Mat H; + Col tao; + + const bool status = auxlib::hess(H, X.get_ref(), tao); + + if(H.n_rows > 2) + { + for(uword i=0; i < H.n_rows-2; ++i) + { + H(span(i+2, H.n_rows-1), i).zeros(); + } + } + + if(status == false) + { + H.soft_reset(); + arma_stop_runtime_error("hess(): decomposition failed"); + } + + return H; + } + + + +template +inline +bool +hess + ( + Mat& U, + Mat& H, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + arma_debug_check( void_ptr(&U) == void_ptr(&H), "hess(): 'U' is an alias of 'H'" ); + + typedef typename T1::elem_type eT; + + Col tao; + + const bool status = auxlib::hess(H, X.get_ref(), tao); + + if(H.n_rows == 0) + { + U.reset(); + } + else + if(H.n_rows == 1) + { + U.ones(1, 1); + } + else + if(H.n_rows == 2) + { + U.eye(2, 2); + } + else + { + U.eye(size(H)); + + Col v; + + for(uword i=0; i < H.n_rows-2; ++i) + { + // TODO: generate v in a more efficient manner; + // TODO: the .ones() operation is an overkill, as most of v is overwritten afterwards + + v.ones(H.n_rows-i-1); + + v(span(1, H.n_rows-i-2)) = H(span(i+2, H.n_rows-1), i); + + U(span::all, span(i+1, H.n_rows-1)) -= tao(i) * (U(span::all, span(i+1, H.n_rows-1)) * v * v.t()); + } + + U(span::all, H.n_rows-1) = U(span::all, H.n_rows-1) * (eT(1) - tao(H.n_rows-2)); + + for(uword i=0; i < H.n_rows-2; ++i) + { + H(span(i+2, H.n_rows-1), i).zeros(); + } + } + + if(status == false) + { + U.soft_reset(); + H.soft_reset(); + arma_debug_warn_level(3, "hess(): decomposition failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_hist.hpp b/src/armadillo/include/armadillo_bits/fn_hist.hpp new file mode 100644 index 0000000..018de94 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_hist.hpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_hist +//! @{ + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_cx::no, + const mtOp + >::result +hist(const T1& A, const uword n_bins = 10) + { + arma_extra_debug_sigprint(); + + return mtOp(A, n_bins, 0); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_cx::no && is_same_type::value, + const mtGlue + >::result +hist(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return mtGlue(X, Y); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_cx::no && is_same_type::value, + const mtGlue + >::result +hist(const T1& X, const T2& Y, const uword dim) + { + arma_extra_debug_sigprint(); + + return mtGlue(X, Y, dim); + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_histc.hpp b/src/armadillo/include/armadillo_bits/fn_histc.hpp new file mode 100644 index 0000000..e99f896 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_histc.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_histc +//! @{ + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_cx::no && is_same_type::value, + const mtGlue + >::result +histc(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return mtGlue(X, Y); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_cx::no && is_same_type::value, + const mtGlue + >::result +histc(const T1& X, const T2& Y, const uword dim) + { + arma_extra_debug_sigprint(); + + return mtGlue(X, Y, dim); + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_index_max.hpp b/src/armadillo/include/armadillo_bits/fn_index_max.hpp new file mode 100644 index 0000000..aad33f8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_index_max.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_index_max +//! @{ + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value && resolves_to_vector::yes, uword>::result +index_max(const T1& X) + { + arma_extra_debug_sigprint(); + + return X.index_max(); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value && resolves_to_vector::no, const mtOp >::result +index_max(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const mtOp >::result +index_max(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return mtOp(X, dim, 0); + } + + + +template +arma_warn_unused +arma_inline +const mtOpCube +index_max + ( + const BaseCube& X, + const uword dim = 0 + ) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), dim, 0, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::yes, + typename T1::elem_type + >::result +index_max(const T1& x) + { + arma_extra_debug_sigprint(); + + return x.index_max(); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::no, + Mat + >::result +index_max(const T1& X) + { + arma_extra_debug_sigprint(); + + Mat out; + + op_index_max::apply(out, X, 0); + + return out; + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value, + Mat + >::result +index_max(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + Mat out; + + op_index_max::apply(out, X, dim); + + return out; + } + + + +arma_warn_unused +inline +uword +index_max(const SizeMat& s) + { + return (s.n_rows >= s.n_cols) ? uword(0) : uword(1); + } + + + +arma_warn_unused +inline +uword +index_max(const SizeCube& s) + { + const uword tmp_val = (s.n_rows >= s.n_cols) ? s.n_rows : s.n_cols; + const uword tmp_index = (s.n_rows >= s.n_cols) ? uword(0) : uword(1); + + return (tmp_val >= s.n_slices) ? tmp_index : uword(2); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_index_min.hpp b/src/armadillo/include/armadillo_bits/fn_index_min.hpp new file mode 100644 index 0000000..e1b3ce2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_index_min.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_index_min +//! @{ + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value && resolves_to_vector::yes, uword>::result +index_min(const T1& X) + { + arma_extra_debug_sigprint(); + + return X.index_min(); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value && resolves_to_vector::no, const mtOp >::result +index_min(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const mtOp >::result +index_min(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return mtOp(X, dim, 0); + } + + + +template +arma_warn_unused +arma_inline +const mtOpCube +index_min + ( + const BaseCube& X, + const uword dim = 0 + ) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), dim, 0, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::yes, + typename T1::elem_type + >::result +index_min(const T1& x) + { + arma_extra_debug_sigprint(); + + return x.index_min(); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::no, + Mat + >::result +index_min(const T1& X) + { + arma_extra_debug_sigprint(); + + Mat out; + + op_index_min::apply(out, X, 0); + + return out; + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value, + Mat + >::result +index_min(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + Mat out; + + op_index_min::apply(out, X, dim); + + return out; + } + + + +arma_warn_unused +inline +uword +index_min(const SizeMat& s) + { + return (s.n_rows <= s.n_cols) ? uword(0) : uword(1); + } + + + +arma_warn_unused +inline +uword +index_min(const SizeCube& s) + { + const uword tmp_val = (s.n_rows <= s.n_cols) ? s.n_rows : s.n_cols; + const uword tmp_index = (s.n_rows <= s.n_cols) ? uword(0) : uword(1); + + return (tmp_val <= s.n_slices) ? tmp_index : uword(2); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_inplace_strans.hpp b/src/armadillo/include/armadillo_bits/fn_inplace_strans.hpp new file mode 100644 index 0000000..ad09cf2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_inplace_strans.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_inplace_strans +//! @{ + + + +template +inline +void +inplace_strans + ( + Mat& X, + const char* method = "std" + ) + { + arma_extra_debug_sigprint(); + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 's') && (sig != 'l')), "inplace_strans(): unknown method specified" ); + + const bool low_memory = (sig == 'l'); + + if( (low_memory == false) || (X.n_rows == X.n_cols) ) + { + op_strans::apply_mat_inplace(X); + } + else + { + // in-place algorithm inspired by: + // Fred G. Gustavson, Tadeusz Swirszcz. + // In-Place Transposition of Rectangular Matrices. + // Applied Parallel Computing. State of the Art in Scientific Computing. + // Lecture Notes in Computer Science. Volume 4699, pp. 560-569, 2007. + + + // X.set_size() will check whether we can change the dimensions of X; + // X.set_size() will also reuse existing memory, as the number of elements hasn't changed + + X.set_size(X.n_cols, X.n_rows); + + const uword m = X.n_cols; + const uword n = X.n_rows; + + std::vector visited(X.n_elem); // TODO: replace std::vector with a better implementation + + for(uword col = 0; col < m; ++col) + for(uword row = 0; row < n; ++row) + { + const uword pos = col*n + row; + + if(visited[pos] == false) + { + uword curr_pos = pos; + + eT val = X.at(row, col); + + while(visited[curr_pos] == false) + { + visited[curr_pos] = true; + + const uword j = curr_pos / m; + const uword i = curr_pos - m * j; + + const eT tmp = X.at(j, i); + X.at(j, i) = val; + val = tmp; + + curr_pos = i*n + j; + } + } + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_inplace_trans.hpp b/src/armadillo/include/armadillo_bits/fn_inplace_trans.hpp new file mode 100644 index 0000000..0e23848 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_inplace_trans.hpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_inplace_trans +//! @{ + + + +template +inline +typename +enable_if2 + < + is_cx::no, + void + >::result +inplace_htrans + ( + Mat& X, + const char* method = "std" + ) + { + arma_extra_debug_sigprint(); + + inplace_strans(X, method); + } + + + +template +inline +typename +enable_if2 + < + is_cx::yes, + void + >::result +inplace_htrans + ( + Mat& X, + const char* method = "std" + ) + { + arma_extra_debug_sigprint(); + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 's') && (sig != 'l')), "inplace_htrans(): unknown method specified" ); + + const bool low_memory = (sig == 'l'); + + if( (low_memory == false) || (X.n_rows == X.n_cols) ) + { + op_htrans::apply_mat_inplace(X); + } + else + { + inplace_strans(X, method); + + X = conj(X); + } + } + + + +template +inline +typename +enable_if2 + < + is_cx::no, + void + >::result +inplace_trans + ( + Mat& X, + const char* method = "std" + ) + { + arma_extra_debug_sigprint(); + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 's') && (sig != 'l')), "inplace_trans(): unknown method specified" ); + + inplace_strans(X, method); + } + + + +template +inline +typename +enable_if2 + < + is_cx::yes, + void + >::result +inplace_trans + ( + Mat& X, + const char* method = "std" + ) + { + arma_extra_debug_sigprint(); + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 's') && (sig != 'l')), "inplace_trans(): unknown method specified" ); + + inplace_htrans(X, method); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_interp1.hpp b/src/armadillo/include/armadillo_bits/fn_interp1.hpp new file mode 100644 index 0000000..d115423 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_interp1.hpp @@ -0,0 +1,351 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_interp1 +//! @{ + + + +template +inline +void +interp1_helper_nearest(const Mat& XG, const Mat& YG, const Mat& XI, Mat& YI, const eT extrap_val) + { + arma_extra_debug_sigprint(); + + const eT XG_min = XG.min(); + const eT XG_max = XG.max(); + + YI.copy_size(XI); + + const eT* XG_mem = XG.memptr(); + const eT* YG_mem = YG.memptr(); + const eT* XI_mem = XI.memptr(); + eT* YI_mem = YI.memptr(); + + const uword NG = XG.n_elem; + const uword NI = XI.n_elem; + + uword best_j = 0; + + for(uword i=0; i::inf; + + const eT XI_val = XI_mem[i]; + + if((XI_val < XG_min) || (XI_val > XG_max)) + { + YI_mem[i] = extrap_val; + } + else + if(arma_isnan(XI_val)) + { + YI_mem[i] = Datum::nan; + } + else + { + // XG and XI are guaranteed to be sorted in ascending manner, + // so start searching XG from last known optimum position + + for(uword j=best_j; j= eT(0)) ? tmp : -tmp; + + if(err >= best_err) + { + // error is going up, so we have found the optimum position + break; + } + else + { + best_err = err; + best_j = j; // remember the optimum position + } + } + + YI_mem[i] = YG_mem[best_j]; + } + } + } + + + +template +inline +void +interp1_helper_linear(const Mat& XG, const Mat& YG, const Mat& XI, Mat& YI, const eT extrap_val) + { + arma_extra_debug_sigprint(); + + const eT XG_min = XG.min(); + const eT XG_max = XG.max(); + + YI.copy_size(XI); + + const eT* XG_mem = XG.memptr(); + const eT* YG_mem = YG.memptr(); + const eT* XI_mem = XI.memptr(); + eT* YI_mem = YI.memptr(); + + const uword NG = XG.n_elem; + const uword NI = XI.n_elem; + + uword a_best_j = 0; + uword b_best_j = 0; + + for(uword i=0; i XG_max)) + { + YI_mem[i] = extrap_val; + } + else + if(arma_isnan(XI_val)) + { + YI_mem[i] = Datum::nan; + } + else + { + // XG and XI are guaranteed to be sorted in ascending manner, + // so start searching XG from last known optimum position + + eT a_best_err = Datum::inf; + eT b_best_err = Datum::inf; + + for(uword j=a_best_j; j= eT(0)) ? tmp : -tmp; + + if(err >= a_best_err) + { + break; + } + else + { + a_best_err = err; + a_best_j = j; + } + } + + if( (XG_mem[a_best_j] - XI_val) <= eT(0) ) + { + // a_best_j is to the left of the interpolated position + + b_best_j = ( (a_best_j+1) < NG) ? (a_best_j+1) : a_best_j; + } + else + { + // a_best_j is to the right of the interpolated position + + b_best_j = (a_best_j >= 1) ? (a_best_j-1) : a_best_j; + } + + b_best_err = std::abs( XG_mem[b_best_j] - XI_val ); + + if(a_best_j > b_best_j) + { + std::swap(a_best_j, b_best_j ); + std::swap(a_best_err, b_best_err); + } + + const eT weight = (a_best_err > eT(0)) ? (a_best_err / (a_best_err + b_best_err)) : eT(0); + + YI_mem[i] = (eT(1) - weight)*YG_mem[a_best_j] + (weight)*YG_mem[b_best_j]; + } + } + } + + + +template +inline +void +interp1_helper(const Mat& X, const Mat& Y, const Mat& XI, Mat& YI, const uword sig, const eT extrap_val) + { + arma_extra_debug_sigprint(); + + arma_debug_check( ((X.is_vec() == false) || (Y.is_vec() == false) || (XI.is_vec() == false)), "interp1(): currently only vectors are supported" ); + + arma_debug_check( (X.n_elem != Y.n_elem), "interp1(): X and Y must have the same number of elements" ); + + arma_debug_check( (X.n_elem < 2), "interp1(): X must have at least two unique elements" ); + + // sig = 10: nearest neighbour + // sig = 11: nearest neighbour, assume monotonic increase in X and XI + // + // sig = 20: linear + // sig = 21: linear, assume monotonic increase in X and XI + + if(sig == 11) { interp1_helper_nearest(X, Y, XI, YI, extrap_val); return; } + if(sig == 21) { interp1_helper_linear (X, Y, XI, YI, extrap_val); return; } + + uvec X_indices; + + try { X_indices = find_unique(X,false); } catch(...) { } + + // NOTE: find_unique(X,false) provides indices of elements sorted in ascending order + // NOTE: find_unique(X,false) will reset X_indices if X has NaN + + const uword N_subset = X_indices.n_elem; + + arma_debug_check( (N_subset < 2), "interp1(): X must have at least two unique elements" ); + + Mat X_sanitised(N_subset, 1, arma_nozeros_indicator()); + Mat Y_sanitised(N_subset, 1, arma_nozeros_indicator()); + + eT* X_sanitised_mem = X_sanitised.memptr(); + eT* Y_sanitised_mem = Y_sanitised.memptr(); + + const eT* X_mem = X.memptr(); + const eT* Y_mem = Y.memptr(); + + const uword* X_indices_mem = X_indices.memptr(); + + for(uword i=0; i XI_tmp; + uvec XI_indices; + + const bool XI_is_sorted = XI.is_sorted(); // NOTE: .is_sorted() currently doesn't detect NaN + + if(XI_is_sorted == false) + { + XI_indices = sort_index(XI); // NOTE: sort_index() will throw if XI has NaN + + const uword N = XI.n_elem; + + XI_tmp.copy_size(XI); + + const uword* XI_indices_mem = XI_indices.memptr(); + + const eT* XI_mem = XI.memptr(); + eT* XI_tmp_mem = XI_tmp.memptr(); + + for(uword i=0; i& XI_sorted = (XI_is_sorted) ? XI : XI_tmp; + + // NOTE: XI_sorted may have NaN + + + if(sig == 10) { interp1_helper_nearest(X_sanitised, Y_sanitised, XI_sorted, YI, extrap_val); } + else if(sig == 20) { interp1_helper_linear (X_sanitised, Y_sanitised, XI_sorted, YI, extrap_val); } + + + if( (XI_is_sorted == false) && (YI.n_elem > 0) ) + { + Mat YI_unsorted; + + YI_unsorted.copy_size(YI); + + const eT* YI_mem = YI.memptr(); + eT* YI_unsorted_mem = YI_unsorted.memptr(); + + const uword N = XI_sorted.n_elem; + const uword* XI_indices_mem = XI_indices.memptr(); + + for(uword i=0; i +inline +typename +enable_if2 + < + is_real::value, + void + >::result +interp1 + ( + const Base& X, + const Base& Y, + const Base& XI, + Mat& YI, + const char* method = "linear", + const typename T1::elem_type extrap_val = Datum::nan + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + uword sig = 0; + + if(method != nullptr) + if(method[0] != char(0)) + if(method[1] != char(0)) + { + const char c1 = method[0]; + const char c2 = method[1]; + + if(c1 == 'n') { sig = 10; } // nearest neighbour + else if(c1 == 'l') { sig = 20; } // linear + else + { + if( (c1 == '*') && (c2 == 'n') ) { sig = 11; } // nearest neighour, assume monotonic increase in X and XI + if( (c1 == '*') && (c2 == 'l') ) { sig = 21; } // linear, assume monotonic increase in X and XI + } + } + + arma_debug_check( (sig == 0), "interp1(): unsupported interpolation type" ); + + const quasi_unwrap X_tmp( X.get_ref()); + const quasi_unwrap Y_tmp( Y.get_ref()); + const quasi_unwrap XI_tmp(XI.get_ref()); + + if( X_tmp.is_alias(YI) || Y_tmp.is_alias(YI) || XI_tmp.is_alias(YI) ) + { + Mat tmp; + + interp1_helper(X_tmp.M, Y_tmp.M, XI_tmp.M, tmp, sig, extrap_val); + + YI.steal_mem(tmp); + } + else + { + interp1_helper(X_tmp.M, Y_tmp.M, XI_tmp.M, YI, sig, extrap_val); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_interp2.hpp b/src/armadillo/include/armadillo_bits/fn_interp2.hpp new file mode 100644 index 0000000..b9b6127 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_interp2.hpp @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_interp2 +//! @{ + + + +template +inline +void +interp2_helper_nearest(const Mat& XG, const Mat& ZG, const Mat& XI, Mat& ZI, const eT extrap_val, const uword mode) + { + arma_extra_debug_sigprint(); + + const eT XG_min = XG.min(); + const eT XG_max = XG.max(); + + // mode = 0: interpolate across rows (eg. expand in vertical direction) + // mode = 1: interpolate across columns (eg. expand in horizontal direction) + + if(mode == 0) { ZI.set_size(XI.n_elem, ZG.n_cols); } + if(mode == 1) { ZI.set_size(ZG.n_rows, XI.n_elem); } + + const eT* XG_mem = XG.memptr(); + const eT* XI_mem = XI.memptr(); + + const uword NG = XG.n_elem; + const uword NI = XI.n_elem; + + uword best_j = 0; + + for(uword i=0; i::inf; + + const eT XI_val = XI_mem[i]; + + if((XI_val < XG_min) || (XI_val > XG_max)) + { + if(mode == 0) { ZI.row(i).fill(extrap_val); } + if(mode == 1) { ZI.col(i).fill(extrap_val); } + } + else + { + // XG and XI are guaranteed to be sorted in ascending manner, + // so start searching XG from last known optimum position + + for(uword j=best_j; j= eT(0)) ? tmp : -tmp; + + if(err >= best_err) + { + // error is going up, so we have found the optimum position + break; + } + else + { + best_err = err; + best_j = j; // remember the optimum position + } + } + + if(mode == 0) { ZI.row(i) = ZG.row(best_j); } + if(mode == 1) { ZI.col(i) = ZG.col(best_j); } + } + } + } + + + +template +inline +void +interp2_helper_linear(const Mat& XG, const Mat& ZG, const Mat& XI, Mat& ZI, const eT extrap_val, const uword mode) + { + arma_extra_debug_sigprint(); + + const eT XG_min = XG.min(); + const eT XG_max = XG.max(); + + // mode = 0: interpolate across rows (eg. expand in vertical direction) + // mode = 1: interpolate across columns (eg. expand in horizontal direction) + + if(mode == 0) { ZI.set_size(XI.n_elem, ZG.n_cols); } + if(mode == 1) { ZI.set_size(ZG.n_rows, XI.n_elem); } + + const eT* XG_mem = XG.memptr(); + const eT* XI_mem = XI.memptr(); + + const uword NG = XG.n_elem; + const uword NI = XI.n_elem; + + uword a_best_j = 0; + uword b_best_j = 0; + + for(uword i=0; i XG_max)) + { + if(mode == 0) { ZI.row(i).fill(extrap_val); } + if(mode == 1) { ZI.col(i).fill(extrap_val); } + } + else + { + // XG and XI are guaranteed to be sorted in ascending manner, + // so start searching XG from last known optimum position + + eT a_best_err = Datum::inf; + eT b_best_err = Datum::inf; + + for(uword j=a_best_j; j= eT(0)) ? tmp : -tmp; + + if(err >= a_best_err) + { + break; + } + else + { + a_best_err = err; + a_best_j = j; + } + } + + if( (XG_mem[a_best_j] - XI_val) <= eT(0) ) + { + // a_best_j is to the left of the interpolated position + + b_best_j = ( (a_best_j+1) < NG) ? (a_best_j+1) : a_best_j; + } + else + { + // a_best_j is to the right of the interpolated position + + b_best_j = (a_best_j >= 1) ? (a_best_j-1) : a_best_j; + } + + b_best_err = std::abs( XG_mem[b_best_j] - XI_val ); + + if(a_best_j > b_best_j) + { + std::swap(a_best_j, b_best_j ); + std::swap(a_best_err, b_best_err); + } + + const eT weight = (a_best_err > eT(0)) ? (a_best_err / (a_best_err + b_best_err)) : eT(0); + + if(mode == 0) { ZI.row(i) = (eT(1) - weight)*ZG.row(a_best_j) + (weight)*ZG.row(b_best_j); } + if(mode == 1) { ZI.col(i) = (eT(1) - weight)*ZG.col(a_best_j) + (weight)*ZG.col(b_best_j); } + } + } + } + + + +template +inline +typename +enable_if2< is_real::value, void >::result +interp2 + ( + const Base& X, + const Base& Y, + const Base& Z, + const Base& XI, + const Base& YI, + Mat& ZI, + const char* method = "linear", + const typename T1::elem_type extrap_val = Datum::nan + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 'n') && (sig != 'l')), "interp2(): unsupported interpolation type" ); + + const quasi_unwrap UXG( X.get_ref() ); + const quasi_unwrap UYG( Y.get_ref() ); + const quasi_unwrap UZG( Z.get_ref() ); + const quasi_unwrap UXI( XI.get_ref() ); + const quasi_unwrap UYI( YI.get_ref() ); + + arma_debug_check( (UXG.M.is_vec() == false), "interp2(): X must resolve to a vector" ); + arma_debug_check( (UYG.M.is_vec() == false), "interp2(): Y must resolve to a vector" ); + + arma_debug_check( (UXI.M.is_vec() == false), "interp2(): XI must resolve to a vector" ); + arma_debug_check( (UYI.M.is_vec() == false), "interp2(): YI must resolve to a vector" ); + + arma_debug_check( (UXG.M.n_elem < 2), "interp2(): X must have at least two unique elements" ); + arma_debug_check( (UYG.M.n_elem < 2), "interp2(): Y must have at least two unique elements" ); + + arma_debug_check( (UXG.M.n_elem != UZG.M.n_cols), "interp2(): number of elements in X must equal the number of columns in Z" ); + arma_debug_check( (UYG.M.n_elem != UZG.M.n_rows), "interp2(): number of elements in Y must equal the number of rows in Z" ); + + arma_debug_check( (UXG.M.is_sorted("strictascend") == false), "interp2(): X must be monotonically increasing" ); + arma_debug_check( (UYG.M.is_sorted("strictascend") == false), "interp2(): Y must be monotonically increasing" ); + + arma_debug_check( (UXI.M.is_sorted("strictascend") == false), "interp2(): XI must be monotonically increasing" ); + arma_debug_check( (UYI.M.is_sorted("strictascend") == false), "interp2(): YI must be monotonically increasing" ); + + Mat tmp; + + if( UXG.is_alias(ZI) || UXI.is_alias(ZI) ) + { + Mat out; + + if(sig == 'n') + { + interp2_helper_nearest(UYG.M, UZG.M, UYI.M, tmp, extrap_val, 0); + interp2_helper_nearest(UXG.M, tmp, UXI.M, out, extrap_val, 1); + } + else + if(sig == 'l') + { + interp2_helper_linear(UYG.M, UZG.M, UYI.M, tmp, extrap_val, 0); + interp2_helper_linear(UXG.M, tmp, UXI.M, out, extrap_val, 1); + } + + ZI.steal_mem(out); + } + else + { + if(sig == 'n') + { + interp2_helper_nearest(UYG.M, UZG.M, UYI.M, tmp, extrap_val, 0); + interp2_helper_nearest(UXG.M, tmp, UXI.M, ZI, extrap_val, 1); + } + else + if(sig == 'l') + { + interp2_helper_linear(UYG.M, UZG.M, UYI.M, tmp, extrap_val, 0); + interp2_helper_linear(UXG.M, tmp, UXI.M, ZI, extrap_val, 1); + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_intersect.hpp b/src/armadillo/include/armadillo_bits/fn_intersect.hpp new file mode 100644 index 0000000..37afa52 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_intersect.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_intersect +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + ( is_arma_type::value && is_arma_type::value && is_same_type::value ), + const Glue + >::result +intersect + ( + const T1& A, + const T2& B + ) + { + arma_extra_debug_sigprint(); + + return Glue(A, B); + } + + + +template +inline +void +intersect + ( + Mat& C, + uvec& iA, + uvec& iB, + const Base& A, + const Base& B + ) + { + arma_extra_debug_sigprint(); + + glue_intersect::apply(C, iA, iB, A.get_ref(), B.get_ref(), true); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_inv.hpp b/src/armadillo/include/armadillo_bits/fn_inv.hpp new file mode 100644 index 0000000..65589f7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_inv.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_inv +//! @{ + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +inv + ( + const Base& X + ) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +inv + ( + Mat& out, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + const bool status = op_inv_gen_default::apply_direct(out, X.get_ref(), "inv()"); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "inv(): matrix is singular"); + } + + return status; + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +inv + ( + const Base& X, + const inv_opts::opts& opts + ) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref(), opts.flags, uword(0)); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +inv + ( + Mat& out, + const Base& X, + const inv_opts::opts& opts + ) + { + arma_extra_debug_sigprint(); + + const bool status = op_inv_gen_full::apply_direct(out, X.get_ref(), "inv()", opts.flags); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "inv(): matrix is singular"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +inv + ( + Mat& out_inv, + typename T1::pod_type& out_rcond, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + op_inv_gen_state inv_state; + + const bool status = op_inv_gen_rcond::apply_direct(out_inv, inv_state, X.get_ref()); + + out_rcond = inv_state.rcond; + + if(status == false) + { + out_rcond = T(0); + out_inv.soft_reset(); + arma_debug_warn_level(3, "inv(): matrix is singular"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_inv_sympd.hpp b/src/armadillo/include/armadillo_bits/fn_inv_sympd.hpp new file mode 100644 index 0000000..ffd1d0d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_inv_sympd.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_inv_sympd +//! @{ + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +inv_sympd + ( + const Base& X + ) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +inv_sympd + ( + Mat& out, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + const bool status = op_inv_spd_default::apply_direct(out, X.get_ref()); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "inv_sympd(): matrix is singular or not positive definite"); + } + + return status; + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +inv_sympd + ( + const Base& X, + const inv_opts::opts& opts + ) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref(), opts.flags, uword(0)); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +inv_sympd + ( + Mat& out, + const Base& X, + const inv_opts::opts& opts + ) + { + arma_extra_debug_sigprint(); + + const bool status = op_inv_spd_full::apply_direct(out, X.get_ref(), opts.flags); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "inv_sympd(): matrix is singular or not positive definite"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +inv_sympd + ( + Mat& out_inv, + typename T1::pod_type& out_rcond, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + op_inv_spd_state inv_state; + + const bool status = op_inv_spd_rcond::apply_direct(out_inv, inv_state, X.get_ref()); + + out_rcond = inv_state.rcond; + + if(status == false) + { + out_rcond = T(0); + out_inv.soft_reset(); + arma_debug_warn_level(3, "inv_sympd(): matrix is singular or not positive definite"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_join.hpp b/src/armadillo/include/armadillo_bits/fn_join.hpp new file mode 100644 index 0000000..6d3ed07 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_join.hpp @@ -0,0 +1,502 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_join +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && is_same_type::value), + const Glue + >::result +join_cols(const T1& A, const T2& B) + { + arma_extra_debug_sigprint(); + + return Glue(A, B); + } + + + +template +arma_warn_unused +inline +Mat +join_cols(const Base& A, const Base& B, const Base& C) + { + arma_extra_debug_sigprint(); + + Mat out; + + glue_join_cols::apply(out, A.get_ref(), B.get_ref(), C.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +Mat +join_cols(const Base& A, const Base& B, const Base& C, const Base& D) + { + arma_extra_debug_sigprint(); + + Mat out; + + glue_join_cols::apply(out, A.get_ref(), B.get_ref(), C.get_ref(), D.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && is_same_type::value), + const Glue + >::result +join_vert(const T1& A, const T2& B) + { + arma_extra_debug_sigprint(); + + return Glue(A, B); + } + + + +template +arma_warn_unused +inline +Mat +join_vert(const Base& A, const Base& B, const Base& C) + { + arma_extra_debug_sigprint(); + + Mat out; + + glue_join_cols::apply(out, A.get_ref(), B.get_ref(), C.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +Mat +join_vert(const Base& A, const Base& B, const Base& C, const Base& D) + { + arma_extra_debug_sigprint(); + + Mat out; + + glue_join_cols::apply(out, A.get_ref(), B.get_ref(), C.get_ref(), D.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && is_same_type::value), + const Glue + >::result +join_rows(const T1& A, const T2& B) + { + arma_extra_debug_sigprint(); + + return Glue(A, B); + } + + + +template +arma_warn_unused +inline +Mat +join_rows(const Base& A, const Base& B, const Base& C) + { + arma_extra_debug_sigprint(); + + Mat out; + + glue_join_rows::apply(out, A.get_ref(), B.get_ref(), C.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +Mat +join_rows(const Base& A, const Base& B, const Base& C, const Base& D) + { + arma_extra_debug_sigprint(); + + Mat out; + + glue_join_rows::apply(out, A.get_ref(), B.get_ref(), C.get_ref(), D.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && is_same_type::value), + const Glue + >::result +join_horiz(const T1& A, const T2& B) + { + arma_extra_debug_sigprint(); + + return Glue(A, B); + } + + + +template +arma_warn_unused +inline +Mat +join_horiz(const Base& A, const Base& B, const Base& C) + { + arma_extra_debug_sigprint(); + + Mat out; + + glue_join_rows::apply(out, A.get_ref(), B.get_ref(), C.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +Mat +join_horiz(const Base& A, const Base& B, const Base& C, const Base& D) + { + arma_extra_debug_sigprint(); + + Mat out; + + glue_join_rows::apply(out, A.get_ref(), B.get_ref(), C.get_ref(), D.get_ref()); + + return out; + } + + + +// +// for cubes + +template +arma_warn_unused +inline +const GlueCube +join_slices(const BaseCube& A, const BaseCube& B) + { + arma_extra_debug_sigprint(); + + return GlueCube(A.get_ref(), B.get_ref()); + } + + + +template +arma_warn_unused +inline +Cube +join_slices(const Base& A, const Base& B) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UA(A.get_ref()); + const quasi_unwrap UB(B.get_ref()); + + arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "join_slices(): incompatible dimensions"); + + Cube out(UA.M.n_rows, UA.M.n_cols, 2, arma_nozeros_indicator()); + + arrayops::copy(out.slice_memptr(0), UA.M.memptr(), UA.M.n_elem); + arrayops::copy(out.slice_memptr(1), UB.M.memptr(), UB.M.n_elem); + + return out; + } + + + +template +arma_warn_unused +inline +Cube +join_slices(const Base& A, const BaseCube& B) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(A.get_ref()); + + const Cube M(const_cast(U.M.memptr()), U.M.n_rows, U.M.n_cols, 1, false); + + return join_slices(M,B); + } + + + +template +arma_warn_unused +inline +Cube +join_slices(const BaseCube& A, const Base& B) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(B.get_ref()); + + const Cube M(const_cast(U.M.memptr()), U.M.n_rows, U.M.n_cols, 1, false); + + return join_slices(A,M); + } + + + +// +// for sparse matrices + +template +arma_warn_unused +inline +const SpGlue +join_cols(const SpBase& A, const SpBase& B) + { + arma_extra_debug_sigprint(); + + return SpGlue(A.get_ref(), B.get_ref()); + } + + + +template +arma_warn_unused +inline +SpMat +join_cols(const SpBase& A, const SpBase& B, const SpBase& C) + { + arma_extra_debug_sigprint(); + + SpMat out; + + spglue_join_cols::apply(out, A.get_ref(), B.get_ref(), C.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +SpMat +join_cols(const SpBase& A, const SpBase& B, const SpBase& C, const SpBase& D) + { + arma_extra_debug_sigprint(); + + SpMat out; + + spglue_join_cols::apply(out, A.get_ref(), B.get_ref(), C.get_ref(), D.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +const SpGlue +join_vert(const SpBase& A, const SpBase& B) + { + arma_extra_debug_sigprint(); + + return SpGlue(A.get_ref(), B.get_ref()); + } + + + +template +arma_warn_unused +inline +SpMat +join_vert(const SpBase& A, const SpBase& B, const SpBase& C) + { + arma_extra_debug_sigprint(); + + SpMat out; + + spglue_join_cols::apply(out, A.get_ref(), B.get_ref(), C.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +SpMat +join_vert(const SpBase& A, const SpBase& B, const SpBase& C, const SpBase& D) + { + arma_extra_debug_sigprint(); + + SpMat out; + + spglue_join_cols::apply(out, A.get_ref(), B.get_ref(), C.get_ref(), D.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +const SpGlue +join_rows(const SpBase& A, const SpBase& B) + { + arma_extra_debug_sigprint(); + + return SpGlue(A.get_ref(), B.get_ref()); + } + + + +template +arma_warn_unused +inline +SpMat +join_rows(const SpBase& A, const SpBase& B, const SpBase& C) + { + arma_extra_debug_sigprint(); + + SpMat out; + + spglue_join_rows::apply(out, A.get_ref(), B.get_ref(), C.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +SpMat +join_rows(const SpBase& A, const SpBase& B, const SpBase& C, const SpBase& D) + { + arma_extra_debug_sigprint(); + + SpMat out; + + spglue_join_rows::apply(out, A.get_ref(), B.get_ref(), C.get_ref(), D.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +const SpGlue +join_horiz(const SpBase& A, const SpBase& B) + { + arma_extra_debug_sigprint(); + + return SpGlue(A.get_ref(), B.get_ref()); + } + + + +template +arma_warn_unused +inline +SpMat +join_horiz(const SpBase& A, const SpBase& B, const SpBase& C) + { + arma_extra_debug_sigprint(); + + SpMat out; + + spglue_join_rows::apply(out, A.get_ref(), B.get_ref(), C.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +SpMat +join_horiz(const SpBase& A, const SpBase& B, const SpBase& C, const SpBase& D) + { + arma_extra_debug_sigprint(); + + SpMat out; + + spglue_join_rows::apply(out, A.get_ref(), B.get_ref(), C.get_ref(), D.get_ref()); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_kmeans.hpp b/src/armadillo/include/armadillo_bits/fn_kmeans.hpp new file mode 100644 index 0000000..8707435 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_kmeans.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_kmeans +//! @{ + + + +template +inline +typename enable_if2::value, bool>::result +kmeans + ( + Mat& means, + const Base& data, + const uword k, + const gmm_seed_mode& seed_mode, + const uword n_iter, + const bool print_mode + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + gmm_priv::gmm_diag model; + + const bool status = model.kmeans_wrapper(means, data.get_ref(), k, seed_mode, n_iter, print_mode); + + if(status) + { + means = model.means; + } + else + { + means.soft_reset(); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_kron.hpp b/src/armadillo/include/armadillo_bits/fn_kron.hpp new file mode 100644 index 0000000..61b5be2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_kron.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_kron +//! @{ + + + +template +arma_warn_unused +arma_inline +const Glue +kron(const Base& A, const Base& B) + { + arma_extra_debug_sigprint(); + + return Glue(A.get_ref(), B.get_ref()); + } + + + +template +arma_warn_unused +inline +Mat::eT> +kron(const Base,T1>& X, const Base& Y) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT1; + + promote_type::check(); + + const quasi_unwrap tmp1(X.get_ref()); + const quasi_unwrap tmp2(Y.get_ref()); + + const Mat& A = tmp1.M; + const Mat& B = tmp2.M; + + Mat out; + + glue_kron::direct_kron(out, A, B); + + return out; + } + + + +template +arma_warn_unused +inline +Mat::eT> +kron(const Base& X, const Base,T2>& Y) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT2; + + promote_type::check(); + + const quasi_unwrap tmp1(X.get_ref()); + const quasi_unwrap tmp2(Y.get_ref()); + + const Mat& A = tmp1.M; + const Mat& B = tmp2.M; + + Mat out; + + glue_kron::direct_kron(out, A, B); + + return out; + } + + + +template +arma_warn_unused +arma_inline +const SpGlue +kron(const SpBase& A, const SpBase& B) + { + arma_extra_debug_sigprint(); + + return SpGlue(A.get_ref(), B.get_ref()); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_log_det.hpp b/src/armadillo/include/armadillo_bits/fn_log_det.hpp new file mode 100644 index 0000000..3ea463e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_log_det.hpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_log_det +//! @{ + + + +//! log determinant of mat +template +inline +bool +log_det + ( + typename T1::elem_type& out_val, + typename T1::pod_type& out_sign, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const bool status = op_log_det::apply_direct(out_val, out_sign, X.get_ref()); + + if(status == false) + { + out_val = eT(Datum::nan); + out_sign = T(0); + + arma_debug_warn_level(3, "log_det(): failed to find determinant"); + } + + return status; + } + + + +template +arma_warn_unused +inline +std::complex +log_det + ( + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + eT out_val = eT(0); + T out_sign = T(0); + + const bool status = op_log_det::apply_direct(out_val, out_sign, X.get_ref()); + + if(status == false) + { + out_val = eT(Datum::nan); + out_sign = T(0); + + arma_stop_runtime_error("log_det(): failed to find determinant"); + } + + return (out_sign >= T(1)) ? std::complex(out_val) : (out_val + std::complex(T(0),Datum::pi)); + } + + + +// + + + +template +inline +bool +log_det_sympd + ( + typename T1::pod_type& out_val, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::pod_type T; + + out_val = T(0); + + const bool status = op_log_det_sympd::apply_direct(out_val, X.get_ref()); + + if(status == false) + { + out_val = Datum::nan; + + arma_debug_warn_level(3, "log_det_sympd(): given matrix is not symmetric positive definite"); + } + + return status; + } + + + +template +arma_warn_unused +inline +typename T1::pod_type +log_det_sympd + ( + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::pod_type T; + + T out_val = T(0); + + const bool status = op_log_det_sympd::apply_direct(out_val, X.get_ref()); + + if(status == false) + { + out_val = Datum::nan; + + arma_stop_runtime_error("log_det_sympd(): given matrix is not symmetric positive definite"); + } + + return out_val; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_log_normpdf.hpp b/src/armadillo/include/armadillo_bits/fn_log_normpdf.hpp new file mode 100644 index 0000000..cb404db --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_log_normpdf.hpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_log_normpdf +//! @{ + + + +template +inline +typename enable_if2< (is_real::value), void >::result +log_normpdf_helper(Mat& out, const Base& X_expr, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(Proxy::use_at || Proxy::use_at || Proxy::use_at) + { + const quasi_unwrap UX(X_expr.get_ref()); + const quasi_unwrap UM(M_expr.get_ref()); + const quasi_unwrap US(S_expr.get_ref()); + + log_normpdf_helper(out, UX.M, UM.M, US.M); + + return; + } + + const Proxy PX(X_expr.get_ref()); + const Proxy PM(M_expr.get_ref()); + const Proxy PS(S_expr.get_ref()); + + arma_debug_check( ( (PX.get_n_rows() != PM.get_n_rows()) || (PX.get_n_cols() != PM.get_n_cols()) || (PM.get_n_rows() != PS.get_n_rows()) || (PM.get_n_cols() != PS.get_n_cols()) ), "log_normpdf(): size mismatch" ); + + out.set_size(PX.get_n_rows(), PX.get_n_cols()); + + eT* out_mem = out.memptr(); + + const uword N = PX.get_n_elem(); + + typename Proxy::ea_type X_ea = PX.get_ea(); + typename Proxy::ea_type M_ea = PM.get_ea(); + typename Proxy::ea_type S_ea = PS.get_ea(); + + const bool use_mp = arma_config::openmp && mp_gate::eval(N); + + if(use_mp) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i::log_sqrt2pi); + } + } + #endif + } + else + { + for(uword i=0; i::log_sqrt2pi); + } + } + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), eT >::result +log_normpdf(const eT x) + { + const eT out = (eT(-0.5) * (x*x)) - Datum::log_sqrt2pi; + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), eT >::result +log_normpdf(const eT x, const eT mu, const eT sigma) + { + const eT tmp = (x - mu) / sigma; + + const eT out = (eT(-0.5) * (tmp*tmp)) - (std::log(sigma) + Datum::log_sqrt2pi); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +log_normpdf(const eT x, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap UM(M_expr.get_ref()); + const Mat& M = UM.M; + + Mat out; + + log_normpdf_helper(out, x*ones< Mat >(arma::size(M)), M, S_expr.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +log_normpdf(const Base& X_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UX(X_expr.get_ref()); + const Mat& X = UX.M; + + Mat out; + + log_normpdf_helper(out, X, zeros< Mat >(arma::size(X)), ones< Mat >(arma::size(X))); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +log_normpdf(const Base& X_expr, const typename T1::elem_type mu, const typename T1::elem_type sigma) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UX(X_expr.get_ref()); + const Mat& X = UX.M; + + Mat out; + + log_normpdf_helper(out, X, mu*ones< Mat >(arma::size(X)), sigma*ones< Mat >(arma::size(X))); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +log_normpdf(const Base& X_expr, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + Mat out; + + log_normpdf_helper(out, X_expr.get_ref(), M_expr.get_ref(), S_expr.get_ref()); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_logmat.hpp b/src/armadillo/include/armadillo_bits/fn_logmat.hpp new file mode 100644 index 0000000..e169987 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_logmat.hpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_logmat +//! @{ + + + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_supported_blas_type::value && is_cx::no), const mtOp, T1, op_logmat> >::result +logmat(const Base& X, const uword n_iters = 100u) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_logmat>(X.get_ref(), n_iters, uword(0)); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_supported_blas_type::value && is_cx::yes), const Op >::result +logmat(const Base& X, const uword n_iters = 100u) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref(), n_iters, uword(0)); + } + + + +template +inline +typename enable_if2< (is_supported_blas_type::value && is_cx::no), bool >::result +logmat(Mat< std::complex >& Y, const Base& X, const uword n_iters = 100u) + { + arma_extra_debug_sigprint(); + + const bool status = op_logmat::apply_direct(Y, X.get_ref(), n_iters); + + if(status == false) + { + Y.soft_reset(); + arma_debug_warn_level(3, "logmat(): transformation failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< (is_supported_blas_type::value && is_cx::yes), bool >::result +logmat(Mat& Y, const Base& X, const uword n_iters = 100u) + { + arma_extra_debug_sigprint(); + + const bool status = op_logmat_cx::apply_direct(Y, X.get_ref(), n_iters); + + if(status == false) + { + Y.soft_reset(); + arma_debug_warn_level(3, "logmat(): transformation failed"); + } + + return status; + } + + + +// + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +logmat_sympd(const Base& X) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +logmat_sympd(Mat& Y, const Base& X) + { + arma_extra_debug_sigprint(); + + const bool status = op_logmat_sympd::apply_direct(Y, X.get_ref()); + + if(status == false) + { + Y.soft_reset(); + arma_debug_warn_level(3, "logmat_sympd(): transformation failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_lu.hpp b/src/armadillo/include/armadillo_bits/fn_lu.hpp new file mode 100644 index 0000000..60c52b8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_lu.hpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_lu +//! @{ + + + +//! immediate lower upper decomposition, permutation info is embedded into L (similar to Matlab/Octave) +template +inline +bool +lu + ( + Mat& L, + Mat& U, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + arma_debug_check( (&L == &U), "lu(): L and U are the same object" ); + + const bool status = auxlib::lu(L, U, X); + + if(status == false) + { + L.soft_reset(); + U.soft_reset(); + arma_debug_warn_level(3, "lu(): decomposition failed"); + } + + return status; + } + + + +//! immediate lower upper decomposition, also providing the permutation matrix +template +inline +bool +lu + ( + Mat& L, + Mat& U, + Mat& P, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + arma_debug_check( ( (&L == &U) || (&L == &P) || (&U == &P) ), "lu(): two or more output objects are the same object" ); + + const bool status = auxlib::lu(L, U, P, X); + + if(status == false) + { + L.soft_reset(); + U.soft_reset(); + P.soft_reset(); + arma_debug_warn_level(3, "lu(): decomposition failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_max.hpp b/src/armadillo/include/armadillo_bits/fn_max.hpp new file mode 100644 index 0000000..dcbf1ce --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_max.hpp @@ -0,0 +1,277 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_max +//! @{ + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value && resolves_to_vector::yes, typename T1::elem_type >::result +max(const T1& X) + { + arma_extra_debug_sigprint(); + + return op_max::max(X); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value && resolves_to_vector::no, const Op >::result +max(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const Op >::result +max(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return Op(X, dim, 0); + } + + + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +max(const T& x) + { + return x; + } + + + +//! element-wise maximum +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + ( is_arma_type::value && is_arma_type::value && is_same_type::value ), + const Glue + >::result +max + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + return Glue(X, Y); + } + + + +template +arma_warn_unused +arma_inline +const OpCube +max + ( + const BaseCube& X, + const uword dim = 0 + ) + { + arma_extra_debug_sigprint(); + + return OpCube(X.get_ref(), dim, 0); + } + + + +template +arma_warn_unused +arma_inline +const GlueCube +max + ( + const BaseCube& X, + const BaseCube& Y + ) + { + arma_extra_debug_sigprint(); + + return GlueCube(X.get_ref(), Y.get_ref()); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::yes, + typename T1::elem_type + >::result +max(const T1& x) + { + arma_extra_debug_sigprint(); + + return spop_max::vector_max(x); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::no, + const SpOp + >::result +max(const T1& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X, 0, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value, + const SpOp + >::result +max(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return SpOp(X, dim, 0); + } + + + +// elementwise sparse max +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && is_same_type::value), + const SpGlue + >::result +max(const T1& x, const T2& y) + { + arma_extra_debug_sigprint(); + + return SpGlue(x, y); + } + + + +//! elementwise max of dense and sparse objects with the same element type +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_sparse_type::value && is_same_type::value), + Mat + >::result +max + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + Mat out; + + spglue_max::dense_sparse_max(out, x, y); + + return out; + } + + + +//! elementwise max of sparse and dense objects with the same element type +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_type::value && is_same_type::value), + Mat + >::result +max + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + Mat out; + + // Just call the other order (these operations are commutative) + // TODO: if there is a matrix size mismatch, the debug assert will print the matrix sizes in wrong order + spglue_max::dense_sparse_max(out, y, x); + + return out; + } + + + +arma_warn_unused +inline +uword +max(const SizeMat& s) + { + return (std::max)(s.n_rows, s.n_cols); + } + + + +arma_warn_unused +inline +uword +max(const SizeCube& s) + { + return (std::max)( (std::max)(s.n_rows, s.n_cols), s.n_slices ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_mean.hpp b/src/armadillo/include/armadillo_bits/fn_mean.hpp new file mode 100644 index 0000000..b1400c1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_mean.hpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_mean +//! @{ + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value && resolves_to_vector::yes, typename T1::elem_type >::result +mean(const T1& X) + { + arma_extra_debug_sigprint(); + + return op_mean::mean_all(X); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value && resolves_to_vector::no, const Op >::result +mean(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const Op >::result +mean(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return Op(X, dim, 0); + } + + + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +mean(const T& x) + { + return x; + } + + + +template +arma_warn_unused +arma_inline +const OpCube +mean + ( + const BaseCube& X, + const uword dim = 0 + ) + { + arma_extra_debug_sigprint(); + + return OpCube(X.get_ref(), dim, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::yes, + typename T1::elem_type + >::result +mean(const T1& x) + { + arma_extra_debug_sigprint(); + + return spop_mean::mean_all(x); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::no, + const SpOp + >::result +mean(const T1& x) + { + arma_extra_debug_sigprint(); + + return SpOp(x, 0, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value, + const SpOp + >::result +mean(const T1& x, const uword dim) + { + arma_extra_debug_sigprint(); + + return SpOp(x, dim, 0); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_median.hpp b/src/armadillo/include/armadillo_bits/fn_median.hpp new file mode 100644 index 0000000..48ff756 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_median.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_median +//! @{ + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value && resolves_to_vector::yes, typename T1::elem_type >::result +median(const T1& X) + { + arma_extra_debug_sigprint(); + + return op_median::median_vec(X); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value && resolves_to_vector::no, const Op >::result +median(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const Op >::result +median(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return Op(X, dim, 0); + } + + + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +median(const T& x) + { + return x; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_min.hpp b/src/armadillo/include/armadillo_bits/fn_min.hpp new file mode 100644 index 0000000..3baa128 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_min.hpp @@ -0,0 +1,277 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_min +//! @{ + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value && resolves_to_vector::yes, typename T1::elem_type >::result +min(const T1& X) + { + arma_extra_debug_sigprint(); + + return op_min::min(X); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value && resolves_to_vector::no, const Op >::result +min(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const Op >::result +min(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return Op(X, dim, 0); + } + + + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +min(const T& x) + { + return x; + } + + + +//! element-wise minimum +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + ( is_arma_type::value && is_arma_type::value && is_same_type::value ), + const Glue + >::result +min + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + return Glue(X, Y); + } + + + +template +arma_warn_unused +arma_inline +const OpCube +min + ( + const BaseCube& X, + const uword dim = 0 + ) + { + arma_extra_debug_sigprint(); + + return OpCube(X.get_ref(), dim, 0); + } + + + +template +arma_warn_unused +arma_inline +const GlueCube +min + ( + const BaseCube& X, + const BaseCube& Y + ) + { + arma_extra_debug_sigprint(); + + return GlueCube(X.get_ref(), Y.get_ref()); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::yes, + typename T1::elem_type + >::result +min(const T1& x) + { + arma_extra_debug_sigprint(); + + return spop_min::vector_min(x); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::no, + const SpOp + >::result +min(const T1& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X, 0, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value, + const SpOp + >::result +min(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return SpOp(X, dim, 0); + } + + + +// elementwise sparse min +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && is_same_type::value), + const SpGlue + >::result +min(const T1& x, const T2& y) + { + arma_extra_debug_sigprint(); + + return SpGlue(x, y); + } + + + +//! elementwise min of dense and sparse objects with the same element type +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_sparse_type::value && is_same_type::value), + Mat + >::result +min + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + Mat out; + + spglue_min::dense_sparse_min(out, x, y); + + return out; + } + + + +//! elementwise min of sparse and dense objects with the same element type +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_type::value && is_same_type::value), + Mat + >::result +min + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + Mat out; + + // Just call the other order (these operations are commutative) + // TODO: if there is a matrix size mismatch, the debug assert will print the matrix sizes in wrong order + spglue_min::dense_sparse_min(out, y, x); + + return out; + } + + + +arma_warn_unused +inline +uword +min(const SizeMat& s) + { + return (std::min)(s.n_rows, s.n_cols); + } + + + +arma_warn_unused +inline +uword +min(const SizeCube& s) + { + return (std::min)( (std::min)(s.n_rows, s.n_cols), s.n_slices ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_misc.hpp b/src/armadillo/include/armadillo_bits/fn_misc.hpp new file mode 100644 index 0000000..51930e4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_misc.hpp @@ -0,0 +1,587 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_misc +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_Mat::value, + out_type + >::result +linspace + ( + const typename out_type::pod_type start, + const typename out_type::pod_type end, + const uword num = 100u + ) + { + arma_extra_debug_sigprint(); + + typedef typename out_type::elem_type eT; + typedef typename out_type::pod_type T; + + out_type x; + + if(num == 1) + { + x.set_size(1); + + x[0] = eT(end); + } + else + if(num >= 2) + { + x.set_size(num); + + eT* x_mem = x.memptr(); + + const uword num_m1 = num - 1; + + if(is_non_integral::value) + { + const T delta = (end-start)/T(num_m1); + + for(uword i=0; i= start) ? double(end-start)/double(num_m1) : -double(start-end)/double(num_m1); + + for(uword i=0; i(start, end, num); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_Mat::value && is_real::value), + out_type + >::result +logspace + ( + const typename out_type::pod_type A, + const typename out_type::pod_type B, + const uword N = 50u + ) + { + arma_extra_debug_sigprint(); + + typedef typename out_type::elem_type eT; + typedef typename out_type::pod_type T; + + out_type x = linspace(A,B,N); + + const uword n_elem = x.n_elem; + + eT* x_mem = x.memptr(); + + for(uword i=0; i < n_elem; ++i) + { + x_mem[i] = std::pow(T(10), x_mem[i]); + } + + return x; + } + + + +arma_warn_unused +inline +vec +logspace(const double A, const double B, const uword N = 50u) + { + arma_extra_debug_sigprint(); + return logspace(A, B, N); + } + + + +// +// log_exp_add + +template +arma_warn_unused +inline +typename arma_real_only::result +log_add_exp(eT log_a, eT log_b) + { + if(log_a < log_b) + { + std::swap(log_a, log_b); + } + + const eT negdelta = log_b - log_a; + + if( (negdelta < Datum::log_min) || (arma_isfinite(negdelta) == false) ) + { + return log_a; + } + else + { + return (log_a + std::log1p(std::exp(negdelta))); + } + } + + + +// for compatibility with earlier versions +template +arma_warn_unused +inline +typename arma_real_only::result +log_add(eT log_a, eT log_b) + { + return log_add_exp(log_a, log_b); + } + + + +//! kept for compatibility with old user code +template +arma_warn_unused +arma_inline +bool +is_finite(const eT x, const typename arma_scalar_only::result* junk = nullptr) + { + arma_ignore(junk); + + return arma_isfinite(x); + } + + + +//! kept for compatibility with old user code +template +arma_warn_unused +inline +bool +is_finite(const Base& X) + { + arma_extra_debug_sigprint(); + + return X.is_finite(); + } + + + +//! kept for compatibility with old user code +template +arma_warn_unused +inline +bool +is_finite(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return X.is_finite(); + } + + + +//! kept for compatibility with old user code +template +arma_warn_unused +inline +bool +is_finite(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + return X.is_finite(); + } + + + +template +inline +void +swap(Mat& A, Mat& B) + { + arma_extra_debug_sigprint(); + + A.swap(B); + } + + + +template +inline +void +swap(Cube& A, Cube& B) + { + arma_extra_debug_sigprint(); + + A.swap(B); + } + + + +arma_warn_unused +inline +uvec +ind2sub(const SizeMat& s, const uword i) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + + arma_debug_check( (i >= (s_n_rows * s.n_cols) ), "ind2sub(): index out of range" ); + + const uword row = i % s_n_rows; + const uword col = i / s_n_rows; + + uvec out(2, arma_nozeros_indicator()); + + uword* out_mem = out.memptr(); + + out_mem[0] = row; + out_mem[1] = col; + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_arma_type::value && is_same_type::yes), umat >::result +ind2sub(const SizeMat& s, const T1& indices) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_elem = s_n_rows * s.n_cols; + + const Proxy P(indices); + + const uword P_n_rows = P.get_n_rows(); + const uword P_n_cols = P.get_n_cols(); + const uword P_n_elem = P.get_n_elem(); + + const bool P_is_empty = (P_n_elem == 0); + const bool P_is_vec = ((P_n_rows == 1) || (P_n_cols == 1)); + + arma_debug_check( ((P_is_empty == false) && (P_is_vec == false)), "ind2sub(): parameter 'indices' must be a vector" ); + + umat out(2, P_n_elem, arma_nozeros_indicator()); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword count=0; count < P_n_elem; ++count) + { + const uword i = Pea[count]; + + arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" ); + + const uword row = i % s_n_rows; + const uword col = i / s_n_rows; + + uword* out_colptr = out.colptr(count); + + out_colptr[0] = row; + out_colptr[1] = col; + } + } + else + { + if(P_n_rows == 1) + { + for(uword count=0; count < P_n_cols; ++count) + { + const uword i = P.at(0,count); + + arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" ); + + const uword row = i % s_n_rows; + const uword col = i / s_n_rows; + + uword* out_colptr = out.colptr(count); + + out_colptr[0] = row; + out_colptr[1] = col; + } + } + else + if(P_n_cols == 1) + { + for(uword count=0; count < P_n_rows; ++count) + { + const uword i = P.at(count,0); + + arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" ); + + const uword row = i % s_n_rows; + const uword col = i / s_n_rows; + + uword* out_colptr = out.colptr(count); + + out_colptr[0] = row; + out_colptr[1] = col; + } + } + } + + return out; + } + + + +arma_warn_unused +inline +uvec +ind2sub(const SizeCube& s, const uword i) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_elem_slice = s_n_rows * s.n_cols; + + arma_debug_check( (i >= (s_n_elem_slice * s.n_slices) ), "ind2sub(): index out of range" ); + + const uword slice = i / s_n_elem_slice; + const uword j = i - (slice * s_n_elem_slice); + const uword row = j % s_n_rows; + const uword col = j / s_n_rows; + + uvec out(3, arma_nozeros_indicator()); + + uword* out_mem = out.memptr(); + + out_mem[0] = row; + out_mem[1] = col; + out_mem[2] = slice; + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_arma_type::value && is_same_type::yes), umat >::result +ind2sub(const SizeCube& s, const T1& indices) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_elem_slice = s_n_rows * s.n_cols; + const uword s_n_elem = s.n_slices * s_n_elem_slice; + + const quasi_unwrap U(indices); + + arma_debug_check( ((U.M.is_empty() == false) && (U.M.is_vec() == false)), "ind2sub(): parameter 'indices' must be a vector" ); + + const uword U_n_elem = U.M.n_elem; + const uword* U_mem = U.M.memptr(); + + umat out(3, U_n_elem, arma_nozeros_indicator()); + + for(uword count=0; count < U_n_elem; ++count) + { + const uword i = U_mem[count]; + + arma_debug_check( (i >= s_n_elem), "ind2sub(): index out of range" ); + + const uword slice = i / s_n_elem_slice; + const uword j = i - (slice * s_n_elem_slice); + const uword row = j % s_n_rows; + const uword col = j / s_n_rows; + + uword* out_colptr = out.colptr(count); + + out_colptr[0] = row; + out_colptr[1] = col; + out_colptr[2] = slice; + } + + return out; + } + + + +arma_warn_unused +arma_inline +uword +sub2ind(const SizeMat& s, const uword row, const uword col) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + + arma_debug_check( ((row >= s_n_rows) || (col >= s.n_cols)), "sub2ind(): subscript out of range" ); + + return uword(row + col*s_n_rows); + } + + + +template +arma_warn_unused +inline +uvec +sub2ind(const SizeMat& s, const Base& subscripts) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + const quasi_unwrap U(subscripts.get_ref()); + + arma_debug_check( (U.M.n_rows != 2), "sub2ind(): matrix of subscripts must have 2 rows" ); + + const uword U_M_n_cols = U.M.n_cols; + + uvec out(U_M_n_cols, arma_nozeros_indicator()); + + uword* out_mem = out.memptr(); + const uword* U_M_mem = U.M.memptr(); + + for(uword count=0; count < U_M_n_cols; ++count) + { + const uword row = U_M_mem[0]; + const uword col = U_M_mem[1]; + + U_M_mem += 2; // next column + + arma_debug_check( ((row >= s_n_rows) || (col >= s_n_cols)), "sub2ind(): subscript out of range" ); + + out_mem[count] = uword(row + col*s_n_rows); + } + + return out; + } + + + +arma_warn_unused +arma_inline +uword +sub2ind(const SizeCube& s, const uword row, const uword col, const uword slice) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + arma_debug_check( ((row >= s_n_rows) || (col >= s_n_cols) || (slice >= s.n_slices)), "sub2ind(): subscript out of range" ); + + return uword( (slice * s_n_rows * s_n_cols) + (col * s_n_rows) + row ); + } + + + +template +arma_warn_unused +inline +uvec +sub2ind(const SizeCube& s, const Base& subscripts) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + const uword s_n_slices = s.n_slices; + + const quasi_unwrap U(subscripts.get_ref()); + + arma_debug_check( (U.M.n_rows != 3), "sub2ind(): matrix of subscripts must have 3 rows" ); + + const uword U_M_n_cols = U.M.n_cols; + + uvec out(U_M_n_cols, arma_nozeros_indicator()); + + uword* out_mem = out.memptr(); + const uword* U_M_mem = U.M.memptr(); + + for(uword count=0; count < U_M_n_cols; ++count) + { + const uword row = U_M_mem[0]; + const uword col = U_M_mem[1]; + const uword slice = U_M_mem[2]; + + U_M_mem += 3; // next column + + arma_debug_check( ((row >= s_n_rows) || (col >= s_n_cols) || (slice >= s_n_slices)), "sub2ind(): subscript out of range" ); + + out_mem[count] = uword( (slice * s_n_rows * s_n_cols) + (col * s_n_rows) + row ); + } + + return out; + } + + + +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_same_type::value), + const Glue + >::result +affmul(const T1& A, const T2& B) + { + arma_extra_debug_sigprint(); + + return Glue(A,B); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_mvnrnd.hpp b/src/armadillo/include/armadillo_bits/fn_mvnrnd.hpp new file mode 100644 index 0000000..dd873d0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_mvnrnd.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_mvnrnd +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_real::value, + const Glue + >::result +mvnrnd(const Base& M, const Base& C) + { + arma_extra_debug_sigprint(); + + return Glue(M.get_ref(), C.get_ref()); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_real::value, + const Glue + >::result +mvnrnd(const Base& M, const Base& C, const uword N) + { + arma_extra_debug_sigprint(); + + return Glue(M.get_ref(), C.get_ref(), N); + } + + + +template +inline +typename +enable_if2 + < + is_real::value, + bool + >::result +mvnrnd(Mat& out, const Base& M, const Base& C) + { + arma_extra_debug_sigprint(); + + const bool status = glue_mvnrnd::apply_direct(out, M.get_ref(), C.get_ref(), uword(1)); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "mvnrnd(): given covariance matrix is not symmetric positive semi-definite"); + } + + return status; + } + + + +template +inline +typename +enable_if2 + < + is_real::value, + bool + >::result +mvnrnd(Mat& out, const Base& M, const Base& C, const uword N) + { + arma_extra_debug_sigprint(); + + const bool status = glue_mvnrnd::apply_direct(out, M.get_ref(), C.get_ref(), N); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "mvnrnd(): given covariance matrix is not symmetric positive semi-definite"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_n_unique.hpp b/src/armadillo/include/armadillo_bits/fn_n_unique.hpp new file mode 100644 index 0000000..2f00b72 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_n_unique.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_n_unique +//! @{ + + +//! \brief +//! Get the number of unique nonzero elements in two sparse matrices. +//! This is very useful for determining the amount of memory necessary before +//! a sparse matrix operation on two matrices. + +template +inline +uword +n_unique + ( + const SpBase& x, + const SpBase& y, + const op_n_unique_type junk + ) + { + arma_extra_debug_sigprint(); + + const SpProxy pa(x.get_ref()); + const SpProxy pb(y.get_ref()); + + return n_unique(pa,pb,junk); + } + + + +template +arma_hot +inline +uword +n_unique + ( + const SpProxy& pa, + const SpProxy& pb, + const op_n_unique_type junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typename SpProxy::const_iterator_type x_it = pa.begin(); + typename SpProxy::const_iterator_type x_it_end = pa.end(); + + typename SpProxy::const_iterator_type y_it = pb.begin(); + typename SpProxy::const_iterator_type y_it_end = pb.end(); + + uword total_n_nonzero = 0; + + while( (x_it != x_it_end) || (y_it != y_it_end) ) + { + if(x_it == y_it) + { + if(op_n_unique_type::eval((*x_it), (*y_it)) != typename T1::elem_type(0)) + { + ++total_n_nonzero; + } + + ++x_it; + ++y_it; + } + else + { + if((x_it.col() < y_it.col()) || ((x_it.col() == y_it.col()) && (x_it.row() < y_it.row()))) // if y is closer to the end + { + if(op_n_unique_type::eval((*x_it), typename T1::elem_type(0)) != typename T1::elem_type(0)) + { + ++total_n_nonzero; + } + + ++x_it; + } + else // x is closer to the end + { + if(op_n_unique_type::eval(typename T1::elem_type(0), (*y_it)) != typename T1::elem_type(0)) + { + ++total_n_nonzero; + } + + ++y_it; + } + } + } + + return total_n_nonzero; + } + + +// Simple operators. +struct op_n_unique_add + { + template inline static eT eval(const eT& l, const eT& r) { return (l + r); } + }; + +struct op_n_unique_sub + { + template inline static eT eval(const eT& l, const eT& r) { return (l - r); } + }; + +struct op_n_unique_mul + { + template inline static eT eval(const eT& l, const eT& r) { return (l * r); } + }; + +struct op_n_unique_count + { + template inline static eT eval(const eT&, const eT&) { return eT(1); } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_nonzeros.hpp b/src/armadillo/include/armadillo_bits/fn_nonzeros.hpp new file mode 100644 index 0000000..202efe1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_nonzeros.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_nonzeros +//! @{ + + +template +arma_warn_unused +inline +const Op +nonzeros(const Base& X) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + +template +arma_warn_unused +inline +const SpToDOp +nonzeros(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return SpToDOp(X.get_ref()); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_norm.hpp b/src/armadillo/include/armadillo_bits/fn_norm.hpp new file mode 100644 index 0000000..a8f05f0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_norm.hpp @@ -0,0 +1,342 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_norm +//! @{ + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value, typename T1::pod_type >::result +norm + ( + const T1& X, + const uword k = uword(2), + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::pod_type T; + + const Proxy P(X); + + if(P.get_n_elem() == 0) { return T(0); } + + const bool is_vec = (T1::is_xvec) || (T1::is_row) || (T1::is_col) || (P.get_n_rows() == 1) || (P.get_n_cols() == 1); + + if(is_vec) + { + if(k == uword(1)) { return op_norm::vec_norm_1(P); } + if(k == uword(2)) { return op_norm::vec_norm_2(P); } + + arma_debug_check( (k == 0), "norm(): unsupported vector norm type" ); + + return op_norm::vec_norm_k(P, int(k)); + } + else + { + const quasi_unwrap::stored_type> U(P.Q); + + if(k == uword(1)) { return op_norm::mat_norm_1(U.M); } + if(k == uword(2)) { return op_norm::mat_norm_2(U.M); } + + arma_stop_logic_error("norm(): unsupported matrix norm type"); + } + + return T(0); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value, typename T1::pod_type >::result +norm + ( + const T1& X, + const char* method, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::pod_type T; + + const Proxy P(X); + + if(P.get_n_elem() == 0) { return T(0); } + + const char sig = (method != nullptr) ? method[0] : char(0); + const bool is_vec = (T1::is_xvec) || (T1::is_row) || (T1::is_col) || (P.get_n_rows() == 1) || (P.get_n_cols() == 1); + + if(is_vec) + { + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { return op_norm::vec_norm_max(P); } + if( (sig == '-') ) { return op_norm::vec_norm_min(P); } + if( (sig == 'f') || (sig == 'F') ) { return op_norm::vec_norm_2(P); } + + arma_stop_logic_error("norm(): unsupported vector norm type"); + } + else + { + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) // inf norm + { + const quasi_unwrap::stored_type> U(P.Q); + + return op_norm::mat_norm_inf(U.M); + } + else + if( (sig == 'f') || (sig == 'F') ) + { + return op_norm::vec_norm_2(P); + } + + arma_stop_logic_error("norm(): unsupported matrix norm type"); + } + + return T(0); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value, double >::result +norm + ( + const T1& X, + const uword k = uword(2), + const typename arma_integral_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(resolves_to_colvector::value) { return norm(conv_to< Col >::from(X), k); } + if(resolves_to_rowvector::value) { return norm(conv_to< Row >::from(X), k); } + + return norm(conv_to< Mat >::from(X), k); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value, double >::result +norm + ( + const T1& X, + const char* method, + const typename arma_integral_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(resolves_to_colvector::value) { return norm(conv_to< Col >::from(X), method); } + if(resolves_to_rowvector::value) { return norm(conv_to< Row >::from(X), method); } + + return norm(conv_to< Mat >::from(X), method); + } + + + +// +// norms for sparse matrices + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_sparse_type::value, typename T1::pod_type >::result +norm + ( + const T1& expr, + const uword k = uword(2), + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(is_SpSubview_col::value) + { + const SpSubview_col& sv = reinterpret_cast< const SpSubview_col& >(expr); + + if(sv.n_rows == sv.m.n_rows) + { + const SpMat& m = sv.m; + const uword col = sv.aux_col1; + const eT* mem = &(m.values[ m.col_ptrs[col] ]); + + return spop_norm::vec_norm_k(mem, sv.n_nonzero, k); + } + } + + const unwrap_spmat U(expr); + const SpMat& X = U.M; + + if(X.n_nonzero == 0) { return T(0); } + + const bool is_vec = (T1::is_xvec) || (T1::is_row) || (T1::is_col) || (X.n_rows == 1) || (X.n_cols == 1); + + if(is_vec) + { + return spop_norm::vec_norm_k(X.values, X.n_nonzero, k); + } + else + { + if(k == uword(1)) { return spop_norm::mat_norm_1(X); } + if(k == uword(2)) { return spop_norm::mat_norm_2(X); } + + arma_stop_logic_error("norm(): unsupported or unimplemented norm type for sparse matrices"); + } + + return T(0); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_sparse_type::value, typename T1::pod_type >::result +norm + ( + const T1& expr, + const char* method, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const unwrap_spmat U(expr); + const SpMat& X = U.M; + + if(X.n_nonzero == 0) { return T(0); } + + // create a fake dense vector to allow reuse of code for dense vectors + Col fake_vector( access::rwp(X.values), X.n_nonzero, false ); + + const Proxy< Col > P_fake_vector(fake_vector); + + + const char sig = (method != nullptr) ? method[0] : char(0); + const bool is_vec = (T1::is_xvec) || (T1::is_row) || (T1::is_col) || (X.n_rows == 1) || (X.n_cols == 1); + + if(is_vec) + { + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) // max norm + { + return op_norm::vec_norm_max(P_fake_vector); + } + else + if(sig == '-') // min norm + { + const T val = op_norm::vec_norm_min(P_fake_vector); + + return (X.n_nonzero < X.n_elem) ? T((std::min)(T(0), val)) : T(val); + } + else + if( (sig == 'f') || (sig == 'F') ) + { + return op_norm::vec_norm_2(P_fake_vector); + } + + arma_stop_logic_error("norm(): unsupported vector norm type"); + } + else + { + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) // inf norm + { + return spop_norm::mat_norm_inf(X); + } + else + if( (sig == 'f') || (sig == 'F') ) + { + return op_norm::vec_norm_2(P_fake_vector); + } + + arma_stop_logic_error("norm(): unsupported matrix norm type"); + } + + return T(0); + } + + + +// +// approximate norms + + +template +arma_warn_unused +inline +typename T1::pod_type +norm2est + ( + const Base& X, + const typename T1::pod_type tolerance = 0, + const uword max_iter = 100, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return op_norm2est::norm2est(X.get_ref(), tolerance, max_iter); + } + + + +template +arma_warn_unused +inline +typename T1::pod_type +norm2est + ( + const SpBase& X, + const typename T1::pod_type tolerance = 0, + const uword max_iter = 100, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return op_norm2est::norm2est(X.get_ref(), tolerance, max_iter); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_normalise.hpp b/src/armadillo/include/armadillo_bits/fn_normalise.hpp new file mode 100644 index 0000000..ae07430 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_normalise.hpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_normalise +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + const Op + >::result +normalise + ( + const T1& X, + const uword p = uword(2), + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + return Op(X, p, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const Op + >::result +normalise + ( + const T1& X, + const uword p = uword(2), + const uword dim = 0, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return Op(X, p, dim); + } + + + +template +arma_warn_unused +inline +const SpOp +normalise + ( + const SpBase& expr, + const uword p = uword(2), + const uword dim = 0, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return SpOp(expr.get_ref(), p, dim); + } + + + +//! for compatibility purposes: allows compiling user code designed for earlier versions of Armadillo +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_supported_blas_type::value, + Col + >::result +normalise(const T& val) + { + Col out(1, arma_nozeros_indicator()); + + out[0] = (val != T(0)) ? T(val / (std::abs)(val)) : T(val); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_normcdf.hpp b/src/armadillo/include/armadillo_bits/fn_normcdf.hpp new file mode 100644 index 0000000..06ed5cb --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_normcdf.hpp @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_normcdf +//! @{ + + + +template +inline +typename enable_if2< (is_real::value), void >::result +normcdf_helper(Mat& out, const Base& X_expr, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(Proxy::use_at || Proxy::use_at || Proxy::use_at) + { + const quasi_unwrap UX(X_expr.get_ref()); + const quasi_unwrap UM(M_expr.get_ref()); + const quasi_unwrap US(S_expr.get_ref()); + + normcdf_helper(out, UX.M, UM.M, US.M); + + return; + } + + const Proxy PX(X_expr.get_ref()); + const Proxy PM(M_expr.get_ref()); + const Proxy PS(S_expr.get_ref()); + + arma_debug_check( ( (PX.get_n_rows() != PM.get_n_rows()) || (PX.get_n_cols() != PM.get_n_cols()) || (PM.get_n_rows() != PS.get_n_rows()) || (PM.get_n_cols() != PS.get_n_cols()) ), "normcdf(): size mismatch" ); + + out.set_size(PX.get_n_rows(), PX.get_n_cols()); + + eT* out_mem = out.memptr(); + + const uword N = PX.get_n_elem(); + + typename Proxy::ea_type X_ea = PX.get_ea(); + typename Proxy::ea_type M_ea = PM.get_ea(); + typename Proxy::ea_type S_ea = PS.get_ea(); + + const bool use_mp = arma_config::openmp && mp_gate::eval(N); + + if(use_mp) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i::sqrt2)); + + out_mem[i] = eT(0.5) * std::erfc(tmp); + } + } + #endif + } + else + { + for(uword i=0; i::sqrt2)); + + out_mem[i] = eT(0.5) * std::erfc(tmp); + } + } + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), eT >::result +normcdf(const eT x) + { + const eT out = eT(0.5) * std::erfc( x / (-Datum::sqrt2) ); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), eT >::result +normcdf(const eT x, const eT mu, const eT sigma) + { + const eT tmp = (x - mu) / (sigma * (-Datum::sqrt2)); + + const eT out = eT(0.5) * std::erfc(tmp); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +normcdf(const eT x, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap UM(M_expr.get_ref()); + const Mat& M = UM.M; + + Mat out; + + normcdf_helper(out, x*ones< Mat >(arma::size(M)), M, S_expr.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +normcdf(const Base& X_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UX(X_expr.get_ref()); + const Mat& X = UX.M; + + Mat out; + + normcdf_helper(out, X, zeros< Mat >(arma::size(X)), ones< Mat >(arma::size(X))); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +normcdf(const Base& X_expr, const typename T1::elem_type mu, const typename T1::elem_type sigma) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UX(X_expr.get_ref()); + const Mat& X = UX.M; + + Mat out; + + normcdf_helper(out, X, mu*ones< Mat >(arma::size(X)), sigma*ones< Mat >(arma::size(X))); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +normcdf(const Base& X_expr, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + Mat out; + + normcdf_helper(out, X_expr.get_ref(), M_expr.get_ref(), S_expr.get_ref()); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_normpdf.hpp b/src/armadillo/include/armadillo_bits/fn_normpdf.hpp new file mode 100644 index 0000000..e05af41 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_normpdf.hpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_normpdf +//! @{ + + + +template +inline +typename enable_if2< (is_real::value), void >::result +normpdf_helper(Mat& out, const Base& X_expr, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(Proxy::use_at || Proxy::use_at || Proxy::use_at) + { + const quasi_unwrap UX(X_expr.get_ref()); + const quasi_unwrap UM(M_expr.get_ref()); + const quasi_unwrap US(S_expr.get_ref()); + + normpdf_helper(out, UX.M, UM.M, US.M); + + return; + } + + const Proxy PX(X_expr.get_ref()); + const Proxy PM(M_expr.get_ref()); + const Proxy PS(S_expr.get_ref()); + + arma_debug_check( ( (PX.get_n_rows() != PM.get_n_rows()) || (PX.get_n_cols() != PM.get_n_cols()) || (PM.get_n_rows() != PS.get_n_rows()) || (PM.get_n_cols() != PS.get_n_cols()) ), "normpdf(): size mismatch" ); + + out.set_size(PX.get_n_rows(), PX.get_n_cols()); + + eT* out_mem = out.memptr(); + + const uword N = PX.get_n_elem(); + + typename Proxy::ea_type X_ea = PX.get_ea(); + typename Proxy::ea_type M_ea = PM.get_ea(); + typename Proxy::ea_type S_ea = PS.get_ea(); + + const bool use_mp = arma_config::openmp && mp_gate::eval(N); + + if(use_mp) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i::sqrt2pi); + } + } + #endif + } + else + { + for(uword i=0; i::sqrt2pi); + } + } + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), eT >::result +normpdf(const eT x) + { + const eT out = std::exp(eT(-0.5) * (x*x)) / Datum::sqrt2pi; + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), eT >::result +normpdf(const eT x, const eT mu, const eT sigma) + { + const eT tmp = (x - mu) / sigma; + + const eT out = std::exp(eT(-0.5) * (tmp*tmp)) / (sigma * Datum::sqrt2pi); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +normpdf(const eT x, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap UM(M_expr.get_ref()); + const Mat& M = UM.M; + + Mat out; + + normpdf_helper(out, x*ones< Mat >(arma::size(M)), M, S_expr.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +normpdf(const Base& X_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UX(X_expr.get_ref()); + const Mat& X = UX.M; + + Mat out; + + normpdf_helper(out, X, zeros< Mat >(arma::size(X)), ones< Mat >(arma::size(X))); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +normpdf(const Base& X_expr, const typename T1::elem_type mu, const typename T1::elem_type sigma) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UX(X_expr.get_ref()); + const Mat& X = UX.M; + + Mat out; + + normpdf_helper(out, X, mu*ones< Mat >(arma::size(X)), sigma*ones< Mat >(arma::size(X))); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +normpdf(const Base& X_expr, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + Mat out; + + normpdf_helper(out, X_expr.get_ref(), M_expr.get_ref(), S_expr.get_ref()); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_numel.hpp b/src/armadillo/include/armadillo_bits/fn_numel.hpp new file mode 100644 index 0000000..fe5c191 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_numel.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_numel +//! @{ + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value, uword >::result +numel(const T1& X) + { + arma_extra_debug_sigprint(); + + const Proxy P(X); + + return P.get_n_elem(); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_cube_type::value, uword >::result +numel(const T1& X) + { + arma_extra_debug_sigprint(); + + const ProxyCube P(X); + + return P.get_n_elem(); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_sparse_type::value, uword >::result +numel(const T1& X) + { + arma_extra_debug_sigprint(); + + const SpProxy P(X); + + return P.get_n_elem(); + } + + + +template +arma_warn_unused +inline +uword +numel(const field& X) + { + arma_extra_debug_sigprint(); + + return X.n_elem; + } + + + +template +arma_warn_unused +inline +uword +numel(const subview_field& X) + { + arma_extra_debug_sigprint(); + + return X.n_elem; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_ones.hpp b/src/armadillo/include/armadillo_bits/fn_ones.hpp new file mode 100644 index 0000000..ae8b622 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_ones.hpp @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_ones +//! @{ + + + +arma_warn_unused +arma_inline +const Gen +ones(const uword n_elem) + { + arma_extra_debug_sigprint(); + + return Gen(n_elem, 1); + } + + + +template +arma_warn_unused +arma_inline +const Gen +ones(const uword n_elem, const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const uword n_rows = (is_Row::value) ? uword(1) : n_elem; + const uword n_cols = (is_Row::value) ? n_elem : uword(1); + + return Gen(n_rows, n_cols); + } + + + +arma_warn_unused +arma_inline +const Gen +ones(const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + return Gen(n_rows, n_cols); + } + + + +arma_warn_unused +arma_inline +const Gen +ones(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return Gen(s.n_rows, s.n_cols); + } + + + +template +arma_warn_unused +inline +const Gen +ones(const uword n_rows, const uword n_cols, const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(is_Col::value) { arma_debug_check( (n_cols != 1), "ones(): incompatible size" ); } + if(is_Row::value) { arma_debug_check( (n_rows != 1), "ones(): incompatible size" ); } + + return Gen(n_rows, n_cols); + } + + + +template +arma_warn_unused +inline +const Gen +ones(const SizeMat& s, const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return ones(s.n_rows, s.n_cols); + } + + + +arma_warn_unused +arma_inline +const GenCube +ones(const uword n_rows, const uword n_cols, const uword n_slices) + { + arma_extra_debug_sigprint(); + + return GenCube(n_rows, n_cols, n_slices); + } + + + +arma_warn_unused +arma_inline +const GenCube +ones(const SizeCube& s) + { + arma_extra_debug_sigprint(); + + return GenCube(s.n_rows, s.n_cols, s.n_slices); + } + + + +template +arma_warn_unused +arma_inline +const GenCube +ones(const uword n_rows, const uword n_cols, const uword n_slices, const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return GenCube(n_rows, n_cols, n_slices); + } + + + +template +arma_warn_unused +arma_inline +const GenCube +ones(const SizeCube& s, const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return GenCube(s.n_rows, s.n_cols, s.n_slices); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_orth_null.hpp b/src/armadillo/include/armadillo_bits/fn_orth_null.hpp new file mode 100644 index 0000000..fe68906 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_orth_null.hpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_orth_null +//! @{ + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_real::value, const Op >::result +orth(const Base& X, const typename T1::pod_type tol = 0.0) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + return Op(X.get_ref(), eT(tol)); + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +orth(Mat& out, const Base& X, const typename T1::pod_type tol = 0.0) + { + arma_extra_debug_sigprint(); + + const bool status = op_orth::apply_direct(out, X.get_ref(), tol); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "orth(): svd failed"); + } + + return status; + } + + + +// + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_real::value, const Op >::result +null(const Base& X, const typename T1::pod_type tol = 0.0) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + return Op(X.get_ref(), eT(tol)); + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +null(Mat& out, const Base& X, const typename T1::pod_type tol = 0.0) + { + arma_extra_debug_sigprint(); + + const bool status = op_null::apply_direct(out, X.get_ref(), tol); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "null(): svd failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_pinv.hpp b/src/armadillo/include/armadillo_bits/fn_pinv.hpp new file mode 100644 index 0000000..6a87322 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_pinv.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_pinv +//! @{ + + + +template +arma_warn_unused +inline +typename enable_if2< is_real::value, const Op >::result +pinv + ( + const Base& X + ) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_real::value, const Op >::result +pinv + ( + const Base& X, + const typename T1::pod_type tol, + const char* method = nullptr + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + uword method_id = 0; // default setting + + if(method != nullptr) + { + const char sig = method[0]; + + arma_debug_check( ((sig != 's') && (sig != 'd')), "pinv(): unknown method specified" ); + + if(sig == 's') { method_id = 1; } + if(sig == 'd') { method_id = 2; } + } + + return Op(X.get_ref(), eT(tol), method_id, uword(0)); + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +pinv + ( + Mat& out, + const Base& X, + const typename T1::pod_type tol = 0.0, + const char* method = nullptr + ) + { + arma_extra_debug_sigprint(); + + uword method_id = 0; // default setting + + if(method != nullptr) + { + const char sig = method[0]; + + arma_debug_check( ((sig != 's') && (sig != 'd')), "pinv(): unknown method specified" ); + + if(sig == 's') { method_id = 1; } + if(sig == 'd') { method_id = 2; } + } + + const bool status = op_pinv::apply_direct(out, X.get_ref(), tol, method_id); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "pinv(): svd failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_polyfit.hpp b/src/armadillo/include/armadillo_bits/fn_polyfit.hpp new file mode 100644 index 0000000..e51e37b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_polyfit.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_polyfit +//! @{ + + + +template +inline +typename +enable_if2 + < + is_supported_blas_type::value, + bool + >::result +polyfit(Mat& out, const Base& X, const Base& Y, const uword N) + { + arma_extra_debug_sigprint(); + + const bool status = glue_polyfit::apply_direct(out, X.get_ref(), Y.get_ref(), N); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "polyfit(): failed"); + } + + return status; + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_supported_blas_type::value, + const Glue + >::result +polyfit(const Base& X, const Base& Y, const uword N) + { + arma_extra_debug_sigprint(); + + return Glue(X.get_ref(), Y.get_ref(), N); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_polyval.hpp b/src/armadillo/include/armadillo_bits/fn_polyval.hpp new file mode 100644 index 0000000..f22c728 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_polyval.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_polyval +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_supported_blas_type::value && is_arma_type::value && is_same_type::value), + const Glue + >::result +polyval(const Base& P, const T2& X) + { + arma_extra_debug_sigprint(); + + return Glue(P.get_ref(), X); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_powext.hpp b/src/armadillo/include/armadillo_bits/fn_powext.hpp new file mode 100644 index 0000000..a971219 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_powext.hpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_powext +//! @{ + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value, + const Glue + >::result +pow + ( + const T1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return Glue(X, Y.get_ref()); + } + + + +template +arma_warn_unused +inline +Mat +pow + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return glue_powext::apply(X,Y); + } + + + +template +arma_warn_unused +arma_inline +const GlueCube +pow + ( + const BaseCube& X, + const BaseCube& Y + ) + { + arma_extra_debug_sigprint(); + + return GlueCube(X.get_ref(), Y.get_ref()); + } + + + +template +arma_warn_unused +inline +Cube +pow + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return glue_powext::apply(X,Y); + } + + + +// + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + ( is_arma_type::value && is_cx::yes ), + const mtGlue + >::result +pow + ( + const T1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return mtGlue(X, Y.get_ref()); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_cx::yes, + Mat + >::result +pow + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return glue_powext_cx::apply(X,Y); + } + + + +template +arma_warn_unused +arma_inline +const mtGlueCube +pow + ( + const BaseCube< std::complex, T1>& X, + const BaseCube< typename T1::pod_type , T2>& Y + ) + { + arma_extra_debug_sigprint(); + + return mtGlueCube(X.get_ref(), Y.get_ref()); + } + + + +template +arma_warn_unused +inline +Cube< std::complex > +pow + ( + const subview_cube_each1< std::complex >& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return glue_powext_cx::apply(X,Y); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_powmat.hpp b/src/armadillo/include/armadillo_bits/fn_powmat.hpp new file mode 100644 index 0000000..17d0293 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_powmat.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_powmat +//! @{ + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +powmat(const Base& X, const int y) + { + arma_extra_debug_sigprint(); + + const uword aux_a = (y < int(0)) ? uword(-y) : uword(y); + const uword aux_b = (y < int(0)) ? uword(1) : uword(0); + + return Op(X.get_ref(), aux_a, aux_b); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +powmat + ( + Mat& out, + const Base& X, + const int y + ) + { + arma_extra_debug_sigprint(); + + const uword y_val = (y < int(0)) ? uword(-y) : uword(y); + const bool y_neg = (y < int(0)); + + const bool status = op_powmat::apply_direct(out, X.get_ref(), y_val, y_neg); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "powmat(): transformation failed"); + } + + return status; + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, const mtOp,T1,op_powmat_cx> >::result +powmat(const Base& X, const double y) + { + arma_extra_debug_sigprint(); + + typedef std::complex out_eT; + + return mtOp('j', X.get_ref(), out_eT(y)); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +powmat + ( + Mat< std::complex >& out, + const Base& X, + const double y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const bool status = op_powmat_cx::apply_direct(out, X.get_ref(), T(y)); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "powmat(): transformation failed"); + } + + return status; + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_princomp.hpp b/src/armadillo/include/armadillo_bits/fn_princomp.hpp new file mode 100644 index 0000000..0a251b5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_princomp.hpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_princomp +//! @{ + + + +//! \brief +//! principal component analysis -- 4 arguments version +//! coeff_out -> principal component coefficients +//! score_out -> projected samples +//! latent_out -> eigenvalues of principal vectors +//! tsquared_out -> Hotelling's T^2 statistic +template +inline +bool +princomp + ( + Mat& coeff_out, + Mat& score_out, + Col& latent_out, + Col& tsquared_out, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const bool status = op_princomp::direct_princomp(coeff_out, score_out, latent_out, tsquared_out, X); + + if(status == false) + { + coeff_out.soft_reset(); + score_out.soft_reset(); + latent_out.soft_reset(); + tsquared_out.soft_reset(); + + arma_debug_warn_level(3, "princomp(): decomposition failed"); + } + + return status; + } + + + +//! \brief +//! principal component analysis -- 3 arguments version +//! coeff_out -> principal component coefficients +//! score_out -> projected samples +//! latent_out -> eigenvalues of principal vectors +template +inline +bool +princomp + ( + Mat& coeff_out, + Mat& score_out, + Col& latent_out, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const bool status = op_princomp::direct_princomp(coeff_out, score_out, latent_out, X); + + if(status == false) + { + coeff_out.soft_reset(); + score_out.soft_reset(); + latent_out.soft_reset(); + + arma_debug_warn_level(3, "princomp(): decomposition failed"); + } + + return status; + } + + + +//! \brief +//! principal component analysis -- 2 arguments version +//! coeff_out -> principal component coefficients +//! score_out -> projected samples +template +inline +bool +princomp + ( + Mat& coeff_out, + Mat& score_out, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const bool status = op_princomp::direct_princomp(coeff_out, score_out, X); + + if(status == false) + { + coeff_out.soft_reset(); + score_out.soft_reset(); + + arma_debug_warn_level(3, "princomp(): decomposition failed"); + } + + return status; + } + + + +//! \brief +//! principal component analysis -- 1 argument version +//! coeff_out -> principal component coefficients +template +inline +bool +princomp + ( + Mat& coeff_out, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const bool status = op_princomp::direct_princomp(coeff_out, X); + + if(status == false) + { + coeff_out.soft_reset(); + + arma_debug_warn_level(3, "princomp(): decomposition failed"); + } + + return status; + } + + + +template +arma_warn_unused +inline +const Op +princomp + ( + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return Op(X.get_ref()); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_prod.hpp b/src/armadillo/include/armadillo_bits/fn_prod.hpp new file mode 100644 index 0000000..c15110f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_prod.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_prod +//! @{ + + +//! \brief +//! Delayed product of elements of a matrix along a specified dimension (either rows or columns). +//! The result is stored in a dense matrix that has either one column or one row. +//! For dim = 0, find the sum of each column (ie. traverse across rows) +//! For dim = 1, find the sum of each row (ie. traverse across columns) +//! The default is dim = 0. +//! NOTE: this function works differently than in Matlab/Octave. + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value && resolves_to_vector::yes, typename T1::elem_type >::result +prod(const T1& X) + { + arma_extra_debug_sigprint(); + + return op_prod::prod(X); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value && resolves_to_vector::no, const Op >::result +prod(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const Op >::result +prod(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return Op(X, dim, 0); + } + + + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +prod(const T& x) + { + return x; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_qr.hpp b/src/armadillo/include/armadillo_bits/fn_qr.hpp new file mode 100644 index 0000000..3d49a1b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_qr.hpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_qr +//! @{ + + + +//! QR decomposition +template +inline +bool +qr + ( + Mat& Q, + Mat& R, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + arma_debug_check( (&Q == &R), "qr(): Q and R are the same object" ); + + const bool status = auxlib::qr(Q, R, X); + + if(status == false) + { + Q.soft_reset(); + R.soft_reset(); + arma_debug_warn_level(3, "qr(): decomposition failed"); + } + + return status; + } + + + +//! economical QR decomposition +template +inline +bool +qr_econ + ( + Mat& Q, + Mat& R, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + arma_debug_check( (&Q == &R), "qr_econ(): Q and R are the same object" ); + + const bool status = auxlib::qr_econ(Q, R, X); + + if(status == false) + { + Q.soft_reset(); + R.soft_reset(); + arma_debug_warn_level(3, "qr_econ(): decomposition failed"); + } + + return status; + } + + + +//! QR decomposition with pivoting +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +qr + ( + Mat& Q, + Mat& R, + Mat& P, + const Base& X, + const char* P_mode = "matrix" + ) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (&Q == &R), "qr(): Q and R are the same object" ); + + const char sig = (P_mode != nullptr) ? P_mode[0] : char(0); + + arma_debug_check( ((sig != 'm') && (sig != 'v')), "qr(): argument 'P_mode' must be \"vector\" or \"matrix\"" ); + + bool status = false; + + if(sig == 'v') + { + status = auxlib::qr_pivot(Q, R, P, X); + } + else + if(sig == 'm') + { + Mat P_vec; + + status = auxlib::qr_pivot(Q, R, P_vec, X); + + if(status) + { + // construct P + + const uword N = P_vec.n_rows; + + P.zeros(N,N); + + for(uword row=0; row < N; ++row) { P.at(P_vec[row], row) = uword(1); } + } + } + + if(status == false) + { + Q.soft_reset(); + R.soft_reset(); + P.soft_reset(); + arma_debug_warn_level(3, "qr(): decomposition failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_quantile.hpp b/src/armadillo/include/armadillo_bits/fn_quantile.hpp new file mode 100644 index 0000000..6c1ea2b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_quantile.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_quantile +//! @{ + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_cx::no && is_real::value, + const mtGlue + >::result +quantile(const T1& X, const Base& P) + { + arma_extra_debug_sigprint(); + + return mtGlue(X, P.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_cx::no && is_real::value, + const mtGlue + >::result +quantile(const T1& X, const Base& P, const uword dim) + { + arma_extra_debug_sigprint(); + + return mtGlue(X, P.get_ref(), dim); + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_qz.hpp b/src/armadillo/include/armadillo_bits/fn_qz.hpp new file mode 100644 index 0000000..9979dfa --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_qz.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_qz +//! @{ + + + +//! QZ decomposition for pair of N-by-N general matrices A and B +template +inline +typename +enable_if2 + < + is_supported_blas_type::value, + bool + >::result +qz + ( + Mat& AA, + Mat& BB, + Mat& Q, + Mat& Z, + const Base& A_expr, + const Base& B_expr, + const char* select = "none" + ) + { + arma_extra_debug_sigprint(); + + const char sig = (select != nullptr) ? select[0] : char(0); + + arma_debug_check( ( (sig != 'n') && (sig != 'l') && (sig != 'r') && (sig != 'i') && (sig != 'o') ), "qz(): unknown select form" ); + + const bool status = auxlib::qz(AA, BB, Q, Z, A_expr.get_ref(), B_expr.get_ref(), sig); + + if(status == false) + { + AA.soft_reset(); + BB.soft_reset(); + Q.soft_reset(); + Z.soft_reset(); + arma_debug_warn_level(3, "qz(): decomposition failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_randg.hpp b/src/armadillo/include/armadillo_bits/fn_randg.hpp new file mode 100644 index 0000000..a0e998a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_randg.hpp @@ -0,0 +1,241 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_randg +//! @{ + + + +template +arma_warn_unused +inline +obj_type +randg(const uword n_rows, const uword n_cols, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename obj_type::elem_type eT; + + if(is_Col::value) + { + arma_debug_check( (n_cols != 1), "randg(): incompatible size" ); + } + else + if(is_Row::value) + { + arma_debug_check( (n_rows != 1), "randg(): incompatible size" ); + } + + double a = double(1); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( ((a <= double(0)) || (b <= double(0))), "randg(): incorrect distribution parameters; a and b must be greater than zero" ); + + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); + + arma_rng::randg::fill(out.memptr(), out.n_elem, a, b); + + return out; + } + + + +template +arma_warn_unused +inline +obj_type +randg(const SizeMat& s, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return randg(s.n_rows, s.n_cols, param); + } + + + +template +arma_warn_unused +inline +obj_type +randg(const uword n_elem, const distr_param& param = distr_param(), const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const uword n_rows = (is_Row::value) ? uword(1) : n_elem; + const uword n_cols = (is_Row::value) ? n_elem : uword(1); + + return randg(n_rows, n_cols, param); + } + + + +arma_warn_unused +inline +mat +randg(const uword n_rows, const uword n_cols, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randg(n_rows, n_cols, param); + } + + + +arma_warn_unused +inline +mat +randg(const SizeMat& s, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randg(s.n_rows, s.n_cols, param); + } + + + +arma_warn_unused +inline +vec +randg(const uword n_elem, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randg(n_elem, uword(1), param); + } + + + +arma_warn_unused +inline +double +randg(const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + double a = double(1); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( ((a <= double(0)) || (b <= double(0))), "randg(): incorrect distribution parameters; a and b must be greater than zero" ); + + double out_val = double(0); + + arma_rng::randg::fill(&out_val, uword(1), a, b); + + return out_val; + } + + + +template +arma_warn_unused +inline +typename arma_real_or_cx_only::result +randg(const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + double a = double(1); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( ((a <= double(0)) || (b <= double(0))), "randg(): incorrect distribution parameters; a and b must be greater than zero" ); + + eT out_val = eT(0); + + arma_rng::randg::fill(&out_val, uword(1), a, b); + + return out_val; + } + + + +template +arma_warn_unused +inline +cube_type +randg(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename cube_type::elem_type eT; + + double a = double(1); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( ((a <= double(0)) || (b <= double(0))), "randg(): incorrect distribution parameters; a and b must be greater than zero" ); + + cube_type out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); + + arma_rng::randg::fill(out.memptr(), out.n_elem, a, b); + + return out; + } + + + +template +arma_warn_unused +inline +cube_type +randg(const SizeCube& s, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return randg(s.n_rows, s.n_cols, s.n_slices, param); + } + + + +arma_warn_unused +inline +cube +randg(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randg(n_rows, n_cols, n_slices, param); + } + + + +arma_warn_unused +inline +cube +randg(const SizeCube& s, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randg(s.n_rows, s.n_cols, s.n_slices, param); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_randi.hpp b/src/armadillo/include/armadillo_bits/fn_randi.hpp new file mode 100644 index 0000000..2aae9b5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_randi.hpp @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_randi +//! @{ + + + +template +arma_warn_unused +inline +obj_type +randi(const uword n_rows, const uword n_cols, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename obj_type::elem_type eT; + + if(is_Col::value) + { + arma_debug_check( (n_cols != 1), "randi(): incompatible size" ); + } + else + if(is_Row::value) + { + arma_debug_check( (n_rows != 1), "randi(): incompatible size" ); + } + + int a = 0; + int b = arma_rng::randi::max_val(); + + param.get_int_vals(a,b); + + arma_debug_check( (a > b), "randi(): incorrect distribution parameters; a must be less than b" ); + + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); + + arma_rng::randi::fill(out.memptr(), out.n_elem, a, b); + + return out; + } + + + +template +arma_warn_unused +inline +obj_type +randi(const SizeMat& s, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return randi(s.n_rows, s.n_cols, param); + } + + + +template +arma_warn_unused +inline +obj_type +randi(const uword n_elem, const distr_param& param = distr_param(), const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + if(is_Row::value) + { + return randi(1, n_elem, param); + } + else + { + return randi(n_elem, 1, param); + } + } + + + +arma_warn_unused +inline +imat +randi(const uword n_rows, const uword n_cols, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randi(n_rows, n_cols, param); + } + + + +arma_warn_unused +inline +imat +randi(const SizeMat& s, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randi(s.n_rows, s.n_cols, param); + } + + + +arma_warn_unused +inline +ivec +randi(const uword n_elem, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randi(n_elem, uword(1), param); + } + + + +arma_warn_unused +inline +sword +randi(const distr_param& param) + { + arma_extra_debug_sigprint(); + + int a = 0; + int b = arma_rng::randi::max_val(); + + param.get_int_vals(a,b); + + arma_debug_check( (a > b), "randi(): incorrect distribution parameters; a must be less than b" ); + + sword out_val = sword(0); + + arma_rng::randi::fill(&out_val, uword(1), a, b); + + return out_val; + } + + + +template +arma_warn_unused +inline +typename arma_scalar_only::result +randi(const distr_param& param) + { + arma_extra_debug_sigprint(); + + int a = 0; + int b = arma_rng::randi::max_val(); + + param.get_int_vals(a,b); + + arma_debug_check( (a > b), "randi(): incorrect distribution parameters; a must be less than b" ); + + eT out_val = eT(0); + + arma_rng::randi::fill(&out_val, uword(1), a, b); + + return out_val; + } + + + +arma_warn_unused +inline +sword +randi() + { + arma_extra_debug_sigprint(); + + return sword( arma_rng::randi() ); + } + + + +template +arma_warn_unused +inline +typename arma_scalar_only::result +randi() + { + arma_extra_debug_sigprint(); + + return eT( arma_rng::randi() ); + } + + + +template +arma_warn_unused +inline +cube_type +randi(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename cube_type::elem_type eT; + + int a = 0; + int b = arma_rng::randi::max_val(); + + param.get_int_vals(a,b); + + arma_debug_check( (a > b), "randi(): incorrect distribution parameters; a must be less than b" ); + + cube_type out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); + + arma_rng::randi::fill(out.memptr(), out.n_elem, a, b); + + return out; + } + + + +template +arma_warn_unused +inline +cube_type +randi(const SizeCube& s, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return randi(s.n_rows, s.n_cols, s.n_slices, param); + } + + + +arma_warn_unused +inline +icube +randi(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randi(n_rows, n_cols, n_slices, param); + } + + + +arma_warn_unused +inline +icube +randi(const SizeCube& s, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randi(s.n_rows, s.n_cols, s.n_slices, param); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_randn.hpp b/src/armadillo/include/armadillo_bits/fn_randn.hpp new file mode 100644 index 0000000..37cb3db --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_randn.hpp @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_randn +//! @{ + + + +// scalars + +arma_warn_unused +inline +double +randn() + { + arma_extra_debug_sigprint(); + + return double(arma_rng::randn()); + } + + + +template +arma_warn_unused +inline +typename arma_real_or_cx_only::result +randn() + { + arma_extra_debug_sigprint(); + + return eT(arma_rng::randn()); + } + + + +arma_warn_unused +inline +double +randn(const distr_param& param) + { + arma_extra_debug_sigprint(); + + if(param.state == 0) { return double(arma_rng::randn()); } + + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + const double val = double(arma_rng::randn()); + + return ((val * sd) + mu); + } + + + +template +arma_warn_unused +inline +typename arma_real_or_cx_only::result +randn(const distr_param& param) + { + arma_extra_debug_sigprint(); + + if(param.state == 0) { return eT(arma_rng::randn()); } + + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + eT val = eT(0); + + arma_rng::randn::fill(&val, 1, mu, sd); // using fill() as eT can be complex + + return val; + } + + + +// vectors + +arma_warn_unused +inline +vec +randn(const uword n_elem, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + vec out(n_elem, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randn::fill(out.memptr(), n_elem); + } + else + { + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), n_elem, mu, sd); + } + + return out; + } + + + +template +arma_warn_unused +inline +obj_type +randn(const uword n_elem, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename obj_type::elem_type eT; + + const uword n_rows = (is_Row::value) ? uword(1) : n_elem; + const uword n_cols = (is_Row::value) ? n_elem : uword(1); + + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randn::fill(out.memptr(), out.n_elem); + } + else + { + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), out.n_elem, mu, sd); + } + + return out; + } + + + +// matrices + +arma_warn_unused +inline +mat +randn(const uword n_rows, const uword n_cols, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + mat out(n_rows, n_cols, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randn::fill(out.memptr(), out.n_elem); + } + else + { + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), out.n_elem, mu, sd); + } + + return out; + } + + + +arma_warn_unused +inline +mat +randn(const SizeMat& s, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randn(s.n_rows, s.n_cols, param); + } + + + +template +arma_warn_unused +inline +obj_type +randn(const uword n_rows, const uword n_cols, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename obj_type::elem_type eT; + + if(is_Col::value) { arma_debug_check( (n_cols != 1), "randn(): incompatible size" ); } + if(is_Row::value) { arma_debug_check( (n_rows != 1), "randn(): incompatible size" ); } + + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randn::fill(out.memptr(), out.n_elem); + } + else + { + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), out.n_elem, mu, sd); + } + + return out; + } + + + +template +arma_warn_unused +inline +obj_type +randn(const SizeMat& s, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return randn(s.n_rows, s.n_cols, param); + } + + + +// cubes + + +arma_warn_unused +inline +cube +randn(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + cube out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randn::fill(out.memptr(), out.n_elem); + } + else + { + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), out.n_elem, mu, sd); + } + + return out; + } + + + +arma_warn_unused +inline +cube +randn(const SizeCube& s, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randn(s.n_rows, s.n_cols, s.n_slices, param); + } + + + +template +arma_warn_unused +inline +cube_type +randn(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename cube_type::elem_type eT; + + cube_type out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randn::fill(out.memptr(), out.n_elem); + } + else + { + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), out.n_elem, mu, sd); + } + + return out; + } + + + +template +arma_warn_unused +inline +cube_type +randn(const SizeCube& s, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return randn(s.n_rows, s.n_cols, s.n_slices, param); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_randperm.hpp b/src/armadillo/include/armadillo_bits/fn_randperm.hpp new file mode 100644 index 0000000..19623a7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_randperm.hpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_randperm +//! @{ + + + +template +inline +void +internal_randperm_helper(obj_type& x, const uword N, const uword N_keep) + { + arma_extra_debug_sigprint(); + + typedef typename obj_type::elem_type eT; + + // see op_sort_index_bones.hpp for the definition of arma_sort_index_packet + // and the associated comparison functor + + typedef arma_sort_index_packet packet; + + std::vector packet_vec(N); + + for(uword i=0; i < N; ++i) + { + packet_vec[i].val = int(arma_rng::randi()); + packet_vec[i].index = i; + } + + arma_sort_index_helper_ascend comparator; + + if(N >= 2) + { + if(N_keep < N) + { + typename std::vector::iterator first = packet_vec.begin(); + typename std::vector::iterator nth = first + N_keep; + typename std::vector::iterator pastlast = packet_vec.end(); + + std::partial_sort(first, nth, pastlast, comparator); + } + else + { + std::sort( packet_vec.begin(), packet_vec.end(), comparator ); + } + } + + if(is_Row::value) + { + x.set_size(1,N_keep); + } + else + { + x.set_size(N_keep,1); + } + + eT* x_mem = x.memptr(); + + for(uword i=0; i < N_keep; ++i) + { + x_mem[i] = eT( packet_vec[i].index ); + } + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_Mat::value, obj_type >::result +randperm(const uword N) + { + arma_extra_debug_sigprint(); + + obj_type x; + + if(N > 0) { internal_randperm_helper(x, N, N); } + + return x; + } + + + +arma_warn_unused +inline +uvec +randperm(const uword N) + { + arma_extra_debug_sigprint(); + + uvec x; + + if(N > 0) { internal_randperm_helper(x, N, N); } + + return x; + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_Mat::value, obj_type >::result +randperm(const uword N, const uword M) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (M > N), "randperm(): 'M' must be less than or equal to 'N'" ); + + obj_type x; + + if( (N > 0) && (M > 0) ) { internal_randperm_helper(x, N, M); } + + return x; + } + + + +arma_warn_unused +inline +uvec +randperm(const uword N, const uword M) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (M > N), "randperm(): 'M' must be less than or equal to 'N'" ); + + uvec x; + + if( (N > 0) && (M > 0) ) { internal_randperm_helper(x, N, M); } + + return x; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_randu.hpp b/src/armadillo/include/armadillo_bits/fn_randu.hpp new file mode 100644 index 0000000..432c171 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_randu.hpp @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_randu +//! @{ + + + +// scalars + +arma_warn_unused +inline +double +randu() + { + arma_extra_debug_sigprint(); + + return double(arma_rng::randu()); + } + + + +template +arma_warn_unused +inline +typename arma_real_or_cx_only::result +randu() + { + arma_extra_debug_sigprint(); + + return eT(arma_rng::randu()); + } + + + +arma_warn_unused +inline +double +randu(const distr_param& param) + { + arma_extra_debug_sigprint(); + + if(param.state == 0) { return double(arma_rng::randu()); } + + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + const double val = double(arma_rng::randu()); + + return ((val * (b - a)) + a); + } + + + +template +arma_warn_unused +inline +typename arma_real_or_cx_only::result +randu(const distr_param& param) + { + arma_extra_debug_sigprint(); + + if(param.state == 0) { return eT(arma_rng::randu()); } + + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + eT val = eT(0); + + arma_rng::randu::fill(&val, 1, a, b); // using fill() as eT can be complex + + return val; + } + + + +// vectors + +arma_warn_unused +inline +vec +randu(const uword n_elem, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + vec out(n_elem, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randu::fill(out.memptr(), n_elem); + } + else + { + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), n_elem, a, b); + } + + return out; + } + + + +template +arma_warn_unused +inline +obj_type +randu(const uword n_elem, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename obj_type::elem_type eT; + + const uword n_rows = (is_Row::value) ? uword(1) : n_elem; + const uword n_cols = (is_Row::value) ? n_elem : uword(1); + + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randu::fill(out.memptr(), out.n_elem); + } + else + { + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), out.n_elem, a, b); + } + + return out; + } + + + +// matrices + +arma_warn_unused +inline +mat +randu(const uword n_rows, const uword n_cols, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + mat out(n_rows, n_cols, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randu::fill(out.memptr(), out.n_elem); + } + else + { + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), out.n_elem, a, b); + } + + return out; + } + + + +arma_warn_unused +inline +mat +randu(const SizeMat& s, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randu(s.n_rows, s.n_cols, param); + } + + + +template +arma_warn_unused +inline +obj_type +randu(const uword n_rows, const uword n_cols, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename obj_type::elem_type eT; + + if(is_Col::value) { arma_debug_check( (n_cols != 1), "randu(): incompatible size" ); } + if(is_Row::value) { arma_debug_check( (n_rows != 1), "randu(): incompatible size" ); } + + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randu::fill(out.memptr(), out.n_elem); + } + else + { + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), out.n_elem, a, b); + } + + return out; + } + + + +template +arma_warn_unused +inline +obj_type +randu(const SizeMat& s, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return randu(s.n_rows, s.n_cols, param); + } + + + +// cubes + + +arma_warn_unused +inline +cube +randu(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + cube out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randu::fill(out.memptr(), out.n_elem); + } + else + { + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), out.n_elem, a, b); + } + + return out; + } + + + +arma_warn_unused +inline +cube +randu(const SizeCube& s, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + return randu(s.n_rows, s.n_cols, s.n_slices, param); + } + + + +template +arma_warn_unused +inline +cube_type +randu(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename cube_type::elem_type eT; + + cube_type out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randu::fill(out.memptr(), out.n_elem); + } + else + { + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), out.n_elem, a, b); + } + + return out; + } + + + +template +arma_warn_unused +inline +cube_type +randu(const SizeCube& s, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return randu(s.n_rows, s.n_cols, s.n_slices, param); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_range.hpp b/src/armadillo/include/armadillo_bits/fn_range.hpp new file mode 100644 index 0000000..3a28094 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_range.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_range +//! @{ + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value && resolves_to_vector::yes, typename T1::elem_type >::result +range(const T1& X) + { + arma_extra_debug_sigprint(); + + return op_range::vector_range(X); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value && resolves_to_vector::no, const Op >::result +range(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const Op >::result +range(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return Op(X, dim, 0); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_rank.hpp b/src/armadillo/include/armadillo_bits/fn_rank.hpp new file mode 100644 index 0000000..7701a04 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_rank.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_rank +//! @{ + + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, uword >::result +rank(const Base& expr, const typename T1::pod_type tol = 0) + { + arma_extra_debug_sigprint(); + + uword out = uword(0); + + const bool status = op_rank::apply(out, expr.get_ref(), tol); + + if(status == false) { arma_stop_runtime_error("rank(): failed"); return uword(0); } + + return out; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +rank(uword& out, const Base& expr, const typename T1::pod_type tol = 0) + { + arma_extra_debug_sigprint(); + + out = uword(0); + + return op_rank::apply(out, expr.get_ref(), tol); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_regspace.hpp b/src/armadillo/include/armadillo_bits/fn_regspace.hpp new file mode 100644 index 0000000..83e7de2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_regspace.hpp @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_regspace +//! @{ + + + +template +inline +void +internal_regspace_default_delta + ( + Mat& x, + const typename Mat::pod_type start, + const typename Mat::pod_type end + ) + { + arma_extra_debug_sigprint(); + + typedef typename Mat::pod_type T; + + const bool ascend = (start <= end); + + const uword N = uword(1) + uword((ascend) ? (end-start) : (start-end)); + + x.set_size(N); + + eT* x_mem = x.memptr(); + + if(ascend) + { + for(uword i=0; i < N; ++i) { x_mem[i] = eT(start + T(i)); } + } + else + { + for(uword i=0; i < N; ++i) { x_mem[i] = eT(start - T(i)); } + } + } + + + +template +inline +typename enable_if2< (is_signed::value == true), void >::result +internal_regspace_var_delta + ( + Mat& x, + const typename Mat::pod_type start, + const sT delta, + const typename Mat::pod_type end + ) + { + arma_extra_debug_sigprint(); + arma_extra_debug_print("internal_regspace_var_delta(): signed version"); + + typedef typename Mat::pod_type T; + + if( ((start < end) && (delta < sT(0))) || ((start > end) && (delta > sT(0))) || (delta == sT(0)) ) { return; } + + const bool ascend = (start <= end); + + const T inc = (delta < sT(0)) ? T(-delta) : T(delta); + + const T M = ((ascend) ? T(end-start) : T(start-end)) / T(inc); + + const uword N = uword(1) + ( (is_non_integral::value) ? uword(std::floor(double(M))) : uword(M) ); + + x.set_size(N); + + eT* x_mem = x.memptr(); + + if(ascend) + { + for(uword i=0; i < N; ++i) { x_mem[i] = eT( start + T(i*inc) ); } + } + else + { + for(uword i=0; i < N; ++i) { x_mem[i] = eT( start - T(i*inc) ); } + } + } + + + +template +inline +typename enable_if2< (is_signed::value == false), void >::result +internal_regspace_var_delta + ( + Mat& x, + const typename Mat::pod_type start, + const uT delta, + const typename Mat::pod_type end + ) + { + arma_extra_debug_sigprint(); + arma_extra_debug_print("internal_regspace_var_delta(): unsigned version"); + + typedef typename Mat::pod_type T; + + if( ((start > end) && (delta > uT(0))) || (delta == uT(0)) ) { return; } + + const bool ascend = (start <= end); + + const T inc = T(delta); + + const T M = ((ascend) ? T(end-start) : T(start-end)) / T(inc); + + const uword N = uword(1) + ( (is_non_integral::value) ? uword(std::floor(double(M))) : uword(M) ); + + x.set_size(N); + + eT* x_mem = x.memptr(); + + if(ascend) + { + for(uword i=0; i < N; ++i) { x_mem[i] = eT( start + T(i*inc) ); } + } + else + { + for(uword i=0; i < N; ++i) { x_mem[i] = eT( start - T(i*inc) ); } + } + } + + + +template +inline +typename enable_if2< is_Mat::value && (is_signed::value == true), vec_type >::result +regspace + ( + const typename vec_type::pod_type start, + const sT delta, + const typename vec_type::pod_type end + ) + { + arma_extra_debug_sigprint(); + arma_extra_debug_print("regspace(): signed version"); + + vec_type x; + + if( ((delta == sT(+1)) && (start <= end)) || ((delta == sT(-1)) && (start > end)) ) + { + internal_regspace_default_delta(x, start, end); + } + else + { + internal_regspace_var_delta(x, start, delta, end); + } + + if(x.n_elem == 0) + { + if(is_Mat_only::value) { x.set_size(1,0); } + } + + return x; + } + + + +template +inline +typename enable_if2< is_Mat::value && (is_signed::value == false), vec_type >::result +regspace + ( + const typename vec_type::pod_type start, + const uT delta, + const typename vec_type::pod_type end + ) + { + arma_extra_debug_sigprint(); + arma_extra_debug_print("regspace(): unsigned version"); + + vec_type x; + + if( (delta == uT(+1)) && (start <= end) ) + { + internal_regspace_default_delta(x, start, end); + } + else + { + internal_regspace_var_delta(x, start, delta, end); + } + + if(x.n_elem == 0) + { + if(is_Mat_only::value) { x.set_size(1,0); } + } + + return x; + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_Mat::value, + vec_type + >::result +regspace + ( + const typename vec_type::pod_type start, + const typename vec_type::pod_type end + ) + { + arma_extra_debug_sigprint(); + + vec_type x; + + internal_regspace_default_delta(x, start, end); + + if(x.n_elem == 0) + { + if(is_Mat_only::value) { x.set_size(1,0); } + } + + return x; + } + + + +arma_warn_unused +inline +vec +regspace(const double start, const double delta, const double end) + { + arma_extra_debug_sigprint(); + + return regspace(start, delta, end); + } + + + +arma_warn_unused +inline +vec +regspace(const double start, const double end) + { + arma_extra_debug_sigprint(); + + return regspace(start, end); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_repelem.hpp b/src/armadillo/include/armadillo_bits/fn_repelem.hpp new file mode 100644 index 0000000..5d1e817 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_repelem.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup fn_repelem +//! @{ + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value, + const Op + >::result +repelem(const T1& A, const uword r, const uword c) + { + arma_extra_debug_sigprint(); + + return Op(A, r, c); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +repelem(const SpBase& A, const uword r, const uword c) + { + arma_extra_debug_sigprint(); + + return SpOp(A.get_ref(), r, c); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_repmat.hpp b/src/armadillo/include/armadillo_bits/fn_repmat.hpp new file mode 100644 index 0000000..113bfb3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_repmat.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup fn_repmat +//! @{ + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value, + const Op + >::result +repmat(const T1& A, const uword r, const uword c) + { + arma_extra_debug_sigprint(); + + return Op(A, r, c); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +repmat(const SpBase& A, const uword r, const uword c) + { + arma_extra_debug_sigprint(); + + return SpOp(A.get_ref(), r, c); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_reshape.hpp b/src/armadillo/include/armadillo_bits/fn_reshape.hpp new file mode 100644 index 0000000..35bef09 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_reshape.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_reshape +//! @{ + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value, const Op >::result +reshape(const T1& X, const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + return Op(X, new_n_rows, new_n_cols); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value, const Op >::result +reshape(const T1& X, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return Op(X, s.n_rows, s.n_cols); + } + + + +template +arma_frown("don't use this form: it will be removed") +inline +Mat +reshape(const Base& X, const uword new_n_rows, const uword new_n_cols, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_check( (dim > 1), "reshape(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(X.get_ref()); + const Mat& A = U.M; + + Mat out; + + if(dim == 0) + { + op_reshape::apply_mat_noalias(out, A, new_n_rows, new_n_cols); + } + else + if(dim == 1) + { + Mat tmp; + + op_strans::apply_mat_noalias(tmp, A); + + op_reshape::apply_mat_noalias(out, tmp, new_n_rows, new_n_cols); + } + + return out; + } + + + +template +arma_warn_unused +inline +const OpCube +reshape(const BaseCube& X, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + return OpCube(X.get_ref(), new_n_rows, new_n_cols, new_n_slices); + } + + + +template +arma_warn_unused +inline +const OpCube +reshape(const BaseCube& X, const SizeCube& s) + { + arma_extra_debug_sigprint(); + + return OpCube(X.get_ref(), s.n_rows, s.n_cols, s.n_slices); + } + + + +template +arma_warn_unused +inline +const SpOp +reshape(const SpBase& X, const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref(), new_n_rows, new_n_cols); + } + + + +template +arma_warn_unused +inline +const SpOp +reshape(const SpBase& X, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref(), s.n_rows, s.n_cols); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_resize.hpp b/src/armadillo/include/armadillo_bits/fn_resize.hpp new file mode 100644 index 0000000..7088290 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_resize.hpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_resize +//! @{ + + + +template +arma_warn_unused +inline +const Op +resize(const Base& X, const uword in_n_rows, const uword in_n_cols) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref(), in_n_rows, in_n_cols); + } + + + +template +arma_warn_unused +inline +const Op +resize(const Base& X, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref(), s.n_rows, s.n_cols); + } + + + +template +arma_warn_unused +inline +const OpCube +resize(const BaseCube& X, const uword in_n_rows, const uword in_n_cols, const uword in_n_slices) + { + arma_extra_debug_sigprint(); + + return OpCube(X.get_ref(), in_n_rows, in_n_cols, in_n_slices); + } + + + +template +arma_warn_unused +inline +const OpCube +resize(const BaseCube& X, const SizeCube& s) + { + arma_extra_debug_sigprint(); + + return OpCube(X.get_ref(), s.n_rows, s.n_cols, s.n_slices); + } + + + +template +arma_warn_unused +inline +const SpOp +resize(const SpBase& X, const uword in_n_rows, const uword in_n_cols) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref(), in_n_rows, in_n_cols); + } + + + +template +arma_warn_unused +inline +const SpOp +resize(const SpBase& X, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref(), s.n_rows, s.n_cols); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_reverse.hpp b/src/armadillo/include/armadillo_bits/fn_reverse.hpp new file mode 100644 index 0000000..284c80d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_reverse.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_reverse +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + const Op + >::result +reverse + ( + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return Op(X); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const Op + >::result +reverse + ( + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value, const Op >::result +reverse + ( + const T1& X, + const uword dim + ) + { + arma_extra_debug_sigprint(); + + return Op(X, dim, 0); + } + + + +template +arma_warn_unused +inline +const SpOp +reverse + ( + const SpBase& X, + const uword dim = 0 + ) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref(), dim, 0); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_roots.hpp b/src/armadillo/include/armadillo_bits/fn_roots.hpp new file mode 100644 index 0000000..80fe240 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_roots.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_roots +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_supported_blas_type::value, + const mtOp, T1, op_roots> + >::result +roots(const Base& X) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_roots>(X.get_ref()); + } + + + +template +inline +typename +enable_if2 + < + is_supported_blas_type::value, + bool + >::result +roots(Mat< std::complex >& out, const Base& X) + { + arma_extra_debug_sigprint(); + + const bool status = op_roots::apply_direct(out, X.get_ref()); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "roots(): eigen decomposition failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_schur.hpp b/src/armadillo/include/armadillo_bits/fn_schur.hpp new file mode 100644 index 0000000..f9f277c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_schur.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_schur +//! @{ + + +template +inline +bool +schur + ( + Mat& S, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + Mat U; + + const bool status = auxlib::schur(U, S, X.get_ref(), false); + + if(status == false) + { + S.soft_reset(); + arma_debug_warn_level(3, "schur(): decomposition failed"); + } + + return status; + } + + + +template +arma_warn_unused +inline +Mat +schur + ( + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + Mat S; + Mat U; + + const bool status = auxlib::schur(U, S, X.get_ref(), false); + + if(status == false) + { + S.soft_reset(); + arma_stop_runtime_error("schur(): decomposition failed"); + } + + return S; + } + + + +template +inline +bool +schur + ( + Mat& U, + Mat& S, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + arma_debug_check( void_ptr(&U) == void_ptr(&S), "schur(): 'U' is an alias of 'S'" ); + + const bool status = auxlib::schur(U, S, X.get_ref(), true); + + if(status == false) + { + U.soft_reset(); + S.soft_reset(); + arma_debug_warn_level(3, "schur(): decomposition failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_shift.hpp b/src/armadillo/include/armadillo_bits/fn_shift.hpp new file mode 100644 index 0000000..d3de6a7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_shift.hpp @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup fn_shift +//! @{ + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + const Op + >::result +shift + ( + const T1& X, + const sword N + ) + { + arma_extra_debug_sigprint(); + + const uword len = (N < 0) ? uword(-N) : uword(N); + const uword neg = (N < 0) ? uword( 1) : uword(0); + + return Op(X, len, neg); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + Mat + >::result +shift + ( + const T1& X, + const sword N + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword len = (N < 0) ? uword(-N) : uword(N); + const uword neg = (N < 0) ? uword( 1) : uword(0); + + quasi_unwrap U(X); + + Mat out; + + op_shift::apply_noalias(out, U.M, len, neg, 0); + + return out; + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + (is_arma_type::value), + Mat + >::result +shift + ( + const T1& X, + const sword N, + const uword dim + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_check( (dim > 1), "shift(): parameter 'dim' must be 0 or 1" ); + + const uword len = (N < 0) ? uword(-N) : uword(N); + const uword neg = (N < 0) ? uword( 1) : uword(0); + + quasi_unwrap U(X); + + Mat out; + + op_shift::apply_noalias(out, U.M, len, neg, dim); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_shuffle.hpp b/src/armadillo/include/armadillo_bits/fn_shuffle.hpp new file mode 100644 index 0000000..a0e0f65 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_shuffle.hpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup fn_shuffle +//! @{ + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + const Op + >::result +shuffle + ( + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return Op(X); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const Op + >::result +shuffle + ( + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + (is_arma_type::value), + const Op + >::result +shuffle + ( + const T1& X, + const uword dim + ) + { + arma_extra_debug_sigprint(); + + return Op(X, dim, 0); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_size.hpp b/src/armadillo/include/armadillo_bits/fn_size.hpp new file mode 100644 index 0000000..b6ac80e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_size.hpp @@ -0,0 +1,327 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_size +//! @{ + + + +arma_warn_unused +inline +const SizeMat +size(const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + return SizeMat(n_rows, n_cols); + } + + + +template +arma_warn_unused +inline +const SizeMat +size(const Base& X) + { + arma_extra_debug_sigprint(); + + const Proxy P(X.get_ref()); + + return SizeMat( P.get_n_rows(), P.get_n_cols() ); + } + + + +// explicit overload to workround ADL issues with C++17 std::size() +template +arma_warn_unused +inline +const SizeMat +size(const Mat& X) + { + arma_extra_debug_sigprint(); + + return SizeMat( X.n_rows, X.n_cols ); + } + + + +// explicit overload to workround ADL issues with C++17 std::size() +template +arma_warn_unused +inline +const SizeMat +size(const Row& X) + { + arma_extra_debug_sigprint(); + + return SizeMat( X.n_rows, X.n_cols ); + } + + + +// explicit overload to workround ADL issues with C++17 std::size() +template +arma_warn_unused +inline +const SizeMat +size(const Col& X) + { + arma_extra_debug_sigprint(); + + return SizeMat( X.n_rows, X.n_cols ); + } + + + +arma_warn_unused +inline +const SizeMat +size(const arma::span& row_span, const arma::span& col_span) + { + arma_extra_debug_sigprint(); + + uword n_rows = 0; + uword n_cols = 0; + + if(row_span.whole || col_span.whole) + { + arma_debug_check(true, "size(): span::all not supported"); + } + else + { + if((row_span.a > row_span.b) || (col_span.a > col_span.b)) + { + arma_debug_check_bounds(true, "size(): span indices incorrectly used"); + } + else + { + n_rows = row_span.b - row_span.a + 1; + n_cols = col_span.b - col_span.a + 1; + } + } + + return SizeMat(n_rows, n_cols); + } + + + +template +arma_warn_unused +inline +uword +size(const Base& X, const uword dim) + { + arma_extra_debug_sigprint(); + + const Proxy P(X.get_ref()); + + return SizeMat( P.get_n_rows(), P.get_n_cols() )( dim ); + } + + + +arma_warn_unused +inline +const SizeCube +size(const uword n_rows, const uword n_cols, const uword n_slices) + { + arma_extra_debug_sigprint(); + + return SizeCube(n_rows, n_cols, n_slices); + } + + + +template +arma_warn_unused +inline +const SizeCube +size(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + const ProxyCube P(X.get_ref()); + + return SizeCube( P.get_n_rows(), P.get_n_cols(), P.get_n_slices() ); + } + + + +// explicit overload to workround ADL issues with C++17 std::size() +template +arma_warn_unused +inline +const SizeCube +size(const Cube& X) + { + arma_extra_debug_sigprint(); + + return SizeCube( X.n_rows, X.n_cols, X.n_slices ); + } + + + +template +arma_warn_unused +inline +uword +size(const BaseCube& X, const uword dim) + { + arma_extra_debug_sigprint(); + + const ProxyCube P(X.get_ref()); + + return SizeCube( P.get_n_rows(), P.get_n_cols(), P.get_n_slices() )( dim ); + } + + + +arma_warn_unused +inline +const SizeCube +size(const arma::span& row_span, const arma::span& col_span, const arma::span& slice_span) + { + arma_extra_debug_sigprint(); + + uword n_rows = 0; + uword n_cols = 0; + uword n_slices = 0; + + if(row_span.whole || col_span.whole || slice_span.whole) + { + arma_debug_check(true, "size(): span::all not supported"); + } + else + { + if((row_span.a > row_span.b) || (col_span.a > col_span.b) || (slice_span.a > slice_span.b)) + { + arma_debug_check_bounds(true, "size(): span indices incorrectly used"); + } + else + { + n_rows = row_span.b - row_span.a + 1; + n_cols = col_span.b - col_span.a + 1; + n_slices = slice_span.b - slice_span.a + 1; + } + } + + return SizeCube(n_rows, n_cols, n_slices); + } + + + +template +arma_warn_unused +inline +const SizeMat +size(const SpBase& X) + { + arma_extra_debug_sigprint(); + + const SpProxy P(X.get_ref()); + + return SizeMat( P.get_n_rows(), P.get_n_cols() ); + } + + + +// explicit overload to workround ADL issues with C++17 std::size() +template +arma_warn_unused +inline +const SizeMat +size(const SpMat& X) + { + arma_extra_debug_sigprint(); + + return SizeMat( X.n_rows, X.n_cols ); + } + + + +template +arma_warn_unused +inline +uword +size(const SpBase& X, const uword dim) + { + arma_extra_debug_sigprint(); + + const SpProxy P(X.get_ref()); + + return SizeMat( P.get_n_rows(), P.get_n_cols() )( dim ); + } + + + + +template +arma_warn_unused +inline +const SizeCube +size(const field& X) + { + arma_extra_debug_sigprint(); + + return SizeCube( X.n_rows, X.n_cols, X.n_slices ); + } + + + +template +arma_warn_unused +inline +uword +size(const field& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return SizeCube( X.n_rows, X.n_cols, X.n_slices )( dim ); + } + + + +template +arma_warn_unused +inline +const SizeCube +size(const subview_field& X) + { + arma_extra_debug_sigprint(); + + return SizeCube( X.n_rows, X.n_cols, X.n_slices ); + } + + + +template +arma_warn_unused +inline +uword +size(const subview_field& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return SizeCube( X.n_rows, X.n_cols, X.n_slices )( dim ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_solve.hpp b/src/armadillo/include/armadillo_bits/fn_solve.hpp new file mode 100644 index 0000000..12ca693 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_solve.hpp @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_solve +//! @{ + + + +// +// solve_gen + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, const Glue >::result +solve + ( + const Base& A, + const Base& B + ) + { + arma_extra_debug_sigprint(); + + return Glue(A.get_ref(), B.get_ref()); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +solve + ( + Mat& out, + const Base& A, + const Base& B + ) + { + arma_extra_debug_sigprint(); + + const bool status = glue_solve_gen_default::apply(out, A.get_ref(), B.get_ref()); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "solve(): solution not found"); + } + + return status; + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, const Glue >::result +solve + ( + const Base& A, + const Base& B, + const solve_opts::opts& opts + ) + { + arma_extra_debug_sigprint(); + + return Glue(A.get_ref(), B.get_ref(), opts.flags); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +solve + ( + Mat& out, + const Base& A, + const Base& B, + const solve_opts::opts& opts + ) + { + arma_extra_debug_sigprint(); + + const bool status = glue_solve_gen_full::apply(out, A.get_ref(), B.get_ref(), opts.flags); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "solve(): solution not found"); + } + + return status; + } + + + +// +// solve_tri + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, const Glue >::result +solve + ( + const Op& A, + const Base& B + ) + { + arma_extra_debug_sigprint(); + + uword flags = uword(0); + + if(A.aux_uword_a == 0) { flags |= solve_opts::flag_triu; } + if(A.aux_uword_a == 1) { flags |= solve_opts::flag_tril; } + + return Glue(A.m, B.get_ref(), flags); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, const Glue >::result +solve + ( + const Op& A, + const Base& B, + const solve_opts::opts& opts + ) + { + arma_extra_debug_sigprint(); + + uword flags = opts.flags; + + if(A.aux_uword_a == 0) { flags |= solve_opts::flag_triu; } + if(A.aux_uword_a == 1) { flags |= solve_opts::flag_tril; } + + return Glue(A.m, B.get_ref(), flags); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +solve + ( + Mat& out, + const Op& A, + const Base& B + ) + { + arma_extra_debug_sigprint(); + + uword flags = uword(0); + + if(A.aux_uword_a == 0) { flags |= solve_opts::flag_triu; } + if(A.aux_uword_a == 1) { flags |= solve_opts::flag_tril; } + + const bool status = glue_solve_tri_default::apply(out, A.m, B.get_ref(), flags); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "solve(): solution not found"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +solve + ( + Mat& out, + const Op& A, + const Base& B, + const solve_opts::opts& opts + ) + { + arma_extra_debug_sigprint(); + + uword flags = opts.flags; + + if(A.aux_uword_a == 0) { flags |= solve_opts::flag_triu; } + if(A.aux_uword_a == 1) { flags |= solve_opts::flag_tril; } + + const bool status = glue_solve_tri_full::apply(out, A.m, B.get_ref(), flags); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "solve(): solution not found"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_sort.hpp b/src/armadillo/include/armadillo_bits/fn_sort.hpp new file mode 100644 index 0000000..01b45f2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_sort.hpp @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_sort +//! @{ + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + const Op + >::result +sort + ( + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const Op + >::result +sort + ( + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes && is_same_type::value, + const Op + >::result +sort + ( + const T1& X, + const T2* sort_direction + ) + { + arma_extra_debug_sigprint(); + + const char sig = (sort_direction != nullptr) ? sort_direction[0] : char(0); + + arma_debug_check( (sig != 'a') && (sig != 'd'), "sort(): unknown sort direction" ); + + const uword sort_type = (sig == 'a') ? 0 : 1; + + return Op(X, sort_type, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no && is_same_type::value, + const Op + >::result +sort + ( + const T1& X, + const T2* sort_direction + ) + { + arma_extra_debug_sigprint(); + + const char sig = (sort_direction != nullptr) ? sort_direction[0] : char(0); + + arma_debug_check( (sig != 'a') && (sig != 'd'), "sort(): unknown sort direction" ); + + const uword sort_type = (sig == 'a') ? 0 : 1; + + return Op(X, sort_type, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + ( (is_arma_type::value) && (is_same_type::value) ), + const Op + >::result +sort + ( + const T1& X, + const T2* sort_direction, + const uword dim + ) + { + arma_extra_debug_sigprint(); + + const char sig = (sort_direction != nullptr) ? sort_direction[0] : char(0); + + arma_debug_check( (sig != 'a') && (sig != 'd'), "sort(): unknown sort direction" ); + + const uword sort_type = (sig == 'a') ? 0 : 1; + + return Op(X, sort_type, dim); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_sort_index.hpp b/src/armadillo/include/armadillo_bits/fn_sort_index.hpp new file mode 100644 index 0000000..1df3693 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_sort_index.hpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_sort_index +//! @{ + + + +template +arma_warn_unused +arma_inline +const mtOp +sort_index + ( + const Base& X + ) + { + arma_extra_debug_sigprint(); + + return mtOp(X.get_ref(), uword(0), uword(0)); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + ( (is_arma_type::value) && (is_same_type::value) ), + const mtOp + >::result +sort_index + ( + const T1& X, + const T2* sort_direction + ) + { + arma_extra_debug_sigprint(); + + const char sig = (sort_direction != nullptr) ? sort_direction[0] : char(0); + + arma_debug_check( ((sig != 'a') && (sig != 'd')), "sort_index(): unknown sort direction" ); + + return mtOp(X, ((sig == 'a') ? uword(0) : uword(1)), uword(0)); + } + + + +// + + + +template +arma_warn_unused +arma_inline +const mtOp +stable_sort_index + ( + const Base& X + ) + { + arma_extra_debug_sigprint(); + + return mtOp(X.get_ref(), uword(0), uword(0)); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + ( (is_arma_type::value) && (is_same_type::value) ), + const mtOp + >::result +stable_sort_index + ( + const T1& X, + const T2* sort_direction + ) + { + arma_extra_debug_sigprint(); + + const char sig = (sort_direction != nullptr) ? sort_direction[0] : char(0); + + arma_debug_check( ((sig != 'a') && (sig != 'd')), "stable_sort_index(): unknown sort direction" ); + + return mtOp(X, ((sig == 'a') ? uword(0) : uword(1)), uword(0)); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_speye.hpp b/src/armadillo/include/armadillo_bits/fn_speye.hpp new file mode 100644 index 0000000..48570be --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_speye.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_speye +//! @{ + + + +//! Generate a sparse matrix with the values along the main diagonal set to one +template +arma_warn_unused +inline +obj_type +speye(const uword n_rows, const uword n_cols, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(is_SpCol::value) { arma_debug_check( (n_cols != 1), "speye(): incompatible size" ); } + if(is_SpRow::value) { arma_debug_check( (n_rows != 1), "speye(): incompatible size" ); } + + obj_type out; + + out.eye(n_rows, n_cols); + + return out; + } + + + +template +arma_warn_unused +inline +obj_type +speye(const SizeMat& s, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return speye(s.n_rows, s.n_cols); + } + + + +// Convenience shortcut method (no template parameter necessary) +arma_warn_unused +inline +sp_mat +speye(const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + sp_mat out; + + out.eye(n_rows, n_cols); + + return out; + } + + + +arma_warn_unused +inline +sp_mat +speye(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + sp_mat out; + + out.eye(s.n_rows, s.n_cols); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_spones.hpp b/src/armadillo/include/armadillo_bits/fn_spones.hpp new file mode 100644 index 0000000..ff45b21 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_spones.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_spones +//! @{ + + + +//! Generate a sparse matrix with the non-zero values in the same locations as in the given sparse matrix X, +//! with the non-zero values set to one +template +arma_warn_unused +inline +SpMat +spones(const SpBase& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(X.get_ref()); + + SpMat out(arma_layout_indicator(), U.M); + + arrayops::inplace_set( access::rwp(out.values), eT(1), out.n_nonzero ); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_sprandn.hpp b/src/armadillo/include/armadillo_bits/fn_sprandn.hpp new file mode 100644 index 0000000..1798224 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_sprandn.hpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_sprandn +//! @{ + + + +//! Generate a sparse matrix with a randomly selected subset of the elements +//! set to random values from a Gaussian distribution with zero mean and unit variance +template +arma_warn_unused +inline +obj_type +sprandn + ( + const uword n_rows, + const uword n_cols, + const double density, + const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(is_SpCol::value) + { + arma_debug_check( (n_cols != 1), "sprandn(): incompatible size" ); + } + else + if(is_SpRow::value) + { + arma_debug_check( (n_rows != 1), "sprandn(): incompatible size" ); + } + + obj_type out; + + out.sprandn(n_rows, n_cols, density); + + return out; + } + + + +template +arma_warn_unused +inline +obj_type +sprandn(const SizeMat& s, const double density, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return sprandn(s.n_rows, s.n_cols, density); + } + + + +arma_warn_unused +inline +sp_mat +sprandn(const uword n_rows, const uword n_cols, const double density) + { + arma_extra_debug_sigprint(); + + sp_mat out; + + out.sprandn(n_rows, n_cols, density); + + return out; + } + + + +arma_warn_unused +inline +sp_mat +sprandn(const SizeMat& s, const double density) + { + arma_extra_debug_sigprint(); + + sp_mat out; + + out.sprandn(s.n_rows, s.n_cols, density); + + return out; + } + + + +//! Generate a sparse matrix with the non-zero values in the same locations as in the given sparse matrix X, +//! with the non-zero values set to random values from a Gaussian distribution with zero mean and unit variance +template +arma_warn_unused +inline +SpMat +sprandn(const SpBase& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + SpMat out( X.get_ref() ); + + arma_rng::randn::fill( access::rwp(out.values), out.n_nonzero ); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_sprandu.hpp b/src/armadillo/include/armadillo_bits/fn_sprandu.hpp new file mode 100644 index 0000000..846e75b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_sprandu.hpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_sprandu +//! @{ + + + +//! Generate a sparse matrix with a randomly selected subset of the elements +//! set to random values in the [0,1] interval (uniform distribution) +template +arma_warn_unused +inline +obj_type +sprandu + ( + const uword n_rows, + const uword n_cols, + const double density, + const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(is_SpCol::value) + { + arma_debug_check( (n_cols != 1), "sprandu(): incompatible size" ); + } + else + if(is_SpRow::value) + { + arma_debug_check( (n_rows != 1), "sprandu(): incompatible size" ); + } + + obj_type out; + + out.sprandu(n_rows, n_cols, density); + + return out; + } + + + +template +arma_warn_unused +inline +obj_type +sprandu(const SizeMat& s, const double density, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return sprandu(s.n_rows, s.n_cols, density); + } + + + +arma_warn_unused +inline +sp_mat +sprandu(const uword n_rows, const uword n_cols, const double density) + { + arma_extra_debug_sigprint(); + + sp_mat out; + + out.sprandu(n_rows, n_cols, density); + + return out; + } + + + +arma_warn_unused +inline +sp_mat +sprandu(const SizeMat& s, const double density) + { + arma_extra_debug_sigprint(); + + sp_mat out; + + out.sprandu(s.n_rows, s.n_cols, density); + + return out; + } + + + +//! Generate a sparse matrix with the non-zero values in the same locations as in the given sparse matrix X, +//! with the non-zero values set to random values in the [0,1] interval (uniform distribution) +template +arma_warn_unused +inline +SpMat +sprandu(const SpBase& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + SpMat out( X.get_ref() ); + + arma_rng::randu::fill( access::rwp(out.values), out.n_nonzero ); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_spsolve.hpp b/src/armadillo/include/armadillo_bits/fn_spsolve.hpp new file mode 100644 index 0000000..3eaf333 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_spsolve.hpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_spsolve +//! @{ + + + +template +inline +bool +spsolve_helper + ( + Mat& out, + const SpBase& A, + const Base& B, + const char* solver, + const spsolve_opts_base& settings, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + const char sig = (solver != nullptr) ? solver[0] : char(0); + + arma_debug_check( ((sig != 'l') && (sig != 's')), "spsolve(): unknown solver" ); + + T rcond = T(0); + + bool status = false; + + superlu_opts superlu_opts_default; + + // if(is_float ::value) { superlu_opts_default.refine = superlu_opts::REF_SINGLE; } + // if(is_double::value) { superlu_opts_default.refine = superlu_opts::REF_DOUBLE; } + + const superlu_opts& opts = (settings.id == 1) ? static_cast(settings) : superlu_opts_default; + + arma_debug_check( ( (opts.pivot_thresh < double(0)) || (opts.pivot_thresh > double(1)) ), "spsolve(): pivot_thresh must be in the [0,1] interval" ); + + if(sig == 's') // SuperLU solver + { + if( (opts.equilibrate == false) && (opts.refine == superlu_opts::REF_NONE) ) + { + status = sp_auxlib::spsolve_simple(out, A.get_ref(), B.get_ref(), opts); + } + else + { + status = sp_auxlib::spsolve_refine(out, rcond, A.get_ref(), B.get_ref(), opts); + } + } + else + if(sig == 'l') // brutal LAPACK solver + { + if( (settings.id != 0) && ((opts.symmetric) || (opts.pivot_thresh != double(1))) ) + { + arma_debug_warn_level(1, "spsolve(): ignoring settings not applicable to LAPACK based solver"); + } + + Mat AA; + + bool conversion_ok = false; + + try + { + Mat tmp(A.get_ref()); // conversion from sparse to dense can throw std::bad_alloc + + AA.steal_mem(tmp); + + conversion_ok = true; + } + catch(...) + { + arma_debug_warn_level(1, "spsolve(): not enough memory to use LAPACK based solver"); + } + + if(conversion_ok) + { + arma_debug_check( (AA.n_rows != AA.n_cols), "spsolve(): matrix A must be square sized" ); + + uword flags = solve_opts::flag_none; + + if(opts.refine != superlu_opts::REF_NONE) { flags |= solve_opts::flag_refine; } + if(opts.equilibrate == true ) { flags |= solve_opts::flag_equilibrate; } + if(opts.allow_ugly == true ) { flags |= solve_opts::flag_allow_ugly; } + + status = glue_solve_gen_full::apply(out, AA, B.get_ref(), flags); + } + } + + + if( (status == false) && (rcond > T(0)) ) + { + arma_debug_warn_level(2, "spsolve(): system is singular (rcond: ", rcond, ")"); + } + + if( (status == true) && (rcond > T(0)) && (rcond < std::numeric_limits::epsilon()) ) + { + arma_debug_warn_level(2, "solve(): solution computed, but system is singular to working precision (rcond: ", rcond, ")"); + } + + return status; + } + + + +// + + + +template +inline +bool +spsolve + ( + Mat& out, + const SpBase& A, + const Base& B, + const char* solver = "superlu", + const spsolve_opts_base& settings = spsolve_opts_none(), + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const bool status = spsolve_helper(out, A.get_ref(), B.get_ref(), solver, settings); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "spsolve(): solution not found"); + } + + return status; + } + + + +template +arma_warn_unused +inline +Mat +spsolve + ( + const SpBase& A, + const Base& B, + const char* solver = "superlu", + const spsolve_opts_base& settings = spsolve_opts_none(), + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + Mat out; + + const bool status = spsolve_helper(out, A.get_ref(), B.get_ref(), solver, settings); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("spsolve(): solution not found"); + } + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_sqrtmat.hpp b/src/armadillo/include/armadillo_bits/fn_sqrtmat.hpp new file mode 100644 index 0000000..882aa15 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_sqrtmat.hpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_sqrtmat +//! @{ + + + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_supported_blas_type::value && is_cx::no), const mtOp, T1, op_sqrtmat> >::result +sqrtmat(const Base& X) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_sqrtmat>(X.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_supported_blas_type::value && is_cx::yes), const Op >::result +sqrtmat(const Base& X) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + +template +inline +typename enable_if2< (is_supported_blas_type::value && is_cx::no), bool >::result +sqrtmat(Mat< std::complex >& Y, const Base& X) + { + arma_extra_debug_sigprint(); + + const bool status = op_sqrtmat::apply_direct(Y, X.get_ref()); + + if(status == false) + { + arma_debug_warn_level(3, "sqrtmat(): given matrix is singular; may not have a square root"); + } + + return status; + } + + + +template +inline +typename enable_if2< (is_supported_blas_type::value && is_cx::yes), bool >::result +sqrtmat(Mat& Y, const Base& X) + { + arma_extra_debug_sigprint(); + + const bool status = op_sqrtmat_cx::apply_direct(Y, X.get_ref()); + + if(status == false) + { + arma_debug_warn_level(3, "sqrtmat(): given matrix is singular; may not have a square root"); + } + + return status; + } + + + +// + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +sqrtmat_sympd(const Base& X) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +sqrtmat_sympd(Mat& Y, const Base& X) + { + arma_extra_debug_sigprint(); + + const bool status = op_sqrtmat_sympd::apply_direct(Y, X.get_ref()); + + if(status == false) + { + Y.soft_reset(); + arma_debug_warn_level(3, "sqrtmat_sympd(): transformation failed"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_stddev.hpp b/src/armadillo/include/armadillo_bits/fn_stddev.hpp new file mode 100644 index 0000000..7544280 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_stddev.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_stddev +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + typename T1::pod_type + >::result +stddev(const T1& X, const uword norm_type = 0) + { + arma_extra_debug_sigprint(); + + return std::sqrt( op_var::var_vec(X, norm_type) ); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const mtOp + >::result +stddev(const T1& X, const uword norm_type = 0) + { + arma_extra_debug_sigprint(); + + return mtOp(X, norm_type, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +stddev(const T1& X, const uword norm_type, const uword dim) + { + arma_extra_debug_sigprint(); + + return mtOp(X, norm_type, dim); + } + + + +template +arma_warn_unused +inline +typename arma_scalar_only::result +stddev(const T&) + { + return T(0); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_strans.hpp b/src/armadillo/include/armadillo_bits/fn_strans.hpp new file mode 100644 index 0000000..de81e19 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_strans.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_strans +//! @{ + + + +template +arma_warn_unused +arma_inline +const Op +strans + ( + const T1& X, + const typename enable_if< is_arma_type::value >::result* junk1 = nullptr, + const typename arma_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + return Op(X); + } + + + +// NOTE: for non-complex objects, deliberately returning op_htrans instead of op_strans, +// NOTE: due to currently more optimisations available when using op_htrans, especially by glue_times +template +arma_warn_unused +arma_inline +const Op +strans + ( + const T1& X, + const typename enable_if< is_arma_type::value >::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + return Op(X); + } + + + +// +// handling of sparse matrices + + +template +arma_warn_unused +arma_inline +const SpOp +strans + ( + const T1& X, + const typename enable_if< is_arma_sparse_type::value >::result* junk1 = nullptr, + const typename arma_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + return SpOp(X); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +strans + ( + const T1& X, + const typename enable_if< is_arma_sparse_type::value >::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + return SpOp(X); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_sum.hpp b/src/armadillo/include/armadillo_bits/fn_sum.hpp new file mode 100644 index 0000000..0fa8936 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_sum.hpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_sum +//! @{ + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value && resolves_to_vector::yes, typename T1::elem_type >::result +sum(const T1& X) + { + arma_extra_debug_sigprint(); + + return accu(X); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value && resolves_to_vector::no, const Op >::result +sum(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X, 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const Op >::result +sum(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + return Op(X, dim, 0); + } + + + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +sum(const T& x) + { + return x; + } + + + +//! sum of cube +template +arma_warn_unused +arma_inline +const OpCube +sum + ( + const BaseCube& X, + const uword dim = 0 + ) + { + arma_extra_debug_sigprint(); + + return OpCube(X.get_ref(), dim, 0); + } + + + +//! sum of sparse object +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::yes, + typename T1::elem_type + >::result +sum(const T1& x) + { + arma_extra_debug_sigprint(); + + // sum elements + return accu(x); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::no, + const SpOp + >::result +sum(const T1& x) + { + arma_extra_debug_sigprint(); + + return SpOp(x, 0, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value, + const SpOp + >::result +sum(const T1& x, const uword dim) + { + arma_extra_debug_sigprint(); + + return SpOp(x, dim, 0); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_svd.hpp b/src/armadillo/include/armadillo_bits/fn_svd.hpp new file mode 100644 index 0000000..ff987bb --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_svd.hpp @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_svd +//! @{ + + + +template +inline +bool +svd + ( + Col& S, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + Mat A(X.get_ref()); + + const bool status = auxlib::svd_dc(S, A); + + if(status == false) + { + S.soft_reset(); + arma_debug_warn_level(3, "svd(): decomposition failed"); + } + + return status; + } + + + +template +arma_warn_unused +inline +Col +svd + ( + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + Col out; + + Mat A(X.get_ref()); + + const bool status = auxlib::svd_dc(out, A); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("svd(): decomposition failed"); + } + + return out; + } + + + +template +inline +bool +svd + ( + Mat& U, + Col& S, + Mat& V, + const Base& X, + const char* method = "dc", + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + arma_debug_check + ( + ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), + "svd(): two or more output objects are the same object" + ); + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 's') && (sig != 'd')), "svd(): unknown method specified" ); + + Mat A(X.get_ref()); + + const bool status = (sig == 'd') ? auxlib::svd_dc(U, S, V, A) : auxlib::svd(U, S, V, A); + + if(status == false) + { + U.soft_reset(); + S.soft_reset(); + V.soft_reset(); + arma_debug_warn_level(3, "svd(): decomposition failed"); + } + + return status; + } + + + +template +inline +bool +svd_econ + ( + Mat& U, + Col& S, + Mat& V, + const Base& X, + const char mode, + const char* method = "dc", + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + arma_debug_check + ( + ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), + "svd_econ(): two or more output objects are the same object" + ); + + arma_debug_check + ( + ( (mode != 'l') && (mode != 'r') && (mode != 'b') ), + "svd_econ(): parameter 'mode' is incorrect" + ); + + const char sig = (method != nullptr) ? method[0] : char(0); + + arma_debug_check( ((sig != 's') && (sig != 'd')), "svd_econ(): unknown method specified" ); + + Mat A(X.get_ref()); + + const bool status = ((mode == 'b') && (sig == 'd')) ? auxlib::svd_dc_econ(U, S, V, A) : auxlib::svd_econ(U, S, V, A, mode); + + if(status == false) + { + U.soft_reset(); + S.soft_reset(); + V.soft_reset(); + arma_debug_warn_level(3, "svd_econ(): decomposition failed"); + } + + return status; + } + + + +template +inline +bool +svd_econ + ( + Mat& U, + Col& S, + Mat& V, + const Base& X, + const char* mode = "both", + const char* method = "dc", + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return svd_econ(U, S, V, X, ((mode != nullptr) ? mode[0] : char(0)), method); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_svds.hpp b/src/armadillo/include/armadillo_bits/fn_svds.hpp new file mode 100644 index 0000000..26c8c50 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_svds.hpp @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_svds +//! @{ + + +template +inline +bool +svds_helper + ( + Mat& U, + Col& S, + Mat& V, + const SpBase& X, + const uword k, + const typename T1::pod_type tol, + const bool calc_UV, + const typename arma_real_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + arma_debug_check + ( + ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), + "svds(): two or more output objects are the same object" + ); + + arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" ); + + const unwrap_spmat tmp(X.get_ref()); + const SpMat& A = tmp.M; + + const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k ); + + const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col(const_cast(A.values), A.n_nonzero, false)))) : T(0); + + if(A_max == T(0)) + { + // TODO: use reset instead ? + S.zeros(kk); + + if(calc_UV) + { + U.eye(A.n_rows, kk); + V.eye(A.n_cols, kk); + } + } + else + { + SpMat C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) ); + + SpMat B = A / A_max; + SpMat Bt = B.t(); + + C(0, A.n_rows, arma::size(B) ) = B; + C(A.n_rows, 0, arma::size(Bt)) = Bt; + + Bt.reset(); + B.reset(); + + Col eigval; + Mat eigvec; + + eigs_opts opts; + opts.tol = (tol / Datum::sqrt2); + + const bool status = eigs_sym(eigval, eigvec, C, kk, "la", opts); + + if(status == false) + { + U.soft_reset(); + S.soft_reset(); + V.soft_reset(); + + return false; + } + + const T A_norm = max(eigval); + + const T tol2 = tol / Datum::sqrt2 * A_norm; + + uvec indices = find(eigval > tol2); + + if(indices.n_elem > kk) + { + indices = indices.subvec(0,kk-1); + } + else + if(indices.n_elem < kk) + { + const uvec indices2 = find(abs(eigval) <= tol2); + + const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) ); + + if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); } + } + + const uvec sorted_indices = sort_index(eigval, "descend"); + + S = eigval.elem(sorted_indices); S *= A_max; + + if(calc_UV) + { + uvec U_row_indices(A.n_rows, arma_nozeros_indicator()); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } + uvec V_row_indices(A.n_cols, arma_nozeros_indicator()); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } + + U = Datum::sqrt2 * eigvec(U_row_indices, sorted_indices); + V = Datum::sqrt2 * eigvec(V_row_indices, sorted_indices); + } + } + + if(S.n_elem < k) { arma_debug_warn_level(1, "svds(): found fewer singular values than specified"); } + + return true; + } + + + +template +inline +bool +svds_helper + ( + Mat& U, + Col& S, + Mat& V, + const SpBase& X, + const uword k, + const typename T1::pod_type tol, + const bool calc_UV, + const typename arma_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(arma_config::arpack == false) + { + arma_stop_logic_error("svds(): use of ARPACK must be enabled for decomposition of complex matrices"); + return false; + } + + arma_debug_check + ( + ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), + "svds(): two or more output objects are the same object" + ); + + arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" ); + + const unwrap_spmat tmp(X.get_ref()); + const SpMat& A = tmp.M; + + const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k ); + + const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col(const_cast(A.values), A.n_nonzero, false)))) : T(0); + + if(A_max == T(0)) + { + // TODO: use reset instead ? + S.zeros(kk); + + if(calc_UV) + { + U.eye(A.n_rows, kk); + V.eye(A.n_cols, kk); + } + } + else + { + SpMat C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) ); + + SpMat B = A / A_max; + SpMat Bt = B.t(); + + C(0, A.n_rows, arma::size(B) ) = B; + C(A.n_rows, 0, arma::size(Bt)) = Bt; + + Bt.reset(); + B.reset(); + + Col eigval_tmp; + Mat eigvec; + + eigs_opts opts; + opts.tol = (tol / Datum::sqrt2); + + const bool status = eigs_gen(eigval_tmp, eigvec, C, kk, "lr", opts); + + if(status == false) + { + U.soft_reset(); + S.soft_reset(); + V.soft_reset(); + + return false; + } + + const Col eigval = real(eigval_tmp); + + const T A_norm = max(eigval); + + const T tol2 = tol / Datum::sqrt2 * A_norm; + + uvec indices = find(eigval > tol2); + + if(indices.n_elem > kk) + { + indices = indices.subvec(0,kk-1); + } + else + if(indices.n_elem < kk) + { + const uvec indices2 = find(abs(eigval) <= tol2); + + const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) ); + + if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); } + } + + const uvec sorted_indices = sort_index(eigval, "descend"); + + S = eigval.elem(sorted_indices); S *= A_max; + + if(calc_UV) + { + uvec U_row_indices(A.n_rows, arma_nozeros_indicator()); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } + uvec V_row_indices(A.n_cols, arma_nozeros_indicator()); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } + + U = Datum::sqrt2 * eigvec(U_row_indices, sorted_indices); + V = Datum::sqrt2 * eigvec(V_row_indices, sorted_indices); + } + } + + if(S.n_elem < k) { arma_debug_warn_level(1, "svds(): found fewer singular values than specified"); } + + return true; + } + + + +//! find the k largest singular values and corresponding singular vectors of sparse matrix X +template +inline +bool +svds + ( + Mat& U, + Col& S, + Mat& V, + const SpBase& X, + const uword k, + const typename T1::pod_type tol = 0.0, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, true); + + if(status == false) { arma_debug_warn_level(3, "svds(): decomposition failed"); } + + return status; + } + + + +//! find the k largest singular values of sparse matrix X +template +inline +bool +svds + ( + Col& S, + const SpBase& X, + const uword k, + const typename T1::pod_type tol = 0.0, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + Mat U; + Mat V; + + const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false); + + if(status == false) { arma_debug_warn_level(3, "svds(): decomposition failed"); } + + return status; + } + + + +//! find the k largest singular values of sparse matrix X +template +arma_warn_unused +inline +Col +svds + ( + const SpBase& X, + const uword k, + const typename T1::pod_type tol = 0.0, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + Col S; + + Mat U; + Mat V; + + const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false); + + if(status == false) { arma_stop_runtime_error("svds(): decomposition failed"); } + + return S; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_sylvester.hpp b/src/armadillo/include/armadillo_bits/fn_sylvester.hpp new file mode 100644 index 0000000..a5b8165 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_sylvester.hpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_syl_lyap +//! @{ + + +//! find the solution of the Sylvester equation AX + XB = C +template +inline +bool +syl + ( + Mat & out, + const Base& in_A, + const Base& in_B, + const Base& in_C, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + const unwrap_check tmp_A(in_A.get_ref(), out); + const unwrap_check tmp_B(in_B.get_ref(), out); + const unwrap_check tmp_C(in_C.get_ref(), out); + + const Mat& A = tmp_A.M; + const Mat& B = tmp_B.M; + const Mat& C = tmp_C.M; + + const bool status = auxlib::syl(out, A, B, C); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "syl(): solution not found"); + } + + return status; + } + + + +template +inline +bool +sylvester + ( + Mat & out, + const Base& in_A, + const Base& in_B, + const Base& in_C, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_ignore(junk); + return syl(out, in_A, in_B, in_C); + } + + + +template +arma_warn_unused +inline +Mat +syl + ( + const Base& in_A, + const Base& in_B, + const Base& in_C, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + const unwrap tmp_A( in_A.get_ref() ); + const unwrap tmp_B( in_B.get_ref() ); + const unwrap tmp_C( in_C.get_ref() ); + + const Mat& A = tmp_A.M; + const Mat& B = tmp_B.M; + const Mat& C = tmp_C.M; + + Mat out; + + const bool status = auxlib::syl(out, A, B, C); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("syl(): solution not found"); + } + + return out; + } + + + +template +arma_warn_unused +inline +Mat +sylvester + ( + const Base& in_A, + const Base& in_B, + const Base& in_C, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_ignore(junk); + return syl(in_A, in_B, in_C); + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_symmat.hpp b/src/armadillo/include/armadillo_bits/fn_symmat.hpp new file mode 100644 index 0000000..4bee64d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_symmat.hpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_symmat +//! @{ + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::no, const Op >::result +symmatu(const Base& X, const bool do_conj = false) + { + arma_extra_debug_sigprint(); + arma_ignore(do_conj); + + return Op(X.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::no, const Op >::result +symmatl(const Base& X, const bool do_conj = false) + { + arma_extra_debug_sigprint(); + arma_ignore(do_conj); + + return Op(X.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::yes, const Op >::result +symmatu(const Base& X, const bool do_conj = true) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref(), 0, (do_conj ? 1 : 0)); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::yes, const Op >::result +symmatl(const Base& X, const bool do_conj = true) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref(), 0, (do_conj ? 1 : 0)); + } + + + +// + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::no, const SpOp >::result +symmatu(const SpBase& X, const bool do_conj = false) + { + arma_extra_debug_sigprint(); + arma_ignore(do_conj); + + return SpOp(X.get_ref(), 0, 0); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::no, const SpOp >::result +symmatl(const SpBase& X, const bool do_conj = false) + { + arma_extra_debug_sigprint(); + arma_ignore(do_conj); + + return SpOp(X.get_ref(), 1, 0); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::yes, const SpOp >::result +symmatu(const SpBase& X, const bool do_conj = true) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref(), 0, (do_conj ? 1 : 0)); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::yes, const SpOp >::result +symmatl(const SpBase& X, const bool do_conj = true) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref(), 1, (do_conj ? 1 : 0)); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_toeplitz.hpp b/src/armadillo/include/armadillo_bits/fn_toeplitz.hpp new file mode 100644 index 0000000..660541b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_toeplitz.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_toeplitz +//! @{ + + + +template +arma_warn_unused +inline +const Op +toeplitz(const Base& X) + { + arma_extra_debug_sigprint(); + + return Op( X.get_ref() ); + } + + + +template +arma_warn_unused +inline +const Op +circ_toeplitz(const Base& X) + { + arma_extra_debug_sigprint(); + + return Op( X.get_ref() ); + } + + + +template +arma_warn_unused +inline +const Glue +toeplitz(const Base& X, const Base& Y) + { + arma_extra_debug_sigprint(); + + return Glue( X.get_ref(), Y.get_ref() ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_trace.hpp b/src/armadillo/include/armadillo_bits/fn_trace.hpp new file mode 100644 index 0000000..8a15bac --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_trace.hpp @@ -0,0 +1,663 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_trace +//! @{ + + +template +arma_warn_unused +inline +typename T1::elem_type +trace(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(X.get_ref()); + + const uword N = (std::min)(P.get_n_rows(), P.get_n_cols()); + + eT val1 = eT(0); + eT val2 = eT(0); + + uword i,j; + for(i=0, j=1; j +arma_warn_unused +inline +typename T1::elem_type +trace(const Op& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const diagmat_proxy A(X.m); + + const uword N = (std::min)(A.n_rows, A.n_cols); + + eT val = eT(0); + + for(uword i=0; i +arma_warn_unused +inline +typename enable_if2< is_cx::no, typename T1::elem_type>::result +trace(const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const partial_unwrap tmp1(X.A); + const partial_unwrap tmp2(X.B); + + const typename partial_unwrap::stored_type& A = tmp1.M; + const typename partial_unwrap::stored_type& B = tmp2.M; + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0); + + arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + if( (A.n_elem == 0) || (B.n_elem == 0) ) { return eT(0); } + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + eT acc = eT(0); + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) + { + const uword N = (std::min)(A_n_rows, B_n_cols); + + eT acc1 = eT(0); + eT acc2 = eT(0); + + for(uword k=0; k < N; ++k) + { + const eT* B_colptr = B.colptr(k); + + // condition: A_n_cols = B_n_rows + + uword j; + + for(j=1; j < A_n_cols; j+=2) + { + const uword i = (j-1); + + const eT tmp_i = B_colptr[i]; + const eT tmp_j = B_colptr[j]; + + acc1 += A.at(k, i) * tmp_i; + acc2 += A.at(k, j) * tmp_j; + } + + const uword i = (j-1); + + if(i < A_n_cols) + { + acc1 += A.at(k, i) * B_colptr[i]; + } + } + + acc = (acc1 + acc2); + } + else + if( (partial_unwrap::do_trans == true ) && (partial_unwrap::do_trans == false) ) + { + const uword N = (std::min)(A_n_cols, B_n_cols); + + for(uword k=0; k < N; ++k) + { + const eT* A_colptr = A.colptr(k); + const eT* B_colptr = B.colptr(k); + + // condition: A_n_rows = B_n_rows + acc += op_dot::direct_dot(A_n_rows, A_colptr, B_colptr); + } + } + else + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == true ) ) + { + const uword N = (std::min)(A_n_rows, B_n_rows); + + for(uword k=0; k < N; ++k) + { + // condition: A_n_cols = B_n_cols + for(uword i=0; i < A_n_cols; ++i) + { + acc += A.at(k,i) * B.at(k,i); + } + } + } + else + if( (partial_unwrap::do_trans == true ) && (partial_unwrap::do_trans == true ) ) + { + const uword N = (std::min)(A_n_cols, B_n_rows); + + for(uword k=0; k < N; ++k) + { + const eT* A_colptr = A.colptr(k); + + // condition: A_n_rows = B_n_cols + for(uword i=0; i < A_n_rows; ++i) + { + acc += A_colptr[i] * B.at(k,i); + } + } + } + + return (use_alpha) ? (alpha * acc) : acc; + } + + + +//! speedup for trace(A*B); complex elements +template +arma_warn_unused +inline +typename enable_if2< is_cx::yes, typename T1::elem_type>::result +trace(const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + const partial_unwrap tmp1(X.A); + const partial_unwrap tmp2(X.B); + + const typename partial_unwrap::stored_type& A = tmp1.M; + const typename partial_unwrap::stored_type& B = tmp2.M; + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0); + + arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + if( (A.n_elem == 0) || (B.n_elem == 0) ) { return eT(0); } + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + eT acc = eT(0); + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) + { + const uword N = (std::min)(A_n_rows, B_n_cols); + + T acc_real = T(0); + T acc_imag = T(0); + + for(uword k=0; k < N; ++k) + { + const eT* B_colptr = B.colptr(k); + + // condition: A_n_cols = B_n_rows + + for(uword i=0; i < A_n_cols; ++i) + { + // acc += A.at(k, i) * B_colptr[i]; + + const std::complex& xx = A.at(k, i); + const std::complex& yy = B_colptr[i]; + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = yy.imag(); + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + } + + acc = std::complex(acc_real, acc_imag); + } + else + if( (partial_unwrap::do_trans == true) && (partial_unwrap::do_trans == false) ) + { + const uword N = (std::min)(A_n_cols, B_n_cols); + + T acc_real = T(0); + T acc_imag = T(0); + + for(uword k=0; k < N; ++k) + { + const eT* A_colptr = A.colptr(k); + const eT* B_colptr = B.colptr(k); + + // condition: A_n_rows = B_n_rows + + for(uword i=0; i < A_n_rows; ++i) + { + // acc += std::conj(A_colptr[i]) * B_colptr[i]; + + const std::complex& xx = A_colptr[i]; + const std::complex& yy = B_colptr[i]; + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = yy.imag(); + + // take into account the complex conjugate of xx + + acc_real += (a*c) + (b*d); + acc_imag += (a*d) - (b*c); + } + } + + acc = std::complex(acc_real, acc_imag); + } + else + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == true) ) + { + const uword N = (std::min)(A_n_rows, B_n_rows); + + T acc_real = T(0); + T acc_imag = T(0); + + for(uword k=0; k < N; ++k) + { + // condition: A_n_cols = B_n_cols + for(uword i=0; i < A_n_cols; ++i) + { + // acc += A.at(k,i) * std::conj(B.at(k,i)); + + const std::complex& xx = A.at(k, i); + const std::complex& yy = B.at(k, i); + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = -yy.imag(); // take the conjugate + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + } + + acc = std::complex(acc_real, acc_imag); + } + else + if( (partial_unwrap::do_trans == true) && (partial_unwrap::do_trans == true) ) + { + const uword N = (std::min)(A_n_cols, B_n_rows); + + T acc_real = T(0); + T acc_imag = T(0); + + for(uword k=0; k < N; ++k) + { + const eT* A_colptr = A.colptr(k); + + // condition: A_n_rows = B_n_cols + for(uword i=0; i < A_n_rows; ++i) + { + // acc += std::conj(A_colptr[i]) * std::conj(B.at(k,i)); + + const std::complex& xx = A_colptr[i]; + const std::complex& yy = B.at(k, i); + + const T a = xx.real(); + const T b = -xx.imag(); // take the conjugate + + const T c = yy.real(); + const T d = -yy.imag(); // take the conjugate + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + } + + acc = std::complex(acc_real, acc_imag); + } + + return (use_alpha) ? eT(alpha * acc) : eT(acc); + } + + + +//! trace of sparse object; generic version +template +arma_warn_unused +inline +typename T1::elem_type +trace(const SpBase& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy P(expr.get_ref()); + + const uword N = (std::min)(P.get_n_rows(), P.get_n_cols()); + + eT acc = eT(0); + + if( (is_SpMat::stored_type>::value) && (P.get_n_nonzero() >= 5*N) ) + { + const unwrap_spmat::stored_type> U(P.Q); + + const SpMat& X = U.M; + + for(uword i=0; i < N; ++i) + { + acc += X.at(i,i); // use binary search + } + } + else + { + typename SpProxy::const_iterator_type it = P.begin(); + + const uword P_n_nz = P.get_n_nonzero(); + + for(uword i=0; i < P_n_nz; ++i) + { + if(it.row() == it.col()) { acc += (*it); } + + ++it; + } + } + + return acc; + } + + + +//! trace of sparse object; speedup for trace(A + B) +template +arma_warn_unused +inline +typename T1::elem_type +trace(const SpGlue& expr) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "addition"); + + return (trace(UA.M) + trace(UB.M)); + } + + + +//! trace of sparse object; speedup for trace(A - B) +template +arma_warn_unused +inline +typename T1::elem_type +trace(const SpGlue& expr) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + arma_debug_assert_same_size(UA.M.n_rows, UA.M.n_cols, UB.M.n_rows, UB.M.n_cols, "subtraction"); + + return (trace(UA.M) - trace(UB.M)); + } + + + +//! trace of sparse object; speedup for trace(A % B) +template +arma_warn_unused +inline +typename T1::elem_type +trace(const SpGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + arma_debug_assert_same_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "element-wise multiplication"); + + const uword N = (std::min)(A.n_rows, A.n_cols); + + eT acc = eT(0); + + for(uword i=0; i +arma_warn_unused +inline +typename T1::elem_type +trace(const SpGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // better-than-nothing implementation + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + if( (A.n_nonzero == 0) || (B.n_nonzero == 0) ) { return eT(0); } + + const uword N = (std::min)(A.n_rows, B.n_cols); + + eT acc = eT(0); + + // TODO: the threshold may need tuning for complex matrices + if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) ) + { + for(uword k=0; k < N; ++k) + { + typename SpMat::const_col_iterator B_it = B.begin_col_no_sync(k); + typename SpMat::const_col_iterator B_it_end = B.end_col_no_sync(k); + + while(B_it != B_it_end) + { + const eT B_val = (*B_it); + const uword i = B_it.row(); + + acc += A.at(k,i) * B_val; + + ++B_it; + } + } + } + else + { + const SpMat AB = A * B; + + acc = trace(AB); + } + + return acc; + } + + + +//! trace of sparse object; speedup for trace(A.t()*B); non-complex elements +template +arma_warn_unused +inline +typename enable_if2< is_cx::no, typename T1::elem_type>::result +trace(const SpGlue, T2, spglue_times>& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(expr.A.m); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + // NOTE: deliberately swapped A.n_rows and A.n_cols to take into account the requested transpose operation + arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiplication"); + + if( (A.n_nonzero == 0) || (B.n_nonzero == 0) ) { return eT(0); } + + const uword N = (std::min)(A.n_cols, B.n_cols); + + eT acc = eT(0); + + if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) ) + { + for(uword k=0; k < N; ++k) + { + typename SpMat::const_col_iterator B_it = B.begin_col_no_sync(k); + typename SpMat::const_col_iterator B_it_end = B.end_col_no_sync(k); + + while(B_it != B_it_end) + { + const eT B_val = (*B_it); + const uword i = B_it.row(); + + acc += A.at(i,k) * B_val; + + ++B_it; + } + } + } + else + { + const SpMat AtB = A.t() * B; + + acc = trace(AtB); + } + + return acc; + } + + + +//! trace of sparse object; speedup for trace(A.t()*B); complex elements +template +arma_warn_unused +inline +typename enable_if2< is_cx::yes, typename T1::elem_type>::result +trace(const SpGlue, T2, spglue_times>& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(expr.A.m); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + // NOTE: deliberately swapped A.n_rows and A.n_cols to take into account the requested transpose operation + arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiplication"); + + if( (A.n_nonzero == 0) || (B.n_nonzero == 0) ) { return eT(0); } + + const uword N = (std::min)(A.n_cols, B.n_cols); + + eT acc = eT(0); + + // TODO: the threshold may need tuning for complex matrices + if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) ) + { + for(uword k=0; k < N; ++k) + { + typename SpMat::const_col_iterator B_it = B.begin_col_no_sync(k); + typename SpMat::const_col_iterator B_it_end = B.end_col_no_sync(k); + + while(B_it != B_it_end) + { + const eT B_val = (*B_it); + const uword i = B_it.row(); + + acc += std::conj(A.at(i,k)) * B_val; + + ++B_it; + } + } + } + else + { + const SpMat AtB = A.t() * B; + + acc = trace(AtB); + } + + return acc; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_trans.hpp b/src/armadillo/include/armadillo_bits/fn_trans.hpp new file mode 100644 index 0000000..f558e10 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_trans.hpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_trans +//! @{ + + +template +arma_warn_unused +arma_inline +const Op +trans + ( + const T1& X, + const typename enable_if< is_arma_type::value >::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return Op(X); + } + + + +template +arma_warn_unused +arma_inline +const Op +htrans + ( + const T1& X, + const typename enable_if< is_arma_type::value >::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return Op(X); + } + + + +// +// handling of sparse matrices + + +template +arma_warn_unused +arma_inline +const SpOp +trans + ( + const T1& X, + const typename enable_if< is_arma_sparse_type::value >::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return SpOp(X); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +htrans + ( + const T1& X, + const typename enable_if< is_arma_sparse_type::value >::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return SpOp(X); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_trapz.hpp b/src/armadillo/include/armadillo_bits/fn_trapz.hpp new file mode 100644 index 0000000..72646b7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_trapz.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_trapz +//! @{ + + + +template +arma_warn_unused +inline +const Glue +trapz + ( + const Base& X, + const Base& Y, + const uword dim = 0 + ) + { + arma_extra_debug_sigprint(); + + return Glue(X.get_ref(), Y.get_ref(), dim); + } + + + +template +arma_warn_unused +inline +const Op +trapz + ( + const Base& Y, + const uword dim = 0 + ) + { + arma_extra_debug_sigprint(); + + return Op(Y.get_ref(), dim, uword(0)); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_trig.hpp b/src/armadillo/include/armadillo_bits/fn_trig.hpp new file mode 100644 index 0000000..d73947b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_trig.hpp @@ -0,0 +1,493 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_trig +//! @{ + + +// single argument trigonometric functions: +// cos family: cos, acos, cosh, acosh +// sin family: sin, asin, sinh, asinh +// tan family: tan, atan, tanh, atanh +// +// misc functions: +// sinc +// +// dual argument trigonometric functions: +// atan2 +// hypot + + + +// +// cos + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +cos(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +cos(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// acos + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +acos(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +acos(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// cosh + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +cosh(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +cosh(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// acosh + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +acosh(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +acosh(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// sin + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +sin(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +sin(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// asin + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +asin(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +asin(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// sinh + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +sinh(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +sinh(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// asinh + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +asinh(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +asinh(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// tan + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +tan(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +tan(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// atan + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +atan(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +atan(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// tanh + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +tanh(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +tanh(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// atanh + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +atanh(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +atanh(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// sinc + +template +arma_warn_unused +arma_inline +typename arma_scalar_only::result +sinc(const T x) + { + return arma_sinc(x); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +sinc(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +sinc(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +// +// atan2 + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && is_real::value && is_same_type::value), + const Glue + >::result +atan2(const T1& Y, const T2& X) + { + arma_extra_debug_sigprint(); + + return Glue(Y, X); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_real::value, const GlueCube >::result +atan2(const BaseCube& Y, const BaseCube& X) + { + arma_extra_debug_sigprint(); + + return GlueCube(Y.get_ref(), X.get_ref()); + } + + + +// +// hypot + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && is_real::value && is_same_type::value), + const Glue + >::result +hypot(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return Glue(X, Y); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_real::value, const GlueCube >::result +hypot(const BaseCube& X, const BaseCube& Y) + { + arma_extra_debug_sigprint(); + + return GlueCube(X.get_ref(), Y.get_ref()); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_trimat.hpp b/src/armadillo/include/armadillo_bits/fn_trimat.hpp new file mode 100644 index 0000000..24c95f2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_trimat.hpp @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_trimat +//! @{ + + +template +arma_warn_unused +arma_inline +const Op +trimatu(const Base& X) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref(), 0, 0); + } + + + +template +arma_warn_unused +arma_inline +const Op +trimatl(const Base& X) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref(), 1, 0); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +trimatu(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref(), 0, 0); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +trimatl(const SpBase& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X.get_ref(), 1, 0); + } + + + +// + + + +template +arma_warn_unused +arma_inline +const Op +trimatl(const Base& X, const sword k) + { + arma_extra_debug_sigprint(); + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + return Op(X.get_ref(), row_offset, col_offset); + } + + + +template +arma_warn_unused +arma_inline +const Op +trimatu(const Base& X, const sword k) + { + arma_extra_debug_sigprint(); + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + return Op(X.get_ref(), row_offset, col_offset); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +trimatu(const SpBase& X, const sword k) + { + arma_extra_debug_sigprint(); + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + return SpOp(X.get_ref(), row_offset, col_offset); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +trimatl(const SpBase& X, const sword k) + { + arma_extra_debug_sigprint(); + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + return SpOp(X.get_ref(), row_offset, col_offset); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_trimat_ind.hpp b/src/armadillo/include/armadillo_bits/fn_trimat_ind.hpp new file mode 100644 index 0000000..4e65705 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_trimat_ind.hpp @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_trimat_ind +//! @{ + + +arma_warn_unused +inline +uvec +trimatu_ind(const SizeMat& s, const sword k = 0) + { + arma_extra_debug_sigprint(); + + const uword n_rows = s.n_rows; + const uword n_cols = s.n_cols; + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatu_ind(): requested diagonal is out of bounds" ); + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + uvec tmp(n_rows * n_cols, arma_nozeros_indicator()); // worst case scenario + uword* tmp_mem = tmp.memptr(); + uword count = 0; + + for(uword i=0; i < n_cols; ++i) + { + const uword col = i + col_offset; + + if(i < N) + { + const uword end_row = i + row_offset; + + const uword index_offset = (n_rows * col); + + for(uword row=0; row <= end_row; ++row) + { + tmp_mem[count] = index_offset + row; + ++count; + } + } + else + { + if(col < n_cols) + { + const uword index_offset = (n_rows * col); + + for(uword row=0; row < n_rows; ++row) + { + tmp_mem[count] = index_offset + row; + ++count; + } + } + } + } + + uvec out; + + out.steal_mem_col(tmp, count); + + return out; + } + + + +arma_warn_unused +inline +uvec +trimatl_ind(const SizeMat& s, const sword k = 0) + { + arma_extra_debug_sigprint(); + + const uword n_rows = s.n_rows; + const uword n_cols = s.n_cols; + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatl_ind(): requested diagonal is out of bounds" ); + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + uvec tmp(n_rows * n_cols, arma_nozeros_indicator()); // worst case scenario + uword* tmp_mem = tmp.memptr(); + uword count = 0; + + for(uword col=0; col < col_offset; ++col) + { + const uword index_offset = (n_rows * col); + + for(uword row=0; row < n_rows; ++row) + { + tmp_mem[count] = index_offset + row; + ++count; + } + } + + for(uword i=0; i +arma_warn_unused +inline +static +typename arma_real_only::result +trunc_exp(const eT x) + { + if(std::numeric_limits::is_iec559 && (x >= Datum::log_max )) + { + return std::numeric_limits::max(); + } + else + { + return std::exp(x); + } + } + + + +template +arma_warn_unused +inline +static +typename arma_integral_only::result +trunc_exp(const eT x) + { + return eT( trunc_exp( double(x) ) ); + } + + + +template +arma_warn_unused +inline +static +std::complex +trunc_exp(const std::complex& x) + { + return std::polar( trunc_exp( x.real() ), x.imag() ); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +trunc_exp(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +trunc_exp(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_trunc_log.hpp b/src/armadillo/include/armadillo_bits/fn_trunc_log.hpp new file mode 100644 index 0000000..9cbf826 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_trunc_log.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_trunc_log +//! @{ + + + +template +arma_warn_unused +inline +static +typename arma_real_only::result +trunc_log(const eT x) + { + if(std::numeric_limits::is_iec559) + { + if(x == std::numeric_limits::infinity()) + { + return Datum::log_max; + } + else + { + return (x <= eT(0)) ? Datum::log_min : std::log(x); + } + } + else + { + return std::log(x); + } + } + + + +template +arma_warn_unused +inline +static +typename arma_integral_only::result +trunc_log(const eT x) + { + return eT( trunc_log( double(x) ) ); + } + + + +template +arma_warn_unused +inline +static +std::complex +trunc_log(const std::complex& x) + { + return std::complex( trunc_log( std::abs(x) ), std::arg(x) ); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +trunc_log(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +const eOpCube +trunc_log(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_unique.hpp b/src/armadillo/include/armadillo_bits/fn_unique.hpp new file mode 100644 index 0000000..99861a5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_unique.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_unique +//! @{ + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + const Op + >::result +unique(const T1& A) + { + arma_extra_debug_sigprint(); + + return Op(A); + } + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const Op + >::result +unique(const T1& A) + { + arma_extra_debug_sigprint(); + + return Op(A); + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_var.hpp b/src/armadillo/include/armadillo_bits/fn_var.hpp new file mode 100644 index 0000000..090f4ce --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_var.hpp @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_var +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + typename T1::pod_type + >::result +var(const T1& X, const uword norm_type = 0) + { + arma_extra_debug_sigprint(); + + return op_var::var_vec(X, norm_type); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const mtOp + >::result +var(const T1& X, const uword norm_type = 0) + { + arma_extra_debug_sigprint(); + + return mtOp(X, norm_type, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +var(const T1& X, const uword norm_type, const uword dim) + { + arma_extra_debug_sigprint(); + + return mtOp(X, norm_type, dim); + } + + + +template +arma_warn_unused +inline +typename arma_scalar_only::result +var(const T&) + { + return T(0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::yes, + typename T1::pod_type + >::result +var(const T1& X, const uword norm_type = 0) + { + arma_extra_debug_sigprint(); + + return spop_var::var_vec(X, norm_type); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::no, + const mtSpOp + >::result +var(const T1& X, const uword norm_type = 0) + { + arma_extra_debug_sigprint(); + + return mtSpOp(X, norm_type, 0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value, + const mtSpOp + >::result +var(const T1& X, const uword norm_type, const uword dim) + { + arma_extra_debug_sigprint(); + + return mtSpOp(X, norm_type, dim); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_vecnorm.hpp b/src/armadillo/include/armadillo_bits/fn_vecnorm.hpp new file mode 100644 index 0000000..0fa88aa --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_vecnorm.hpp @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_vecnorm +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + typename T1::pod_type + >::result +vecnorm + ( + const T1& X, + const uword k = uword(2), + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + typedef typename T1::pod_type T; + + const Proxy P(X); + + if(P.get_n_elem() == 0) { return T(0); } + + if(k == uword(1)) { return op_norm::vec_norm_1(P); } + if(k == uword(2)) { return op_norm::vec_norm_2(P); } + + arma_debug_check( (k == 0), "vecnorm(): unsupported vector norm type" ); + + return op_norm::vec_norm_k(P, int(k)); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const mtOp + >::result +vecnorm + ( + const T1& X, + const uword k = uword(2), + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const uword dim = 0; + + return mtOp(X, k, dim); + } + + + +template +arma_warn_unused +inline +const mtOp +vecnorm + ( + const Base& X, + const uword k, + const uword dim, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return mtOp(X.get_ref(), k, dim); + } + + + +// + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + typename T1::pod_type + >::result +vecnorm + ( + const T1& X, + const char* method, + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + typedef typename T1::pod_type T; + + const Proxy P(X); + + if(P.get_n_elem() == 0) { return T(0); } + + const char sig = (method != nullptr) ? method[0] : char(0); + + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { return op_norm::vec_norm_max(P); } + if( (sig == '-') ) { return op_norm::vec_norm_min(P); } + + arma_stop_logic_error("vecnorm(): unsupported vector norm type"); + + return T(0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const mtOp + >::result +vecnorm + ( + const T1& X, + const char* method, + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const char sig = (method != nullptr) ? method[0] : char(0); + + uword method_id = 0; + + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { method_id = 1; } + if( (sig == '-') ) { method_id = 2; } + + const uword dim = 0; + + return mtOp(X, method_id, dim); + } + + + +template +arma_warn_unused +inline +const mtOp +vecnorm + ( + const Base& X, + const char* method, + const uword dim, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const char sig = (method != nullptr) ? method[0] : char(0); + + uword method_id = 0; + + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { method_id = 1; } + if( (sig == '-') ) { method_id = 2; } + + return mtOp(X.get_ref(), method_id, dim); + } + + + +// +// norms for sparse matrices + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::yes, + typename T1::pod_type + >::result +vecnorm + ( + const T1& X, + const uword k = uword(2), + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + return arma::norm(X, k); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::no, + const mtSpOp + >::result +vecnorm + ( + const T1& X, + const uword k = uword(2), + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const uword dim = 0; + + return mtSpOp(X, k, dim); + } + + + +template +arma_warn_unused +inline +const mtSpOp +vecnorm + ( + const SpBase& X, + const uword k, + const uword dim, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return mtSpOp(X.get_ref(), k, dim); + } + + + +// + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::yes, + typename T1::pod_type + >::result +vecnorm + ( + const T1& X, + const char* method, + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + return arma::norm(X, method); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::no, + const mtSpOp + >::result +vecnorm + ( + const T1& X, + const char* method, + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const char sig = (method != nullptr) ? method[0] : char(0); + + uword method_id = 0; + + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { method_id = 1; } + if( (sig == '-') ) { method_id = 2; } + + const uword dim = 0; + + return mtSpOp(X, method_id, dim); + } + + + +template +arma_warn_unused +inline +const mtSpOp +vecnorm + ( + const SpBase& X, + const char* method, + const uword dim, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const char sig = (method != nullptr) ? method[0] : char(0); + + uword method_id = 0; + + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { method_id = 1; } + if( (sig == '-') ) { method_id = 2; } + + return mtSpOp(X.get_ref(), method_id, dim); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_vectorise.hpp b/src/armadillo/include/armadillo_bits/fn_vectorise.hpp new file mode 100644 index 0000000..ff21006 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_vectorise.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_vectorise +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + const Op + >::result +vectorise(const T1& X) + { + arma_extra_debug_sigprint(); + + return Op(X); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + const Op + >::result +vectorise(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (dim > 1), "vectorise(): parameter 'dim' must be 0 or 1" ); + + return Op(X, dim, 0); + } + + + +template +arma_warn_unused +inline +CubeToMatOp +vectorise(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + return CubeToMatOp(X.get_ref()); + } + + + +//! Vectorization for sparse objects. +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value, + const SpOp + >::result +vectorise(const T1& X) + { + arma_extra_debug_sigprint(); + + return SpOp(X); + } + + + +//! Vectorization for sparse objects. +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value, + const SpOp + >::result +vectorise(const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (dim > 1), "vectorise(): parameter 'dim' must be 0 or 1" ); + + return SpOp(X, dim, 0); + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_wishrnd.hpp b/src/armadillo/include/armadillo_bits/fn_wishrnd.hpp new file mode 100644 index 0000000..3f05b77 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_wishrnd.hpp @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_wishrnd +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_real::value, + const Op + >::result +wishrnd(const Base& S, typename T1::elem_type df) + { + arma_extra_debug_sigprint(); + + return Op(S.get_ref(), df, uword(1), uword(0)); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_real::value, + const Op + >::result +wishrnd(const Base& S, typename T1::elem_type df, const Base& D) + { + arma_extra_debug_sigprint(); + arma_ignore(S); + + return Op(D.get_ref(), df, uword(2), uword(0)); + } + + + +template +inline +typename +enable_if2 + < + is_real::value, + bool + >::result +wishrnd(Mat& W, const Base& S, typename T1::elem_type df) + { + arma_extra_debug_sigprint(); + + const bool status = op_wishrnd::apply_direct(W, S.get_ref(), df, uword(1)); + + if(status == false) + { + W.soft_reset(); + arma_debug_warn_level(3, "wishrnd(): given matrix is not symmetric positive definite"); + } + + return status; + } + + + +template +inline +typename +enable_if2 + < + is_real::value, + bool + >::result +wishrnd(Mat& W, const Base& S, typename T1::elem_type df, const Base& D) + { + arma_extra_debug_sigprint(); + arma_ignore(S); + + const bool status = op_wishrnd::apply_direct(W, D.get_ref(), df, uword(2)); + + if(status == false) + { + W.soft_reset(); + arma_debug_warn_level(3, "wishrnd(): problem with given 'D' matrix"); + } + + return status; + } + + + +// + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_real::value, + const Op + >::result +iwishrnd(const Base& T, typename T1::elem_type df) + { + arma_extra_debug_sigprint(); + + return Op(T.get_ref(), df, uword(1), uword(0)); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_real::value, + const Op + >::result +iwishrnd(const Base& T, typename T1::elem_type df, const Base& Dinv) + { + arma_extra_debug_sigprint(); + arma_ignore(T); + + return Op(Dinv.get_ref(), df, uword(2), uword(0)); + } + + + +template +inline +typename +enable_if2 + < + is_real::value, + bool + >::result +iwishrnd(Mat& W, const Base& T, typename T1::elem_type df) + { + arma_extra_debug_sigprint(); + + const bool status = op_iwishrnd::apply_direct(W, T.get_ref(), df, uword(1)); + + if(status == false) + { + W.soft_reset(); + arma_debug_warn_level(3, "iwishrnd(): given matrix is not symmetric positive definite and/or df is too low"); + } + + return status; + } + + + +template +inline +typename +enable_if2 + < + is_real::value, + bool + >::result +iwishrnd(Mat& W, const Base& T, typename T1::elem_type df, const Base& Dinv) + { + arma_extra_debug_sigprint(); + arma_ignore(T); + + const bool status = op_iwishrnd::apply_direct(W, Dinv.get_ref(), df, uword(2)); + + if(status == false) + { + W.soft_reset(); + arma_debug_warn_level(3, "wishrnd(): problem with given 'Dinv' matrix and/or df is too low"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/fn_zeros.hpp b/src/armadillo/include/armadillo_bits/fn_zeros.hpp new file mode 100644 index 0000000..5f06922 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/fn_zeros.hpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_zeros +//! @{ + + + +arma_warn_unused +arma_inline +const Gen +zeros(const uword n_elem) + { + arma_extra_debug_sigprint(); + + return Gen(n_elem, 1); + } + + + +template +arma_warn_unused +arma_inline +const Gen +zeros(const uword n_elem, const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const uword n_rows = (is_Row::value) ? uword(1) : n_elem; + const uword n_cols = (is_Row::value) ? n_elem : uword(1); + + return Gen(n_rows, n_cols); + } + + + +arma_warn_unused +arma_inline +const Gen +zeros(const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + return Gen(n_rows, n_cols); + } + + + +arma_warn_unused +arma_inline +const Gen +zeros(const SizeMat& s) + { + arma_extra_debug_sigprint(); + + return Gen(s.n_rows, s.n_cols); + } + + + +template +arma_warn_unused +arma_inline +const Gen +zeros(const uword n_rows, const uword n_cols, const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(is_Col::value) { arma_debug_check( (n_cols != 1), "zeros(): incompatible size" ); } + if(is_Row::value) { arma_debug_check( (n_rows != 1), "zeros(): incompatible size" ); } + + return Gen(n_rows, n_cols); + } + + + +template +arma_warn_unused +arma_inline +const Gen +zeros(const SizeMat& s, const typename arma_Mat_Col_Row_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return zeros(s.n_rows, s.n_cols); + } + + + +arma_warn_unused +arma_inline +const GenCube +zeros(const uword n_rows, const uword n_cols, const uword n_slices) + { + arma_extra_debug_sigprint(); + + return GenCube(n_rows, n_cols, n_slices); + } + + + +arma_warn_unused +arma_inline +const GenCube +zeros(const SizeCube& s) + { + arma_extra_debug_sigprint(); + + return GenCube(s.n_rows, s.n_cols, s.n_slices); + } + + + +template +arma_warn_unused +arma_inline +const GenCube +zeros(const uword n_rows, const uword n_cols, const uword n_slices, const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return GenCube(n_rows, n_cols, n_slices); + } + + + +template +arma_warn_unused +arma_inline +const GenCube +zeros(const SizeCube& s, const typename arma_Cube_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return GenCube(s.n_rows, s.n_cols, s.n_slices); + } + + + +template +arma_warn_unused +inline +sp_obj_type +zeros(const uword n_rows, const uword n_cols, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(is_SpCol::value) { arma_debug_check( (n_cols != 1), "zeros(): incompatible size" ); } + if(is_SpRow::value) { arma_debug_check( (n_rows != 1), "zeros(): incompatible size" ); } + + return sp_obj_type(n_rows, n_cols); + } + + + +template +arma_warn_unused +inline +sp_obj_type +zeros(const SizeMat& s, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return zeros(s.n_rows, s.n_cols); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_affmul_bones.hpp b/src/armadillo/include/armadillo_bits/glue_affmul_bones.hpp new file mode 100644 index 0000000..5284b6c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_affmul_bones.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_affmul +//! @{ + + + +class glue_affmul + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(Mat& out, const Glue& X); + + template + inline static void apply_noalias(Mat& out, const T1& A, const T2& B); + + template + inline static void apply_noalias_square(Mat& out, const T1& A, const T2& B); + + template + inline static void apply_noalias_rectangle(Mat& out, const T1& A, const T2& B); + + template + inline static void apply_noalias_generic(Mat& out, const T1& A, const T2& B); + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/glue_affmul_meat.hpp b/src/armadillo/include/armadillo_bits/glue_affmul_meat.hpp new file mode 100644 index 0000000..19c3799 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_affmul_meat.hpp @@ -0,0 +1,490 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_affmul +//! @{ + + + +template +inline +void +glue_affmul::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U1(X.A); + const quasi_unwrap U2(X.B); + + const bool is_alias = (U1.is_alias(out) || U2.is_alias(out)); + + if(is_alias == false) + { + glue_affmul::apply_noalias(out, U1.M, U2.M); + } + else + { + Mat tmp; + + glue_affmul::apply_noalias(tmp, U1.M, U2.M); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +glue_affmul::apply_noalias(Mat& out, const T1& A, const T2& B) + { + arma_extra_debug_sigprint(); + + const uword A_n_cols = A.n_cols; + const uword A_n_rows = A.n_rows; + const uword B_n_rows = B.n_rows; + + arma_debug_check( (A_n_cols != B_n_rows+1), "affmul(): size mismatch" ); + + if(A_n_rows == A_n_cols) + { + glue_affmul::apply_noalias_square(out, A, B); + } + else + if(A_n_rows == B_n_rows) + { + glue_affmul::apply_noalias_rectangle(out, A, B); + } + else + { + glue_affmul::apply_noalias_generic(out, A, B); + } + } + + + +template +inline +void +glue_affmul::apply_noalias_square(Mat& out, const T1& A, const T2& B) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // assuming that A is square sized, and A.n_cols = B.n_rows+1 + + const uword N = A.n_rows; + const uword B_n_cols = B.n_cols; + + out.set_size(N, B_n_cols); + + if(out.n_elem == 0) { return; } + + const eT* A_mem = A.memptr(); + + switch(N) + { + case 0: + break; + + case 1: // A is 1x1 + out.fill(A_mem[0]); + break; + + case 2: // A is 2x2 + { + if(B_n_cols == 1) + { + const eT* B_mem = B.memptr(); + eT* out_mem = out.memptr(); + + const eT x = B_mem[0]; + + out_mem[0] = A_mem[0]*x + A_mem[2]; + out_mem[1] = A_mem[1]*x + A_mem[3]; + } + else + for(uword col=0; col < B_n_cols; ++col) + { + const eT* B_mem = B.colptr(col); + eT* out_mem = out.colptr(col); + + const eT x = B_mem[0]; + + out_mem[0] = A_mem[0]*x + A_mem[2]; + out_mem[1] = A_mem[1]*x + A_mem[3]; + } + } + break; + + case 3: // A is 3x3 + { + if(B_n_cols == 1) + { + const eT* B_mem = B.memptr(); + eT* out_mem = out.memptr(); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + + out_mem[0] = A_mem[0]*x + A_mem[3]*y + A_mem[6]; + out_mem[1] = A_mem[1]*x + A_mem[4]*y + A_mem[7]; + out_mem[2] = A_mem[2]*x + A_mem[5]*y + A_mem[8]; + } + else + for(uword col=0; col < B_n_cols; ++col) + { + const eT* B_mem = B.colptr(col); + eT* out_mem = out.colptr(col); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + + out_mem[0] = A_mem[0]*x + A_mem[3]*y + A_mem[6]; + out_mem[1] = A_mem[1]*x + A_mem[4]*y + A_mem[7]; + out_mem[2] = A_mem[2]*x + A_mem[5]*y + A_mem[8]; + } + } + break; + + case 4: // A is 4x4 + { + if(B_n_cols == 1) + { + const eT* B_mem = B.memptr(); + eT* out_mem = out.memptr(); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + const eT z = B_mem[2]; + + out_mem[0] = A_mem[ 0]*x + A_mem[ 4]*y + A_mem[ 8]*z + A_mem[12]; + out_mem[1] = A_mem[ 1]*x + A_mem[ 5]*y + A_mem[ 9]*z + A_mem[13]; + out_mem[2] = A_mem[ 2]*x + A_mem[ 6]*y + A_mem[10]*z + A_mem[14]; + out_mem[3] = A_mem[ 3]*x + A_mem[ 7]*y + A_mem[11]*z + A_mem[15]; + } + else + for(uword col=0; col < B_n_cols; ++col) + { + const eT* B_mem = B.colptr(col); + eT* out_mem = out.colptr(col); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + const eT z = B_mem[2]; + + out_mem[0] = A_mem[ 0]*x + A_mem[ 4]*y + A_mem[ 8]*z + A_mem[12]; + out_mem[1] = A_mem[ 1]*x + A_mem[ 5]*y + A_mem[ 9]*z + A_mem[13]; + out_mem[2] = A_mem[ 2]*x + A_mem[ 6]*y + A_mem[10]*z + A_mem[14]; + out_mem[3] = A_mem[ 3]*x + A_mem[ 7]*y + A_mem[11]*z + A_mem[15]; + } + } + break; + + case 5: // A is 5x5 + { + if(B_n_cols == 1) + { + const eT* B_mem = B.memptr(); + eT* out_mem = out.memptr(); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + const eT z = B_mem[2]; + const eT w = B_mem[3]; + + out_mem[0] = A_mem[ 0]*x + A_mem[ 5]*y + A_mem[10]*z + A_mem[15]*w + A_mem[20]; + out_mem[1] = A_mem[ 1]*x + A_mem[ 6]*y + A_mem[11]*z + A_mem[16]*w + A_mem[21]; + out_mem[2] = A_mem[ 2]*x + A_mem[ 7]*y + A_mem[12]*z + A_mem[17]*w + A_mem[22]; + out_mem[3] = A_mem[ 3]*x + A_mem[ 8]*y + A_mem[13]*z + A_mem[18]*w + A_mem[23]; + out_mem[4] = A_mem[ 4]*x + A_mem[ 9]*y + A_mem[14]*z + A_mem[19]*w + A_mem[24]; + } + else + for(uword col=0; col < B_n_cols; ++col) + { + const eT* B_mem = B.colptr(col); + eT* out_mem = out.colptr(col); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + const eT z = B_mem[2]; + const eT w = B_mem[3]; + + out_mem[0] = A_mem[ 0]*x + A_mem[ 5]*y + A_mem[10]*z + A_mem[15]*w + A_mem[20]; + out_mem[1] = A_mem[ 1]*x + A_mem[ 6]*y + A_mem[11]*z + A_mem[16]*w + A_mem[21]; + out_mem[2] = A_mem[ 2]*x + A_mem[ 7]*y + A_mem[12]*z + A_mem[17]*w + A_mem[22]; + out_mem[3] = A_mem[ 3]*x + A_mem[ 8]*y + A_mem[13]*z + A_mem[18]*w + A_mem[23]; + out_mem[4] = A_mem[ 4]*x + A_mem[ 9]*y + A_mem[14]*z + A_mem[19]*w + A_mem[24]; + } + } + break; + + default: + { + if(B_n_cols == 1) + { + Col tmp(N, arma_nozeros_indicator()); + eT* tmp_mem = tmp.memptr(); + + arrayops::copy(tmp_mem, B.memptr(), N-1); + + tmp_mem[N-1] = eT(1); + + out = A * tmp; + } + else + { + Mat tmp(N, B_n_cols, arma_nozeros_indicator()); + + for(uword col=0; col < B_n_cols; ++col) + { + const eT* B_mem = B.colptr(col); + eT* tmp_mem = tmp.colptr(col); + + arrayops::copy(tmp_mem, B_mem, N-1); + + tmp_mem[N-1] = eT(1); + } + + out = A * tmp; + } + } + } + } + + + +template +inline +void +glue_affmul::apply_noalias_rectangle(Mat& out, const T1& A, const T2& B) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // assuming that A.n_rows = A.n_cols-1, and A.n_cols = B.n_rows+1 + // (A and B have the same number of rows) + + const uword A_n_rows = A.n_rows; + const uword B_n_cols = B.n_cols; + + out.set_size(A_n_rows, B_n_cols); + + if(out.n_elem == 0) { return; } + + const eT* A_mem = A.memptr(); + + switch(A_n_rows) + { + case 0: + break; + + case 1: // A is 1x2 + { + if(B_n_cols == 1) + { + const eT* B_mem = B.memptr(); + eT* out_mem = out.memptr(); + + const eT x = B_mem[0]; + + out_mem[0] = A_mem[0]*x + A_mem[1]; + } + else + for(uword col=0; col < B_n_cols; ++col) + { + const eT* B_mem = B.colptr(col); + eT* out_mem = out.colptr(col); + + const eT x = B_mem[0]; + + out_mem[0] = A_mem[0]*x + A_mem[1]; + } + } + break; + + case 2: // A is 2x3 + { + if(B_n_cols == 1) + { + const eT* B_mem = B.memptr(); + eT* out_mem = out.memptr(); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + + out_mem[0] = A_mem[0]*x + A_mem[2]*y + A_mem[4]; + out_mem[1] = A_mem[1]*x + A_mem[3]*y + A_mem[5]; + } + else + for(uword col=0; col < B_n_cols; ++col) + { + const eT* B_mem = B.colptr(col); + eT* out_mem = out.colptr(col); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + + out_mem[0] = A_mem[0]*x + A_mem[2]*y + A_mem[4]; + out_mem[1] = A_mem[1]*x + A_mem[3]*y + A_mem[5]; + } + } + break; + + case 3: // A is 3x4 + { + if(B_n_cols == 1) + { + const eT* B_mem = B.memptr(); + eT* out_mem = out.memptr(); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + const eT z = B_mem[2]; + + out_mem[0] = A_mem[ 0]*x + A_mem[ 3]*y + A_mem[ 6]*z + A_mem[ 9]; + out_mem[1] = A_mem[ 1]*x + A_mem[ 4]*y + A_mem[ 7]*z + A_mem[10]; + out_mem[2] = A_mem[ 2]*x + A_mem[ 5]*y + A_mem[ 8]*z + A_mem[11]; + } + else + for(uword col=0; col < B_n_cols; ++col) + { + const eT* B_mem = B.colptr(col); + eT* out_mem = out.colptr(col); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + const eT z = B_mem[2]; + + out_mem[0] = A_mem[ 0]*x + A_mem[ 3]*y + A_mem[ 6]*z + A_mem[ 9]; + out_mem[1] = A_mem[ 1]*x + A_mem[ 4]*y + A_mem[ 7]*z + A_mem[10]; + out_mem[2] = A_mem[ 2]*x + A_mem[ 5]*y + A_mem[ 8]*z + A_mem[11]; + } + } + break; + + case 4: // A is 4x5 + { + if(B_n_cols == 1) + { + const eT* B_mem = B.memptr(); + eT* out_mem = out.memptr(); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + const eT z = B_mem[2]; + const eT w = B_mem[3]; + + out_mem[0] = A_mem[ 0]*x + A_mem[ 4]*y + A_mem[ 8]*z + A_mem[12]*w + A_mem[16]; + out_mem[1] = A_mem[ 1]*x + A_mem[ 5]*y + A_mem[ 9]*z + A_mem[13]*w + A_mem[17]; + out_mem[2] = A_mem[ 2]*x + A_mem[ 6]*y + A_mem[10]*z + A_mem[14]*w + A_mem[18]; + out_mem[3] = A_mem[ 3]*x + A_mem[ 7]*y + A_mem[11]*z + A_mem[15]*w + A_mem[19]; + } + else + for(uword col=0; col < B_n_cols; ++col) + { + const eT* B_mem = B.colptr(col); + eT* out_mem = out.colptr(col); + + const eT x = B_mem[0]; + const eT y = B_mem[1]; + const eT z = B_mem[2]; + const eT w = B_mem[3]; + + out_mem[0] = A_mem[ 0]*x + A_mem[ 4]*y + A_mem[ 8]*z + A_mem[12]*w + A_mem[16]; + out_mem[1] = A_mem[ 1]*x + A_mem[ 5]*y + A_mem[ 9]*z + A_mem[13]*w + A_mem[17]; + out_mem[2] = A_mem[ 2]*x + A_mem[ 6]*y + A_mem[10]*z + A_mem[14]*w + A_mem[18]; + out_mem[3] = A_mem[ 3]*x + A_mem[ 7]*y + A_mem[11]*z + A_mem[15]*w + A_mem[19]; + } + } + break; + + default: + { + const uword A_n_cols = A.n_cols; + + if(B_n_cols == 1) + { + Col tmp(A_n_cols, arma_nozeros_indicator()); + eT* tmp_mem = tmp.memptr(); + + arrayops::copy(tmp_mem, B.memptr(), A_n_cols-1); + + tmp_mem[A_n_cols-1] = eT(1); + + out = A * tmp; + } + else + { + Mat tmp(A_n_cols, B_n_cols, arma_nozeros_indicator()); + + for(uword col=0; col < B_n_cols; ++col) + { + const eT* B_mem = B.colptr(col); + eT* tmp_mem = tmp.colptr(col); + + arrayops::copy(tmp_mem, B_mem, A_n_cols-1); + + tmp_mem[A_n_cols-1] = eT(1); + } + + out = A * tmp; + } + } + } + } + + + +template +inline +void +glue_affmul::apply_noalias_generic(Mat& out, const T1& A, const T2& B) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // assuming that A.n_cols = B.n_rows+1 + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + Mat tmp(B_n_rows+1, B_n_cols, arma_nozeros_indicator()); + + for(uword col=0; col < B_n_cols; ++col) + { + const eT* B_mem = B.colptr(col); + eT* tmp_mem = tmp.colptr(col); + + arrayops::copy(tmp_mem, B_mem, B_n_rows); + + tmp_mem[B_n_rows] = eT(1); + } + + out = A * tmp; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_atan2_bones.hpp b/src/armadillo/include/armadillo_bits/glue_atan2_bones.hpp new file mode 100644 index 0000000..f60e783 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_atan2_bones.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_atan2 +//! @{ + + + +class glue_atan2 + : public traits_glue_or + { + public: + + + // matrices + + template inline static void apply(Mat& out, const Glue& expr); + + template inline static void apply_noalias(Mat& out, const Proxy& P1, const Proxy& P2); + + + // cubes + + template inline static void apply(Cube& out, const GlueCube& expr); + + template inline static void apply_noalias(Cube& out, const ProxyCube& P1, const ProxyCube& P2); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_atan2_meat.hpp b/src/armadillo/include/armadillo_bits/glue_atan2_meat.hpp new file mode 100644 index 0000000..38469ed --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_atan2_meat.hpp @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_atan2 +//! @{ + + + +template +inline +void +glue_atan2::apply(Mat& out, const Glue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P1(expr.A); + const Proxy P2(expr.B); + + arma_assert_same_size(P1, P2, "atan2()"); + + const bool bad_alias = ( (Proxy::has_subview && P1.is_alias(out)) || (Proxy::has_subview && P2.is_alias(out)) ); + + if(bad_alias == false) + { + glue_atan2::apply_noalias(out, P1, P2); + } + else + { + Mat tmp; + + glue_atan2::apply_noalias(tmp, P1, P2); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +glue_atan2::apply_noalias(Mat& out, const Proxy& P1, const Proxy& P2) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P1.get_n_rows(); + const uword n_cols = P1.get_n_cols(); + const uword n_elem = P1.get_n_elem(); + + out.set_size(n_rows, n_cols); + + eT* out_mem = out.memptr(); + + const bool use_mp = arma_config::openmp && mp_gate::use_mp || Proxy::use_mp)>::eval(n_elem); + const bool use_at = Proxy::use_at || Proxy::use_at; + + if(use_at == false) + { + typename Proxy::ea_type eaP1 = P1.get_ea(); + typename Proxy::ea_type eaP2 = P2.get_ea(); + + if(use_mp) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i::stored_type> U1(P1.Q); + const unwrap::stored_type> U2(P2.Q); + + out = arma::atan2(U1.M, U2.M); + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + *out_mem = std::atan2( P1.at(row,col), P2.at(row,col) ); + out_mem++; + } + } + } + } + + + +template +inline +void +glue_atan2::apply(Cube& out, const GlueCube& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const ProxyCube P1(expr.A); + const ProxyCube P2(expr.B); + + arma_assert_same_size(P1, P2, "atan2()"); + + const bool bad_alias = ( (ProxyCube::has_subview && P1.is_alias(out)) || (ProxyCube::has_subview && P2.is_alias(out)) ); + + if(bad_alias == false) + { + glue_atan2::apply_noalias(out, P1, P2); + } + else + { + Cube tmp; + + glue_atan2::apply_noalias(tmp, P1, P2); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +glue_atan2::apply_noalias(Cube& out, const ProxyCube& P1, const ProxyCube& P2) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P1.get_n_rows(); + const uword n_cols = P1.get_n_cols(); + const uword n_slices = P1.get_n_slices(); + const uword n_elem = P1.get_n_elem(); + + out.set_size(n_rows, n_cols, n_slices); + + eT* out_mem = out.memptr(); + + const bool use_mp = arma_config::openmp && mp_gate::use_mp || ProxyCube::use_mp)>::eval(n_elem); + const bool use_at = ProxyCube::use_at || ProxyCube::use_at; + + if(use_at == false) + { + typename ProxyCube::ea_type eaP1 = P1.get_ea(); + typename ProxyCube::ea_type eaP2 = P2.get_ea(); + + if(use_mp) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i::stored_type> U1(P1.Q); + const unwrap_cube::stored_type> U2(P2.Q); + + out = arma::atan2(U1.M, U2.M); + } + else + { + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + *out_mem = std::atan2( P1.at(row,col,slice), P2.at(row,col,slice) ); + out_mem++; + } + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_conv_bones.hpp b/src/armadillo/include/armadillo_bits/glue_conv_bones.hpp new file mode 100644 index 0000000..5382f84 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_conv_bones.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_conv +//! @{ + + + +class glue_conv + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; + }; + + template inline static void apply(Mat& out, const Mat& A, const Mat& B, const bool A_is_col); + + template inline static void apply(Mat& out, const Glue& X); + }; + + + +class glue_conv2 + : public traits_glue_default + { + public: + + template inline static void apply(Mat& out, const Mat& A, const Mat& B); + + template inline static void apply(Mat& out, const Glue& expr); + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/glue_conv_meat.hpp b/src/armadillo/include/armadillo_bits/glue_conv_meat.hpp new file mode 100644 index 0000000..e722ff4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_conv_meat.hpp @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_conv +//! @{ + + + +// TODO: this implementation of conv() is rudimentary; replace with faster version +template +inline +void +glue_conv::apply(Mat& out, const Mat& A, const Mat& B, const bool A_is_col) + { + arma_extra_debug_sigprint(); + + const Mat& h = (A.n_elem <= B.n_elem) ? A : B; + const Mat& x = (A.n_elem <= B.n_elem) ? B : A; + + const uword h_n_elem = h.n_elem; + const uword h_n_elem_m1 = h_n_elem - 1; + const uword x_n_elem = x.n_elem; + const uword out_n_elem = ((h_n_elem + x_n_elem) > 0) ? (h_n_elem + x_n_elem - 1) : uword(0); + + if( (h_n_elem == 0) || (x_n_elem == 0) ) { out.zeros(); return; } + + + Col hh(h_n_elem, arma_nozeros_indicator()); // flipped version of h + + const eT* h_mem = h.memptr(); + eT* hh_mem = hh.memptr(); + + for(uword i=0; i < h_n_elem; ++i) + { + hh_mem[h_n_elem_m1-i] = h_mem[i]; + } + + + Col xx( (x_n_elem + 2*h_n_elem_m1), arma_zeros_indicator() ); // zero padded version of x + + const eT* x_mem = x.memptr(); + eT* xx_mem = xx.memptr(); + + arrayops::copy( &(xx_mem[h_n_elem_m1]), x_mem, x_n_elem ); + + + (A_is_col) ? out.set_size(out_n_elem, 1) : out.set_size(1, out_n_elem); + + eT* out_mem = out.memptr(); + + if( (arma_config::openmp) && (x_n_elem >= 128) && (h_n_elem >= 64) && (mp_thread_limit::in_parallel() == false) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < out_n_elem; ++i) + { + out_mem[i] = op_dot::direct_dot( h_n_elem, hh_mem, &(xx_mem[i]) ); + } + } + #endif + } + else + { + for(uword i=0; i < out_n_elem; ++i) + { + // out_mem[i] = dot( hh, xx.subvec(i, (i + h_n_elem_m1)) ); + + out_mem[i] = op_dot::direct_dot( h_n_elem, hh_mem, &(xx_mem[i]) ); + } + } + } + + + +// // alternative implementation of 1d convolution +// template +// inline +// void +// glue_conv::apply(Mat& out, const Mat& A, const Mat& B, const bool A_is_col) +// { +// arma_extra_debug_sigprint(); +// +// const Mat& h = (A.n_elem <= B.n_elem) ? A : B; +// const Mat& x = (A.n_elem <= B.n_elem) ? B : A; +// +// const uword h_n_elem = h.n_elem; +// const uword h_n_elem_m1 = h_n_elem - 1; +// const uword x_n_elem = x.n_elem; +// const uword out_n_elem = ((h_n_elem + x_n_elem) > 0) ? (h_n_elem + x_n_elem - 1) : uword(0); +// +// if( (h_n_elem == 0) || (x_n_elem == 0) ) { out.zeros(); return; } +// +// +// Col hh(h_n_elem, arma_nozeros_indicator()); // flipped version of h +// +// const eT* h_mem = h.memptr(); +// eT* hh_mem = hh.memptr(); +// +// for(uword i=0; i < h_n_elem; ++i) +// { +// hh_mem[h_n_elem_m1-i] = h_mem[i]; +// } +// +// // construct HH matrix, with the column containing shifted versions of hh; +// // upper limit for number of zeros is about 50%; may not be optimal +// const uword N_copies = (std::min)(uword(10), h_n_elem); +// +// const uword HH_n_rows = h_n_elem + (N_copies-1); +// +// Mat HH(HH_n_rows, N_copies, arma_zeros_indicator()); +// +// for(uword i=0; i xx( (x_n_elem + 2*h_n_elem_m1), arma_zeros_indicator() ); // zero padded version of x +// +// const eT* x_mem = x.memptr(); +// eT* xx_mem = xx.memptr(); +// +// arrayops::copy( &(xx_mem[h_n_elem_m1]), x_mem, x_n_elem ); +// +// +// (A_is_col) ? out.set_size(out_n_elem, 1) : out.set_size(1, out_n_elem); +// +// eT* out_mem = out.memptr(); +// +// uword last_i = 0; +// bool last_i_done = false; +// +// for(uword i=0; i < xx.n_elem; i += N_copies) +// { +// if( ((i + HH_n_rows) <= xx.n_elem) && ((i + N_copies) <= out_n_elem) ) +// { +// const Row xx_sub(xx_mem + i, HH_n_rows, false, true); +// +// Row out_sub(out_mem + i, N_copies, false, true); +// +// out_sub = xx_sub * HH; +// +// last_i_done = true; +// } +// else +// { +// last_i = i; +// last_i_done = false; +// break; +// } +// } +// +// if(last_i_done == false) +// { +// for(uword i=last_i; i < out_n_elem; ++i) +// { +// // out_mem[i] = dot( hh, xx.subvec(i, (i + h_n_elem_m1)) ); +// +// out_mem[i] = op_dot::direct_dot( h_n_elem, hh_mem, &(xx_mem[i]) ); +// } +// } +// } + + + +template +inline +void +glue_conv::apply(Mat& out, const Glue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UA(expr.A); + const quasi_unwrap UB(expr.B); + + const Mat& A = UA.M; + const Mat& B = UB.M; + + arma_debug_check + ( + ( ((A.is_vec() == false) && (A.is_empty() == false)) || ((B.is_vec() == false) && (B.is_empty() == false)) ), + "conv(): given object must be a vector" + ); + + const bool A_is_col = ((T1::is_col) || (A.n_cols == 1)); + + const uword mode = expr.aux_uword; + + if(mode == 0) // full convolution + { + glue_conv::apply(out, A, B, A_is_col); + } + else + if(mode == 1) // same size as A + { + Mat tmp; + + glue_conv::apply(tmp, A, B, A_is_col); + + if( (tmp.is_empty() == false) && (A.is_empty() == false) && (B.is_empty() == false) ) + { + const uword start = uword( std::floor( double(B.n_elem) / double(2) ) ); + + out = (A_is_col) ? tmp(start, 0, arma::size(A)) : tmp(0, start, arma::size(A)); + } + else + { + out.zeros( arma::size(A) ); + } + } + } + + + +/// + + + +// TODO: this implementation of conv2() is rudimentary; replace with faster version +template +inline +void +glue_conv2::apply(Mat& out, const Mat& A, const Mat& B) + { + arma_extra_debug_sigprint(); + + const Mat& G = (A.n_elem <= B.n_elem) ? A : B; // unflipped filter coefficients + const Mat& W = (A.n_elem <= B.n_elem) ? B : A; // original 2D image + + const uword out_n_rows = ((W.n_rows + G.n_rows) > 0) ? (W.n_rows + G.n_rows - 1) : uword(0); + const uword out_n_cols = ((W.n_cols + G.n_cols) > 0) ? (W.n_cols + G.n_cols - 1) : uword(0); + + if(G.is_empty() || W.is_empty()) { out.zeros(); return; } + + + Mat H(G.n_rows, G.n_cols, arma_nozeros_indicator()); // flipped filter coefficients + + const uword H_n_rows = H.n_rows; + const uword H_n_cols = H.n_cols; + + const uword H_n_rows_m1 = H_n_rows - 1; + const uword H_n_cols_m1 = H_n_cols - 1; + + for(uword col=0; col < H_n_cols; ++col) + { + eT* H_colptr = H.colptr(H_n_cols_m1 - col); + const eT* G_colptr = G.colptr(col); + + for(uword row=0; row < H_n_rows; ++row) + { + H_colptr[H_n_rows_m1 - row] = G_colptr[row]; + } + } + + + Mat X( (W.n_rows + 2*H_n_rows_m1), (W.n_cols + 2*H_n_cols_m1), arma_zeros_indicator() ); + + X( H_n_rows_m1, H_n_cols_m1, arma::size(W) ) = W; // zero padded version of 2D image + + + out.set_size( out_n_rows, out_n_cols ); + + if( (arma_config::openmp) && (out_n_cols >= 2) && (mp_thread_limit::in_parallel() == false) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col=0; col < out_n_cols; ++col) + { + eT* out_colptr = out.colptr(col); + + for(uword row=0; row < out_n_rows; ++row) + { + // out.at(row, col) = accu( H % X(row, col, size(H)) ); + + eT acc = eT(0); + + for(uword H_col = 0; H_col < H_n_cols; ++H_col) + { + const eT* X_colptr = X.colptr(col + H_col); + + acc += op_dot::direct_dot( H_n_rows, H.colptr(H_col), &(X_colptr[row]) ); + } + + out_colptr[row] = acc; + } + } + } + #endif + } + else + { + for(uword col=0; col < out_n_cols; ++col) + { + eT* out_colptr = out.colptr(col); + + for(uword row=0; row < out_n_rows; ++row) + { + // out.at(row, col) = accu( H % X(row, col, size(H)) ); + + eT acc = eT(0); + + for(uword H_col = 0; H_col < H_n_cols; ++H_col) + { + const eT* X_colptr = X.colptr(col + H_col); + + acc += op_dot::direct_dot( H_n_rows, H.colptr(H_col), &(X_colptr[row]) ); + } + + out_colptr[row] = acc; + } + } + } + } + + + +template +inline +void +glue_conv2::apply(Mat& out, const Glue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UA(expr.A); + const quasi_unwrap UB(expr.B); + + const Mat& A = UA.M; + const Mat& B = UB.M; + + const uword mode = expr.aux_uword; + + if(mode == 0) // full convolution + { + glue_conv2::apply(out, A, B); + } + else + if(mode == 1) // same size as A + { + Mat tmp; + + glue_conv2::apply(tmp, A, B); + + if( (tmp.is_empty() == false) && (A.is_empty() == false) && (B.is_empty() == false) ) + { + const uword start_row = uword( std::floor( double(B.n_rows) / double(2) ) ); + const uword start_col = uword( std::floor( double(B.n_cols) / double(2) ) ); + + out = tmp(start_row, start_col, arma::size(A)); + } + else + { + out.zeros( arma::size(A) ); + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_cor_bones.hpp b/src/armadillo/include/armadillo_bits/glue_cor_bones.hpp new file mode 100644 index 0000000..eabb897 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_cor_bones.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_cor +//! @{ + + + +class glue_cor + { + public: + + template + struct traits + { + static constexpr bool is_row = false; // T1::is_col; // TODO: check + static constexpr bool is_col = false; // T2::is_col; // TODO: check + static constexpr bool is_xvec = false; + }; + + template inline static void apply(Mat& out, const Glue& X); + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/glue_cor_meat.hpp b/src/armadillo/include/armadillo_bits/glue_cor_meat.hpp new file mode 100644 index 0000000..8f93797 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_cor_meat.hpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_cor +//! @{ + + + +template +inline +void +glue_cor::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword norm_type = X.aux_uword; + + const unwrap UA(X.A); + const unwrap UB(X.B); + + const Mat& A = UA.M; + const Mat& B = UB.M; + + const Mat& AA = (A.n_rows == 1) + ? Mat(const_cast(A.memptr()), A.n_cols, A.n_rows, false, false) + : Mat(const_cast(A.memptr()), A.n_rows, A.n_cols, false, false); + + const Mat& BB = (B.n_rows == 1) + ? Mat(const_cast(B.memptr()), B.n_cols, B.n_rows, false, false) + : Mat(const_cast(B.memptr()), B.n_rows, B.n_cols, false, false); + + arma_debug_assert_mul_size(AA, BB, true, false, "cor()"); + + if( (AA.n_elem == 0) || (BB.n_elem == 0) ) + { + out.reset(); + return; + } + + const uword N = AA.n_rows; + const eT norm_val = (norm_type == 0) ? ( (N > 1) ? eT(N-1) : eT(1) ) : eT(N); + + const Mat tmp1 = AA.each_row() - mean(AA,0); + const Mat tmp2 = BB.each_row() - mean(BB,0); + + out = tmp1.t() * tmp2; + out /= norm_val; + + out /= conv_to< Mat >::from( stddev(AA).t() * stddev(BB) ); // TODO: check for zeros? + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_cov_bones.hpp b/src/armadillo/include/armadillo_bits/glue_cov_bones.hpp new file mode 100644 index 0000000..385dd7a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_cov_bones.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_cov +//! @{ + + + +class glue_cov + { + public: + + template + struct traits + { + static constexpr bool is_row = false; // T1::is_col; // TODO: check + static constexpr bool is_col = false; // T2::is_col; // TODO: check + static constexpr bool is_xvec = false; + }; + + template inline static void apply(Mat& out, const Glue& X); + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/glue_cov_meat.hpp b/src/armadillo/include/armadillo_bits/glue_cov_meat.hpp new file mode 100644 index 0000000..d5768e2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_cov_meat.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_cov +//! @{ + + + +template +inline +void +glue_cov::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword norm_type = X.aux_uword; + + const unwrap UA(X.A); + const unwrap UB(X.B); + + const Mat& A = UA.M; + const Mat& B = UB.M; + + const Mat& AA = (A.n_rows == 1) + ? Mat(const_cast(A.memptr()), A.n_cols, A.n_rows, false, false) + : Mat(const_cast(A.memptr()), A.n_rows, A.n_cols, false, false); + + const Mat& BB = (B.n_rows == 1) + ? Mat(const_cast(B.memptr()), B.n_cols, B.n_rows, false, false) + : Mat(const_cast(B.memptr()), B.n_rows, B.n_cols, false, false); + + arma_debug_assert_mul_size(AA, BB, true, false, "cov()"); + + if( (A.n_elem == 0) || (B.n_elem == 0) ) + { + out.reset(); + return; + } + + const uword N = AA.n_rows; + const eT norm_val = (norm_type == 0) ? ( (N > 1) ? eT(N-1) : eT(1) ) : eT(N); + + const Mat tmp1 = AA.each_row() - mean(AA,0); + const Mat tmp2 = BB.each_row() - mean(BB,0); + + out = tmp1.t() * tmp2; + out /= norm_val; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_cross_bones.hpp b/src/armadillo/include/armadillo_bits/glue_cross_bones.hpp new file mode 100644 index 0000000..469e2e7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_cross_bones.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_cross +//! @{ + + + +class glue_cross + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = true; + }; + + template inline static void apply(Mat& out, const Glue& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_cross_meat.hpp b/src/armadillo/include/armadillo_bits/glue_cross_meat.hpp new file mode 100644 index 0000000..bdf38d1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_cross_meat.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_cross +//! @{ + + + +template +inline +void +glue_cross::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy PA(X.A); + const Proxy PB(X.B); + + arma_debug_check( ((PA.get_n_elem() != 3) || (PB.get_n_elem() != 3)), "cross(): each vector must have 3 elements" ); + + out.set_size(PA.get_n_rows(), PA.get_n_cols()); + + eT* out_mem = out.memptr(); + + if( (Proxy::use_at == false) && (Proxy::use_at == false) ) + { + typename Proxy::ea_type A = PA.get_ea(); + typename Proxy::ea_type B = PB.get_ea(); + + const eT ax = A[0]; + const eT ay = A[1]; + const eT az = A[2]; + + const eT bx = B[0]; + const eT by = B[1]; + const eT bz = B[2]; + + out_mem[0] = ay*bz - az*by; + out_mem[1] = az*bx - ax*bz; + out_mem[2] = ax*by - ay*bx; + } + else + { + const bool PA_is_col = Proxy::is_col ? true : (PA.get_n_cols() == 1); + const bool PB_is_col = Proxy::is_col ? true : (PB.get_n_cols() == 1); + + const eT ax = PA.at(0,0); + const eT ay = PA_is_col ? PA.at(1,0) : PA.at(0,1); + const eT az = PA_is_col ? PA.at(2,0) : PA.at(0,2); + + const eT bx = PB.at(0,0); + const eT by = PB_is_col ? PB.at(1,0) : PB.at(0,1); + const eT bz = PB_is_col ? PB.at(2,0) : PB.at(0,2); + + out_mem[0] = ay*bz - az*by; + out_mem[1] = az*bx - ax*bz; + out_mem[2] = ax*by - ay*bx; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_hist_bones.hpp b/src/armadillo/include/armadillo_bits/glue_hist_bones.hpp new file mode 100644 index 0000000..2d05358 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_hist_bones.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_hist +//! @{ + + +class glue_hist + : public traits_glue_default + { + public: + + template + inline static void apply_noalias(Mat& out, const Mat& X, const Mat& C, const uword dim); + + template + inline static void apply(Mat& out, const mtGlue& expr); + }; + + + +class glue_hist_default + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; + }; + + template + inline static void apply(Mat& out, const mtGlue& expr); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_hist_meat.hpp b/src/armadillo/include/armadillo_bits/glue_hist_meat.hpp new file mode 100644 index 0000000..ec5b4a3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_hist_meat.hpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_hist +//! @{ + + +template +inline +void +glue_hist::apply_noalias(Mat& out, const Mat& X, const Mat& C, const uword dim) + { + arma_extra_debug_sigprint(); + + arma_debug_check( ((C.is_vec() == false) && (C.is_empty() == false)), "hist(): parameter 'centers' must be a vector" ); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + const uword C_n_elem = C.n_elem; + + if( C_n_elem == 0 ) { out.reset(); return; } + + arma_debug_check + ( + ((Col(const_cast(C.memptr()), C_n_elem, false, false)).is_sorted("strictascend") == false), + "hist(): given 'centers' vector does not contain monotonically increasing values" + ); + + const eT* C_mem = C.memptr(); + const eT center_0 = C_mem[0]; + + if(dim == 0) + { + out.zeros(C_n_elem, X_n_cols); + + for(uword col=0; col < X_n_cols; ++col) + { + const eT* X_coldata = X.colptr(col); + uword* out_coldata = out.colptr(col); + + for(uword row=0; row < X_n_rows; ++row) + { + const eT val = X_coldata[row]; + + if(arma_isfinite(val)) + { + eT opt_dist = (center_0 >= val) ? (center_0 - val) : (val - center_0); + uword opt_index = 0; + + for(uword j=1; j < C_n_elem; ++j) + { + const eT center = C_mem[j]; + const eT dist = (center >= val) ? (center - val) : (val - center); + + if(dist < opt_dist) + { + opt_dist = dist; + opt_index = j; + } + else + { + break; + } + } + + out_coldata[opt_index]++; + } + else + { + // -inf + if(val < eT(0)) { out_coldata[0]++; } + + // +inf + if(val > eT(0)) { out_coldata[C_n_elem-1]++; } + + // ignore NaN + } + } + } + } + else + if(dim == 1) + { + out.zeros(X_n_rows, C_n_elem); + + if(X_n_rows == 1) + { + const uword X_n_elem = X.n_elem; + const eT* X_mem = X.memptr(); + uword* out_mem = out.memptr(); + + for(uword i=0; i < X_n_elem; ++i) + { + const eT val = X_mem[i]; + + if(is_finite(val)) + { + eT opt_dist = (val >= center_0) ? (val - center_0) : (center_0 - val); + uword opt_index = 0; + + for(uword j=1; j < C_n_elem; ++j) + { + const eT center = C_mem[j]; + const eT dist = (val >= center) ? (val - center) : (center - val); + + if(dist < opt_dist) + { + opt_dist = dist; + opt_index = j; + } + else + { + break; + } + } + + out_mem[opt_index]++; + } + else + { + // -inf + if(val < eT(0)) { out_mem[0]++; } + + // +inf + if(val > eT(0)) { out_mem[C_n_elem-1]++; } + + // ignore NaN + } + } + } + else + { + for(uword row=0; row < X_n_rows; ++row) + { + for(uword col=0; col < X_n_cols; ++col) + { + const eT val = X.at(row,col); + + if(arma_isfinite(val)) + { + eT opt_dist = (center_0 >= val) ? (center_0 - val) : (val - center_0); + uword opt_index = 0; + + for(uword j=1; j < C_n_elem; ++j) + { + const eT center = C_mem[j]; + const eT dist = (center >= val) ? (center - val) : (val - center); + + if(dist < opt_dist) + { + opt_dist = dist; + opt_index = j; + } + else + { + break; + } + } + + out.at(row,opt_index)++; + } + else + { + // -inf + if(val < eT(0)) { out.at(row,0)++; } + + // +inf + if(val > eT(0)) { out.at(row,C_n_elem-1)++; } + + // ignore NaN + } + } + } + } + } + } + + + +template +inline +void +glue_hist::apply(Mat& out, const mtGlue& expr) + { + arma_extra_debug_sigprint(); + + const uword dim = expr.aux_uword; + + arma_debug_check( (dim > 1), "hist(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap UA(expr.A); + const quasi_unwrap UB(expr.B); + + if(UA.is_alias(out) || UB.is_alias(out)) + { + Mat tmp; + + glue_hist::apply_noalias(tmp, UA.M, UB.M, dim); + + out.steal_mem(tmp); + } + else + { + glue_hist::apply_noalias(out, UA.M, UB.M, dim); + } + } + + + +template +inline +void +glue_hist_default::apply(Mat& out, const mtGlue& expr) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap UA(expr.A); + const quasi_unwrap UB(expr.B); + + const uword dim = (T1::is_xvec) ? uword(UA.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); + + if(UA.is_alias(out) || UB.is_alias(out)) + { + Mat tmp; + + glue_hist::apply_noalias(tmp, UA.M, UB.M, dim); + + out.steal_mem(tmp); + } + else + { + glue_hist::apply_noalias(out, UA.M, UB.M, dim); + } + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_histc_bones.hpp b/src/armadillo/include/armadillo_bits/glue_histc_bones.hpp new file mode 100644 index 0000000..c1cc687 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_histc_bones.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_histc +//! @{ + + +class glue_histc + : public traits_glue_default + { + public: + + template + inline static void apply_noalias(Mat& C, const Mat& A, const Mat& B, const uword dim); + + template + inline static void apply(Mat& C, const mtGlue& expr); + }; + + + +class glue_histc_default + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; + }; + + template + inline static void apply(Mat& C, const mtGlue& expr); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_histc_meat.hpp b/src/armadillo/include/armadillo_bits/glue_histc_meat.hpp new file mode 100644 index 0000000..6e79175 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_histc_meat.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_histc +//! @{ + + +template +inline +void +glue_histc::apply_noalias(Mat& C, const Mat& A, const Mat& B, const uword dim) + { + arma_extra_debug_sigprint(); + + arma_debug_check( ((B.is_vec() == false) && (B.is_empty() == false)), "histc(): parameter 'edges' must be a vector" ); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_elem = B.n_elem; + + if( B_n_elem == uword(0) ) { C.reset(); return; } + + arma_debug_check + ( + ((Col(const_cast(B.memptr()), B_n_elem, false, false)).is_sorted("strictascend") == false), + "hist(): given 'edges' vector does not contain monotonically increasing values" + ); + + const eT* B_mem = B.memptr(); + const uword B_n_elem_m1 = B_n_elem - 1; + + if(dim == uword(0)) + { + C.zeros(B_n_elem, A_n_cols); + + for(uword col=0; col < A_n_cols; ++col) + { + const eT* A_coldata = A.colptr(col); + uword* C_coldata = C.colptr(col); + + for(uword row=0; row < A_n_rows; ++row) + { + const eT x = A_coldata[row]; + + for(uword i=0; i < B_n_elem_m1; ++i) + { + if( (B_mem[i] <= x) && (x < B_mem[i+1]) ) { C_coldata[i]++; break; } + else if( B_mem[B_n_elem_m1] == x ) { C_coldata[B_n_elem_m1]++; break; } // for compatibility with Matlab + } + } + } + } + else + if(dim == uword(1)) + { + C.zeros(A_n_rows, B_n_elem); + + if(A.n_rows == 1) + { + const uword A_n_elem = A.n_elem; + const eT* A_mem = A.memptr(); + uword* C_mem = C.memptr(); + + for(uword j=0; j < A_n_elem; ++j) + { + const eT x = A_mem[j]; + + for(uword i=0; i < B_n_elem_m1; ++i) + { + if( (B_mem[i] <= x) && (x < B_mem[i+1]) ) { C_mem[i]++; break; } + else if( B_mem[B_n_elem_m1] == x ) { C_mem[B_n_elem_m1]++; break; } // for compatibility with Matlab + } + } + } + else + { + for(uword row=0; row < A_n_rows; ++row) + for(uword col=0; col < A_n_cols; ++col) + { + const eT x = A.at(row,col); + + for(uword i=0; i < B_n_elem_m1; ++i) + { + if( (B_mem[i] <= x) && (x < B_mem[i+1]) ) { C.at(row,i)++; break; } + else if( B_mem[B_n_elem_m1] == x ) { C.at(row,B_n_elem_m1)++; break; } // for compatibility with Matlab + } + } + } + } + } + + + +template +inline +void +glue_histc::apply(Mat& C, const mtGlue& expr) + { + arma_extra_debug_sigprint(); + + const uword dim = expr.aux_uword; + + arma_debug_check( (dim > 1), "histc(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap UA(expr.A); + const quasi_unwrap UB(expr.B); + + if(UA.is_alias(C) || UB.is_alias(C)) + { + Mat tmp; + + glue_histc::apply_noalias(tmp, UA.M, UB.M, dim); + + C.steal_mem(tmp); + } + else + { + glue_histc::apply_noalias(C, UA.M, UB.M, dim); + } + } + + + +template +inline +void +glue_histc_default::apply(Mat& C, const mtGlue& expr) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap UA(expr.A); + const quasi_unwrap UB(expr.B); + + const uword dim = (T1::is_xvec) ? uword(UA.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); + + if(UA.is_alias(C) || UB.is_alias(C)) + { + Mat tmp; + + glue_histc::apply_noalias(tmp, UA.M, UB.M, dim); + + C.steal_mem(tmp); + } + else + { + glue_histc::apply_noalias(C, UA.M, UB.M, dim); + } + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_hypot_bones.hpp b/src/armadillo/include/armadillo_bits/glue_hypot_bones.hpp new file mode 100644 index 0000000..53985cc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_hypot_bones.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_hypot +//! @{ + + + +class glue_hypot + : public traits_glue_or + { + public: + + + // matrices + + template inline static void apply(Mat& out, const Glue& expr); + + template inline static void apply_noalias(Mat& out, const Proxy& P1, const Proxy& P2); + + + // cubes + + template inline static void apply(Cube& out, const GlueCube& expr); + + template inline static void apply_noalias(Cube& out, const ProxyCube& P1, const ProxyCube& P2); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_hypot_meat.hpp b/src/armadillo/include/armadillo_bits/glue_hypot_meat.hpp new file mode 100644 index 0000000..f773b33 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_hypot_meat.hpp @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_hypot +//! @{ + + + +template +inline +void +glue_hypot::apply(Mat& out, const Glue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P1(expr.A); + const Proxy P2(expr.B); + + arma_assert_same_size(P1, P2, "hypot()"); + + const bool bad_alias = ( (Proxy::has_subview && P1.is_alias(out)) || (Proxy::has_subview && P2.is_alias(out)) ); + + if(bad_alias == false) + { + glue_hypot::apply_noalias(out, P1, P2); + } + else + { + Mat tmp; + + glue_hypot::apply_noalias(tmp, P1, P2); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +glue_hypot::apply_noalias(Mat& out, const Proxy& P1, const Proxy& P2) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P1.get_n_rows(); + const uword n_cols = P1.get_n_cols(); + + out.set_size(n_rows, n_cols); + + eT* out_mem = out.memptr(); + + if( (Proxy::use_at == false) && (Proxy::use_at == false) ) + { + typename Proxy::ea_type eaP1 = P1.get_ea(); + typename Proxy::ea_type eaP2 = P2.get_ea(); + + const uword N = P1.get_n_elem(); + + for(uword i=0; i +inline +void +glue_hypot::apply(Cube& out, const GlueCube& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const ProxyCube P1(expr.A); + const ProxyCube P2(expr.B); + + arma_assert_same_size(P1, P2, "hypot()"); + + const bool bad_alias = ( (ProxyCube::has_subview && P1.is_alias(out)) || (ProxyCube::has_subview && P2.is_alias(out)) ); + + if(bad_alias == false) + { + glue_hypot::apply_noalias(out, P1, P2); + } + else + { + Cube tmp; + + glue_hypot::apply_noalias(tmp, P1, P2); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +glue_hypot::apply_noalias(Cube& out, const ProxyCube& P1, const ProxyCube& P2) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P1.get_n_rows(); + const uword n_cols = P1.get_n_cols(); + const uword n_slices = P1.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + eT* out_mem = out.memptr(); + + if( (ProxyCube::use_at == false) && (ProxyCube::use_at == false) ) + { + typename ProxyCube::ea_type eaP1 = P1.get_ea(); + typename ProxyCube::ea_type eaP2 = P2.get_ea(); + + const uword N = P1.get_n_elem(); + + for(uword i=0; i + struct traits + { + static constexpr bool is_row = (T1::is_row && T2::is_row); + static constexpr bool is_col = (T1::is_col || T2::is_col); + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(Mat& out, const Glue& X); + + template + inline static void apply(Mat& out, uvec& iA, uvec& iB, const Base& A_expr, const Base& B_expr, const bool calc_indx); + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/glue_intersect_meat.hpp b/src/armadillo/include/armadillo_bits/glue_intersect_meat.hpp new file mode 100644 index 0000000..21a2b6d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_intersect_meat.hpp @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_intersect +//! @{ + + + +template +inline +void +glue_intersect::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + uvec iA; + uvec iB; + + glue_intersect::apply(out, iA, iB, X.A, X.B, false); + } + + + +template +inline +void +glue_intersect::apply(Mat& out, uvec& iA, uvec& iB, const Base& A_expr, const Base& B_expr, const bool calc_indx) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UA(A_expr.get_ref()); + const quasi_unwrap UB(B_expr.get_ref()); + + if(UA.M.is_empty() || UB.M.is_empty()) + { + out.reset(); + iA.reset(); + iB.reset(); + return; + } + + uvec A_uniq_indx; + uvec B_uniq_indx; + + Mat A_uniq; + Mat B_uniq; + + if(calc_indx) + { + A_uniq_indx = find_unique(UA.M); + B_uniq_indx = find_unique(UB.M); + + A_uniq = UA.M.elem(A_uniq_indx); + B_uniq = UB.M.elem(B_uniq_indx); + } + else + { + A_uniq = unique(UA.M); + B_uniq = unique(UB.M); + } + + const uword C_n_elem = A_uniq.n_elem + B_uniq.n_elem; + + Col C(C_n_elem, arma_nozeros_indicator()); + + arrayops::copy(C.memptr(), A_uniq.memptr(), A_uniq.n_elem); + arrayops::copy(C.memptr() + A_uniq.n_elem, B_uniq.memptr(), B_uniq.n_elem); + + uvec C_sorted_indx; + Col C_sorted; + + if(calc_indx) + { + C_sorted_indx = stable_sort_index(C); + C_sorted = C.elem(C_sorted_indx); + } + else + { + C_sorted = sort(C); + } + + const eT* C_sorted_mem = C_sorted.memptr(); + + uvec jj(C_n_elem, arma_nozeros_indicator()); // worst case length + + uword* jj_mem = jj.memptr(); + uword jj_count = 0; + + for(uword i=0; i < (C_n_elem-1); ++i) + { + if( C_sorted_mem[i] == C_sorted_mem[i+1] ) + { + jj_mem[jj_count] = i; + ++jj_count; + } + } + + if(jj_count == 0) + { + out.reset(); + iA.reset(); + iB.reset(); + return; + } + + const uvec ii(jj.memptr(), jj_count, false); + + if(UA.M.is_rowvec() && UB.M.is_rowvec()) + { + out.set_size(1, ii.n_elem); + + Mat out_alias(out.memptr(), ii.n_elem, 1, false, true); + + // NOTE: this relies on .elem() not changing the size of the output and not reallocating memory for the output + out_alias = C_sorted.elem(ii); + } + else + { + out = C_sorted.elem(ii); + } + + if(calc_indx) + { + iA = A_uniq_indx.elem(C_sorted_indx.elem(ii ) ); + iB = B_uniq_indx.elem(C_sorted_indx.elem(ii+1) - A_uniq.n_elem); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_join_bones.hpp b/src/armadillo/include/armadillo_bits/glue_join_bones.hpp new file mode 100644 index 0000000..b84116a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_join_bones.hpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_join +//! @{ + + + +class glue_join_cols + { + public: + + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = (T1::is_col && T2::is_col); + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(Mat& out, const Glue& X); + + template + inline static void apply_noalias(Mat& out, const Proxy& A, const Proxy& B); + + template + inline static void apply(Mat& out, const Base& A, const Base& B, const Base& C); + + template + inline static void apply(Mat& out, const Base& A, const Base& B, const Base& C, const Base& D); + }; + + + +class glue_join_rows + { + public: + + template + struct traits + { + static constexpr bool is_row = (T1::is_row && T2::is_row); + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(Mat& out, const Glue& X); + + template + inline static void apply_noalias(Mat& out, const Proxy& A, const Proxy& B); + + template + inline static void apply(Mat& out, const Base& A, const Base& B, const Base& C); + + template + inline static void apply(Mat& out, const Base& A, const Base& B, const Base& C, const Base& D); + }; + + + +class glue_join_slices + { + public: + + template + inline static void apply(Cube& out, const GlueCube& X); + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/glue_join_meat.hpp b/src/armadillo/include/armadillo_bits/glue_join_meat.hpp new file mode 100644 index 0000000..1ffd3b1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_join_meat.hpp @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_join +//! @{ + + + +template +inline +void +glue_join_cols::apply_noalias(Mat& out, const Proxy& A, const Proxy& B) + { + arma_extra_debug_sigprint(); + + const uword A_n_rows = A.get_n_rows(); + const uword A_n_cols = A.get_n_cols(); + + const uword B_n_rows = B.get_n_rows(); + const uword B_n_cols = B.get_n_cols(); + + arma_debug_check + ( + ( (A_n_cols != B_n_cols) && ( (A_n_rows > 0) || (A_n_cols > 0) ) && ( (B_n_rows > 0) || (B_n_cols > 0) ) ), + "join_cols() / join_vert(): number of columns must be the same" + ); + + out.set_size( A_n_rows + B_n_rows, (std::max)(A_n_cols, B_n_cols) ); + + if( out.n_elem > 0 ) + { + if(A.get_n_elem() > 0) + { + out.submat(0, 0, A_n_rows-1, out.n_cols-1) = A.Q; + } + + if(B.get_n_elem() > 0) + { + out.submat(A_n_rows, 0, out.n_rows-1, out.n_cols-1) = B.Q; + } + } + } + + + + +template +inline +void +glue_join_cols::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy A(X.A); + const Proxy B(X.B); + + if( (A.is_alias(out) == false) && (B.is_alias(out) == false) ) + { + glue_join_cols::apply_noalias(out, A, B); + } + else + { + Mat tmp; + + glue_join_cols::apply_noalias(tmp, A, B); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +glue_join_cols::apply(Mat& out, const Base& A_expr, const Base& B_expr, const Base& C_expr) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap UA(A_expr.get_ref()); + const quasi_unwrap UB(B_expr.get_ref()); + const quasi_unwrap UC(C_expr.get_ref()); + + const Mat& A = UA.M; + const Mat& B = UB.M; + const Mat& C = UC.M; + + const uword out_n_rows = A.n_rows + B.n_rows + C.n_rows; + const uword out_n_cols = (std::max)((std::max)(A.n_cols, B.n_cols), C.n_cols); + + arma_debug_check( ((A.n_cols != out_n_cols) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + arma_debug_check( ((B.n_cols != out_n_cols) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + arma_debug_check( ((C.n_cols != out_n_cols) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + + out.set_size(out_n_rows, out_n_cols); + + if(out.n_elem == 0) { return; } + + uword row_start = 0; + uword row_end_p1 = 0; + + if(A.n_elem > 0) { row_end_p1 += A.n_rows; out.rows(row_start, row_end_p1 - 1) = A; } + + row_start = row_end_p1; + + if(B.n_elem > 0) { row_end_p1 += B.n_rows; out.rows(row_start, row_end_p1 - 1) = B; } + + row_start = row_end_p1; + + if(C.n_elem > 0) { row_end_p1 += C.n_rows; out.rows(row_start, row_end_p1 - 1) = C; } + } + + + +template +inline +void +glue_join_cols::apply(Mat& out, const Base& A_expr, const Base& B_expr, const Base& C_expr, const Base& D_expr) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap UA(A_expr.get_ref()); + const quasi_unwrap UB(B_expr.get_ref()); + const quasi_unwrap UC(C_expr.get_ref()); + const quasi_unwrap UD(D_expr.get_ref()); + + const Mat& A = UA.M; + const Mat& B = UB.M; + const Mat& C = UC.M; + const Mat& D = UD.M; + + const uword out_n_rows = A.n_rows + B.n_rows + C.n_rows + D.n_rows; + const uword out_n_cols = (std::max)(((std::max)((std::max)(A.n_cols, B.n_cols), C.n_cols)), D.n_cols); + + arma_debug_check( ((A.n_cols != out_n_cols) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + arma_debug_check( ((B.n_cols != out_n_cols) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + arma_debug_check( ((C.n_cols != out_n_cols) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + arma_debug_check( ((D.n_cols != out_n_cols) && ((D.n_rows > 0) || (D.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + + out.set_size(out_n_rows, out_n_cols); + + if(out.n_elem == 0) { return; } + + uword row_start = 0; + uword row_end_p1 = 0; + + if(A.n_elem > 0) { row_end_p1 += A.n_rows; out.rows(row_start, row_end_p1 - 1) = A; } + + row_start = row_end_p1; + + if(B.n_elem > 0) { row_end_p1 += B.n_rows; out.rows(row_start, row_end_p1 - 1) = B; } + + row_start = row_end_p1; + + if(C.n_elem > 0) { row_end_p1 += C.n_rows; out.rows(row_start, row_end_p1 - 1) = C; } + + row_start = row_end_p1; + + if(D.n_elem > 0) { row_end_p1 += D.n_rows; out.rows(row_start, row_end_p1 - 1) = D; } + } + + + +template +inline +void +glue_join_rows::apply_noalias(Mat& out, const Proxy& A, const Proxy& B) + { + arma_extra_debug_sigprint(); + + const uword A_n_rows = A.get_n_rows(); + const uword A_n_cols = A.get_n_cols(); + + const uword B_n_rows = B.get_n_rows(); + const uword B_n_cols = B.get_n_cols(); + + arma_debug_check + ( + ( (A_n_rows != B_n_rows) && ( (A_n_rows > 0) || (A_n_cols > 0) ) && ( (B_n_rows > 0) || (B_n_cols > 0) ) ), + "join_rows() / join_horiz(): number of rows must be the same" + ); + + out.set_size( (std::max)(A_n_rows, B_n_rows), A_n_cols + B_n_cols ); + + if( out.n_elem > 0 ) + { + if(A.get_n_elem() > 0) + { + out.submat(0, 0, out.n_rows-1, A_n_cols-1) = A.Q; + } + + if(B.get_n_elem() > 0) + { + out.submat(0, A_n_cols, out.n_rows-1, out.n_cols-1) = B.Q; + } + } + } + + + + +template +inline +void +glue_join_rows::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy A(X.A); + const Proxy B(X.B); + + if( (A.is_alias(out) == false) && (B.is_alias(out) == false) ) + { + glue_join_rows::apply_noalias(out, A, B); + } + else + { + Mat tmp; + + glue_join_rows::apply_noalias(tmp, A, B); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +glue_join_rows::apply(Mat& out, const Base& A_expr, const Base& B_expr, const Base& C_expr) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap UA(A_expr.get_ref()); + const quasi_unwrap UB(B_expr.get_ref()); + const quasi_unwrap UC(C_expr.get_ref()); + + const Mat& A = UA.M; + const Mat& B = UB.M; + const Mat& C = UC.M; + + const uword out_n_rows = (std::max)((std::max)(A.n_rows, B.n_rows), C.n_rows); + const uword out_n_cols = A.n_cols + B.n_cols + C.n_cols; + + arma_debug_check( ((A.n_rows != out_n_rows) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" ); + arma_debug_check( ((B.n_rows != out_n_rows) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" ); + arma_debug_check( ((C.n_rows != out_n_rows) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" ); + + out.set_size(out_n_rows, out_n_cols); + + if(out.n_elem == 0) { return; } + + uword col_start = 0; + uword col_end_p1 = 0; + + if(A.n_elem > 0) { col_end_p1 += A.n_cols; out.cols(col_start, col_end_p1 - 1) = A; } + + col_start = col_end_p1; + + if(B.n_elem > 0) { col_end_p1 += B.n_cols; out.cols(col_start, col_end_p1 - 1) = B; } + + col_start = col_end_p1; + + if(C.n_elem > 0) { col_end_p1 += C.n_cols; out.cols(col_start, col_end_p1 - 1) = C; } + } + + + +template +inline +void +glue_join_rows::apply(Mat& out, const Base& A_expr, const Base& B_expr, const Base& C_expr, const Base& D_expr) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap UA(A_expr.get_ref()); + const quasi_unwrap UB(B_expr.get_ref()); + const quasi_unwrap UC(C_expr.get_ref()); + const quasi_unwrap UD(D_expr.get_ref()); + + const Mat& A = UA.M; + const Mat& B = UB.M; + const Mat& C = UC.M; + const Mat& D = UD.M; + + const uword out_n_rows = (std::max)(((std::max)((std::max)(A.n_rows, B.n_rows), C.n_rows)), D.n_rows); + const uword out_n_cols = A.n_cols + B.n_cols + C.n_cols + D.n_cols; + + arma_debug_check( ((A.n_rows != out_n_rows) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" ); + arma_debug_check( ((B.n_rows != out_n_rows) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" ); + arma_debug_check( ((C.n_rows != out_n_rows) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" ); + arma_debug_check( ((D.n_rows != out_n_rows) && ((D.n_rows > 0) || (D.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" ); + + out.set_size(out_n_rows, out_n_cols); + + if(out.n_elem == 0) { return; } + + uword col_start = 0; + uword col_end_p1 = 0; + + if(A.n_elem > 0) { col_end_p1 += A.n_cols; out.cols(col_start, col_end_p1 - 1) = A; } + + col_start = col_end_p1; + + if(B.n_elem > 0) { col_end_p1 += B.n_cols; out.cols(col_start, col_end_p1 - 1) = B; } + + col_start = col_end_p1; + + if(C.n_elem > 0) { col_end_p1 += C.n_cols; out.cols(col_start, col_end_p1 - 1) = C; } + + col_start = col_end_p1; + + if(D.n_elem > 0) { col_end_p1 += D.n_cols; out.cols(col_start, col_end_p1 - 1) = D; } + } + + + +template +inline +void +glue_join_slices::apply(Cube& out, const GlueCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube A_tmp(X.A); + const unwrap_cube B_tmp(X.B); + + const Cube& A = A_tmp.M; + const Cube& B = B_tmp.M; + + if(A.n_elem == 0) { out = B; return; } + if(B.n_elem == 0) { out = A; return; } + + arma_debug_check( ( (A.n_rows != B.n_rows) || (A.n_cols != B.n_cols) ), "join_slices(): size of slices must be the same" ); + + if( (&out != &A) && (&out != &B) ) + { + out.set_size(A.n_rows, A.n_cols, A.n_slices + B.n_slices); + + out.slices(0, A.n_slices-1 ) = A; + out.slices(A.n_slices, out.n_slices-1) = B; + } + else // we have aliasing + { + Cube C(A.n_rows, A.n_cols, A.n_slices + B.n_slices, arma_nozeros_indicator()); + + C.slices(0, A.n_slices-1) = A; + C.slices(A.n_slices, C.n_slices-1) = B; + + out.steal_mem(C); + } + + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_kron_bones.hpp b/src/armadillo/include/armadillo_bits/glue_kron_bones.hpp new file mode 100644 index 0000000..84c9347 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_kron_bones.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_kron +//! @{ + + + +class glue_kron + { + public: + + template + struct traits + { + static constexpr bool is_row = (T1::is_row && T2::is_row); + static constexpr bool is_col = (T1::is_col && T2::is_col); + static constexpr bool is_xvec = false; + }; + + template inline static void direct_kron(Mat& out, const Mat& A, const Mat& B); + template inline static void direct_kron(Mat< std::complex >& out, const Mat< std::complex >& A, const Mat& B); + template inline static void direct_kron(Mat< std::complex >& out, const Mat& A, const Mat< std::complex >& B); + + template inline static void apply(Mat& out, const Glue& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_kron_meat.hpp b/src/armadillo/include/armadillo_bits/glue_kron_meat.hpp new file mode 100644 index 0000000..c7c4ff6 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_kron_meat.hpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_kron +//! @{ + + + +//! \brief +//! both input matrices have the same element type +template +inline +void +glue_kron::direct_kron(Mat& out, const Mat& A, const Mat& B) + { + arma_extra_debug_sigprint(); + + const uword A_rows = A.n_rows; + const uword A_cols = A.n_cols; + const uword B_rows = B.n_rows; + const uword B_cols = B.n_cols; + + out.set_size(A_rows*B_rows, A_cols*B_cols); + + if(out.is_empty()) { return; } + + for(uword j = 0; j < A_cols; j++) + { + for(uword i = 0; i < A_rows; i++) + { + out.submat(i*B_rows, j*B_cols, (i+1)*B_rows-1, (j+1)*B_cols-1) = A.at(i,j) * B; + } + } + } + + + +//! \brief +//! different types of input matrices +//! A -> complex, B -> basic element type +template +inline +void +glue_kron::direct_kron(Mat< std::complex >& out, const Mat< std::complex >& A, const Mat& B) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const uword A_rows = A.n_rows; + const uword A_cols = A.n_cols; + const uword B_rows = B.n_rows; + const uword B_cols = B.n_cols; + + out.set_size(A_rows*B_rows, A_cols*B_cols); + + if(out.is_empty()) { return; } + + Mat tmp_B = conv_to< Mat >::from(B); + + for(uword j = 0; j < A_cols; j++) + { + for(uword i = 0; i < A_rows; i++) + { + out.submat(i*B_rows, j*B_cols, (i+1)*B_rows-1, (j+1)*B_cols-1) = A.at(i,j) * tmp_B; + } + } + } + + + +//! \brief +//! different types of input matrices +//! A -> basic element type, B -> complex +template +inline +void +glue_kron::direct_kron(Mat< std::complex >& out, const Mat& A, const Mat< std::complex >& B) + { + arma_extra_debug_sigprint(); + + const uword A_rows = A.n_rows; + const uword A_cols = A.n_cols; + const uword B_rows = B.n_rows; + const uword B_cols = B.n_cols; + + out.set_size(A_rows*B_rows, A_cols*B_cols); + + if(out.is_empty()) { return; } + + for(uword j = 0; j < A_cols; j++) + { + for(uword i = 0; i < A_rows; i++) + { + out.submat(i*B_rows, j*B_cols, (i+1)*B_rows-1, (j+1)*B_cols-1) = A.at(i,j) * B; + } + } + } + + + +//! \brief +//! apply Kronecker product for two objects with same element type +template +inline +void +glue_kron::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UA(X.A); + const quasi_unwrap UB(X.B); + + if(UA.is_alias(out) || UB.is_alias(out)) + { + Mat tmp; + + glue_kron::direct_kron(tmp, UA.M, UB.M); + + out.steal_mem(tmp); + } + else + { + glue_kron::direct_kron(out, UA.M, UB.M); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_max_bones.hpp b/src/armadillo/include/armadillo_bits/glue_max_bones.hpp new file mode 100644 index 0000000..149988e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_max_bones.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_max +//! @{ + + + +class glue_max + : public traits_glue_or + { + public: + + // dense matrices + + template inline static void apply(Mat& out, const Glue& X); + + template inline static void apply(Mat& out, const Proxy& PA, const Proxy& PB); + + + // cubes + + template inline static void apply(Cube& out, const GlueCube& X); + + template inline static void apply(Cube& out, const ProxyCube& PA, const ProxyCube& PB); + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/glue_max_meat.hpp b/src/armadillo/include/armadillo_bits/glue_max_meat.hpp new file mode 100644 index 0000000..b1e52c2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_max_meat.hpp @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_max +//! @{ + + + +template +inline +void +glue_max::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy PA(X.A); + const Proxy PB(X.B); + + if( (PA.is_alias(out) && PA.has_subview) || (PB.is_alias(out) && PB.has_subview) ) + { + Mat tmp; + + glue_max::apply(tmp, PA, PB); + + out.steal_mem(tmp); + } + else + { + glue_max::apply(out, PA, PB); + } + } + + + +template +inline +void +glue_max::apply(Mat& out, const Proxy& PA, const Proxy& PB) + { + arma_extra_debug_sigprint(); + + const uword n_rows = PA.get_n_rows(); + const uword n_cols = PA.get_n_cols(); + + arma_debug_assert_same_size(n_rows, n_cols, PB.get_n_rows(), PB.get_n_cols(), "element-wise max()"); + + const arma_gt_comparator comparator; + + out.set_size(n_rows, n_cols); + + eT* out_mem = out.memptr(); + + if( (Proxy::use_at == false) && (Proxy::use_at == false) ) + { + typename Proxy::ea_type A = PA.get_ea(); + typename Proxy::ea_type B = PB.get_ea(); + + const uword N = PA.get_n_elem(); + + for(uword i=0; i +inline +void +glue_max::apply(Cube& out, const GlueCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const ProxyCube PA(X.A); + const ProxyCube PB(X.B); + + if( (PA.is_alias(out) && PA.has_subview) || (PB.is_alias(out) && PB.has_subview) ) + { + Cube tmp; + + glue_max::apply(tmp, PA, PB); + + out.steal_mem(tmp); + } + else + { + glue_max::apply(out, PA, PB); + } + } + + + +template +inline +void +glue_max::apply(Cube& out, const ProxyCube& PA, const ProxyCube& PB) + { + arma_extra_debug_sigprint(); + + const uword n_rows = PA.get_n_rows(); + const uword n_cols = PA.get_n_cols(); + const uword n_slices = PA.get_n_slices(); + + arma_debug_assert_same_size(n_rows, n_cols, n_slices, PB.get_n_rows(), PB.get_n_cols(), PB.get_n_slices(), "element-wise max()"); + + const arma_gt_comparator comparator; + + out.set_size(n_rows, n_cols, n_slices); + + eT* out_mem = out.memptr(); + + if( (ProxyCube::use_at == false) && (ProxyCube::use_at == false) ) + { + typename ProxyCube::ea_type A = PA.get_ea(); + typename ProxyCube::ea_type B = PB.get_ea(); + + const uword N = PA.get_n_elem(); + + for(uword i=0; i inline static void apply(Mat& out, const Glue& X); + + template inline static void apply(Mat& out, const Proxy& PA, const Proxy& PB); + + + // cubes + + template inline static void apply(Cube& out, const GlueCube& X); + + template inline static void apply(Cube& out, const ProxyCube& PA, const ProxyCube& PB); + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/glue_min_meat.hpp b/src/armadillo/include/armadillo_bits/glue_min_meat.hpp new file mode 100644 index 0000000..0fc6e3f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_min_meat.hpp @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_min +//! @{ + + + +template +inline +void +glue_min::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy PA(X.A); + const Proxy PB(X.B); + + if( (PA.is_alias(out) && PA.has_subview) || (PB.is_alias(out) && PB.has_subview) ) + { + Mat tmp; + + glue_min::apply(tmp, PA, PB); + + out.steal_mem(tmp); + } + else + { + glue_min::apply(out, PA, PB); + } + } + + + +template +inline +void +glue_min::apply(Mat& out, const Proxy& PA, const Proxy& PB) + { + arma_extra_debug_sigprint(); + + const uword n_rows = PA.get_n_rows(); + const uword n_cols = PA.get_n_cols(); + + arma_debug_assert_same_size(n_rows, n_cols, PB.get_n_rows(), PB.get_n_cols(), "element-wise min()"); + + const arma_lt_comparator comparator; + + out.set_size(n_rows, n_cols); + + eT* out_mem = out.memptr(); + + if( (Proxy::use_at == false) && (Proxy::use_at == false) ) + { + typename Proxy::ea_type A = PA.get_ea(); + typename Proxy::ea_type B = PB.get_ea(); + + const uword N = PA.get_n_elem(); + + for(uword i=0; i +inline +void +glue_min::apply(Cube& out, const GlueCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const ProxyCube PA(X.A); + const ProxyCube PB(X.B); + + if( (PA.is_alias(out) && PA.has_subview) || (PB.is_alias(out) && PB.has_subview) ) + { + Cube tmp; + + glue_min::apply(tmp, PA, PB); + + out.steal_mem(tmp); + } + else + { + glue_min::apply(out, PA, PB); + } + } + + + +template +inline +void +glue_min::apply(Cube& out, const ProxyCube& PA, const ProxyCube& PB) + { + arma_extra_debug_sigprint(); + + const uword n_rows = PA.get_n_rows(); + const uword n_cols = PA.get_n_cols(); + const uword n_slices = PA.get_n_slices(); + + arma_debug_assert_same_size(n_rows, n_cols, n_slices, PB.get_n_rows(), PB.get_n_cols(), PB.get_n_slices(), "element-wise min()"); + + const arma_lt_comparator comparator; + + out.set_size(n_rows, n_cols, n_slices); + + eT* out_mem = out.memptr(); + + if( (ProxyCube::use_at == false) && (ProxyCube::use_at == false) ) + { + typename ProxyCube::ea_type A = PA.get_ea(); + typename ProxyCube::ea_type B = PB.get_ea(); + + const uword N = PA.get_n_elem(); + + for(uword i=0; i + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(Mat::eT>& out, const mtGlue::eT, T1, T2, glue_mixed_times>& X); + }; + + + +class glue_mixed_plus + : public traits_glue_or + { + public: + + template + inline static void apply(Mat::eT>& out, const mtGlue::eT, T1, T2, glue_mixed_plus>& X); + + template + inline static void apply(Cube::eT>& out, const mtGlueCube::eT, T1, T2, glue_mixed_plus>& X); + }; + + + +class glue_mixed_minus + : public traits_glue_or + { + public: + + template + inline static void apply(Mat::eT>& out, const mtGlue::eT, T1, T2, glue_mixed_minus>& X); + + template + inline static void apply(Cube::eT>& out, const mtGlueCube::eT, T1, T2, glue_mixed_minus>& X); + }; + + + +class glue_mixed_div + : public traits_glue_or + { + public: + + template + inline static void apply(Mat::eT>& out, const mtGlue::eT, T1, T2, glue_mixed_div>& X); + + template + inline static void apply(Cube::eT>& out, const mtGlueCube::eT, T1, T2, glue_mixed_div>& X); + }; + + + +class glue_mixed_schur + : public traits_glue_or + { + public: + + template + inline static void apply(Mat::eT>& out, const mtGlue::eT, T1, T2, glue_mixed_schur>& X); + + template + inline static void apply(Cube::eT>& out, const mtGlueCube::eT, T1, T2, glue_mixed_schur>& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_mixed_meat.hpp b/src/armadillo/include/armadillo_bits/glue_mixed_meat.hpp new file mode 100644 index 0000000..21b6dc4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_mixed_meat.hpp @@ -0,0 +1,560 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_mixed +//! @{ + + + +//! matrix multiplication with different element types +template +inline +void +glue_mixed_times::apply(Mat::eT>& out, const mtGlue::eT, T1, T2, glue_mixed_times>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_eT1; + typedef typename T2::elem_type in_eT2; + + typedef typename eT_promoter::eT out_eT; + + const partial_unwrap tmp1(X.A); + const partial_unwrap tmp2(X.B); + + const typename partial_unwrap::stored_type& A = tmp1.M; + const typename partial_unwrap::stored_type& B = tmp2.M; + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const out_eT alpha = use_alpha ? (upgrade_val::apply(tmp1.get_val()) * upgrade_val::apply(tmp2.get_val())) : out_eT(0); + + const bool do_trans_A = partial_unwrap::do_trans; + const bool do_trans_B = partial_unwrap::do_trans; + + arma_debug_assert_trans_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + const uword out_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; + const uword out_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; + + const bool alias = tmp1.is_alias(out) || tmp2.is_alias(out); + + if(alias == false) + { + out.set_size(out_n_rows, out_n_cols); + + gemm_mixed::apply(out, A, B, alpha); + } + else + { + Mat tmp(out_n_rows, out_n_cols, arma_nozeros_indicator()); + + gemm_mixed::apply(tmp, A, B, alpha); + + out.steal_mem(tmp); + } + } + + + +//! matrix addition with different element types +template +inline +void +glue_mixed_plus::apply(Mat::eT>& out, const mtGlue::eT, T1, T2, glue_mixed_plus>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + const Proxy A(X.A); + const Proxy B(X.B); + + arma_debug_assert_same_size(A, B, "addition"); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + out.set_size(n_rows, n_cols); + + out_eT* out_mem = out.memptr(); + const uword n_elem = out.n_elem; + + const bool use_at = (Proxy::use_at || Proxy::use_at); + + if(use_at == false) + { + typename Proxy::ea_type AA = A.get_ea(); + typename Proxy::ea_type BB = B.get_ea(); + + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + for(uword i=0; i::apply(AA[i]) + upgrade_val::apply(BB[i]); + } + } + else + { + for(uword i=0; i::apply(AA[i]) + upgrade_val::apply(BB[i]); + } + } + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + (*out_mem) = upgrade_val::apply(A.at(row,col)) + upgrade_val::apply(B.at(row,col)); + out_mem++; + } + } + } + + + +//! matrix subtraction with different element types +template +inline +void +glue_mixed_minus::apply(Mat::eT>& out, const mtGlue::eT, T1, T2, glue_mixed_minus>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + const Proxy A(X.A); + const Proxy B(X.B); + + arma_debug_assert_same_size(A, B, "subtraction"); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + out.set_size(n_rows, n_cols); + + out_eT* out_mem = out.memptr(); + const uword n_elem = out.n_elem; + + const bool use_at = (Proxy::use_at || Proxy::use_at); + + if(use_at == false) + { + typename Proxy::ea_type AA = A.get_ea(); + typename Proxy::ea_type BB = B.get_ea(); + + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + for(uword i=0; i::apply(AA[i]) - upgrade_val::apply(BB[i]); + } + } + else + { + for(uword i=0; i::apply(AA[i]) - upgrade_val::apply(BB[i]); + } + } + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + (*out_mem) = upgrade_val::apply(A.at(row,col)) - upgrade_val::apply(B.at(row,col)); + out_mem++; + } + } + } + + + +//! element-wise matrix division with different element types +template +inline +void +glue_mixed_div::apply(Mat::eT>& out, const mtGlue::eT, T1, T2, glue_mixed_div>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + const Proxy A(X.A); + const Proxy B(X.B); + + arma_debug_assert_same_size(A, B, "element-wise division"); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + out.set_size(n_rows, n_cols); + + out_eT* out_mem = out.memptr(); + const uword n_elem = out.n_elem; + + const bool use_at = (Proxy::use_at || Proxy::use_at); + + if(use_at == false) + { + typename Proxy::ea_type AA = A.get_ea(); + typename Proxy::ea_type BB = B.get_ea(); + + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + for(uword i=0; i::apply(AA[i]) / upgrade_val::apply(BB[i]); + } + } + else + { + for(uword i=0; i::apply(AA[i]) / upgrade_val::apply(BB[i]); + } + } + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + (*out_mem) = upgrade_val::apply(A.at(row,col)) / upgrade_val::apply(B.at(row,col)); + out_mem++; + } + } + } + + + +//! element-wise matrix multiplication with different element types +template +inline +void +glue_mixed_schur::apply(Mat::eT>& out, const mtGlue::eT, T1, T2, glue_mixed_schur>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + const Proxy A(X.A); + const Proxy B(X.B); + + arma_debug_assert_same_size(A, B, "element-wise multiplication"); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + out.set_size(n_rows, n_cols); + + out_eT* out_mem = out.memptr(); + const uword n_elem = out.n_elem; + + const bool use_at = (Proxy::use_at || Proxy::use_at); + + if(use_at == false) + { + typename Proxy::ea_type AA = A.get_ea(); + typename Proxy::ea_type BB = B.get_ea(); + + if(memory::is_aligned(out_mem)) + { + memory::mark_as_aligned(out_mem); + + for(uword i=0; i::apply(AA[i]) * upgrade_val::apply(BB[i]); + } + } + else + { + for(uword i=0; i::apply(AA[i]) * upgrade_val::apply(BB[i]); + } + } + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + (*out_mem) = upgrade_val::apply(A.at(row,col)) * upgrade_val::apply(B.at(row,col)); + out_mem++; + } + } + } + + + +// +// +// + + + +//! cube addition with different element types +template +inline +void +glue_mixed_plus::apply(Cube::eT>& out, const mtGlueCube::eT, T1, T2, glue_mixed_plus>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + const ProxyCube A(X.A); + const ProxyCube B(X.B); + + arma_debug_assert_same_size(A, B, "addition"); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + const uword n_slices = A.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + out_eT* out_mem = out.memptr(); + const uword n_elem = out.n_elem; + + const bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + + if(use_at == false) + { + typename ProxyCube::ea_type AA = A.get_ea(); + typename ProxyCube::ea_type BB = B.get_ea(); + + for(uword i=0; i::apply(AA[i]) + upgrade_val::apply(BB[i]); + } + } + else + { + for(uword slice = 0; slice < n_slices; ++slice) + for(uword col = 0; col < n_cols; ++col ) + for(uword row = 0; row < n_rows; ++row ) + { + (*out_mem) = upgrade_val::apply(A.at(row,col,slice)) + upgrade_val::apply(B.at(row,col,slice)); + out_mem++; + } + } + } + + + +//! cube subtraction with different element types +template +inline +void +glue_mixed_minus::apply(Cube::eT>& out, const mtGlueCube::eT, T1, T2, glue_mixed_minus>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + const ProxyCube A(X.A); + const ProxyCube B(X.B); + + arma_debug_assert_same_size(A, B, "subtraction"); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + const uword n_slices = A.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + out_eT* out_mem = out.memptr(); + const uword n_elem = out.n_elem; + + const bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + + if(use_at == false) + { + typename ProxyCube::ea_type AA = A.get_ea(); + typename ProxyCube::ea_type BB = B.get_ea(); + + for(uword i=0; i::apply(AA[i]) - upgrade_val::apply(BB[i]); + } + } + else + { + for(uword slice = 0; slice < n_slices; ++slice) + for(uword col = 0; col < n_cols; ++col ) + for(uword row = 0; row < n_rows; ++row ) + { + (*out_mem) = upgrade_val::apply(A.at(row,col,slice)) - upgrade_val::apply(B.at(row,col,slice)); + out_mem++; + } + } + } + + + +//! element-wise cube division with different element types +template +inline +void +glue_mixed_div::apply(Cube::eT>& out, const mtGlueCube::eT, T1, T2, glue_mixed_div>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + const ProxyCube A(X.A); + const ProxyCube B(X.B); + + arma_debug_assert_same_size(A, B, "element-wise division"); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + const uword n_slices = A.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + out_eT* out_mem = out.memptr(); + const uword n_elem = out.n_elem; + + const bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + + if(use_at == false) + { + typename ProxyCube::ea_type AA = A.get_ea(); + typename ProxyCube::ea_type BB = B.get_ea(); + + for(uword i=0; i::apply(AA[i]) / upgrade_val::apply(BB[i]); + } + } + else + { + for(uword slice = 0; slice < n_slices; ++slice) + for(uword col = 0; col < n_cols; ++col ) + for(uword row = 0; row < n_rows; ++row ) + { + (*out_mem) = upgrade_val::apply(A.at(row,col,slice)) / upgrade_val::apply(B.at(row,col,slice)); + out_mem++; + } + } + } + + + +//! element-wise cube multiplication with different element types +template +inline +void +glue_mixed_schur::apply(Cube::eT>& out, const mtGlueCube::eT, T1, T2, glue_mixed_schur>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + const ProxyCube A(X.A); + const ProxyCube B(X.B); + + arma_debug_assert_same_size(A, B, "element-wise multiplication"); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + const uword n_slices = A.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + out_eT* out_mem = out.memptr(); + const uword n_elem = out.n_elem; + + const bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + + if(use_at == false) + { + typename ProxyCube::ea_type AA = A.get_ea(); + typename ProxyCube::ea_type BB = B.get_ea(); + + for(uword i=0; i::apply(AA[i]) * upgrade_val::apply(BB[i]); + } + } + else + { + for(uword slice = 0; slice < n_slices; ++slice) + for(uword col = 0; col < n_cols; ++col ) + for(uword row = 0; row < n_rows; ++row ) + { + (*out_mem) = upgrade_val::apply(A.at(row,col,slice)) * upgrade_val::apply(B.at(row,col,slice)); + out_mem++; + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_mvnrnd_bones.hpp b/src/armadillo/include/armadillo_bits/glue_mvnrnd_bones.hpp new file mode 100644 index 0000000..ab1c437 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_mvnrnd_bones.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_mvnrnd +//! @{ + + +class glue_mvnrnd_vec + { + public: + + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(Mat& out, const Glue& expr); + }; + + + +class glue_mvnrnd + : public traits_glue_default + { + public: + + template + inline static void apply(Mat& out, const Glue& expr); + + template + inline static bool apply_direct(Mat& out, const Base& M, const Base& C, const uword N); + + template + inline static bool apply_noalias(Mat& out, const Mat& M, const Mat& C, const uword N); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_mvnrnd_meat.hpp b/src/armadillo/include/armadillo_bits/glue_mvnrnd_meat.hpp new file mode 100644 index 0000000..3c3019f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_mvnrnd_meat.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_mvnrnd +//! @{ + + +// implementation based on: +// James E. Gentle. +// Generation of Random Numbers. +// Computational Statistics, pp. 305-331, 2009. +// http://dx.doi.org/10.1007/978-0-387-98144-4_7 + + +template +inline +void +glue_mvnrnd_vec::apply(Mat& out, const Glue& expr) + { + arma_extra_debug_sigprint(); + + const bool status = glue_mvnrnd::apply_direct(out, expr.A, expr.B, uword(1)); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("mvnrnd(): given covariance matrix is not symmetric positive semi-definite"); + } + } + + + +template +inline +void +glue_mvnrnd::apply(Mat& out, const Glue& expr) + { + arma_extra_debug_sigprint(); + + const bool status = glue_mvnrnd::apply_direct(out, expr.A, expr.B, expr.aux_uword); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("mvnrnd(): given covariance matrix is not symmetric positive semi-definite"); + } + } + + + +template +inline +bool +glue_mvnrnd::apply_direct(Mat& out, const Base& M, const Base& C, const uword N) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UM(M.get_ref()); + const quasi_unwrap UC(C.get_ref()); + + arma_debug_check( (UM.M.is_colvec() == false) && (UM.M.is_empty() == false), "mvnrnd(): given mean must be a column vector" ); + arma_debug_check( (UC.M.is_square() == false), "mvnrnd(): given covariance matrix must be square sized" ); + arma_debug_check( (UM.M.n_rows != UC.M.n_rows), "mvnrnd(): number of rows in given mean vector and covariance matrix must match" ); + + if( UM.M.is_empty() || UC.M.is_empty() ) + { + out.set_size(0,N); + return true; + } + + if((arma_config::debug) && (auxlib::rudimentary_sym_check(UC.M) == false)) + { + arma_debug_warn_level(1, "mvnrnd(): given matrix is not symmetric"); + } + + bool status = false; + + if(UM.is_alias(out) || UC.is_alias(out)) + { + Mat tmp; + + status = glue_mvnrnd::apply_noalias(tmp, UM.M, UC.M, N); + + out.steal_mem(tmp); + } + else + { + status = glue_mvnrnd::apply_noalias(out, UM.M, UC.M, N); + } + + return status; + } + + + +template +inline +bool +glue_mvnrnd::apply_noalias(Mat& out, const Mat& M, const Mat& C, const uword N) + { + arma_extra_debug_sigprint(); + + Mat D; + + const bool chol_status = op_chol::apply_direct(D, C, 1); // '1' means "lower triangular" + + if(chol_status == false) + { + // C is not symmetric positive definite, so find approximate square root of C + + Col eigval; // NOTE: eT is constrained to be real (ie. float or double) in fn_mvnrnd.hpp + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, C, 'd', "mvnrnd()"); + + if(eig_status == false) { return false; } + + eT* eigval_mem = eigval.memptr(); + const uword eigval_n_elem = eigval.n_elem; + + // since we're doing an approximation, tolerate tiny negative eigenvalues + + const eT tol = eT(-100) * Datum::eps * norm(C, "fro"); + + if(arma_isfinite(tol) == false) { return false; } + + for(uword i=0; i DD = eigvec * diagmat(sqrt(eigval)); + + D.steal_mem(DD); + } + + out = D * randn< Mat >(M.n_rows, N); + + if(N == 1) + { + out += M; + } + else + if(N > 1) + { + out.each_col() += M; + } + + return true; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_polyfit_bones.hpp b/src/armadillo/include/armadillo_bits/glue_polyfit_bones.hpp new file mode 100644 index 0000000..8e771dc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_polyfit_bones.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_polyfit +//! @{ + + + +class glue_polyfit + { + public: + + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + }; + + template inline static bool apply_noalias(Mat& out, const Col& X, const Col& Y, const uword N); + + template inline static bool apply_direct(Mat& out, const Base& X_expr, const Base& Y_expr, const uword N); + + template inline static void apply(Mat& out, const Glue& expr); + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/glue_polyfit_meat.hpp b/src/armadillo/include/armadillo_bits/glue_polyfit_meat.hpp new file mode 100644 index 0000000..3969c75 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_polyfit_meat.hpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_polyfit +//! @{ + + + +template +inline +bool +glue_polyfit::apply_noalias(Mat& out, const Col& X, const Col& Y, const uword N) + { + arma_extra_debug_sigprint(); + + // create Vandermonde matrix + + Mat V(X.n_elem, N+1, arma_nozeros_indicator()); + + V.tail_cols(1).ones(); + + for(uword i=1; i <= N; ++i) + { + const uword j = N-i; + + Col V_col_j (V.colptr(j ), V.n_rows, false, false); + Col V_col_jp1(V.colptr(j+1), V.n_rows, false, false); + + V_col_j = V_col_jp1 % X; + } + + Mat Q; + Mat R; + + const bool status1 = auxlib::qr_econ(Q, R, V); + + if(status1 == false) { return false; } + + const bool status2 = auxlib::solve_trimat_fast(out, R, (Q.t() * Y), uword(0)); + + if(status2 == false) { return false; } + + return true; + } + + + +template +inline +bool +glue_polyfit::apply_direct(Mat& out, const Base& X_expr, const Base& Y_expr, const uword N) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UX(X_expr.get_ref()); + const quasi_unwrap UY(Y_expr.get_ref()); + + const Mat& X = UX.M; + const Mat& Y = UY.M; + + arma_debug_check + ( + ( ((X.is_vec() == false) && (X.is_empty() == false)) || ((Y.is_vec() == false) && (Y.is_empty() == false)) ), + "polyfit(): given object must be a vector" + ); + + arma_debug_check( (X.n_elem != Y.n_elem), "polyfit(): given vectors must have the same number of elements" ); + + if(X.n_elem == 0) + { + out.reset(); + return true; + } + + arma_debug_check( (N >= X.n_elem), "polyfit(): N must be less than the number of elements in X" ); + + const Col X_as_colvec( const_cast(X.memptr()), X.n_elem, false, false); + const Col Y_as_colvec( const_cast(Y.memptr()), Y.n_elem, false, false); + + bool status = false; + + if(UX.is_alias(out) || UY.is_alias(out)) + { + Mat tmp; + status = glue_polyfit::apply_noalias(tmp, X_as_colvec, Y_as_colvec, N); + out.steal_mem(tmp); + } + else + { + status = glue_polyfit::apply_noalias(out, X_as_colvec, Y_as_colvec, N); + } + + return status; + } + + + +template +inline +void +glue_polyfit::apply(Mat& out, const Glue& expr) + { + arma_extra_debug_sigprint(); + + const bool status = glue_polyfit::apply_direct(out, expr.A, expr.B, expr.aux_uword); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("polyfit(): failed"); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_polyval_bones.hpp b/src/armadillo/include/armadillo_bits/glue_polyval_bones.hpp new file mode 100644 index 0000000..f937bd5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_polyval_bones.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_polyval +//! @{ + + + +class glue_polyval + { + public: + + template + struct traits + { + static constexpr bool is_row = T2::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = T2::is_xvec; + }; + + template inline static void apply_noalias(Mat& out, const Mat& P, const Mat& X); + + template inline static void apply(Mat& out, const Glue& expr); + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/glue_polyval_meat.hpp b/src/armadillo/include/armadillo_bits/glue_polyval_meat.hpp new file mode 100644 index 0000000..2c2a59c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_polyval_meat.hpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_polyval +//! @{ + + + +template +inline +void +glue_polyval::apply_noalias(Mat& out, const Mat& P, const Mat& X) + { + arma_extra_debug_sigprint(); + + out.set_size(X.n_rows, X.n_cols); + + const eT* P_mem = P.memptr(); + const uword P_n_elem = P.n_elem; + + out.fill(P_mem[0]); + + for(uword i=1; i < P_n_elem; ++i) + { + out = out % X + P_mem[i]; + } + } + + + +template +inline +void +glue_polyval::apply(Mat& out, const Glue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UP(expr.A); + const quasi_unwrap UX(expr.B); + + const Mat& P = UP.M; + const Mat& X = UX.M; + + arma_debug_check( ((P.is_vec() == false) && (P.is_empty() == false)), "polyval(): argument P must be a vector" ); + + if(P.is_empty() || X.is_empty()) + { + out.zeros(X.n_rows, X.n_cols); + return; + } + + if(UP.is_alias(out) || UX.is_alias(out)) + { + Mat tmp; + glue_polyval::apply_noalias(tmp, P, X); + out.steal_mem(tmp); + } + else + { + glue_polyval::apply_noalias(out, P, X); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_powext_bones.hpp b/src/armadillo/include/armadillo_bits/glue_powext_bones.hpp new file mode 100644 index 0000000..d5698c5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_powext_bones.hpp @@ -0,0 +1,70 @@ + +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_powext +//! @{ + + + +class glue_powext + : public traits_glue_or + { + public: + + template inline static void apply(Mat& out, const Glue& X); + + template inline static void apply(Mat& out, const Mat& A, const Mat& B); + + template inline static Mat apply(const subview_each1& X, const Base& Y); + + // + + template inline static void apply(Cube& out, const GlueCube& X); + + template inline static void apply(Cube& out, const Cube& A, const Cube& B); + + template inline static Cube apply(const subview_cube_each1& X, const Base& Y); + }; + + + +class glue_powext_cx + : public traits_glue_or + { + public: + + template inline static void apply(Mat& out, const mtGlue& X); + + template inline static void apply(Mat< std::complex >& out, const Mat< std::complex >& A, const Mat& B); + + template inline static Mat apply(const subview_each1& X, const Base& Y); + + // + + template inline static void apply(Cube& out, const mtGlueCube& X); + + template inline static void apply(Cube< std::complex >& out, const Cube< std::complex >& A, const Cube& B); + + template inline static Cube< std::complex > apply(const subview_cube_each1< std::complex >& X, const Base& Y); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_powext_meat.hpp b/src/armadillo/include/armadillo_bits/glue_powext_meat.hpp new file mode 100644 index 0000000..700a2cf --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_powext_meat.hpp @@ -0,0 +1,674 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_powext +//! @{ + + +template +inline +void +glue_powext::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UA(X.A); + const quasi_unwrap UB(X.B); + + const Mat& A = UA.M; + const Mat& B = UB.M; + + arma_debug_assert_same_size(A, B, "element-wise pow()"); + + const bool UA_bad_alias = UA.is_alias(out) && (UA.has_subview); // allow inplace operation + const bool UB_bad_alias = UB.is_alias(out); + + if(UA_bad_alias || UB_bad_alias) + { + Mat tmp; + + glue_powext::apply(tmp, A, B); + + out.steal_mem(tmp); + } + else + { + glue_powext::apply(out, A, B); + } + } + + + +template +inline +void +glue_powext::apply(Mat& out, const Mat& A, const Mat& B) + { + arma_extra_debug_sigprint(); + + out.set_size(A.n_rows, A.n_cols); + + const uword N = out.n_elem; + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + const eT* B_mem = B.memptr(); + + if( arma_config::openmp && mp_gate::eval(N) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i +inline +Mat +glue_powext::apply + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& A = X.P; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + Mat out(A_n_rows, A_n_cols, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& B = tmp.M; + + X.check_size(B); + + const eT* B_mem = B.memptr(); + + if(mode == 0) // each column + { + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_cols) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = eop_aux::pow(A_mem[row], B_mem[row]); + } + } + } + #endif + } + else + { + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = eop_aux::pow(A_mem[row], B_mem[row]); + } + } + } + } + + if(mode == 1) // each row + { + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_cols) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + const eT B_val = B_mem[i]; + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = eop_aux::pow(A_mem[row], B_val); + } + } + } + #endif + } + else + { + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + const eT B_val = B_mem[i]; + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = eop_aux::pow(A_mem[row], B_val); + } + } + } + } + + return out; + } + + + +template +inline +void +glue_powext::apply(Cube& out, const GlueCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube UA(X.A); + const unwrap_cube UB(X.B); + + const Cube& A = UA.M; + const Cube& B = UB.M; + + arma_debug_assert_same_size(A, B, "element-wise pow()"); + + if(UB.is_alias(out)) + { + Cube tmp; + + glue_powext::apply(tmp, A, B); + + out.steal_mem(tmp); + } + else + { + glue_powext::apply(out, A, B); + } + } + + + +template +inline +void +glue_powext::apply(Cube& out, const Cube& A, const Cube& B) + { + arma_extra_debug_sigprint(); + + out.set_size(A.n_rows, A.n_cols, A.n_slices); + + const uword N = out.n_elem; + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + const eT* B_mem = B.memptr(); + + if( arma_config::openmp && mp_gate::eval(N) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i +inline +Cube +glue_powext::apply + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& A = X.P; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + const uword A_n_slices = A.n_slices; + + Cube out(A_n_rows, A_n_cols, A_n_slices, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& B = tmp.M; + + X.check_size(B); + + const eT* B_mem = B.memptr(); + const uword B_n_elem = B.n_elem; + + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_slices) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword s=0; s < A_n_slices; ++s) + { + const eT* A_slice_mem = A.slice_memptr(s); + eT* out_slice_mem = out.slice_memptr(s); + + for(uword i=0; i < B_n_elem; ++i) + { + out_slice_mem[i] = eop_aux::pow(A_slice_mem[i], B_mem[i]); + } + } + } + #endif + } + else + { + for(uword s=0; s < A_n_slices; ++s) + { + const eT* A_slice_mem = A.slice_memptr(s); + eT* out_slice_mem = out.slice_memptr(s); + + for(uword i=0; i < B_n_elem; ++i) + { + out_slice_mem[i] = eop_aux::pow(A_slice_mem[i], B_mem[i]); + } + } + } + + return out; + } + + + +// + + + +template +inline +void +glue_powext_cx::apply(Mat& out, const mtGlue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const quasi_unwrap UA(X.A); + const quasi_unwrap UB(X.B); + + const Mat& A = UA.M; + const Mat< T>& B = UB.M; + + arma_debug_assert_same_size(A, B, "element-wise pow()"); + + if(UA.is_alias(out) && (UA.has_subview)) + { + Mat tmp; + + glue_powext_cx::apply(tmp, A, B); + + out.steal_mem(tmp); + } + else + { + glue_powext_cx::apply(out, A, B); + } + } + + + +template +inline +void +glue_powext_cx::apply(Mat< std::complex >& out, const Mat< std::complex >& A, const Mat& B) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + out.set_size(A.n_rows, A.n_cols); + + const uword N = out.n_elem; + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + const T* B_mem = B.memptr(); + + if( arma_config::openmp && mp_gate::eval(N) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i +inline +Mat +glue_powext_cx::apply + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + typedef typename parent::pod_type T; + + const parent& A = X.P; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + Mat out(A_n_rows, A_n_cols, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& B = tmp.M; + + X.check_size(B); + + const T* B_mem = B.memptr(); + + if(mode == 0) // each column + { + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_cols) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = std::pow(A_mem[row], B_mem[row]); + } + } + } + #endif + } + else + { + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = std::pow(A_mem[row], B_mem[row]); + } + } + } + } + + if(mode == 1) // each row + { + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_cols) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + const eT B_val = B_mem[i]; + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = std::pow(A_mem[row], B_val); + } + } + } + #endif + } + else + { + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + const eT B_val = B_mem[i]; + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = std::pow(A_mem[row], B_val); + } + } + } + } + + return out; + } + + + +template +inline +void +glue_powext_cx::apply(Cube& out, const mtGlueCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + typedef typename get_pod_type::result T; + + const unwrap_cube UA(X.A); + const unwrap_cube UB(X.B); + + const Cube& A = UA.M; + const Cube< T>& B = UB.M; + + arma_debug_assert_same_size(A, B, "element-wise pow()"); + + glue_powext_cx::apply(out, A, B); + } + + + +template +inline +void +glue_powext_cx::apply(Cube< std::complex >& out, const Cube< std::complex >& A, const Cube& B) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + out.set_size(A.n_rows, A.n_cols, A.n_slices); + + const uword N = out.n_elem; + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + const T* B_mem = B.memptr(); + + if( arma_config::openmp && mp_gate::eval(N) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i +inline +Cube< std::complex > +glue_powext_cx::apply + ( + const subview_cube_each1< std::complex >& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const Cube& A = X.P; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + const uword A_n_slices = A.n_slices; + + Cube out(A_n_rows, A_n_cols, A_n_slices, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& B = tmp.M; + + X.check_size(B); + + const T* B_mem = B.memptr(); + const uword B_n_elem = B.n_elem; + + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_slices) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword s=0; s < A_n_slices; ++s) + { + const eT* A_slice_mem = A.slice_memptr(s); + eT* out_slice_mem = out.slice_memptr(s); + + for(uword i=0; i < B_n_elem; ++i) + { + out_slice_mem[i] = std::pow(A_slice_mem[i], B_mem[i]); + } + } + } + #endif + } + else + { + for(uword s=0; s < A_n_slices; ++s) + { + const eT* A_slice_mem = A.slice_memptr(s); + eT* out_slice_mem = out.slice_memptr(s); + + for(uword i=0; i < B_n_elem; ++i) + { + out_slice_mem[i] = std::pow(A_slice_mem[i], B_mem[i]); + } + } + } + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_quantile_bones.hpp b/src/armadillo/include/armadillo_bits/glue_quantile_bones.hpp new file mode 100644 index 0000000..cd7fcf1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_quantile_bones.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_quantile +//! @{ + + +class glue_quantile + : public traits_glue_default + { + public: + + template + inline static void worker(eTb* out_mem, Col& Y, const Mat& P); + + + template + inline static void apply_noalias(Mat& out, const Mat& X, const Mat& P, const uword dim); + + template + inline static void apply(Mat& out, const mtGlue& expr); + }; + + + +class glue_quantile_default + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; + }; + + template + inline static void apply(Mat& out, const mtGlue& expr); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_quantile_meat.hpp b/src/armadillo/include/armadillo_bits/glue_quantile_meat.hpp new file mode 100644 index 0000000..370e432 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_quantile_meat.hpp @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_quantile +//! @{ + + +template +inline +void +glue_quantile::worker(eTb* out_mem, Col& Y, const Mat& P) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming out_mem is an array with P.n_elem elements + + // TODO: ignore non-finite values ? + + // algorithm based on "Definition 5" in: + // Rob J. Hyndman and Yanan Fan. + // Sample Quantiles in Statistical Packages. + // The American Statistician, Vol. 50, No. 4, pp. 361-365, 1996. + // http://doi.org/10.2307/2684934 + + const eTb* P_mem = P.memptr(); + const uword P_n_elem = P.n_elem; + + const eTb alpha = 0.5; + const eTb N = eTb(Y.n_elem); + const eTb P_min = (eTb(1) - alpha) / N; + const eTb P_max = (N - alpha) / N; + + for(uword i=0; i < P_n_elem; ++i) + { + const eTb P_i = P_mem[i]; + + eTb out_val = eTb(0); + + if(P_i < P_min) + { + out_val = (P_i < eTb(0)) ? eTb(-std::numeric_limits::infinity()) : eTb(Y.min()); + } + else + if(P_i > P_max) + { + out_val = (P_i > eTb(1)) ? eTb( std::numeric_limits::infinity()) : eTb(Y.max()); + } + else + { + const uword k = uword(std::floor(N * P_i + alpha)); + const eTb P_k = (eTb(k) - alpha) / N; + + const eTb w = (P_i - P_k) * N; + + eTa* Y_k_ptr = Y.begin() + uword(k); + std::nth_element( Y.begin(), Y_k_ptr, Y.end() ); + const eTa Y_k_val = (*Y_k_ptr); + + eTa* Y_km1_ptr = Y.begin() + uword(k-1); + // std::nth_element( Y.begin(), Y_km1_ptr, Y.end() ); + std::nth_element( Y.begin(), Y_km1_ptr, Y_k_ptr ); + const eTa Y_km1_val = (*Y_km1_ptr); + + out_val = ((eTb(1) - w) * Y_km1_val) + (w * Y_k_val); + } + + out_mem[i] = out_val; + } + } + + + +template +inline +void +glue_quantile::apply_noalias(Mat& out, const Mat& X, const Mat& P, const uword dim) + { + arma_extra_debug_sigprint(); + + arma_debug_check( ((P.is_vec() == false) && (P.is_empty() == false)), "quantile(): parameter 'P' must be a vector" ); + + if(X.is_empty()) { out.reset(); return; } + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + const uword P_n_elem = P.n_elem; + + if(dim == 0) + { + out.set_size(P_n_elem, X_n_cols); + + if(out.is_empty()) { return; } + + Col Y(X_n_rows, arma_nozeros_indicator()); + + if(X_n_cols == 1) + { + arrayops::copy(Y.memptr(), X.memptr(), X_n_rows); + + glue_quantile::worker(out.memptr(), Y, P); + } + else + { + for(uword col=0; col < X_n_cols; ++col) + { + arrayops::copy(Y.memptr(), X.colptr(col), X_n_rows); + + glue_quantile::worker(out.colptr(col), Y, P); + } + } + } + else + if(dim == 1) + { + out.set_size(X_n_rows, P_n_elem); + + if(out.is_empty()) { return; } + + Col Y(X_n_cols, arma_nozeros_indicator()); + + if(X_n_rows == 1) + { + arrayops::copy(Y.memptr(), X.memptr(), X_n_cols); + + glue_quantile::worker(out.memptr(), Y, P); + } + else + { + Col tmp(P_n_elem, arma_nozeros_indicator()); + + eTb* tmp_mem = tmp.memptr(); + + for(uword row=0; row < X_n_rows; ++row) + { + eTa* Y_mem = Y.memptr(); + + for(uword col=0; col < X_n_cols; ++col) { Y_mem[col] = X.at(row,col); } + + glue_quantile::worker(tmp_mem, Y, P); + + for(uword i=0; i < P_n_elem; ++i) { out.at(row,i) = tmp_mem[i]; } + } + } + } + } + + + +template +inline +void +glue_quantile::apply(Mat& out, const mtGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T2::elem_type eTb; + + const uword dim = expr.aux_uword; + + arma_debug_check( (dim > 1), "quantile(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap UA(expr.A); + const quasi_unwrap UB(expr.B); + + arma_debug_check((UA.M.internal_has_nan() || UB.M.internal_has_nan()), "quantile(): detected NaN"); + + if(UA.is_alias(out) || UB.is_alias(out)) + { + Mat tmp; + + glue_quantile::apply_noalias(tmp, UA.M, UB.M, dim); + + out.steal_mem(tmp); + } + else + { + glue_quantile::apply_noalias(out, UA.M, UB.M, dim); + } + } + + + +template +inline +void +glue_quantile_default::apply(Mat& out, const mtGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T2::elem_type eTb; + + const quasi_unwrap UA(expr.A); + const quasi_unwrap UB(expr.B); + + const uword dim = (T1::is_xvec) ? uword(UA.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); + + arma_debug_check((UA.M.internal_has_nan() || UB.M.internal_has_nan()), "quantile(): detected NaN"); + + if(UA.is_alias(out) || UB.is_alias(out)) + { + Mat tmp; + + glue_quantile::apply_noalias(tmp, UA.M, UB.M, dim); + + out.steal_mem(tmp); + } + else + { + glue_quantile::apply_noalias(out, UA.M, UB.M, dim); + } + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_relational_bones.hpp b/src/armadillo/include/armadillo_bits/glue_relational_bones.hpp new file mode 100644 index 0000000..876ffb7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_relational_bones.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_relational +//! @{ + + + +class glue_rel_lt + : public traits_glue_or + { + public: + + template + inline static void apply(Mat & out, const mtGlue& X); + + template + inline static void apply(Cube & out, const mtGlueCube& X); + }; + + + +class glue_rel_gt + : public traits_glue_or + { + public: + + template + inline static void apply(Mat & out, const mtGlue& X); + + template + inline static void apply(Cube & out, const mtGlueCube& X); + }; + + + +class glue_rel_lteq + : public traits_glue_or + { + public: + + template + inline static void apply(Mat & out, const mtGlue& X); + + template + inline static void apply(Cube & out, const mtGlueCube& X); + }; + + + +class glue_rel_gteq + : public traits_glue_or + { + public: + + template + inline static void apply(Mat & out, const mtGlue& X); + + template + inline static void apply(Cube & out, const mtGlueCube& X); + }; + + + +class glue_rel_eq + : public traits_glue_or + { + public: + + template + inline static void apply(Mat & out, const mtGlue& X); + + template + inline static void apply(Cube & out, const mtGlueCube& X); + }; + + + +class glue_rel_noteq + : public traits_glue_or + { + public: + + template + inline static void apply(Mat & out, const mtGlue& X); + + template + inline static void apply(Cube & out, const mtGlueCube& X); + }; + + + +class glue_rel_and + : public traits_glue_or + { + public: + + template + inline static void apply(Mat & out, const mtGlue& X); + + template + inline static void apply(Cube & out, const mtGlueCube& X); + }; + + + +class glue_rel_or + : public traits_glue_or + { + public: + + template + inline static void apply(Mat & out, const mtGlue& X); + + template + inline static void apply(Cube & out, const mtGlueCube& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_relational_meat.hpp b/src/armadillo/include/armadillo_bits/glue_relational_meat.hpp new file mode 100644 index 0000000..5728a09 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_relational_meat.hpp @@ -0,0 +1,419 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_relational +//! @{ + + + +#undef operator_rel +#undef operator_str + +#undef arma_applier_mat +#undef arma_applier_cube + + +#define arma_applier_mat(operator_rel, operator_str) \ + {\ + const Proxy P1(X.A);\ + const Proxy P2(X.B);\ + \ + arma_debug_assert_same_size(P1, P2, operator_str);\ + \ + const bool bad_alias = (Proxy::has_subview && P1.is_alias(out)) || (Proxy::has_subview && P2.is_alias(out));\ + \ + if(bad_alias == false)\ + {\ + \ + const uword n_rows = P1.get_n_rows();\ + const uword n_cols = P1.get_n_cols();\ + \ + out.set_size(n_rows, n_cols);\ + \ + uword* out_mem = out.memptr();\ + \ + const bool use_at = (Proxy::use_at || Proxy::use_at);\ + \ + if(use_at == false)\ + {\ + typename Proxy::ea_type A = P1.get_ea();\ + typename Proxy::ea_type B = P2.get_ea();\ + \ + const uword n_elem = out.n_elem;\ + \ + for(uword i=0; i::stored_type> tmp1(P1.Q, P1.is_alias(out));\ + const unwrap_check::stored_type> tmp2(P2.Q, P2.is_alias(out));\ + \ + out = (tmp1.M) operator_rel (tmp2.M);\ + }\ + } + + + + +#define arma_applier_cube(operator_rel, operator_str) \ + {\ + const ProxyCube P1(X.A);\ + const ProxyCube P2(X.B);\ + \ + arma_debug_assert_same_size(P1, P2, operator_str);\ + \ + const bool bad_alias = (ProxyCube::has_subview && P1.is_alias(out)) || (ProxyCube::has_subview && P2.is_alias(out));\ + \ + if(bad_alias == false)\ + {\ + \ + const uword n_rows = P1.get_n_rows();\ + const uword n_cols = P1.get_n_cols();\ + const uword n_slices = P1.get_n_slices();\ + \ + out.set_size(n_rows, n_cols, n_slices);\ + \ + uword* out_mem = out.memptr();\ + \ + const bool use_at = (ProxyCube::use_at || ProxyCube::use_at);\ + \ + if(use_at == false)\ + {\ + typename ProxyCube::ea_type A = P1.get_ea();\ + typename ProxyCube::ea_type B = P2.get_ea();\ + \ + const uword n_elem = out.n_elem;\ + \ + for(uword i=0; i::stored_type> tmp1(P1.Q);\ + const unwrap_cube::stored_type> tmp2(P2.Q);\ + \ + out = (tmp1.M) operator_rel (tmp2.M);\ + }\ + } + + + +template +inline +void +glue_rel_lt::apply + ( + Mat & out, + const mtGlue& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_mat(<, "operator<"); + } + + + +template +inline +void +glue_rel_gt::apply + ( + Mat & out, + const mtGlue& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_mat(>, "operator>"); + } + + + +template +inline +void +glue_rel_lteq::apply + ( + Mat & out, + const mtGlue& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_mat(<=, "operator<="); + } + + + +template +inline +void +glue_rel_gteq::apply + ( + Mat & out, + const mtGlue& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_mat(>=, "operator>="); + } + + + +template +inline +void +glue_rel_eq::apply + ( + Mat & out, + const mtGlue& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_mat(==, "operator=="); + } + + + +template +inline +void +glue_rel_noteq::apply + ( + Mat & out, + const mtGlue& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_mat(!=, "operator!="); + } + + + +template +inline +void +glue_rel_and::apply + ( + Mat & out, + const mtGlue& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_mat(&&, "operator&&"); + } + + + +template +inline +void +glue_rel_or::apply + ( + Mat & out, + const mtGlue& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_mat(||, "operator||"); + } + + + +// +// +// + + + +template +inline +void +glue_rel_lt::apply + ( + Cube & out, + const mtGlueCube& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_cube(<, "operator<"); + } + + + +template +inline +void +glue_rel_gt::apply + ( + Cube & out, + const mtGlueCube& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_cube(>, "operator>"); + } + + + +template +inline +void +glue_rel_lteq::apply + ( + Cube & out, + const mtGlueCube& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_cube(<=, "operator<="); + } + + + +template +inline +void +glue_rel_gteq::apply + ( + Cube & out, + const mtGlueCube& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_cube(>=, "operator>="); + } + + + +template +inline +void +glue_rel_eq::apply + ( + Cube & out, + const mtGlueCube& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_cube(==, "operator=="); + } + + + +template +inline +void +glue_rel_noteq::apply + ( + Cube & out, + const mtGlueCube& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_cube(!=, "operator!="); + } + + + +template +inline +void +glue_rel_and::apply + ( + Cube & out, + const mtGlueCube& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_cube(&&, "operator&&"); + } + + + +template +inline +void +glue_rel_or::apply + ( + Cube & out, + const mtGlueCube& X + ) + { + arma_extra_debug_sigprint(); + + arma_applier_cube(||, "operator||"); + } + + + +#undef arma_applier_mat +#undef arma_applier_cube + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_solve_bones.hpp b/src/armadillo/include/armadillo_bits/glue_solve_bones.hpp new file mode 100644 index 0000000..20c0165 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_solve_bones.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_solve +//! @{ + + + +class glue_solve_gen_default + { + public: + + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template inline static void apply(Mat& out, const Glue& X); + + template inline static bool apply(Mat& out, const Base& A_expr, const Base& B_expr); + }; + + + +class glue_solve_gen_full + { + public: + + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template inline static void apply(Mat& out, const Glue& X); + + template inline static bool apply(Mat& out, const Base& A_expr, const Base& B_expr, const uword flags); + }; + + + +class glue_solve_tri_default + { + public: + + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template inline static void apply(Mat& out, const Glue& X); + + template inline static bool apply(Mat& out, const Base& A_expr, const Base& B_expr, const uword flags); + }; + + + +class glue_solve_tri_full + { + public: + + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template inline static void apply(Mat& out, const Glue& X); + + template inline static bool apply(Mat& out, const Base& A_expr, const Base& B_expr, const uword flags); + }; + + + +namespace solve_opts + { + struct opts + { + const uword flags; + + inline constexpr explicit opts(const uword in_flags); + + inline const opts operator+(const opts& rhs) const; + }; + + inline + constexpr + opts::opts(const uword in_flags) + : flags(in_flags) + {} + + inline + const opts + opts::operator+(const opts& rhs) const + { + const opts result( flags | rhs.flags ); + + return result; + } + + // The values below (eg. 1u << 1) are for internal Armadillo use only. + // The values can change without notice. + + static constexpr uword flag_none = uword(0 ); + static constexpr uword flag_fast = uword(1u << 0); + static constexpr uword flag_equilibrate = uword(1u << 1); + static constexpr uword flag_no_approx = uword(1u << 2); + static constexpr uword flag_triu = uword(1u << 3); + static constexpr uword flag_tril = uword(1u << 4); + static constexpr uword flag_no_band = uword(1u << 5); + static constexpr uword flag_no_sympd = uword(1u << 6); + static constexpr uword flag_allow_ugly = uword(1u << 7); + static constexpr uword flag_likely_sympd = uword(1u << 8); + static constexpr uword flag_refine = uword(1u << 9); + static constexpr uword flag_no_trimat = uword(1u << 10); + static constexpr uword flag_force_approx = uword(1u << 11); + + struct opts_none : public opts { inline constexpr opts_none() : opts(flag_none ) {} }; + struct opts_fast : public opts { inline constexpr opts_fast() : opts(flag_fast ) {} }; + struct opts_equilibrate : public opts { inline constexpr opts_equilibrate() : opts(flag_equilibrate ) {} }; + struct opts_no_approx : public opts { inline constexpr opts_no_approx() : opts(flag_no_approx ) {} }; + struct opts_triu : public opts { inline constexpr opts_triu() : opts(flag_triu ) {} }; + struct opts_tril : public opts { inline constexpr opts_tril() : opts(flag_tril ) {} }; + struct opts_no_band : public opts { inline constexpr opts_no_band() : opts(flag_no_band ) {} }; + struct opts_no_sympd : public opts { inline constexpr opts_no_sympd() : opts(flag_no_sympd ) {} }; + struct opts_allow_ugly : public opts { inline constexpr opts_allow_ugly() : opts(flag_allow_ugly ) {} }; + struct opts_likely_sympd : public opts { inline constexpr opts_likely_sympd() : opts(flag_likely_sympd) {} }; + struct opts_refine : public opts { inline constexpr opts_refine() : opts(flag_refine ) {} }; + struct opts_no_trimat : public opts { inline constexpr opts_no_trimat() : opts(flag_no_trimat ) {} }; + struct opts_force_approx : public opts { inline constexpr opts_force_approx() : opts(flag_force_approx) {} }; + + static constexpr opts_none none; + static constexpr opts_fast fast; + static constexpr opts_equilibrate equilibrate; + static constexpr opts_no_approx no_approx; + static constexpr opts_triu triu; + static constexpr opts_tril tril; + static constexpr opts_no_band no_band; + static constexpr opts_no_sympd no_sympd; + static constexpr opts_allow_ugly allow_ugly; + static constexpr opts_likely_sympd likely_sympd; + static constexpr opts_refine refine; + static constexpr opts_no_trimat no_trimat; + static constexpr opts_force_approx force_approx; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_solve_meat.hpp b/src/armadillo/include/armadillo_bits/glue_solve_meat.hpp new file mode 100644 index 0000000..1c3bbf0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_solve_meat.hpp @@ -0,0 +1,587 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_solve +//! @{ + + + +// +// glue_solve_gen_default + + +template +inline +void +glue_solve_gen_default::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + const bool status = glue_solve_gen_default::apply(out, X.A, X.B); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("solve(): solution not found"); + } + } + + + +template +inline +bool +glue_solve_gen_default::apply(Mat& out, const Base& A_expr, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + return glue_solve_gen_full::apply( out, A_expr, B_expr, uword(0)); + } + + + +// +// glue_solve_gen_full + + +template +inline +void +glue_solve_gen_full::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + const bool status = glue_solve_gen_full::apply( out, X.A, X.B, X.aux_uword ); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("solve(): solution not found"); + } + } + + + +template +inline +bool +glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const Base& B_expr, const uword flags) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(has_user_flags == true ) { arma_extra_debug_print("glue_solve_gen_full::apply(): has_user_flags = true" ); } + if(has_user_flags == false) { arma_extra_debug_print("glue_solve_gen_full::apply(): has_user_flags = false"); } + + const bool fast = has_user_flags && bool(flags & solve_opts::flag_fast ); + const bool equilibrate = has_user_flags && bool(flags & solve_opts::flag_equilibrate ); + const bool no_approx = has_user_flags && bool(flags & solve_opts::flag_no_approx ); + const bool no_band = has_user_flags && bool(flags & solve_opts::flag_no_band ); + const bool no_sympd = has_user_flags && bool(flags & solve_opts::flag_no_sympd ); + const bool allow_ugly = has_user_flags && bool(flags & solve_opts::flag_allow_ugly ); + const bool likely_sympd = has_user_flags && bool(flags & solve_opts::flag_likely_sympd); + const bool refine = has_user_flags && bool(flags & solve_opts::flag_refine ); + const bool no_trimat = has_user_flags && bool(flags & solve_opts::flag_no_trimat ); + const bool force_approx = has_user_flags && bool(flags & solve_opts::flag_force_approx); + + if(has_user_flags) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): enabled flags:"); + + if(fast ) { arma_extra_debug_print("fast"); } + if(equilibrate ) { arma_extra_debug_print("equilibrate"); } + if(no_approx ) { arma_extra_debug_print("no_approx"); } + if(no_band ) { arma_extra_debug_print("no_band"); } + if(no_sympd ) { arma_extra_debug_print("no_sympd"); } + if(allow_ugly ) { arma_extra_debug_print("allow_ugly"); } + if(likely_sympd) { arma_extra_debug_print("likely_sympd"); } + if(refine ) { arma_extra_debug_print("refine"); } + if(no_trimat ) { arma_extra_debug_print("no_trimat"); } + if(force_approx) { arma_extra_debug_print("force_approx"); } + + arma_debug_check( (fast && equilibrate ), "solve(): options 'fast' and 'equilibrate' are mutually exclusive" ); + arma_debug_check( (fast && refine ), "solve(): options 'fast' and 'refine' are mutually exclusive" ); + arma_debug_check( (no_sympd && likely_sympd), "solve(): options 'no_sympd' and 'likely_sympd' are mutually exclusive" ); + } + + Mat A = A_expr.get_ref(); + + if(force_approx) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): forced approximate solution"); + + arma_debug_check( no_approx, "solve(): options 'no_approx' and 'force_approx' are mutually exclusive" ); + + if(fast) { arma_debug_warn_level(2, "solve(): option 'fast' ignored for forced approximate solution" ); } + if(equilibrate) { arma_debug_warn_level(2, "solve(): option 'equilibrate' ignored for forced approximate solution" ); } + if(refine) { arma_debug_warn_level(2, "solve(): option 'refine' ignored for forced approximate solution" ); } + if(likely_sympd) { arma_debug_warn_level(2, "solve(): option 'likely_sympd' ignored for forced approximate solution" ); } + + return auxlib::solve_approx_svd(actual_out, A, B_expr.get_ref()); // A is overwritten + } + + // A_expr and B_expr can be used more than once (sympd optimisation fails or approximate solution required), + // so ensure they are not overwritten in case we have aliasing + + bool is_alias = true; // assume we have aliasing until we can prove otherwise + + if(is_Mat::value && is_Mat::value) + { + const quasi_unwrap UA( A_expr.get_ref() ); + const quasi_unwrap UB( B_expr.get_ref() ); + + is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); + } + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + T rcond = T(0); + bool status = false; + + if(A.n_rows == A.n_cols) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): detected square system"); + + uword KL = 0; + uword KU = 0; + + const bool is_band = arma_config::optimise_band && ((no_band || auxlib::crippled_lapack(A)) ? false : band_helper::is_band(KL, KU, A, uword(32))); + + const bool is_triu = (no_trimat || refine || equilibrate || likely_sympd || is_band ) ? false : trimat_helper::is_triu(A); + const bool is_tril = (no_trimat || refine || equilibrate || likely_sympd || is_band || is_triu) ? false : trimat_helper::is_tril(A); + + const bool try_sympd = arma_config::optimise_sym && ((no_sympd || auxlib::crippled_lapack(A) || is_band || is_triu || is_tril) ? false : (likely_sympd ? true : sym_helper::guess_sympd(A, uword(16)))); + + if(fast) + { + // fast mode: solvers without refinement and without rcond estimate + + arma_extra_debug_print("glue_solve_gen_full::apply(): fast mode"); + + if(is_band) + { + if( (KL == 1) && (KU == 1) ) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): fast + tridiagonal"); + + status = auxlib::solve_tridiag_fast(out, A, B_expr.get_ref()); + } + else + { + arma_extra_debug_print("glue_solve_gen_full::apply(): fast + band"); + + status = auxlib::solve_band_fast(out, A, KL, KU, B_expr.get_ref()); + } + } + else + if(is_triu || is_tril) + { + if(is_triu) { arma_extra_debug_print("glue_solve_gen_full::apply(): fast + upper triangular matrix"); } + if(is_tril) { arma_extra_debug_print("glue_solve_gen_full::apply(): fast + lower triangular matrix"); } + + const uword layout = (is_triu) ? uword(0) : uword(1); + + status = auxlib::solve_trimat_fast(out, A, B_expr.get_ref(), layout); + } + else + if(try_sympd) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): fast + try_sympd"); + + status = auxlib::solve_sympd_fast(out, A, B_expr.get_ref()); // A is overwritten + + if(status == false) + { + // auxlib::solve_sympd_fast() may have failed because A isn't really sympd + + arma_extra_debug_print("glue_solve_gen_full::apply(): auxlib::solve_sympd_fast() failed; retrying"); + + A = A_expr.get_ref(); + + status = auxlib::solve_square_fast(out, A, B_expr.get_ref()); // A is overwritten + } + } + else + { + arma_extra_debug_print("glue_solve_gen_full::apply(): fast + dense"); + + status = auxlib::solve_square_fast(out, A, B_expr.get_ref()); // A is overwritten + } + } + else + if(refine || equilibrate) + { + // refine mode: solvers with refinement and with rcond estimate + + arma_extra_debug_print("glue_solve_gen_full::apply(): refine mode"); + + if(is_band) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): refine + band"); + + status = auxlib::solve_band_refine(out, rcond, A, KL, KU, B_expr, equilibrate); + } + else + if(try_sympd) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): refine + try_sympd"); + + status = auxlib::solve_sympd_refine(out, rcond, A, B_expr.get_ref(), equilibrate); // A is overwritten + + if( (status == false) && (rcond == T(0)) ) + { + // auxlib::solve_sympd_refine() may have failed because A isn't really sympd; + // in that case rcond is set to zero + + arma_extra_debug_print("glue_solve_gen_full::apply(): auxlib::solve_sympd_refine() failed; retrying"); + + A = A_expr.get_ref(); + + status = auxlib::solve_square_refine(out, rcond, A, B_expr.get_ref(), equilibrate); // A is overwritten + } + } + else + { + arma_extra_debug_print("glue_solve_gen_full::apply(): refine + dense"); + + status = auxlib::solve_square_refine(out, rcond, A, B_expr, equilibrate); // A is overwritten + } + } + else + { + // default mode: solvers without refinement but with rcond estimate + + arma_extra_debug_print("glue_solve_gen_full::apply(): default mode"); + + if(is_band) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): rcond + band"); + + status = auxlib::solve_band_rcond(out, rcond, A, KL, KU, B_expr.get_ref()); + } + else + if(is_triu || is_tril) + { + if(is_triu) { arma_extra_debug_print("glue_solve_gen_full::apply(): rcond + upper triangular matrix"); } + if(is_tril) { arma_extra_debug_print("glue_solve_gen_full::apply(): rcond + lower triangular matrix"); } + + const uword layout = (is_triu) ? uword(0) : uword(1); + + status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout); + } + else + if(try_sympd) + { + bool sympd_state = false; + + status = auxlib::solve_sympd_rcond(out, sympd_state, rcond, A, B_expr.get_ref()); // A is overwritten + + if( (status == false) && (sympd_state == false) ) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): auxlib::solve_sympd_rcond() failed; retrying"); + + A = A_expr.get_ref(); + + status = auxlib::solve_square_rcond(out, rcond, A, B_expr.get_ref()); // A is overwritten + } + } + else + { + status = auxlib::solve_square_rcond(out, rcond, A, B_expr.get_ref()); // A is overwritten + } + } + } + else + { + arma_extra_debug_print("glue_solve_gen_full::apply(): detected non-square system"); + + if(equilibrate) { arma_debug_warn_level(2, "solve(): option 'equilibrate' ignored for non-square matrix" ); } + if(refine) { arma_debug_warn_level(2, "solve(): option 'refine' ignored for non-square matrix" ); } + if(likely_sympd) { arma_debug_warn_level(2, "solve(): option 'likely_sympd' ignored for non-square matrix" ); } + + if(fast) + { + status = auxlib::solve_rect_fast(out, A, B_expr.get_ref()); // A is overwritten + } + else + { + status = auxlib::solve_rect_rcond(out, rcond, A, B_expr.get_ref()); // A is overwritten + } + } + + + if( (status == true) && (fast == false) && (allow_ugly == false) && ((rcond < std::numeric_limits::epsilon()) || arma_isnan(rcond)) ) + { + status = false; + } + + + if( (status == false) && (no_approx == false) ) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): solving rank deficient system"); + + if(rcond == T(0)) + { + arma_debug_warn_level(2, "solve(): system is singular; attempting approx solution"); + } + else + { + arma_debug_warn_level(2, "solve(): system is singular (rcond: ", rcond, "); attempting approx solution"); + } + + // TODO: conditionally recreate A: have a separate state flag which indicates whether A was previously overwritten + + A = A_expr.get_ref(); // as A may have been overwritten + + status = auxlib::solve_approx_svd(out, A, B_expr.get_ref()); // A is overwritten + } + + if(is_alias) { actual_out.steal_mem(out); } + + return status; + } + + + +// +// glue_solve_tri_default + + +template +inline +void +glue_solve_tri_default::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + const bool status = glue_solve_tri_default::apply( out, X.A, X.B, X.aux_uword ); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("solve(): solution not found"); + } + } + + + +template +inline +bool +glue_solve_tri_default::apply(Mat& actual_out, const Base& A_expr, const Base& B_expr, const uword flags) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const bool triu = bool(flags & solve_opts::flag_triu); + const bool tril = bool(flags & solve_opts::flag_tril); + + arma_extra_debug_print("glue_solve_tri_default::apply(): enabled flags:"); + + if(triu) { arma_extra_debug_print("triu"); } + if(tril) { arma_extra_debug_print("tril"); } + + const quasi_unwrap UA(A_expr.get_ref()); + const Mat& A = UA.M; + + arma_debug_check( (A.is_square() == false), "solve(): matrix marked as triangular must be square sized" ); + + const uword layout = (triu) ? uword(0) : uword(1); + + bool is_alias = true; + + if(is_Mat::value) + { + const quasi_unwrap UB(B_expr.get_ref()); + + is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); + } + + T rcond = T(0); + bool status = false; + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout); // A is not modified + + + if( (status == true) && ( (rcond < std::numeric_limits::epsilon()) || arma_isnan(rcond) ) ) + { + status = false; + } + + + if(status == false) + { + arma_extra_debug_print("glue_solve_tri_default::apply(): solving rank deficient system"); + + if(rcond == T(0)) + { + arma_debug_warn_level(2, "solve(): system is singular; attempting approx solution"); + } + else + { + arma_debug_warn_level(2, "solve(): system is singular (rcond: ", rcond, "); attempting approx solution"); + } + + Mat triA = (triu) ? trimatu(A) : trimatl(A); // trimatu() and trimatl() return the same type + + status = auxlib::solve_approx_svd(out, triA, B_expr.get_ref()); // triA is overwritten + } + + + if(is_alias) { actual_out.steal_mem(out); } + + return status; + } + + + +// +// glue_solve_tri_full + + +template +inline +void +glue_solve_tri_full::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + const bool status = glue_solve_tri_full::apply( out, X.A, X.B, X.aux_uword ); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("solve(): solution not found"); + } + } + + + +template +inline +bool +glue_solve_tri_full::apply(Mat& actual_out, const Base& A_expr, const Base& B_expr, const uword flags) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const bool fast = bool(flags & solve_opts::flag_fast ); + const bool equilibrate = bool(flags & solve_opts::flag_equilibrate ); + const bool no_approx = bool(flags & solve_opts::flag_no_approx ); + const bool triu = bool(flags & solve_opts::flag_triu ); + const bool tril = bool(flags & solve_opts::flag_tril ); + const bool allow_ugly = bool(flags & solve_opts::flag_allow_ugly ); + const bool likely_sympd = bool(flags & solve_opts::flag_likely_sympd); + const bool refine = bool(flags & solve_opts::flag_refine ); + const bool no_trimat = bool(flags & solve_opts::flag_no_trimat ); + const bool force_approx = bool(flags & solve_opts::flag_force_approx); + + arma_extra_debug_print("glue_solve_tri_full::apply(): enabled flags:"); + + if(fast ) { arma_extra_debug_print("fast"); } + if(equilibrate ) { arma_extra_debug_print("equilibrate"); } + if(no_approx ) { arma_extra_debug_print("no_approx"); } + if(triu ) { arma_extra_debug_print("triu"); } + if(tril ) { arma_extra_debug_print("tril"); } + if(allow_ugly ) { arma_extra_debug_print("allow_ugly"); } + if(likely_sympd) { arma_extra_debug_print("likely_sympd"); } + if(refine ) { arma_extra_debug_print("refine"); } + if(no_trimat ) { arma_extra_debug_print("no_trimat"); } + if(force_approx) { arma_extra_debug_print("force_approx"); } + + if(no_trimat || equilibrate || refine || force_approx) + { + const uword mask = ~(solve_opts::flag_triu | solve_opts::flag_tril); + + return glue_solve_gen_full::apply(actual_out, ((triu) ? trimatu(A_expr.get_ref()) : trimatl(A_expr.get_ref())), B_expr, (flags & mask)); + } + + if(likely_sympd) { arma_debug_warn_level(2, "solve(): option 'likely_sympd' ignored for triangular matrix"); } + + const quasi_unwrap UA(A_expr.get_ref()); + const Mat& A = UA.M; + + arma_debug_check( (A.is_square() == false), "solve(): matrix marked as triangular must be square sized" ); + + const uword layout = (triu) ? uword(0) : uword(1); + + bool is_alias = true; + + if(is_Mat::value) + { + const quasi_unwrap UB(B_expr.get_ref()); + + is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); + } + + T rcond = T(0); + bool status = false; + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + if(fast) + { + status = auxlib::solve_trimat_fast(out, A, B_expr.get_ref(), layout); // A is not modified + } + else + { + status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout); // A is not modified + } + + + if( (status == true) && (fast == false) && (allow_ugly == false) && ((rcond < std::numeric_limits::epsilon()) || arma_isnan(rcond)) ) + { + status = false; + } + + + if( (status == false) && (no_approx == false) ) + { + arma_extra_debug_print("glue_solve_tri_full::apply(): solving rank deficient system"); + + if(rcond == T(0)) + { + arma_debug_warn_level(2, "solve(): system is singular; attempting approx solution"); + } + else + { + arma_debug_warn_level(2, "solve(): system is singular (rcond: ", rcond, "); attempting approx solution"); + } + + Mat triA = (triu) ? trimatu(A) : trimatl(A); // trimatu() and trimatl() return the same type + + status = auxlib::solve_approx_svd(out, triA, B_expr.get_ref()); // triA is overwritten + } + + + if(is_alias) { actual_out.steal_mem(out); } + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_times_bones.hpp b/src/armadillo/include/armadillo_bits/glue_times_bones.hpp new file mode 100644 index 0000000..5792e4e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_times_bones.hpp @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_times +//! @{ + + + +//! \brief +//! Template metaprogram depth_lhs +//! calculates the number of Glue instances on the left hand side argument of Glue +//! ie. it recursively expands each Tx, until the type of Tx is not "Glue<..,.., glue_type>" (i.e the "glue_type" changes) + +template +struct depth_lhs + { + static constexpr uword num = 0; + }; + +template +struct depth_lhs< glue_type, Glue > + { + static constexpr uword num = 1 + depth_lhs::num; + }; + + + +template +struct glue_times_redirect2_helper + { + template + arma_hot inline static void apply(Mat& out, const Glue& X); + }; + + +template<> +struct glue_times_redirect2_helper + { + template + arma_hot inline static void apply(Mat& out, const Glue& X); + }; + + + +template +struct glue_times_redirect3_helper + { + template + arma_hot inline static void apply(Mat& out, const Glue< Glue,T3,glue_times>& X); + }; + + +template<> +struct glue_times_redirect3_helper + { + template + arma_hot inline static void apply(Mat& out, const Glue< Glue,T3,glue_times>& X); + }; + + + +template +struct glue_times_redirect + { + template + arma_hot inline static void apply(Mat& out, const Glue& X); + }; + + +template<> +struct glue_times_redirect<2> + { + template + arma_hot inline static void apply(Mat& out, const Glue& X); + }; + + +template<> +struct glue_times_redirect<3> + { + template + arma_hot inline static void apply(Mat& out, const Glue< Glue,T3,glue_times>& X); + }; + + +template<> +struct glue_times_redirect<4> + { + template + arma_hot inline static void apply(Mat& out, const Glue< Glue< Glue, T3, glue_times>, T4, glue_times>& X); + }; + + + +//! Class which implements the immediate multiplication of two or more matrices +class glue_times + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template + arma_hot inline static void apply(Mat& out, const Glue& X); + + + template + arma_hot inline static void apply_inplace(Mat& out, const T1& X); + + template + arma_hot inline static void apply_inplace_plus(Mat& out, const Glue& X, const sword sign); + + // + + template + arma_inline static uword mul_storage_cost(const TA& A, const TB& B); + + template + arma_hot inline static void apply(Mat& out, const TA& A, const TB& B, const eT val); + + template + arma_hot inline static void apply(Mat& out, const TA& A, const TB& B, const TC& C, const eT val); + + template + arma_hot inline static void apply(Mat& out, const TA& A, const TB& B, const TC& C, const TD& D, const eT val); + }; + + + +class glue_times_diag + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template + arma_hot inline static void apply(Mat& out, const Glue& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_times_meat.hpp b/src/armadillo/include/armadillo_bits/glue_times_meat.hpp new file mode 100644 index 0000000..0dc8a02 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_times_meat.hpp @@ -0,0 +1,952 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_times +//! @{ + + + +template +template +inline +void +glue_times_redirect2_helper::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const partial_unwrap tmp1(X.A); + const partial_unwrap tmp2(X.B); + + const typename partial_unwrap::stored_type& A = tmp1.M; + const typename partial_unwrap::stored_type& B = tmp2.M; + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0); + + const bool alias = tmp1.is_alias(out) || tmp2.is_alias(out); + + if(alias == false) + { + glue_times::apply + < + eT, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + (partial_unwrap::do_times || partial_unwrap::do_times) + > + (out, A, B, alpha); + } + else + { + Mat tmp; + + glue_times::apply + < + eT, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + (partial_unwrap::do_times || partial_unwrap::do_times) + > + (tmp, A, B, alpha); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +glue_times_redirect2_helper::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(arma_config::optimise_invexpr && (strip_inv::do_inv_gen || strip_inv::do_inv_spd)) + { + // replace inv(A)*B with solve(A,B) + + arma_extra_debug_print("glue_times_redirect<2>::apply(): detected inv(A)*B"); + + const strip_inv A_strip(X.A); + + Mat A = A_strip.M; + + arma_debug_check( (A.is_square() == false), "inv(): given matrix must be square sized" ); + + if( (strip_inv::do_inv_spd) && (arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false) ) + { + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } + } + + const unwrap_check B_tmp(X.B, out); + const Mat& B = B_tmp.M; + + arma_debug_assert_mul_size(A, B, "matrix multiplication"); + + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, B) : auxlib::solve_square_fast(out, A, B); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead"); + } + + return; + } + + if(arma_config::optimise_invexpr && strip_inv::do_inv_spd) + { + // replace A*inv_sympd(B) with trans( solve(trans(B),trans(A)) ) + // transpose of B is avoided as B is explicitly marked as symmetric + + arma_extra_debug_print("glue_times_redirect<2>::apply(): detected A*inv_sympd(B)"); + + const Mat At = trans(X.A); + + const strip_inv B_strip(X.B); + + Mat B = B_strip.M; + + arma_debug_check( (B.is_square() == false), "inv_sympd(): given matrix must be square sized" ); + + if( (arma_config::debug) && (auxlib::rudimentary_sym_check(B) == false) ) + { + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } + } + + arma_debug_assert_mul_size(At.n_cols, At.n_rows, B.n_rows, B.n_cols, "matrix multiplication"); + + const bool status = auxlib::solve_sympd_fast(out, B, At); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead"); + } + + out = trans(out); + + return; + } + + glue_times_redirect2_helper::apply(out, X); + } + + + +template +template +inline +void +glue_times_redirect3_helper::apply(Mat& out, const Glue< Glue, T3, glue_times>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // we have exactly 3 objects + // hence we can safely expand X as X.A.A, X.A.B and X.B + + const partial_unwrap tmp1(X.A.A); + const partial_unwrap tmp2(X.A.B); + const partial_unwrap tmp3(X.B ); + + const typename partial_unwrap::stored_type& A = tmp1.M; + const typename partial_unwrap::stored_type& B = tmp2.M; + const typename partial_unwrap::stored_type& C = tmp3.M; + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0); + + const bool alias = tmp1.is_alias(out) || tmp2.is_alias(out) || tmp3.is_alias(out); + + if(alias == false) + { + glue_times::apply + < + eT, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + (partial_unwrap::do_times || partial_unwrap::do_times || partial_unwrap::do_times) + > + (out, A, B, C, alpha); + } + else + { + Mat tmp; + + glue_times::apply + < + eT, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + (partial_unwrap::do_times || partial_unwrap::do_times || partial_unwrap::do_times) + > + (tmp, A, B, C, alpha); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +glue_times_redirect3_helper::apply(Mat& out, const Glue< Glue, T3, glue_times>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(arma_config::optimise_invexpr && (strip_inv::do_inv_gen || strip_inv::do_inv_spd)) + { + // replace inv(A)*B*C with solve(A,B*C); + + arma_extra_debug_print("glue_times_redirect<3>::apply(): detected inv(A)*B*C"); + + const strip_inv A_strip(X.A.A); + + Mat A = A_strip.M; + + arma_debug_check( (A.is_square() == false), "inv(): given matrix must be square sized" ); + + const partial_unwrap tmp2(X.A.B); + const partial_unwrap tmp3(X.B ); + + const typename partial_unwrap::stored_type& B = tmp2.M; + const typename partial_unwrap::stored_type& C = tmp3.M; + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (tmp2.get_val() * tmp3.get_val()) : eT(0); + + Mat BC; + + glue_times::apply + < + eT, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + (partial_unwrap::do_times || partial_unwrap::do_times) + > + (BC, B, C, alpha); + + arma_debug_assert_mul_size(A, BC, "matrix multiplication"); + + if( (strip_inv::do_inv_spd) && (arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false) ) + { + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } + } + + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, BC) : auxlib::solve_square_fast(out, A, BC); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead"); + } + + return; + } + + + if(arma_config::optimise_invexpr && (strip_inv::do_inv_gen || strip_inv::do_inv_spd)) + { + // replace A*inv(B)*C with A*solve(B,C) + + arma_extra_debug_print("glue_times_redirect<3>::apply(): detected A*inv(B)*C"); + + const strip_inv B_strip(X.A.B); + + Mat B = B_strip.M; + + arma_debug_check( (B.is_square() == false), "inv(): given matrix must be square sized" ); + + const unwrap C_tmp(X.B); + const Mat& C = C_tmp.M; + + arma_debug_assert_mul_size(B, C, "matrix multiplication"); + + if( (strip_inv::do_inv_spd) && (arma_config::debug) && (auxlib::rudimentary_sym_check(B) == false) ) + { + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } + } + + Mat solve_result; + + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(solve_result, B, C) : auxlib::solve_square_fast(solve_result, B, C); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead"); + return; + } + + const partial_unwrap_check tmp1(X.A.A, out); + + const typename partial_unwrap_check::stored_type& A = tmp1.M; + + const bool use_alpha = partial_unwrap_check::do_times; + const eT alpha = use_alpha ? tmp1.get_val() : eT(0); + + glue_times::apply + < + eT, + partial_unwrap_check::do_trans, + false, + partial_unwrap_check::do_times + > + (out, A, solve_result, alpha); + + return; + } + + + glue_times_redirect3_helper::apply(out, X); + } + + + +template +template +inline +void +glue_times_redirect::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const partial_unwrap tmp1(X.A); + const partial_unwrap tmp2(X.B); + + const typename partial_unwrap::stored_type& A = tmp1.M; + const typename partial_unwrap::stored_type& B = tmp2.M; + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0); + + const bool alias = tmp1.is_alias(out) || tmp2.is_alias(out); + + if(alias == false) + { + glue_times::apply + < + eT, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + (partial_unwrap::do_times || partial_unwrap::do_times) + > + (out, A, B, alpha); + } + else + { + Mat tmp; + + glue_times::apply + < + eT, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + (partial_unwrap::do_times || partial_unwrap::do_times) + > + (tmp, A, B, alpha); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +glue_times_redirect<2>::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + glue_times_redirect2_helper< is_supported_blas_type::value >::apply(out, X); + } + + + +template +inline +void +glue_times_redirect<3>::apply(Mat& out, const Glue< Glue, T3, glue_times>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + glue_times_redirect3_helper< is_supported_blas_type::value >::apply(out, X); + } + + + +template +inline +void +glue_times_redirect<4>::apply(Mat& out, const Glue< Glue< Glue, T3, glue_times>, T4, glue_times>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // there is exactly 4 objects + // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B + + const partial_unwrap tmp1(X.A.A.A); + const partial_unwrap tmp2(X.A.A.B); + const partial_unwrap tmp3(X.A.B ); + const partial_unwrap tmp4(X.B ); + + const typename partial_unwrap::stored_type& A = tmp1.M; + const typename partial_unwrap::stored_type& B = tmp2.M; + const typename partial_unwrap::stored_type& C = tmp3.M; + const typename partial_unwrap::stored_type& D = tmp4.M; + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times || partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0); + + const bool alias = tmp1.is_alias(out) || tmp2.is_alias(out) || tmp3.is_alias(out) || tmp4.is_alias(out); + + if(alias == false) + { + glue_times::apply + < + eT, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + (partial_unwrap::do_times || partial_unwrap::do_times || partial_unwrap::do_times || partial_unwrap::do_times) + > + (out, A, B, C, D, alpha); + } + else + { + Mat tmp; + + glue_times::apply + < + eT, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + partial_unwrap::do_trans, + (partial_unwrap::do_times || partial_unwrap::do_times || partial_unwrap::do_times || partial_unwrap::do_times) + > + (tmp, A, B, C, D, alpha); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +glue_times::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + constexpr uword N_mat = 1 + depth_lhs< glue_times, Glue >::num; + + arma_extra_debug_print(arma_str::format("N_mat = %u") % N_mat); + + glue_times_redirect::apply(out, X); + } + + + +template +inline +void +glue_times::apply_inplace(Mat& out, const T1& X) + { + arma_extra_debug_sigprint(); + + out = out * X; + } + + + +template +inline +void +glue_times::apply_inplace_plus(Mat& out, const Glue& X, const sword sign) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + if( (is_outer_product::value) || (has_op_inv_any::value) || (has_op_inv_any::value) ) + { + // partial workaround for corner cases + + const Mat tmp(X); + + if(sign > sword(0)) { out += tmp; } else { out -= tmp; } + + return; + } + + const partial_unwrap_check tmp1(X.A, out); + const partial_unwrap_check tmp2(X.B, out); + + typedef typename partial_unwrap_check::stored_type TA; + typedef typename partial_unwrap_check::stored_type TB; + + const TA& A = tmp1.M; + const TB& B = tmp2.M; + + const bool do_trans_A = partial_unwrap_check::do_trans; + const bool do_trans_B = partial_unwrap_check::do_trans; + + const bool use_alpha = partial_unwrap_check::do_times || partial_unwrap_check::do_times || (sign < sword(0)); + + const eT alpha = use_alpha ? ( tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ) ) : eT(0); + + arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication"); + + const uword result_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols); + const uword result_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows); + + arma_debug_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, ( (sign > sword(0)) ? "addition" : "subtraction" ) ); + + if(out.n_elem == 0) { return; } + + if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) + { + if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); } + else if( (B.n_cols == 1) || (TB::is_col) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); } + else { gemm::apply(out, A, B, alpha, eT(1)); } + } + else + if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) + { + if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); } + else if( (B.n_cols == 1) || (TB::is_col) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); } + else { gemm::apply(out, A, B, alpha, eT(1)); } + } + else + if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) + { + if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); } + else if( (B.n_cols == 1) || (TB::is_col) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::no) ) { syrk::apply(out, A, alpha, eT(1)); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::yes) ) { herk::apply(out, A, T(0), T(1)); } + else { gemm::apply(out, A, B, alpha, eT(1)); } + } + else + if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) + { + if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); } + else if( (B.n_cols == 1) || (TB::is_col) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::no) ) { syrk::apply(out, A, alpha, eT(1)); } + else { gemm::apply(out, A, B, alpha, eT(1)); } + } + else + if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) + { + if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); } + else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::no) ) { syrk::apply(out, A, alpha, eT(1)); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::yes) ) { herk::apply(out, A, T(0), T(1)); } + else { gemm::apply(out, A, B, alpha, eT(1)); } + } + else + if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) + { + if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); } + else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::no) ) { syrk::apply(out, A, alpha, eT(1)); } + else { gemm::apply(out, A, B, alpha, eT(1)); } + } + else + if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) + { + if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); } + else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); } + else { gemm::apply(out, A, B, alpha, eT(1)); } + } + else + if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) + { + if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); } + else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); } + else { gemm::apply(out, A, B, alpha, eT(1)); } + } + } + + + +template +arma_inline +uword +glue_times::mul_storage_cost(const TA& A, const TB& B) + { + const uword final_A_n_rows = (do_trans_A == false) ? ( TA::is_row ? 1 : A.n_rows ) : ( TA::is_col ? 1 : A.n_cols ); + const uword final_B_n_cols = (do_trans_B == false) ? ( TB::is_col ? 1 : B.n_cols ) : ( TB::is_row ? 1 : B.n_rows ); + + return final_A_n_rows * final_B_n_cols; + } + + + +template + < + typename eT, + const bool do_trans_A, + const bool do_trans_B, + const bool use_alpha, + typename TA, + typename TB + > +inline +void +glue_times::apply + ( + Mat& out, + const TA& A, + const TB& B, + const eT alpha + ) + { + arma_extra_debug_sigprint(); + + //arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication"); + arma_debug_assert_trans_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + const uword final_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols); + const uword final_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows); + + out.set_size(final_n_rows, final_n_cols); + + if( (A.n_elem == 0) || (B.n_elem == 0) ) { out.zeros(); return; } + + if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) + { + if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr()); } + else if( (B.n_cols == 1) || (TB::is_col) ) { gemv::apply(out.memptr(), A, B.memptr()); } + else { gemm::apply(out, A, B ); } + } + else + if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) + { + if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha); } + else if( (B.n_cols == 1) || (TB::is_col) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha); } + else { gemm::apply(out, A, B, alpha); } + } + else + if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) + { + if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr()); } + else if( (B.n_cols == 1) || (TB::is_col) ) { gemv::apply(out.memptr(), A, B.memptr()); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::no) ) { syrk::apply(out, A ); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::yes) ) { herk::apply(out, A ); } + else { gemm::apply(out, A, B ); } + } + else + if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) + { + if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha); } + else if( (B.n_cols == 1) || (TB::is_col) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::no) ) { syrk::apply(out, A, alpha); } + else { gemm::apply(out, A, B, alpha); } + } + else + if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) + { + if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr()); } + else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), A, B.memptr()); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::no) ) { syrk::apply(out, A ); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::yes) ) { herk::apply(out, A ); } + else { gemm::apply(out, A, B ); } + } + else + if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) + { + if( ((A.n_rows == 1) || (TA::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha); } + else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha); } + else if( (void_ptr(&A) == void_ptr(&B)) && (is_cx::no) ) { syrk::apply(out, A, alpha); } + else { gemm::apply(out, A, B, alpha); } + } + else + if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) + { + if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr()); } + else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), A, B.memptr()); } + else { gemm::apply(out, A, B ); } + } + else + if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) + { + if( ((A.n_cols == 1) || (TA::is_col)) && (is_cx::no) ) { gemv::apply(out.memptr(), B, A.memptr(), alpha); } + else if( ((B.n_rows == 1) || (TB::is_row)) && (is_cx::no) ) { gemv::apply(out.memptr(), A, B.memptr(), alpha); } + else { gemm::apply(out, A, B, alpha); } + } + } + + + +template + < + typename eT, + const bool do_trans_A, + const bool do_trans_B, + const bool do_trans_C, + const bool use_alpha, + typename TA, + typename TB, + typename TC + > +inline +void +glue_times::apply + ( + Mat& out, + const TA& A, + const TB& B, + const TC& C, + const eT alpha + ) + { + arma_extra_debug_sigprint(); + + Mat tmp; + + const uword storage_cost_AB = glue_times::mul_storage_cost(A, B); + const uword storage_cost_BC = glue_times::mul_storage_cost(B, C); + + if(storage_cost_AB <= storage_cost_BC) + { + // out = (A*B)*C + + glue_times::apply(tmp, A, B, alpha); + glue_times::apply(out, tmp, C, eT(0)); + } + else + { + // out = A*(B*C) + + glue_times::apply(tmp, B, C, alpha); + glue_times::apply(out, A, tmp, eT(0)); + } + } + + + +template + < + typename eT, + const bool do_trans_A, + const bool do_trans_B, + const bool do_trans_C, + const bool do_trans_D, + const bool use_alpha, + typename TA, + typename TB, + typename TC, + typename TD + > +inline +void +glue_times::apply + ( + Mat& out, + const TA& A, + const TB& B, + const TC& C, + const TD& D, + const eT alpha + ) + { + arma_extra_debug_sigprint(); + + Mat tmp; + + const uword storage_cost_AC = glue_times::mul_storage_cost(A, C); + const uword storage_cost_BD = glue_times::mul_storage_cost(B, D); + + if(storage_cost_AC <= storage_cost_BD) + { + // out = (A*B*C)*D + + glue_times::apply(tmp, A, B, C, alpha); + + glue_times::apply(out, tmp, D, eT(0)); + } + else + { + // out = A*(B*C*D) + + glue_times::apply(tmp, B, C, D, alpha); + + glue_times::apply(out, A, tmp, eT(0)); + } + } + + + +// +// glue_times_diag + + +template +inline +void +glue_times_diag::apply(Mat& actual_out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const strip_diagmat S1(X.A); + const strip_diagmat S2(X.B); + + typedef typename strip_diagmat::stored_type T1_stripped; + typedef typename strip_diagmat::stored_type T2_stripped; + + if( (strip_diagmat::do_diagmat == true) && (strip_diagmat::do_diagmat == false) ) + { + arma_extra_debug_print("glue_times_diag::apply(): diagmat(A) * B"); + + const diagmat_proxy A(S1.M); + + const quasi_unwrap UB(X.B); + const Mat& B = UB.M; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + const uword A_length = (std::min)(A_n_rows, A_n_cols); + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication"); + + const bool is_alias = (A.is_alias(actual_out) || UB.is_alias(actual_out)); + + if(is_alias) { arma_extra_debug_print("glue_times_diag::apply(): aliasing detected"); } + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + out.zeros(A_n_rows, B_n_cols); + + for(uword col=0; col < B_n_cols; ++col) + { + eT* out_coldata = out.colptr(col); + const eT* B_coldata = B.colptr(col); + + for(uword i=0; i < A_length; ++i) { out_coldata[i] = A[i] * B_coldata[i]; } + } + + if(is_alias) { actual_out.steal_mem(tmp); } + } + else + if( (strip_diagmat::do_diagmat == false) && (strip_diagmat::do_diagmat == true) ) + { + arma_extra_debug_print("glue_times_diag::apply(): A * diagmat(B)"); + + const quasi_unwrap UA(X.A); + const Mat& A = UA.M; + + const diagmat_proxy B(S2.M); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + const uword B_length = (std::min)(B_n_rows, B_n_cols); + + arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication"); + + const bool is_alias = (UA.is_alias(actual_out) || B.is_alias(actual_out)); + + if(is_alias) { arma_extra_debug_print("glue_times_diag::apply(): aliasing detected"); } + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + out.zeros(A_n_rows, B_n_cols); + + for(uword col=0; col < B_length; ++col) + { + const eT val = B[col]; + + eT* out_coldata = out.colptr(col); + const eT* A_coldata = A.colptr(col); + + for(uword i=0; i < A_n_rows; ++i) { out_coldata[i] = A_coldata[i] * val; } + } + + if(is_alias) { actual_out.steal_mem(tmp); } + } + else + if( (strip_diagmat::do_diagmat == true) && (strip_diagmat::do_diagmat == true) ) + { + arma_extra_debug_print("glue_times_diag::apply(): diagmat(A) * diagmat(B)"); + + const diagmat_proxy A(S1.M); + const diagmat_proxy B(S2.M); + + arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + const bool is_alias = (A.is_alias(actual_out) || B.is_alias(actual_out)); + + if(is_alias) { arma_extra_debug_print("glue_times_diag::apply(): aliasing detected"); } + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + out.zeros(A.n_rows, B.n_cols); + + const uword A_length = (std::min)(A.n_rows, A.n_cols); + const uword B_length = (std::min)(B.n_rows, B.n_cols); + + const uword N = (std::min)(A_length, B_length); + + for(uword i=0; i < N; ++i) { out.at(i,i) = A[i] * B[i]; } + + if(is_alias) { actual_out.steal_mem(tmp); } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_times_misc_bones.hpp b/src/armadillo/include/armadillo_bits/glue_times_misc_bones.hpp new file mode 100644 index 0000000..ca01a1c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_times_misc_bones.hpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_times_misc +//! @{ + + + +class dense_sparse_helper + { + public: + + template + arma_inline static typename arma_not_cx::result dot(const eT* A_mem, const SpMat& B, const uword col); + + template + arma_inline static typename arma_cx_only::result dot(const eT* A_mem, const SpMat& B, const uword col); + }; + + + +class glue_times_dense_sparse + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(Mat& out, const SpToDGlue& expr); + + template + inline static void apply_noalias(Mat& out, const T1& x, const T2& y); + + template + inline static void apply_mixed(Mat< typename promote_type::result >& out, const T1& X, const T2& Y); + }; + + + +class glue_times_sparse_dense + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(Mat& out, const SpToDGlue& expr); + + template + inline static void apply_noalias(Mat& out, const T1& x, const T2& y); + + template + inline static void apply_noalias_trans(Mat& out, const T1& x, const T2& y); + + template + inline static void apply_mixed(Mat< typename promote_type::result >& out, const T1& X, const T2& Y); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_times_misc_meat.hpp b/src/armadillo/include/armadillo_bits/glue_times_misc_meat.hpp new file mode 100644 index 0000000..cafbb98 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_times_misc_meat.hpp @@ -0,0 +1,646 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_times_misc +//! @{ + + + +template +arma_inline +typename arma_not_cx::result +dense_sparse_helper::dot(const eT* A_mem, const SpMat& B, const uword col) + { + arma_extra_debug_sigprint(); + + uword col_offset = B.col_ptrs[col ]; + const uword next_col_offset = B.col_ptrs[col + 1]; + + const uword* start_ptr = &(B.row_indices[ col_offset]); + const uword* end_ptr = &(B.row_indices[next_col_offset]); + + const eT* B_values = B.values; + + eT acc = eT(0); + + for(const uword* ptr = start_ptr; ptr != end_ptr; ++ptr) + { + const uword index = (*ptr); + + acc += A_mem[index] * B_values[col_offset]; + + ++col_offset; + } + + return acc; + } + + + +template +arma_inline +typename arma_cx_only::result +dense_sparse_helper::dot(const eT* A_mem, const SpMat& B, const uword col) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + uword col_offset = B.col_ptrs[col ]; + const uword next_col_offset = B.col_ptrs[col + 1]; + + const uword* start_ptr = &(B.row_indices[ col_offset]); + const uword* end_ptr = &(B.row_indices[next_col_offset]); + + const eT* B_values = B.values; + + T acc_real = T(0); + T acc_imag = T(0); + + for(const uword* ptr = start_ptr; ptr != end_ptr; ++ptr) + { + const uword index = (*ptr); + + const std::complex& X = A_mem[index]; + const std::complex& Y = B_values[col_offset]; + + const T a = X.real(); + const T b = X.imag(); + + const T c = Y.real(); + const T d = Y.imag(); + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + + ++col_offset; + } + + return std::complex(acc_real, acc_imag); + } + + + +template +inline +void +glue_times_dense_sparse::apply(Mat& out, const SpToDGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_op_diagmat::value) { out = SpMat(expr.A) * expr.B; return; } // SpMat has specialised handling for op_diagmat + + const quasi_unwrap UA(expr.A); + + if(UA.is_alias(out)) + { + Mat tmp; + + glue_times_dense_sparse::apply_noalias(tmp, UA.M, expr.B); + + out.steal_mem(tmp); + } + else + { + glue_times_dense_sparse::apply_noalias(out, UA.M, expr.B); + } + } + + + +template +inline +void +glue_times_dense_sparse::apply_noalias(Mat& out, const T1& x, const T2& y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UA(x); + const Mat& A = UA.M; + + const unwrap_spmat UB(y); + const SpMat& B = UB.M; + + arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + out.set_size(A.n_rows, B.n_cols); + + if((A.n_elem == 0) || (B.n_nonzero == 0)) { out.zeros(); return; } + + if((resolves_to_rowvector::value) || (A.n_rows == 1)) + { + arma_extra_debug_print("using row vector specialisation"); + + if( (arma_config::openmp) && (mp_thread_limit::in_parallel() == false) && (B.n_cols >= 2) && mp_gate::eval(B.n_nonzero) ) + { + #if defined(ARMA_USE_OPENMP) + { + arma_extra_debug_print("openmp implementation"); + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + + const uword B_n_cols = B.n_cols; + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col=0; col < B_n_cols; ++col) + { + out_mem[col] = dense_sparse_helper::dot(A_mem, B, col); + } + } + #endif + } + else + { + arma_extra_debug_print("serial implementation"); + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + + const uword B_n_cols = B.n_cols; + + for(uword col=0; col < B_n_cols; ++col) + { + out_mem[col] = dense_sparse_helper::dot(A_mem, B, col); + } + } + } + else + if( (arma_config::openmp) && (mp_thread_limit::in_parallel() == false) && (A.n_rows <= (A.n_cols / uword(100))) ) + { + #if defined(ARMA_USE_OPENMP) + { + arma_extra_debug_print("using parallelised multiplication"); + + const uword B_n_cols = B.n_cols; + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < B_n_cols; ++i) + { + const uword col_offset_1 = B.col_ptrs[i ]; + const uword col_offset_2 = B.col_ptrs[i+1]; + + const uword col_offset_delta = col_offset_2 - col_offset_1; + + const uvec indices(const_cast(&(B.row_indices[col_offset_1])), col_offset_delta, false, false); + const Col B_col(const_cast< eT*>(&( B.values[col_offset_1])), col_offset_delta, false, false); + + out.col(i) = A.cols(indices) * B_col; + } + } + #endif + } + else + { + arma_extra_debug_print("using standard multiplication"); + + out.zeros(); + + typename SpMat::const_iterator B_it = B.begin(); + + const uword nnz = B.n_nonzero; + const uword out_n_rows = out.n_rows; + + for(uword count = 0; count < nnz; ++count, ++B_it) + { + const eT B_it_val = (*B_it); + const uword B_it_col = B_it.col(); + const uword B_it_row = B_it.row(); + + const eT* A_col = A.colptr(B_it_row); + eT* out_col = out.colptr(B_it_col); + + for(uword row = 0; row < out_n_rows; ++row) + { + out_col[row] += A_col[row] * B_it_val; + } + } + } + } + + + +template +inline +void +glue_times_dense_sparse::apply_mixed(Mat< typename promote_type::result >& out, const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + if( (is_same_type::no) && (is_same_type::yes) ) + { + // upgrade T1 + + const quasi_unwrap UA(X); + const unwrap_spmat UB(Y); + + const Mat& A = UA.M; + const SpMat& B = UB.M; + + const Mat AA = conv_to< Mat >::from(A); + + const SpMat& BB = reinterpret_cast< const SpMat& >(B); + + glue_times_dense_sparse::apply_noalias(out, AA, BB); + } + else + if( (is_same_type::yes) && (is_same_type::no) ) + { + // upgrade T2 + + const quasi_unwrap UA(X); + const unwrap_spmat UB(Y); + + const Mat& A = UA.M; + const SpMat& B = UB.M; + + const Mat& AA = reinterpret_cast< const Mat& >(A); + + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + glue_times_dense_sparse::apply_noalias(out, AA, BB); + } + else + { + // upgrade T1 and T2 + + const quasi_unwrap UA(X); + const unwrap_spmat UB(Y); + + const Mat& A = UA.M; + const SpMat& B = UB.M; + + const Mat AA = conv_to< Mat >::from(A); + + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + glue_times_dense_sparse::apply_noalias(out, AA, BB); + } + } + + + +// + + + +template +inline +void +glue_times_sparse_dense::apply(Mat& out, const SpToDGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_op_diagmat::value) { out = expr.A * SpMat(expr.B); return; } // SpMat has specialised handling for op_diagmat + + const quasi_unwrap UB(expr.B); + + if((sp_strip_trans::do_htrans && is_cx::no) || (sp_strip_trans::do_strans)) + { + arma_extra_debug_print("detected non-conjugate transpose of A"); + + const sp_strip_trans x_strip(expr.A); + + if(UB.is_alias(out)) + { + Mat tmp; + + glue_times_sparse_dense::apply_noalias_trans(tmp, x_strip.M, UB.M); + + out.steal_mem(tmp); + } + else + { + glue_times_sparse_dense::apply_noalias_trans(out, x_strip.M, UB.M); + } + } + else + { + if(UB.is_alias(out)) + { + Mat tmp; + + glue_times_sparse_dense::apply_noalias(tmp, expr.A, UB.M); + + out.steal_mem(tmp); + } + else + { + glue_times_sparse_dense::apply_noalias(out, expr.A, UB.M); + } + } + } + + + +template +inline +void +glue_times_sparse_dense::apply_noalias(Mat& out, const T1& x, const T2& y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(x); + const SpMat& A = UA.M; + + const quasi_unwrap UB(y); + const Mat& B = UB.M; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication"); + + if((resolves_to_colvector::value) || (B_n_cols == 1)) + { + arma_extra_debug_print("using column vector specialisation"); + + out.zeros(A_n_rows, 1); + + eT* out_mem = out.memptr(); + const eT* B_mem = B.memptr(); + + typename SpMat::const_iterator A_it = A.begin(); + + const uword nnz = A.n_nonzero; + + for(uword count = 0; count < nnz; ++count, ++A_it) + { + const eT A_it_val = (*A_it); + const uword A_it_row = A_it.row(); + const uword A_it_col = A_it.col(); + + out_mem[A_it_row] += A_it_val * B_mem[A_it_col]; + } + } + else + if(B_n_cols >= (B_n_rows / uword(100))) + { + arma_extra_debug_print("using transpose-based multiplication"); + + const SpMat At = A.st(); + const Mat Bt = B.st(); + + if(A_n_rows == B_n_cols) + { + glue_times_dense_sparse::apply_noalias(out, Bt, At); + + op_strans::apply_mat(out, out); // since 'out' is square-sized, this will do an inplace transpose + } + else + { + Mat tmp; + + glue_times_dense_sparse::apply_noalias(tmp, Bt, At); + + op_strans::apply_mat(out, tmp); + } + } + else + { + arma_extra_debug_print("using standard multiplication"); + + out.zeros(A_n_rows, B_n_cols); + + typename SpMat::const_iterator A_it = A.begin(); + + const uword nnz = A.n_nonzero; + + for(uword count = 0; count < nnz; ++count, ++A_it) + { + const eT A_it_val = (*A_it); + const uword A_it_row = A_it.row(); + const uword A_it_col = A_it.col(); + + for(uword col = 0; col < B_n_cols; ++col) + { + out.at(A_it_row, col) += A_it_val * B.at(A_it_col, col); + } + } + } + } + + + +template +inline +void +glue_times_sparse_dense::apply_noalias_trans(Mat& out, const T1& x, const T2& y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(x); + const SpMat& A = UA.M; // NOTE: this is the given matrix without the transpose operation applied + + const quasi_unwrap UB(y); + const Mat& B = UB.M; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + arma_debug_assert_mul_size(A_n_cols, A_n_rows, B_n_rows, B_n_cols, "matrix multiplication"); + + if((resolves_to_colvector::value) || (B_n_cols == 1)) + { + arma_extra_debug_print("using column vector specialisation (avoiding transpose of A)"); + + if( (arma_config::openmp) && (mp_thread_limit::in_parallel() == false) && (A_n_cols >= 2) && mp_gate::eval(A.n_nonzero) ) + { + arma_extra_debug_print("opemp implementation"); + + #if defined(ARMA_USE_OPENMP) + { + out.zeros(A_n_cols, 1); + + eT* out_mem = out.memptr(); + const eT* B_mem = B.memptr(); + + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col=0; col < A_n_cols; ++col) + { + out_mem[col] = dense_sparse_helper::dot(B_mem, A, col); + } + } + #endif + } + else + { + arma_extra_debug_print("serial implementation"); + + out.zeros(A_n_cols, 1); + + eT* out_mem = out.memptr(); + const eT* B_mem = B.memptr(); + + for(uword col=0; col < A_n_cols; ++col) + { + out_mem[col] = dense_sparse_helper::dot(B_mem, A, col); + } + } + } + else + if(B_n_cols >= (B_n_rows / uword(100))) + { + arma_extra_debug_print("using transpose-based multiplication (avoiding transpose of A)"); + + const Mat Bt = B.st(); + + if(A_n_cols == B_n_cols) + { + glue_times_dense_sparse::apply_noalias(out, Bt, A); + + op_strans::apply_mat(out, out); // since 'out' is square-sized, this will do an inplace transpose + } + else + { + Mat tmp; + + glue_times_dense_sparse::apply_noalias(tmp, Bt, A); + + op_strans::apply_mat(out, tmp); + } + } + else + { + arma_extra_debug_print("using standard multiplication (avoiding transpose of A)"); + + out.zeros(A_n_cols, B_n_cols); + + typename SpMat::const_iterator A_it = A.begin(); + + const uword nnz = A.n_nonzero; + + for(uword count = 0; count < nnz; ++count, ++A_it) + { + const eT A_it_val = (*A_it); + const uword A_it_row = A_it.row(); + const uword A_it_col = A_it.col(); + + for(uword col = 0; col < B_n_cols; ++col) + { + out.at(A_it_col, col) += A_it_val * B.at(A_it_row, col); + } + } + } + } + + + +template +inline +void +glue_times_sparse_dense::apply_mixed(Mat< typename promote_type::result >& out, const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + if( (is_same_type::no) && (is_same_type::yes) ) + { + // upgrade T1 + + const unwrap_spmat UA(X); + const quasi_unwrap UB(Y); + + const SpMat& A = UA.M; + const Mat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + + const Mat& BB = reinterpret_cast< const Mat& >(B); + + glue_times_sparse_dense::apply_noalias(out, AA, BB); + } + else + if( (is_same_type::yes) && (is_same_type::no) ) + { + // upgrade T2 + + const unwrap_spmat UA(X); + const quasi_unwrap UB(Y); + + const SpMat& A = UA.M; + const Mat& B = UB.M; + + const SpMat& AA = reinterpret_cast< const SpMat& >(A); + + const Mat BB = conv_to< Mat >::from(B); + + glue_times_sparse_dense::apply_noalias(out, AA, BB); + } + else + { + // upgrade T1 and T2 + + const unwrap_spmat UA(X); + const quasi_unwrap UB(Y); + + const SpMat& A = UA.M; + const Mat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + + const Mat BB = conv_to< Mat >::from(B); + + glue_times_sparse_dense::apply_noalias(out, AA, BB); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_toeplitz_bones.hpp b/src/armadillo/include/armadillo_bits/glue_toeplitz_bones.hpp new file mode 100644 index 0000000..338de14 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_toeplitz_bones.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_toeplitz +//! @{ + + + +class glue_toeplitz + : public traits_glue_default + { + public: + + template inline static void apply(Mat& out, const Glue& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_toeplitz_meat.hpp b/src/armadillo/include/armadillo_bits/glue_toeplitz_meat.hpp new file mode 100644 index 0000000..77f9a09 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_toeplitz_meat.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_toeplitz +//! @{ + + + +template +inline +void +glue_toeplitz::apply(Mat& out, const Glue& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_check tmp1(in.A, out); + const unwrap_check tmp2(in.B, out); + + const Mat& A = tmp1.M; + const Mat& B = tmp2.M; + + arma_debug_check + ( + ( ((A.is_vec() == false) && (A.is_empty() == false)) || ((B.is_vec() == false) && (B.is_empty() == false)) ), + "toeplitz(): given object must be a vector" + ); + + const uword A_N = A.n_elem; + const uword B_N = B.n_elem; + + const eT* A_mem = A.memptr(); + const eT* B_mem = B.memptr(); + + out.set_size(A_N, B_N); + + if( out.is_empty() ) { return; } + + for(uword col=0; col < B_N; ++col) + { + eT* col_mem = out.colptr(col); + + uword i = 0; + for(uword row=col; row < A_N; ++row, ++i) { col_mem[row] = A_mem[i]; } + } + + for(uword row=0; row < A_N; ++row) + { + uword i = 1; + for(uword col=(row+1); col < B_N; ++col, ++i) { out.at(row,col) = B_mem[i]; } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_trapz_bones.hpp b/src/armadillo/include/armadillo_bits/glue_trapz_bones.hpp new file mode 100644 index 0000000..8b3019a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_trapz_bones.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_trapz +//! @{ + + + +class glue_trapz + { + public: + + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = true; + }; + + template inline static void apply(Mat& out, const Glue& in); + + template inline static void apply_noalias(Mat& out, const Mat& X, const Mat& Y, const uword dim); + }; + + + +class op_trapz + : public traits_op_xvec + { + public: + + template inline static void apply(Mat& out, const Op& in); + + template inline static void apply_noalias(Mat& out, const Mat& Y, const uword dim); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/glue_trapz_meat.hpp b/src/armadillo/include/armadillo_bits/glue_trapz_meat.hpp new file mode 100644 index 0000000..ed7b577 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/glue_trapz_meat.hpp @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_trapz +//! @{ + + + +template +inline +void +glue_trapz::apply(Mat& out, const Glue& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword; + + const quasi_unwrap UX(in.A); + const quasi_unwrap UY(in.B); + + if( UX.is_alias(out) || UY.is_alias(out) ) + { + Mat tmp; + + glue_trapz::apply_noalias(tmp, UX.M, UY.M, dim); + + out.steal_mem(tmp); + } + else + { + glue_trapz::apply_noalias(out, UX.M, UY.M, dim); + } + } + + + +template +inline +void +glue_trapz::apply_noalias(Mat& out, const Mat& X, const Mat& Y, const uword dim) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (dim > 1), "trapz(): argument 'dim' must be 0 or 1" ); + + arma_debug_check( ((X.is_vec() == false) && (X.is_empty() == false)), "trapz(): argument 'X' must be a vector" ); + + const uword N = X.n_elem; + + if(dim == 0) + { + arma_debug_check( (N != Y.n_rows), "trapz(): length of X must equal the number of rows in Y when dim=0" ); + } + else + if(dim == 1) + { + arma_debug_check( (N != Y.n_cols), "trapz(): length of X must equal the number of columns in Y when dim=1" ); + } + + if(N <= 1) + { + if(dim == 0) { out.zeros(1, Y.n_cols); } + else if(dim == 1) { out.zeros(Y.n_rows, 1); } + + return; + } + + const Col vec_X( const_cast(X.memptr()), X.n_elem, false, true ); + + const Col diff_X = diff(vec_X); + + if(dim == 0) + { + const Row diff_X_t( const_cast(diff_X.memptr()), diff_X.n_elem, false, true ); + + out = diff_X_t * (0.5 * (Y.rows(0, N-2) + Y.rows(1, N-1))); + } + else + if(dim == 1) + { + out = (0.5 * (Y.cols(0, N-2) + Y.cols(1, N-1))) * diff_X; + } + } + + + +template +inline +void +op_trapz::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + + const quasi_unwrap UY(in.m); + + if(UY.is_alias(out)) + { + Mat tmp; + + op_trapz::apply_noalias(tmp, UY.M, dim); + + out.steal_mem(tmp); + } + else + { + op_trapz::apply_noalias(out, UY.M, dim); + } + } + + + +template +inline +void +op_trapz::apply_noalias(Mat& out, const Mat& Y, const uword dim) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (dim > 1), "trapz(): argument 'dim' must be 0 or 1" ); + + uword N = 0; + + if(dim == 0) { N = Y.n_rows; } + else if(dim == 1) { N = Y.n_cols; } + + if(N <= 1) + { + if(dim == 0) { out.zeros(1, Y.n_cols); } + else if(dim == 1) { out.zeros(Y.n_rows, 1); } + + return; + } + + if(dim == 0) + { + out = sum( (0.5 * (Y.rows(0, N-2) + Y.rows(1, N-1))), 0 ); + } + else + if(dim == 1) + { + out = sum( (0.5 * (Y.cols(0, N-2) + Y.cols(1, N-1))), 1 ); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/gmm_diag_bones.hpp b/src/armadillo/include/armadillo_bits/gmm_diag_bones.hpp new file mode 100644 index 0000000..386a40d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/gmm_diag_bones.hpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup gmm_diag +//! @{ + + +namespace gmm_priv +{ + +template +class gmm_diag + { + public: + + arma_aligned const Mat means; + arma_aligned const Mat dcovs; + arma_aligned const Row hefts; + + // + // + + inline ~gmm_diag(); + inline gmm_diag(); + + inline gmm_diag(const gmm_diag& x); + inline gmm_diag& operator=(const gmm_diag& x); + + inline explicit gmm_diag(const gmm_full& x); + inline gmm_diag& operator=(const gmm_full& x); + + inline gmm_diag(const uword in_n_dims, const uword in_n_gaus); + inline void reset(const uword in_n_dims, const uword in_n_gaus); + inline void reset(); + + template + inline void set_params(const Base& in_means, const Base& in_dcovs, const Base& in_hefts); + + template inline void set_means(const Base& in_means); + template inline void set_dcovs(const Base& in_dcovs); + template inline void set_hefts(const Base& in_hefts); + + inline uword n_dims() const; + inline uword n_gaus() const; + + inline bool load(const std::string name); + inline bool save(const std::string name) const; + + inline Col generate() const; + inline Mat generate(const uword N) const; + + template inline eT log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = nullptr) const; + template inline eT log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = nullptr) const; + + template inline Row log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = nullptr) const; + template inline Row log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = nullptr) const; + + template inline eT sum_log_p(const Base& expr) const; + template inline eT sum_log_p(const Base& expr, const uword gaus_id) const; + + template inline eT avg_log_p(const Base& expr) const; + template inline eT avg_log_p(const Base& expr, const uword gaus_id) const; + + template inline uword assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk = nullptr) const; + template inline urowvec assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk = nullptr) const; + + template inline urowvec raw_hist(const Base& expr, const gmm_dist_mode& dist_mode) const; + template inline Row norm_hist(const Base& expr, const gmm_dist_mode& dist_mode) const; + + template + inline + bool + learn + ( + const Base& data, + const uword n_gaus, + const gmm_dist_mode& dist_mode, + const gmm_seed_mode& seed_mode, + const uword km_iter, + const uword em_iter, + const eT var_floor, + const bool print_mode + ); + + + template + inline + bool + kmeans_wrapper + ( + Mat& user_means, + const Base& data, + const uword n_gaus, + const gmm_seed_mode& seed_mode, + const uword km_iter, + const bool print_mode + ); + + + // + + protected: + + arma_aligned Mat inv_dcovs; + arma_aligned Row log_det_etc; + arma_aligned Row log_hefts; + arma_aligned Col mah_aux; + + // + + inline void init(const gmm_diag& x); + inline void init(const gmm_full& x); + + inline void init(const uword in_n_dim, const uword in_n_gaus); + + inline void init_constants(); + + inline umat internal_gen_boundaries(const uword N) const; + + inline eT internal_scalar_log_p(const eT* x ) const; + inline eT internal_scalar_log_p(const eT* x, const uword gaus_id) const; + + inline Row internal_vec_log_p(const Mat& X ) const; + inline Row internal_vec_log_p(const Mat& X, const uword gaus_id) const; + + inline eT internal_sum_log_p(const Mat& X ) const; + inline eT internal_sum_log_p(const Mat& X, const uword gaus_id) const; + + inline eT internal_avg_log_p(const Mat& X ) const; + inline eT internal_avg_log_p(const Mat& X, const uword gaus_id) const; + + inline uword internal_scalar_assign(const Mat& X, const gmm_dist_mode& dist_mode) const; + + inline void internal_vec_assign(urowvec& out, const Mat& X, const gmm_dist_mode& dist_mode) const; + + inline void internal_raw_hist(urowvec& hist, const Mat& X, const gmm_dist_mode& dist_mode) const; + + // + + template inline void generate_initial_means(const Mat& X, const gmm_seed_mode& seed); + + template inline void generate_initial_params(const Mat& X, const eT var_floor); + + template inline bool km_iterate(const Mat& X, const uword max_iter, const bool verbose, const char* signature); + + // + + inline bool em_iterate(const Mat& X, const uword max_iter, const eT var_floor, const bool verbose); + + inline void em_update_params(const Mat& X, const umat& boundaries, field< Mat >& t_acc_means, field< Mat >& t_acc_dcovs, field< Col >& t_acc_norm_lhoods, field< Col >& t_gaus_log_lhoods, Col& t_progress_log_lhoods); + + inline void em_generate_acc(const Mat& X, const uword start_index, const uword end_index, Mat& acc_means, Mat& acc_dcovs, Col& acc_norm_lhoods, Col& gaus_log_lhoods, eT& progress_log_lhood) const; + + inline void em_fix_params(const eT var_floor); + }; + +} + + +typedef gmm_priv::gmm_diag gmm_diag; +typedef gmm_priv::gmm_diag fgmm_diag; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/gmm_diag_meat.hpp b/src/armadillo/include/armadillo_bits/gmm_diag_meat.hpp new file mode 100644 index 0000000..1b6681e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/gmm_diag_meat.hpp @@ -0,0 +1,2655 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup gmm_diag +//! @{ + + +namespace gmm_priv +{ + + +template +inline +gmm_diag::~gmm_diag() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( (is_same_type::value == false) && (is_same_type::value == false) )); + } + + + +template +inline +gmm_diag::gmm_diag() + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +gmm_diag::gmm_diag(const gmm_diag& x) + { + arma_extra_debug_sigprint_this(this); + + init(x); + } + + + +template +inline +gmm_diag& +gmm_diag::operator=(const gmm_diag& x) + { + arma_extra_debug_sigprint(); + + init(x); + + return *this; + } + + + +template +inline +gmm_diag::gmm_diag(const gmm_full& x) + { + arma_extra_debug_sigprint_this(this); + + init(x); + } + + + +template +inline +gmm_diag& +gmm_diag::operator=(const gmm_full& x) + { + arma_extra_debug_sigprint(); + + init(x); + + return *this; + } + + + +template +inline +gmm_diag::gmm_diag(const uword in_n_dims, const uword in_n_gaus) + { + arma_extra_debug_sigprint_this(this); + + init(in_n_dims, in_n_gaus); + } + + + +template +inline +void +gmm_diag::reset() + { + arma_extra_debug_sigprint(); + + init(0, 0); + } + + + +template +inline +void +gmm_diag::reset(const uword in_n_dims, const uword in_n_gaus) + { + arma_extra_debug_sigprint(); + + init(in_n_dims, in_n_gaus); + } + + + +template +template +inline +void +gmm_diag::set_params(const Base& in_means_expr, const Base& in_dcovs_expr, const Base& in_hefts_expr) + { + arma_extra_debug_sigprint(); + + const unwrap tmp1(in_means_expr.get_ref()); + const unwrap tmp2(in_dcovs_expr.get_ref()); + const unwrap tmp3(in_hefts_expr.get_ref()); + + const Mat& in_means = tmp1.M; + const Mat& in_dcovs = tmp2.M; + const Mat& in_hefts = tmp3.M; + + arma_debug_check + ( + (arma::size(in_means) != arma::size(in_dcovs)) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1), + "gmm_diag::set_params(): given parameters have inconsistent and/or wrong sizes" + ); + + arma_debug_check( (in_means.internal_has_nonfinite()), "gmm_diag::set_params(): given means have non-finite values" ); + arma_debug_check( (in_dcovs.internal_has_nonfinite()), "gmm_diag::set_params(): given dcovs have non-finite values" ); + arma_debug_check( (in_hefts.internal_has_nonfinite()), "gmm_diag::set_params(): given hefts have non-finite values" ); + + arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_params(): given dcovs have negative or zero values" ); + arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_params(): given hefts have negative values" ); + + const eT s = accu(in_hefts); + + arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_params(): sum of given hefts is not 1" ); + + access::rw(means) = in_means; + access::rw(dcovs) = in_dcovs; + access::rw(hefts) = in_hefts; + + init_constants(); + } + + + +template +template +inline +void +gmm_diag::set_means(const Base& in_means_expr) + { + arma_extra_debug_sigprint(); + + const unwrap tmp(in_means_expr.get_ref()); + + const Mat& in_means = tmp.M; + + arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_diag::set_means(): given means have incompatible size" ); + arma_debug_check( (in_means.internal_has_nonfinite()), "gmm_diag::set_means(): given means have non-finite values" ); + + access::rw(means) = in_means; + } + + + +template +template +inline +void +gmm_diag::set_dcovs(const Base& in_dcovs_expr) + { + arma_extra_debug_sigprint(); + + const unwrap tmp(in_dcovs_expr.get_ref()); + + const Mat& in_dcovs = tmp.M; + + arma_debug_check( (arma::size(in_dcovs) != arma::size(dcovs)), "gmm_diag::set_dcovs(): given dcovs have incompatible size" ); + arma_debug_check( (in_dcovs.internal_has_nonfinite()), "gmm_diag::set_dcovs(): given dcovs have non-finite values" ); + arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_dcovs(): given dcovs have negative or zero values" ); + + access::rw(dcovs) = in_dcovs; + + init_constants(); + } + + + +template +template +inline +void +gmm_diag::set_hefts(const Base& in_hefts_expr) + { + arma_extra_debug_sigprint(); + + const unwrap tmp(in_hefts_expr.get_ref()); + + const Mat& in_hefts = tmp.M; + + arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_diag::set_hefts(): given hefts have incompatible size" ); + arma_debug_check( (in_hefts.internal_has_nonfinite()), "gmm_diag::set_hefts(): given hefts have non-finite values" ); + arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_hefts(): given hefts have negative values" ); + + const eT s = accu(in_hefts); + + arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_hefts(): sum of given hefts is not 1" ); + + // make sure all hefts are positive and non-zero + + const eT* in_hefts_mem = in_hefts.memptr(); + eT* hefts_mem = access::rw(hefts).memptr(); + + for(uword i=0; i < hefts.n_elem; ++i) + { + hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits::min() ); + } + + access::rw(hefts) /= accu(hefts); + + log_hefts = log(hefts); + } + + + +template +inline +uword +gmm_diag::n_dims() const + { + return means.n_rows; + } + + + +template +inline +uword +gmm_diag::n_gaus() const + { + return means.n_cols; + } + + + +template +inline +bool +gmm_diag::load(const std::string name) + { + arma_extra_debug_sigprint(); + + Cube Q; + + bool status = Q.load(name, arma_binary); + + if( (status == false) || (Q.n_slices != 2) ) + { + reset(); + arma_debug_warn_level(3, "gmm_diag::load(): problem with loading or incompatible format"); + return false; + } + + if( (Q.n_rows < 2) || (Q.n_cols < 1) ) + { + reset(); + return true; + } + + access::rw(hefts) = Q.slice(0).row(0); + access::rw(means) = Q.slice(0).submat(1, 0, Q.n_rows-1, Q.n_cols-1); + access::rw(dcovs) = Q.slice(1).submat(1, 0, Q.n_rows-1, Q.n_cols-1); + + init_constants(); + + return true; + } + + + +template +inline +bool +gmm_diag::save(const std::string name) const + { + arma_extra_debug_sigprint(); + + Cube Q(means.n_rows + 1, means.n_cols, 2, arma_nozeros_indicator()); + + if(Q.n_elem > 0) + { + Q.slice(0).row(0) = hefts; + Q.slice(1).row(0).zeros(); // reserved for future use + + Q.slice(0).submat(1, 0, arma::size(means)) = means; + Q.slice(1).submat(1, 0, arma::size(dcovs)) = dcovs; + } + + const bool status = Q.save(name, arma_binary); + + return status; + } + + + +template +inline +Col +gmm_diag::generate() const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + Col out( ((N_gaus > 0) ? N_dims : uword(0)), fill::randn ); + + if(N_gaus > 0) + { + const double val = randu(); + + double csum = double(0); + uword gaus_id = 0; + + for(uword j=0; j < N_gaus; ++j) + { + csum += hefts[j]; + + if(val <= csum) { gaus_id = j; break; } + } + + out %= sqrt(dcovs.col(gaus_id)); + out += means.col(gaus_id); + } + + return out; + } + + + +template +inline +Mat +gmm_diag::generate(const uword N_vec) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + Mat out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn ); + + if(N_gaus > 0) + { + const eT* hefts_mem = hefts.memptr(); + + const Mat sqrt_dcovs = sqrt(dcovs); + + for(uword i=0; i < N_vec; ++i) + { + const double val = randu(); + + double csum = double(0); + uword gaus_id = 0; + + for(uword j=0; j < N_gaus; ++j) + { + csum += hefts_mem[j]; + + if(val <= csum) { gaus_id = j; break; } + } + + subview_col out_col = out.col(i); + + out_col %= sqrt_dcovs.col(gaus_id); + out_col += means.col(gaus_id); + } + } + + return out; + } + + + +template +template +inline +eT +gmm_diag::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk2) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const quasi_unwrap tmp(expr); + + arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); + + return internal_scalar_log_p( tmp.M.memptr() ); + } + + + +template +template +inline +eT +gmm_diag::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk2) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk2); + + const quasi_unwrap tmp(expr); + + arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); + + arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" ); + + return internal_scalar_log_p( tmp.M.memptr(), gaus_id ); + } + + + +template +template +inline +Row +gmm_diag::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const quasi_unwrap tmp(expr); + + const Mat& X = tmp.M; + + return internal_vec_log_p(X); + } + + + +template +template +inline +Row +gmm_diag::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk2); + + const quasi_unwrap tmp(expr); + + const Mat& X = tmp.M; + + return internal_vec_log_p(X, gaus_id); + } + + + +template +template +inline +eT +gmm_diag::sum_log_p(const Base& expr) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(expr.get_ref()); + + const Mat& X = tmp.M; + + return internal_sum_log_p(X); + } + + + +template +template +inline +eT +gmm_diag::sum_log_p(const Base& expr, const uword gaus_id) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(expr.get_ref()); + + const Mat& X = tmp.M; + + return internal_sum_log_p(X, gaus_id); + } + + + +template +template +inline +eT +gmm_diag::avg_log_p(const Base& expr) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(expr.get_ref()); + + const Mat& X = tmp.M; + + return internal_avg_log_p(X); + } + + + +template +template +inline +eT +gmm_diag::avg_log_p(const Base& expr, const uword gaus_id) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(expr.get_ref()); + + const Mat& X = tmp.M; + + return internal_avg_log_p(X, gaus_id); + } + + + +template +template +inline +uword +gmm_diag::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const quasi_unwrap tmp(expr); + + const Mat& X = tmp.M; + + return internal_scalar_assign(X, dist); + } + + + +template +template +inline +urowvec +gmm_diag::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + urowvec out; + + const quasi_unwrap tmp(expr); + + const Mat& X = tmp.M; + + internal_vec_assign(out, X, dist); + + return out; + } + + + +template +template +inline +urowvec +gmm_diag::raw_hist(const Base& expr, const gmm_dist_mode& dist_mode) const + { + arma_extra_debug_sigprint(); + + const unwrap tmp(expr.get_ref()); + const Mat& X = tmp.M; + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::raw_hist(): incompatible dimensions" ); + + arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::raw_hist(): unsupported distance mode" ); + + urowvec hist; + + internal_raw_hist(hist, X, dist_mode); + + return hist; + } + + + +template +template +inline +Row +gmm_diag::norm_hist(const Base& expr, const gmm_dist_mode& dist_mode) const + { + arma_extra_debug_sigprint(); + + const unwrap tmp(expr.get_ref()); + const Mat& X = tmp.M; + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::norm_hist(): incompatible dimensions" ); + + arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::norm_hist(): unsupported distance mode" ); + + urowvec hist; + + internal_raw_hist(hist, X, dist_mode); + + const uword hist_n_elem = hist.n_elem; + const uword* hist_mem = hist.memptr(); + + eT acc = eT(0); + for(uword i=0; i out(hist_n_elem, arma_nozeros_indicator()); + + eT* out_mem = out.memptr(); + + for(uword i=0; i +template +inline +bool +gmm_diag::learn + ( + const Base& data, + const uword N_gaus, + const gmm_dist_mode& dist_mode, + const gmm_seed_mode& seed_mode, + const uword km_iter, + const uword em_iter, + const eT var_floor, + const bool print_mode + ) + { + arma_extra_debug_sigprint(); + + const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist); + + const bool seed_mode_ok = \ + (seed_mode == keep_existing) + || (seed_mode == static_subset) + || (seed_mode == static_spread) + || (seed_mode == random_subset) + || (seed_mode == random_spread); + + arma_debug_check( (dist_mode_ok == false), "gmm_diag::learn(): dist_mode must be eucl_dist or maha_dist" ); + arma_debug_check( (seed_mode_ok == false), "gmm_diag::learn(): unknown seed_mode" ); + arma_debug_check( (var_floor < eT(0) ), "gmm_diag::learn(): variance floor is negative" ); + + const unwrap tmp_X(data.get_ref()); + const Mat& X = tmp_X.M; + + if(X.is_empty() ) { arma_debug_warn_level(3, "gmm_diag::learn(): given matrix is empty" ); return false; } + if(X.internal_has_nonfinite()) { arma_debug_warn_level(3, "gmm_diag::learn(): given matrix has non-finite values"); return false; } + + if(N_gaus == 0) { reset(); return true; } + + if(dist_mode == maha_dist) + { + mah_aux = var(X,1,1); + + const uword mah_aux_n_elem = mah_aux.n_elem; + eT* mah_aux_mem = mah_aux.memptr(); + + for(uword i=0; i < mah_aux_n_elem; ++i) + { + const eT val = mah_aux_mem[i]; + + mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1); + } + } + + + // copy current model, in case of failure by k-means and/or EM + + const gmm_diag orig = (*this); + + + // initial means + + if(seed_mode == keep_existing) + { + if(means.is_empty() ) { arma_debug_warn_level(3, "gmm_diag::learn(): no existing means" ); return false; } + if(X.n_rows != means.n_rows) { arma_debug_warn_level(3, "gmm_diag::learn(): dimensionality mismatch"); return false; } + + // TODO: also check for number of vectors? + } + else + { + if(X.n_cols < N_gaus) { arma_debug_warn_level(3, "gmm_diag::learn(): number of vectors is less than number of gaussians"); return false; } + + reset(X.n_rows, N_gaus); + + if(print_mode) { get_cout_stream() << "gmm_diag::learn(): generating initial means\n"; get_cout_stream().flush(); } + + if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mode); } + else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mode); } + } + + + // k-means + + if(km_iter > 0) + { + const arma_ostream_state stream_state(get_cout_stream()); + + bool status = false; + + if(dist_mode == eucl_dist) { status = km_iterate<1>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); } + else if(dist_mode == maha_dist) { status = km_iterate<2>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); } + + stream_state.restore(get_cout_stream()); + + if(status == false) { arma_debug_warn_level(3, "gmm_diag::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; } + } + + + // initial dcovs + + const eT var_floor_actual = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits::min(); + + if(seed_mode != keep_existing) + { + if(print_mode) { get_cout_stream() << "gmm_diag::learn(): generating initial covariances\n"; get_cout_stream().flush(); } + + if(dist_mode == eucl_dist) { generate_initial_params<1>(X, var_floor_actual); } + else if(dist_mode == maha_dist) { generate_initial_params<2>(X, var_floor_actual); } + } + + + // EM algorithm + + if(em_iter > 0) + { + const arma_ostream_state stream_state(get_cout_stream()); + + const bool status = em_iterate(X, em_iter, var_floor_actual, print_mode); + + stream_state.restore(get_cout_stream()); + + if(status == false) { arma_debug_warn_level(3, "gmm_diag::learn(): EM algorithm failed"); init(orig); return false; } + } + + mah_aux.reset(); + + init_constants(); + + return true; + } + + + +template +template +inline +bool +gmm_diag::kmeans_wrapper + ( + Mat& user_means, + const Base& data, + const uword N_gaus, + const gmm_seed_mode& seed_mode, + const uword km_iter, + const bool print_mode + ) + { + arma_extra_debug_sigprint(); + + const bool seed_mode_ok = \ + (seed_mode == keep_existing) + || (seed_mode == static_subset) + || (seed_mode == static_spread) + || (seed_mode == random_subset) + || (seed_mode == random_spread); + + arma_debug_check( (seed_mode_ok == false), "kmeans(): unknown seed_mode" ); + + const unwrap tmp_X(data.get_ref()); + const Mat& X = tmp_X.M; + + if(X.is_empty() ) { arma_debug_warn_level(3, "kmeans(): given matrix is empty" ); return false; } + if(X.internal_has_nonfinite()) { arma_debug_warn_level(3, "kmeans(): given matrix has non-finite values"); return false; } + + if(N_gaus == 0) { reset(); return true; } + + + // initial means + + if(seed_mode == keep_existing) + { + access::rw(means) = user_means; + + if(means.is_empty() ) { arma_debug_warn_level(3, "kmeans(): no existing means" ); return false; } + if(X.n_rows != means.n_rows) { arma_debug_warn_level(3, "kmeans(): dimensionality mismatch"); return false; } + + // TODO: also check for number of vectors? + } + else + { + if(X.n_cols < N_gaus) { arma_debug_warn_level(3, "kmeans(): number of vectors is less than number of means"); return false; } + + access::rw(means).zeros(X.n_rows, N_gaus); + + if(print_mode) { get_cout_stream() << "kmeans(): generating initial means\n"; } + + generate_initial_means<1>(X, seed_mode); + } + + + // k-means + + if(km_iter > 0) + { + const arma_ostream_state stream_state(get_cout_stream()); + + bool status = false; + + status = km_iterate<1>(X, km_iter, print_mode, "kmeans()"); + + stream_state.restore(get_cout_stream()); + + if(status == false) { arma_debug_warn_level(3, "kmeans(): clustering failed; not enough data, or too many means requested"); return false; } + } + + return true; + } + + + +// +// +// + + + +template +inline +void +gmm_diag::init(const gmm_diag& x) + { + arma_extra_debug_sigprint(); + + gmm_diag& t = *this; + + if(&t != &x) + { + access::rw(t.means) = x.means; + access::rw(t.dcovs) = x.dcovs; + access::rw(t.hefts) = x.hefts; + + init_constants(); + } + } + + + +template +inline +void +gmm_diag::init(const gmm_full& x) + { + arma_extra_debug_sigprint(); + + access::rw(hefts) = x.hefts; + access::rw(means) = x.means; + + const uword N_dims = x.means.n_rows; + const uword N_gaus = x.means.n_cols; + + access::rw(dcovs).zeros(N_dims,N_gaus); + + for(uword g=0; g < N_gaus; ++g) + { + const Mat& fcov = x.fcovs.slice(g); + + eT* dcov_mem = access::rw(dcovs).colptr(g); + + for(uword d=0; d < N_dims; ++d) + { + dcov_mem[d] = fcov.at(d,d); + } + } + + init_constants(); + } + + + +template +inline +void +gmm_diag::init(const uword in_n_dims, const uword in_n_gaus) + { + arma_extra_debug_sigprint(); + + access::rw(means).zeros(in_n_dims, in_n_gaus); + + access::rw(dcovs).ones(in_n_dims, in_n_gaus); + + access::rw(hefts).set_size(in_n_gaus); + + access::rw(hefts).fill(eT(1) / eT(in_n_gaus)); + + init_constants(); + } + + + +template +inline +void +gmm_diag::init_constants() + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + // + + inv_dcovs.copy_size(dcovs); + + const eT* dcovs_mem = dcovs.memptr(); + eT* inv_dcovs_mem = inv_dcovs.memptr(); + + const uword dcovs_n_elem = dcovs.n_elem; + + for(uword i=0; i < dcovs_n_elem; ++i) + { + inv_dcovs_mem[i] = eT(1) / (std::max)( dcovs_mem[i], std::numeric_limits::min() ); + } + + // + + const eT tmp = (eT(N_dims)/eT(2)) * std::log(Datum::tau); + + log_det_etc.set_size(N_gaus); + + for(uword g=0; g < N_gaus; ++g) + { + const eT* dcovs_colmem = dcovs.colptr(g); + + eT log_det_val = eT(0); + + for(uword d=0; d < N_dims; ++d) + { + log_det_val += std::log( (std::max)( dcovs_colmem[d], std::numeric_limits::min() ) ); + } + + log_det_etc[g] = eT(-1) * ( tmp + eT(0.5) * log_det_val ); + } + + // + + eT* hefts_mem = access::rw(hefts).memptr(); + + for(uword g=0; g < N_gaus; ++g) + { + hefts_mem[g] = (std::max)( hefts_mem[g], std::numeric_limits::min() ); + } + + log_hefts = log(hefts); + } + + + +template +inline +umat +gmm_diag::internal_gen_boundaries(const uword N) const + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_OPENMP) + const uword n_threads_avail = (omp_in_parallel()) ? uword(1) : uword(omp_get_max_threads()); + const uword n_threads = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1; + #else + static constexpr uword n_threads = 1; + #endif + + // get_cout_stream() << "gmm_diag::internal_gen_boundaries(): n_threads: " << n_threads << '\n'; + + umat boundaries(2, n_threads, arma_nozeros_indicator()); + + if(N > 0) + { + const uword chunk_size = N / n_threads; + + uword count = 0; + + for(uword t=0; t +inline +eT +gmm_diag::internal_scalar_log_p(const eT* x) const + { + arma_extra_debug_sigprint(); + + const eT* log_hefts_mem = log_hefts.mem; + + const uword N_gaus = means.n_cols; + + if(N_gaus > 0) + { + eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0]; + + for(uword g=1; g < N_gaus; ++g) + { + const eT tmp = internal_scalar_log_p(x, g) + log_hefts_mem[g]; + + log_sum = log_add_exp(log_sum, tmp); + } + + return log_sum; + } + else + { + return -Datum::inf; + } + } + + + +template +inline +eT +gmm_diag::internal_scalar_log_p(const eT* x, const uword g) const + { + arma_extra_debug_sigprint(); + + const eT* mean = means.colptr(g); + const eT* inv_dcov = inv_dcovs.colptr(g); + + const uword N_dims = means.n_rows; + + eT val_i = eT(0); + eT val_j = eT(0); + + uword i,j; + + for(i=0, j=1; j +inline +Row +gmm_diag::internal_vec_log_p(const Mat& X) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); + + const uword N = X.n_cols; + + Row out(N, arma_nozeros_indicator()); + + if(N > 0) + { + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N); + + const uword n_threads = boundaries.n_cols; + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + eT* out_mem = out.memptr(); + + for(uword i=start_index; i <= end_index; ++i) + { + out_mem[i] = internal_scalar_log_p( X.colptr(i) ); + } + } + } + #else + { + eT* out_mem = out.memptr(); + + for(uword i=0; i < N; ++i) + { + out_mem[i] = internal_scalar_log_p( X.colptr(i) ); + } + } + #endif + } + + return out; + } + + + +template +inline +Row +gmm_diag::internal_vec_log_p(const Mat& X, const uword gaus_id) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); + arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" ); + + const uword N = X.n_cols; + + Row out(N, arma_nozeros_indicator()); + + if(N > 0) + { + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N); + + const uword n_threads = boundaries.n_cols; + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + eT* out_mem = out.memptr(); + + for(uword i=start_index; i <= end_index; ++i) + { + out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id ); + } + } + } + #else + { + eT* out_mem = out.memptr(); + + for(uword i=0; i < N; ++i) + { + out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id ); + } + } + #endif + } + + return out; + } + + + +template +inline +eT +gmm_diag::internal_sum_log_p(const Mat& X) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions" ); + + const uword N = X.n_cols; + + if(N == 0) { return (-Datum::inf); } + + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N); + + const uword n_threads = boundaries.n_cols; + + Col t_accs(n_threads, arma_zeros_indicator()); + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + eT t_acc = eT(0); + + for(uword i=start_index; i <= end_index; ++i) + { + t_acc += internal_scalar_log_p( X.colptr(i) ); + } + + t_accs[t] = t_acc; + } + + return eT(accu(t_accs)); + } + #else + { + eT acc = eT(0); + + for(uword i=0; i +inline +eT +gmm_diag::internal_sum_log_p(const Mat& X, const uword gaus_id) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions" ); + arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::sum_log_p(): specified gaussian is out of range" ); + + const uword N = X.n_cols; + + if(N == 0) { return (-Datum::inf); } + + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N); + + const uword n_threads = boundaries.n_cols; + + Col t_accs(n_threads, arma_zeros_indicator()); + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + eT t_acc = eT(0); + + for(uword i=start_index; i <= end_index; ++i) + { + t_acc += internal_scalar_log_p( X.colptr(i), gaus_id ); + } + + t_accs[t] = t_acc; + } + + return eT(accu(t_accs)); + } + #else + { + eT acc = eT(0); + + for(uword i=0; i +inline +eT +gmm_diag::internal_avg_log_p(const Mat& X) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" ); + + const uword N = X.n_cols; + + if(N == 0) { return (-Datum::inf); } + + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N); + + const uword n_threads = boundaries.n_cols; + + field< running_mean_scalar > t_running_means(n_threads); + + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + running_mean_scalar& current_running_mean = t_running_means[t]; + + for(uword i=start_index; i <= end_index; ++i) + { + current_running_mean( internal_scalar_log_p( X.colptr(i) ) ); + } + } + + + eT avg = eT(0); + + for(uword t=0; t < n_threads; ++t) + { + running_mean_scalar& current_running_mean = t_running_means[t]; + + const eT w = eT(current_running_mean.count()) / eT(N); + + avg += w * current_running_mean.mean(); + } + + return avg; + } + #else + { + running_mean_scalar running_mean; + + for(uword i=0; i +inline +eT +gmm_diag::internal_avg_log_p(const Mat& X, const uword gaus_id) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" ); + arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::avg_log_p(): specified gaussian is out of range" ); + + const uword N = X.n_cols; + + if(N == 0) { return (-Datum::inf); } + + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N); + + const uword n_threads = boundaries.n_cols; + + field< running_mean_scalar > t_running_means(n_threads); + + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + running_mean_scalar& current_running_mean = t_running_means[t]; + + for(uword i=start_index; i <= end_index; ++i) + { + current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) ); + } + } + + + eT avg = eT(0); + + for(uword t=0; t < n_threads; ++t) + { + running_mean_scalar& current_running_mean = t_running_means[t]; + + const eT w = eT(current_running_mean.count()) / eT(N); + + avg += w * current_running_mean.mean(); + } + + return avg; + } + #else + { + running_mean_scalar running_mean; + + for(uword i=0; i +inline +uword +gmm_diag::internal_scalar_assign(const Mat& X, const gmm_dist_mode& dist_mode) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" ); + arma_debug_check( (N_gaus == 0), "gmm_diag::assign(): model has no means" ); + + const eT* X_mem = X.colptr(0); + + if(dist_mode == eucl_dist) + { + eT best_dist = Datum::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_dist = distance::eval(N_dims, X_mem, means.colptr(g), X_mem); + + if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } + } + + return best_g; + } + else + if(dist_mode == prob_dist) + { + const eT* log_hefts_mem = log_hefts.memptr(); + + eT best_p = -Datum::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g]; + + if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; } + } + + return best_g; + } + else + { + arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode"); + } + + return uword(0); + } + + + +template +inline +void +gmm_diag::internal_vec_assign(urowvec& out, const Mat& X, const gmm_dist_mode& dist_mode) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" ); + + const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0; + + out.set_size(1,X_n_cols); + + uword* out_mem = out.memptr(); + + if(dist_mode == eucl_dist) + { + #if defined(ARMA_USE_OPENMP) + { + #pragma omp parallel for schedule(static) + for(uword i=0; i::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), X_colptr); + + if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } + } + + out_mem[i] = best_g; + } + } + #else + { + for(uword i=0; i::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), X_colptr); + + if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } + } + + out_mem[i] = best_g; + } + } + #endif + } + else + if(dist_mode == prob_dist) + { + #if defined(ARMA_USE_OPENMP) + { + const eT* log_hefts_mem = log_hefts.memptr(); + + #pragma omp parallel for schedule(static) + for(uword i=0; i::inf; + uword best_g = 0; + + for(uword g=0; g= best_p) { best_p = tmp_p; best_g = g; } + } + + out_mem[i] = best_g; + } + } + #else + { + const eT* log_hefts_mem = log_hefts.memptr(); + + for(uword i=0; i::inf; + uword best_g = 0; + + for(uword g=0; g= best_p) { best_p = tmp_p; best_g = g; } + } + + out_mem[i] = best_g; + } + } + #endif + } + else + { + arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode"); + } + } + + + + +template +inline +void +gmm_diag::internal_raw_hist(urowvec& hist, const Mat& X, const gmm_dist_mode& dist_mode) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + const uword X_n_cols = X.n_cols; + + hist.zeros(N_gaus); + + if(N_gaus == 0) { return; } + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(X_n_cols); + + const uword n_threads = boundaries.n_cols; + + field thread_hist(n_threads); + + for(uword t=0; t < n_threads; ++t) { thread_hist(t).zeros(N_gaus); } + + + if(dist_mode == eucl_dist) + { + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + uword* thread_hist_mem = thread_hist(t).memptr(); + + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + for(uword i=start_index; i <= end_index; ++i) + { + const eT* X_colptr = X.colptr(i); + + eT best_dist = Datum::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_dist = distance::eval(N_dims, X_colptr, means.colptr(g), X_colptr); + + if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } + } + + thread_hist_mem[best_g]++; + } + } + } + else + if(dist_mode == prob_dist) + { + const eT* log_hefts_mem = log_hefts.memptr(); + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + uword* thread_hist_mem = thread_hist(t).memptr(); + + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + for(uword i=start_index; i <= end_index; ++i) + { + const eT* X_colptr = X.colptr(i); + + eT best_p = -Datum::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g]; + + if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; } + } + + thread_hist_mem[best_g]++; + } + } + } + + // reduction + hist = thread_hist(0); + + for(uword t=1; t < n_threads; ++t) + { + hist += thread_hist(t); + } + } + #else + { + uword* hist_mem = hist.memptr(); + + if(dist_mode == eucl_dist) + { + for(uword i=0; i::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_dist = distance::eval(N_dims, X_colptr, means.colptr(g), X_colptr); + + if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } + } + + hist_mem[best_g]++; + } + } + else + if(dist_mode == prob_dist) + { + const eT* log_hefts_mem = log_hefts.memptr(); + + for(uword i=0; i::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g]; + + if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; } + } + + hist_mem[best_g]++; + } + } + } + #endif + } + + + +template +template +inline +void +gmm_diag::generate_initial_means(const Mat& X, const gmm_seed_mode& seed_mode) + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + if( (seed_mode == static_subset) || (seed_mode == random_subset) ) + { + uvec initial_indices; + + if(seed_mode == static_subset) { initial_indices = linspace(0, X.n_cols-1, N_gaus); } + else if(seed_mode == random_subset) { initial_indices = randperm(X.n_cols, N_gaus); } + + // initial_indices.print("initial_indices:"); + + access::rw(means) = X.cols(initial_indices); + } + else + if( (seed_mode == static_spread) || (seed_mode == random_spread) ) + { + // going through all of the samples can be extremely time consuming; + // instead, if there are enough samples, randomly choose samples with probability 0.1 + + const bool use_sampling = ((X.n_cols/uword(100)) > N_gaus); + const uword step = (use_sampling) ? uword(10) : uword(1); + + uword start_index = 0; + + if(seed_mode == static_spread) { start_index = X.n_cols / 2; } + else if(seed_mode == random_spread) { start_index = as_scalar(randi(1, distr_param(0,X.n_cols-1))); } + + access::rw(means).col(0) = X.unsafe_col(start_index); + + const eT* mah_aux_mem = mah_aux.memptr(); + + running_stat rs; + + for(uword g=1; g < N_gaus; ++g) + { + eT max_dist = eT(0); + uword best_i = uword(0); + uword start_i = uword(0); + + if(use_sampling) + { + uword start_i_proposed = uword(0); + + if(seed_mode == static_spread) { start_i_proposed = g % uword(10); } + if(seed_mode == random_spread) { start_i_proposed = as_scalar(randi(1, distr_param(0,9))); } + + if(start_i_proposed < X.n_cols) { start_i = start_i_proposed; } + } + + + for(uword i=start_i; i < X.n_cols; i += step) + { + rs.reset(); + + const eT* X_colptr = X.colptr(i); + + bool ignore_i = false; + + // find the average distance between sample i and the means so far + for(uword h = 0; h < g; ++h) + { + const eT dist = distance::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem); + + // ignore sample already selected as a mean + if(dist == eT(0)) { ignore_i = true; break; } + else { rs(dist); } + } + + if( (rs.mean() >= max_dist) && (ignore_i == false)) + { + max_dist = eT(rs.mean()); best_i = i; + } + } + + // set the mean to the sample that is the furthest away from the means so far + access::rw(means).col(g) = X.unsafe_col(best_i); + } + } + + // get_cout_stream() << "generate_initial_means():" << '\n'; + // means.print(); + } + + + +template +template +inline +void +gmm_diag::generate_initial_params(const Mat& X, const eT var_floor) + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + const eT* mah_aux_mem = mah_aux.memptr(); + + const uword X_n_cols = X.n_cols; + + if(X_n_cols == 0) { return; } + + // as the covariances are calculated via accumulators, + // the means also need to be calculated via accumulators to ensure numerical consistency + + Mat acc_means(N_dims, N_gaus, arma_zeros_indicator()); + Mat acc_dcovs(N_dims, N_gaus, arma_zeros_indicator()); + + Row acc_hefts(N_gaus, arma_zeros_indicator()); + + uword* acc_hefts_mem = acc_hefts.memptr(); + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(X_n_cols); + + const uword n_threads = boundaries.n_cols; + + field< Mat > t_acc_means(n_threads); + field< Mat > t_acc_dcovs(n_threads); + field< Row > t_acc_hefts(n_threads); + + for(uword t=0; t < n_threads; ++t) + { + t_acc_means(t).zeros(N_dims, N_gaus); + t_acc_dcovs(t).zeros(N_dims, N_gaus); + t_acc_hefts(t).zeros(N_gaus); + } + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + uword* t_acc_hefts_mem = t_acc_hefts(t).memptr(); + + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + for(uword i=start_index; i <= end_index; ++i) + { + const eT* X_colptr = X.colptr(i); + + eT min_dist = Datum::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem); + + if(dist < min_dist) { min_dist = dist; best_g = g; } + } + + eT* t_acc_mean = t_acc_means(t).colptr(best_g); + eT* t_acc_dcov = t_acc_dcovs(t).colptr(best_g); + + for(uword d=0; d::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem); + + if(dist < min_dist) { min_dist = dist; best_g = g; } + } + + eT* acc_mean = acc_means.colptr(best_g); + eT* acc_dcov = acc_dcovs.colptr(best_g); + + for(uword d=0; d= 1) ? tmp : eT(0); + dcov[d] = (acc_heft >= 2) ? eT((acc_dcov[d] / eT(acc_heft)) - (tmp*tmp)) : eT(var_floor); + } + + hefts_mem[g] = eT(acc_heft) / eT(X_n_cols); + } + + em_fix_params(var_floor); + } + + + +//! multi-threaded implementation of k-means, inspired by MapReduce +template +template +inline +bool +gmm_diag::km_iterate(const Mat& X, const uword max_iter, const bool verbose, const char* signature) + { + arma_extra_debug_sigprint(); + + if(verbose) + { + get_cout_stream().unsetf(ios::showbase); + get_cout_stream().unsetf(ios::uppercase); + get_cout_stream().unsetf(ios::showpos); + get_cout_stream().unsetf(ios::scientific); + + get_cout_stream().setf(ios::right); + get_cout_stream().setf(ios::fixed); + } + + const uword X_n_cols = X.n_cols; + + if(X_n_cols == 0) { return true; } + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + const eT* mah_aux_mem = mah_aux.memptr(); + + Mat acc_means(N_dims, N_gaus, arma_zeros_indicator()); + Row acc_hefts( N_gaus, arma_zeros_indicator()); + Row last_indx( N_gaus, arma_zeros_indicator()); + + Mat new_means = means; + Mat old_means = means; + + running_mean_scalar rs_delta; + + #if defined(ARMA_USE_OPENMP) + const umat boundaries = internal_gen_boundaries(X_n_cols); + const uword n_threads = boundaries.n_cols; + + field< Mat > t_acc_means(n_threads); + field< Row > t_acc_hefts(n_threads); + field< Row > t_last_indx(n_threads); + #else + const uword n_threads = 1; + #endif + + if(verbose) { get_cout_stream() << signature << ": n_threads: " << n_threads << '\n'; get_cout_stream().flush(); } + + for(uword iter=1; iter <= max_iter; ++iter) + { + #if defined(ARMA_USE_OPENMP) + { + for(uword t=0; t < n_threads; ++t) + { + t_acc_means(t).zeros(N_dims, N_gaus); + t_acc_hefts(t).zeros(N_gaus); + t_last_indx(t).zeros(N_gaus); + } + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + Mat& t_acc_means_t = t_acc_means(t); + uword* t_acc_hefts_mem = t_acc_hefts(t).memptr(); + uword* t_last_indx_mem = t_last_indx(t).memptr(); + + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + for(uword i=start_index; i <= end_index; ++i) + { + const eT* X_colptr = X.colptr(i); + + eT min_dist = Datum::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem); + + if(dist < min_dist) { min_dist = dist; best_g = g; } + } + + eT* t_acc_mean = t_acc_means_t.colptr(best_g); + + for(uword d=0; d= 1 ) { last_indx(g) = t_last_indx(t)(g); } + } + } + #else + { + acc_hefts.zeros(); + acc_means.zeros(); + last_indx.zeros(); + + uword* acc_hefts_mem = acc_hefts.memptr(); + uword* last_indx_mem = last_indx.memptr(); + + for(uword i=0; i < X_n_cols; ++i) + { + const eT* X_colptr = X.colptr(i); + + eT min_dist = Datum::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem); + + if(dist < min_dist) { min_dist = dist; best_g = g; } + } + + eT* acc_mean = acc_means.colptr(best_g); + + for(uword d=0; d= 1) ? (acc_mean[d] / eT(acc_heft)) : eT(0); + } + } + + + // heuristics to resurrect dead means + + const uvec dead_gs = find(acc_hefts == uword(0)); + + if(dead_gs.n_elem > 0) + { + if(verbose) { get_cout_stream() << signature << ": recovering from dead means\n"; get_cout_stream().flush(); } + + uword* last_indx_mem = last_indx.memptr(); + + const uvec live_gs = sort( find(acc_hefts >= uword(2)), "descend" ); + + if(live_gs.n_elem == 0) { return false; } + + uword live_gs_count = 0; + + for(uword dead_gs_count = 0; dead_gs_count < dead_gs.n_elem; ++dead_gs_count) + { + const uword dead_g_id = dead_gs(dead_gs_count); + + uword proposed_i = 0; + + if(live_gs_count < live_gs.n_elem) + { + const uword live_g_id = live_gs(live_gs_count); ++live_gs_count; + + if(live_g_id == dead_g_id) { return false; } + + // recover by using a sample from a known good mean + proposed_i = last_indx_mem[live_g_id]; + } + else + { + // recover by using a randomly seleced sample (last resort) + proposed_i = as_scalar(randi(1, distr_param(0,X_n_cols-1))); + } + + if(proposed_i >= X_n_cols) { return false; } + + new_means.col(dead_g_id) = X.col(proposed_i); + } + } + + rs_delta.reset(); + + for(uword g=0; g < N_gaus; ++g) + { + rs_delta( distance::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) ); + } + + if(verbose) + { + get_cout_stream() << signature << ": iteration: "; + get_cout_stream().unsetf(ios::scientific); + get_cout_stream().setf(ios::fixed); + get_cout_stream().width(std::streamsize(4)); + get_cout_stream() << iter; + get_cout_stream() << " delta: "; + get_cout_stream().unsetf(ios::fixed); + //get_cout_stream().setf(ios::scientific); + get_cout_stream() << rs_delta.mean() << '\n'; + get_cout_stream().flush(); + } + + arma::swap(old_means, new_means); + + if(rs_delta.mean() <= Datum::eps) { break; } + } + + access::rw(means) = old_means; + + if(means.internal_has_nonfinite()) { return false; } + + return true; + } + + + +//! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce +template +inline +bool +gmm_diag::em_iterate(const Mat& X, const uword max_iter, const eT var_floor, const bool verbose) + { + arma_extra_debug_sigprint(); + + if(X.n_cols == 0) { return true; } + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + if(verbose) + { + get_cout_stream().unsetf(ios::showbase); + get_cout_stream().unsetf(ios::uppercase); + get_cout_stream().unsetf(ios::showpos); + get_cout_stream().unsetf(ios::scientific); + + get_cout_stream().setf(ios::right); + get_cout_stream().setf(ios::fixed); + } + + const umat boundaries = internal_gen_boundaries(X.n_cols); + + const uword n_threads = boundaries.n_cols; + + field< Mat > t_acc_means(n_threads); + field< Mat > t_acc_dcovs(n_threads); + + field< Col > t_acc_norm_lhoods(n_threads); + field< Col > t_gaus_log_lhoods(n_threads); + + Col t_progress_log_lhood(n_threads, arma_nozeros_indicator()); + + for(uword t=0; t::inf; + + for(uword iter=1; iter <= max_iter; ++iter) + { + init_constants(); + + em_update_params(X, boundaries, t_acc_means, t_acc_dcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood); + + em_fix_params(var_floor); + + const eT new_avg_log_p = accu(t_progress_log_lhood) / eT(t_progress_log_lhood.n_elem); + + if(verbose) + { + get_cout_stream() << "gmm_diag::learn(): EM: iteration: "; + get_cout_stream().unsetf(ios::scientific); + get_cout_stream().setf(ios::fixed); + get_cout_stream().width(std::streamsize(4)); + get_cout_stream() << iter; + get_cout_stream() << " avg_log_p: "; + get_cout_stream().unsetf(ios::fixed); + //get_cout_stream().setf(ios::scientific); + get_cout_stream() << new_avg_log_p << '\n'; + get_cout_stream().flush(); + } + + if(arma_isfinite(new_avg_log_p) == false) { return false; } + + if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum::eps) { break; } + + + old_avg_log_p = new_avg_log_p; + } + + + if(any(vectorise(dcovs) <= eT(0))) { return false; } + if(means.internal_has_nonfinite()) { return false; } + if(dcovs.internal_has_nonfinite()) { return false; } + if(hefts.internal_has_nonfinite()) { return false; } + + return true; + } + + + + +template +inline +void +gmm_diag::em_update_params + ( + const Mat& X, + const umat& boundaries, + field< Mat >& t_acc_means, + field< Mat >& t_acc_dcovs, + field< Col >& t_acc_norm_lhoods, + field< Col >& t_gaus_log_lhoods, + Col& t_progress_log_lhood + ) + { + arma_extra_debug_sigprint(); + + const uword n_threads = boundaries.n_cols; + + + // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts + + #if defined(ARMA_USE_OPENMP) + { + #pragma omp parallel for schedule(static) + for(uword t=0; t& acc_means = t_acc_means[t]; + Mat& acc_dcovs = t_acc_dcovs[t]; + Col& acc_norm_lhoods = t_acc_norm_lhoods[t]; + Col& gaus_log_lhoods = t_gaus_log_lhoods[t]; + eT& progress_log_lhood = t_progress_log_lhood[t]; + + em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_dcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood); + } + } + #else + { + em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_dcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]); + } + #endif + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + Mat& final_acc_means = t_acc_means[0]; + Mat& final_acc_dcovs = t_acc_dcovs[0]; + + Col& final_acc_norm_lhoods = t_acc_norm_lhoods[0]; + + + // the "reduce" operation, which combines the partial accumulators produced by the separate threads + + for(uword t=1; t::min() ); + // + // eT* mean_mem = access::rw(means).colptr(g); + // eT* dcov_mem = access::rw(dcovs).colptr(g); + // + // eT* acc_mean_mem = final_acc_means.colptr(g); + // eT* acc_dcov_mem = final_acc_dcovs.colptr(g); + // + // hefts_mem[g] = acc_norm_lhood / eT(X.n_cols); + // + // for(uword d=0; d < N_dims; ++d) + // { + // const eT tmp = acc_mean_mem[d] / acc_norm_lhood; + // + // mean_mem[d] = tmp; + // dcov_mem[d] = acc_dcov_mem[d] / acc_norm_lhood - tmp*tmp; + // } + // } + + + // conditionally update each component; if only a subset of the hefts was updated, em_fix_params() will sanitise them + for(uword g=0; g < N_gaus; ++g) + { + const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits::min() ); + + if(arma_isfinite(acc_norm_lhood) == false) { continue; } + + eT* acc_mean_mem = final_acc_means.colptr(g); + eT* acc_dcov_mem = final_acc_dcovs.colptr(g); + + bool ok = true; + + for(uword d=0; d < N_dims; ++d) + { + const eT tmp1 = acc_mean_mem[d] / acc_norm_lhood; + const eT tmp2 = acc_dcov_mem[d] / acc_norm_lhood - tmp1*tmp1; + + acc_mean_mem[d] = tmp1; + acc_dcov_mem[d] = tmp2; + + if(arma_isfinite(tmp2) == false) { ok = false; } + } + + + if(ok) + { + hefts_mem[g] = acc_norm_lhood / eT(X.n_cols); + + eT* mean_mem = access::rw(means).colptr(g); + eT* dcov_mem = access::rw(dcovs).colptr(g); + + for(uword d=0; d < N_dims; ++d) + { + mean_mem[d] = acc_mean_mem[d]; + dcov_mem[d] = acc_dcov_mem[d]; + } + } + } + } + + + +template +inline +void +gmm_diag::em_generate_acc + ( + const Mat& X, + const uword start_index, + const uword end_index, + Mat& acc_means, + Mat& acc_dcovs, + Col& acc_norm_lhoods, + Col& gaus_log_lhoods, + eT& progress_log_lhood + ) + const + { + arma_extra_debug_sigprint(); + + progress_log_lhood = eT(0); + + acc_means.zeros(); + acc_dcovs.zeros(); + + acc_norm_lhoods.zeros(); + gaus_log_lhoods.zeros(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + const eT* log_hefts_mem = log_hefts.memptr(); + eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr(); + + + for(uword i=start_index; i <= end_index; i++) + { + const eT* x = X.colptr(i); + + for(uword g=0; g < N_gaus; ++g) + { + gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g]; + } + + eT log_lhood_sum = gaus_log_lhoods_mem[0]; + + for(uword g=1; g < N_gaus; ++g) + { + log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]); + } + + progress_log_lhood += log_lhood_sum; + + for(uword g=0; g < N_gaus; ++g) + { + const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum); + + acc_norm_lhoods[g] += norm_lhood; + + eT* acc_mean_mem = acc_means.colptr(g); + eT* acc_dcov_mem = acc_dcovs.colptr(g); + + for(uword d=0; d < N_dims; ++d) + { + const eT x_d = x[d]; + const eT y_d = x_d * norm_lhood; + + acc_mean_mem[d] += y_d; + acc_dcov_mem[d] += y_d * x_d; // equivalent to x_d * x_d * norm_lhood + } + } + } + + progress_log_lhood /= eT((end_index - start_index) + 1); + } + + + +template +inline +void +gmm_diag::em_fix_params(const eT var_floor) + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + const eT var_ceiling = std::numeric_limits::max(); + + const uword dcovs_n_elem = dcovs.n_elem; + eT* dcovs_mem = access::rw(dcovs).memptr(); + + for(uword i=0; i < dcovs_n_elem; ++i) + { + eT& var_val = dcovs_mem[i]; + + if(var_val < var_floor ) { var_val = var_floor; } + else if(var_val > var_ceiling) { var_val = var_ceiling; } + else if(arma_isnan(var_val) ) { var_val = eT(1); } + } + + + eT* hefts_mem = access::rw(hefts).memptr(); + + for(uword g1=0; g1 < N_gaus; ++g1) + { + if(hefts_mem[g1] > eT(0)) + { + const eT* means_colptr_g1 = means.colptr(g1); + + for(uword g2=(g1+1); g2 < N_gaus; ++g2) + { + if( (hefts_mem[g2] > eT(0)) && (std::abs(hefts_mem[g1] - hefts_mem[g2]) <= std::numeric_limits::epsilon()) ) + { + const eT dist = distance::eval(N_dims, means_colptr_g1, means.colptr(g2), means_colptr_g1); + + if(dist == eT(0)) { hefts_mem[g2] = eT(0); } + } + } + } + } + + const eT heft_floor = std::numeric_limits::min(); + const eT heft_initial = eT(1) / eT(N_gaus); + + for(uword i=0; i < N_gaus; ++i) + { + eT& heft_val = hefts_mem[i]; + + if(heft_val < heft_floor) { heft_val = heft_floor; } + else if(heft_val > eT(1) ) { heft_val = eT(1); } + else if(arma_isnan(heft_val) ) { heft_val = heft_initial; } + } + + const eT heft_sum = accu(hefts); + + if((heft_sum < (eT(1) - Datum::eps)) || (heft_sum > (eT(1) + Datum::eps))) { access::rw(hefts) /= heft_sum; } + } + + +} // namespace gmm_priv + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/gmm_full_bones.hpp b/src/armadillo/include/armadillo_bits/gmm_full_bones.hpp new file mode 100644 index 0000000..a842a62 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/gmm_full_bones.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup gmm_full +//! @{ + + +namespace gmm_priv +{ + +template +class gmm_full + { + public: + + arma_aligned const Mat means; + arma_aligned const Cube fcovs; + arma_aligned const Row hefts; + + // + // + + inline ~gmm_full(); + inline gmm_full(); + + inline gmm_full(const gmm_full& x); + inline gmm_full& operator=(const gmm_full& x); + + inline explicit gmm_full(const gmm_diag& x); + inline gmm_full& operator=(const gmm_diag& x); + + inline gmm_full(const uword in_n_dims, const uword in_n_gaus); + inline void reset(const uword in_n_dims, const uword in_n_gaus); + inline void reset(); + + template + inline void set_params(const Base& in_means, const BaseCube& in_fcovs, const Base& in_hefts); + + template inline void set_means(const Base & in_means); + template inline void set_fcovs(const BaseCube& in_fcovs); + template inline void set_hefts(const Base & in_hefts); + + inline uword n_dims() const; + inline uword n_gaus() const; + + inline bool load(const std::string name); + inline bool save(const std::string name) const; + + inline Col generate() const; + inline Mat generate(const uword N) const; + + template inline eT log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = nullptr) const; + template inline eT log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = nullptr) const; + + template inline Row log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = nullptr) const; + template inline Row log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = nullptr) const; + + template inline eT sum_log_p(const Base& expr) const; + template inline eT sum_log_p(const Base& expr, const uword gaus_id) const; + + template inline eT avg_log_p(const Base& expr) const; + template inline eT avg_log_p(const Base& expr, const uword gaus_id) const; + + template inline uword assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk = nullptr) const; + template inline urowvec assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk = nullptr) const; + + template inline urowvec raw_hist(const Base& expr, const gmm_dist_mode& dist_mode) const; + template inline Row norm_hist(const Base& expr, const gmm_dist_mode& dist_mode) const; + + template + inline + bool + learn + ( + const Base& data, + const uword n_gaus, + const gmm_dist_mode& dist_mode, + const gmm_seed_mode& seed_mode, + const uword km_iter, + const uword em_iter, + const eT var_floor, + const bool print_mode + ); + + + // + + protected: + + + arma_aligned Cube inv_fcovs; + arma_aligned Row log_det_etc; + arma_aligned Row log_hefts; + arma_aligned Col mah_aux; + arma_aligned Cube chol_fcovs; + + // + + inline void init(const gmm_full& x); + inline void init(const gmm_diag& x); + + inline void init(const uword in_n_dim, const uword in_n_gaus); + + inline void init_constants(const bool calc_chol = true); + + inline umat internal_gen_boundaries(const uword N) const; + + inline eT internal_scalar_log_p(const eT* x ) const; + inline eT internal_scalar_log_p(const eT* x, const uword gaus_id) const; + + inline Row internal_vec_log_p(const Mat& X ) const; + inline Row internal_vec_log_p(const Mat& X, const uword gaus_id) const; + + inline eT internal_sum_log_p(const Mat& X ) const; + inline eT internal_sum_log_p(const Mat& X, const uword gaus_id) const; + + inline eT internal_avg_log_p(const Mat& X ) const; + inline eT internal_avg_log_p(const Mat& X, const uword gaus_id) const; + + inline uword internal_scalar_assign(const Mat& X, const gmm_dist_mode& dist_mode) const; + + inline void internal_vec_assign(urowvec& out, const Mat& X, const gmm_dist_mode& dist_mode) const; + + inline void internal_raw_hist(urowvec& hist, const Mat& X, const gmm_dist_mode& dist_mode) const; + + // + + template inline void generate_initial_means(const Mat& X, const gmm_seed_mode& seed); + + template inline void generate_initial_params(const Mat& X, const eT var_floor); + + template inline bool km_iterate(const Mat& X, const uword max_iter, const bool verbose); + + // + + inline bool em_iterate(const Mat& X, const uword max_iter, const eT var_floor, const bool verbose); + + inline void em_update_params(const Mat& X, const umat& boundaries, field< Mat >& t_acc_means, field< Cube >& t_acc_fcovs, field< Col >& t_acc_norm_lhoods, field< Col >& t_gaus_log_lhoods, Col& t_progress_log_lhoods, const eT var_floor); + + inline void em_generate_acc(const Mat& X, const uword start_index, const uword end_index, Mat& acc_means, Cube& acc_fcovs, Col& acc_norm_lhoods, Col& gaus_log_lhoods, eT& progress_log_lhood) const; + + inline void em_fix_params(const eT var_floor); + }; + +} + + +typedef gmm_priv::gmm_full gmm_full; +typedef gmm_priv::gmm_full fgmm_full; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/gmm_full_meat.hpp b/src/armadillo/include/armadillo_bits/gmm_full_meat.hpp new file mode 100644 index 0000000..5bbcce0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/gmm_full_meat.hpp @@ -0,0 +1,2739 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup gmm_full +//! @{ + + +namespace gmm_priv +{ + + +template +inline +gmm_full::~gmm_full() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( (is_same_type::value == false) && (is_same_type::value == false) )); + } + + + +template +inline +gmm_full::gmm_full() + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +gmm_full::gmm_full(const gmm_full& x) + { + arma_extra_debug_sigprint_this(this); + + init(x); + } + + + +template +inline +gmm_full& +gmm_full::operator=(const gmm_full& x) + { + arma_extra_debug_sigprint(); + + init(x); + + return *this; + } + + + +template +inline +gmm_full::gmm_full(const gmm_diag& x) + { + arma_extra_debug_sigprint_this(this); + + init(x); + } + + + +template +inline +gmm_full& +gmm_full::operator=(const gmm_diag& x) + { + arma_extra_debug_sigprint(); + + init(x); + + return *this; + } + + + +template +inline +gmm_full::gmm_full(const uword in_n_dims, const uword in_n_gaus) + { + arma_extra_debug_sigprint_this(this); + + init(in_n_dims, in_n_gaus); + } + + + +template +inline +void +gmm_full::reset() + { + arma_extra_debug_sigprint(); + + init(0, 0); + } + + + +template +inline +void +gmm_full::reset(const uword in_n_dims, const uword in_n_gaus) + { + arma_extra_debug_sigprint(); + + init(in_n_dims, in_n_gaus); + } + + + +template +template +inline +void +gmm_full::set_params(const Base& in_means_expr, const BaseCube& in_fcovs_expr, const Base& in_hefts_expr) + { + arma_extra_debug_sigprint(); + + const unwrap tmp1(in_means_expr.get_ref()); + const unwrap_cube tmp2(in_fcovs_expr.get_ref()); + const unwrap tmp3(in_hefts_expr.get_ref()); + + const Mat & in_means = tmp1.M; + const Cube& in_fcovs = tmp2.M; + const Mat & in_hefts = tmp3.M; + + arma_debug_check + ( + (in_means.n_cols != in_fcovs.n_slices) || (in_means.n_rows != in_fcovs.n_rows) || (in_fcovs.n_rows != in_fcovs.n_cols) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1), + "gmm_full::set_params(): given parameters have inconsistent and/or wrong sizes" + ); + + arma_debug_check( (in_means.internal_has_nonfinite()), "gmm_full::set_params(): given means have non-finite values" ); + arma_debug_check( (in_fcovs.internal_has_nonfinite()), "gmm_full::set_params(): given fcovs have non-finite values" ); + arma_debug_check( (in_hefts.internal_has_nonfinite()), "gmm_full::set_params(): given hefts have non-finite values" ); + + for(uword g=0; g < in_fcovs.n_slices; ++g) + { + arma_debug_check( (any(diagvec(in_fcovs.slice(g)) <= eT(0))), "gmm_full::set_params(): given fcovs have negative or zero values on diagonals" ); + } + + arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_full::set_params(): given hefts have negative values" ); + + const eT s = accu(in_hefts); + + arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_full::set_params(): sum of given hefts is not 1" ); + + access::rw(means) = in_means; + access::rw(fcovs) = in_fcovs; + access::rw(hefts) = in_hefts; + + init_constants(); + } + + + +template +template +inline +void +gmm_full::set_means(const Base& in_means_expr) + { + arma_extra_debug_sigprint(); + + const unwrap tmp(in_means_expr.get_ref()); + + const Mat& in_means = tmp.M; + + arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_full::set_means(): given means have incompatible size" ); + arma_debug_check( (in_means.internal_has_nonfinite()), "gmm_full::set_means(): given means have non-finite values" ); + + access::rw(means) = in_means; + } + + + +template +template +inline +void +gmm_full::set_fcovs(const BaseCube& in_fcovs_expr) + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp(in_fcovs_expr.get_ref()); + + const Cube& in_fcovs = tmp.M; + + arma_debug_check( (arma::size(in_fcovs) != arma::size(fcovs)), "gmm_full::set_fcovs(): given fcovs have incompatible size" ); + arma_debug_check( (in_fcovs.internal_has_nonfinite()), "gmm_full::set_fcovs(): given fcovs have non-finite values" ); + + for(uword i=0; i < in_fcovs.n_slices; ++i) + { + arma_debug_check( (any(diagvec(in_fcovs.slice(i)) <= eT(0))), "gmm_full::set_fcovs(): given fcovs have negative or zero values on diagonals" ); + } + + access::rw(fcovs) = in_fcovs; + + init_constants(); + } + + + +template +template +inline +void +gmm_full::set_hefts(const Base& in_hefts_expr) + { + arma_extra_debug_sigprint(); + + const unwrap tmp(in_hefts_expr.get_ref()); + + const Mat& in_hefts = tmp.M; + + arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_full::set_hefts(): given hefts have incompatible size" ); + arma_debug_check( (in_hefts.internal_has_nonfinite()), "gmm_full::set_hefts(): given hefts have non-finite values" ); + arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_full::set_hefts(): given hefts have negative values" ); + + const eT s = accu(in_hefts); + + arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_full::set_hefts(): sum of given hefts is not 1" ); + + // make sure all hefts are positive and non-zero + + const eT* in_hefts_mem = in_hefts.memptr(); + eT* hefts_mem = access::rw(hefts).memptr(); + + for(uword i=0; i < hefts.n_elem; ++i) + { + hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits::min() ); + } + + access::rw(hefts) /= accu(hefts); + + log_hefts = log(hefts); + } + + + +template +inline +uword +gmm_full::n_dims() const + { + return means.n_rows; + } + + + +template +inline +uword +gmm_full::n_gaus() const + { + return means.n_cols; + } + + + +template +inline +bool +gmm_full::load(const std::string name) + { + arma_extra_debug_sigprint(); + + field< Mat > storage; + + bool status = storage.load(name, arma_binary); + + if( (status == false) || (storage.n_elem < 2) ) + { + reset(); + arma_debug_warn_level(3, "gmm_full::load(): problem with loading or incompatible format"); + return false; + } + + uword count = 0; + + const Mat& storage_means = storage(count); ++count; + const Mat& storage_hefts = storage(count); ++count; + + const uword N_dims = storage_means.n_rows; + const uword N_gaus = storage_means.n_cols; + + if( (storage.n_elem != (N_gaus + 2)) || (storage_hefts.n_rows != 1) || (storage_hefts.n_cols != N_gaus) ) + { + reset(); + arma_debug_warn_level(3, "gmm_full::load(): incompatible format"); + return false; + } + + reset(N_dims, N_gaus); + + access::rw(means) = storage_means; + access::rw(hefts) = storage_hefts; + + for(uword g=0; g < N_gaus; ++g) + { + const Mat& storage_fcov = storage(count); ++count; + + if( (storage_fcov.n_rows != N_dims) || (storage_fcov.n_cols != N_dims) ) + { + reset(); + arma_debug_warn_level(3, "gmm_full::load(): incompatible format"); + return false; + } + + access::rw(fcovs).slice(g) = storage_fcov; + } + + init_constants(); + + return true; + } + + + +template +inline +bool +gmm_full::save(const std::string name) const + { + arma_extra_debug_sigprint(); + + const uword N_gaus = means.n_cols; + + field< Mat > storage(2 + N_gaus); + + uword count = 0; + + storage(count) = means; ++count; + storage(count) = hefts; ++count; + + for(uword g=0; g < N_gaus; ++g) + { + storage(count) = fcovs.slice(g); ++count; + } + + const bool status = storage.save(name, arma_binary); + + return status; + } + + + +template +inline +Col +gmm_full::generate() const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + Col out( (N_gaus > 0) ? N_dims : uword(0), arma_nozeros_indicator() ); + Col tmp( (N_gaus > 0) ? N_dims : uword(0), fill::randn ); + + if(N_gaus > 0) + { + const double val = randu(); + + double csum = double(0); + uword gaus_id = 0; + + for(uword j=0; j < N_gaus; ++j) + { + csum += hefts[j]; + + if(val <= csum) { gaus_id = j; break; } + } + + out = chol_fcovs.slice(gaus_id) * tmp; + out += means.col(gaus_id); + } + + return out; + } + + + +template +inline +Mat +gmm_full::generate(const uword N_vec) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + Mat out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, arma_nozeros_indicator() ); + Mat tmp( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn ); + + if(N_gaus > 0) + { + const eT* hefts_mem = hefts.memptr(); + + for(uword i=0; i < N_vec; ++i) + { + const double val = randu(); + + double csum = double(0); + uword gaus_id = 0; + + for(uword j=0; j < N_gaus; ++j) + { + csum += hefts_mem[j]; + + if(val <= csum) { gaus_id = j; break; } + } + + Col out_vec(out.colptr(i), N_dims, false, true); + Col tmp_vec(tmp.colptr(i), N_dims, false, true); + + out_vec = chol_fcovs.slice(gaus_id) * tmp_vec; + out_vec += means.col(gaus_id); + } + } + + return out; + } + + + +template +template +inline +eT +gmm_full::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk2) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const uword N_dims = means.n_rows; + + const quasi_unwrap U(expr); + + arma_debug_check( (U.M.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" ); + + return internal_scalar_log_p( U.M.memptr() ); + } + + + +template +template +inline +eT +gmm_full::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk2) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk2); + + const uword N_dims = means.n_rows; + + const quasi_unwrap U(expr); + + arma_debug_check( (U.M.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" ); + arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::log_p(): specified gaussian is out of range" ); + + return internal_scalar_log_p( U.M.memptr(), gaus_id ); + } + + + +template +template +inline +Row +gmm_full::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const quasi_unwrap tmp(expr); + + const Mat& X = tmp.M; + + return internal_vec_log_p(X); + } + + + +template +template +inline +Row +gmm_full::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk2); + + const quasi_unwrap tmp(expr); + + const Mat& X = tmp.M; + + return internal_vec_log_p(X, gaus_id); + } + + + +template +template +inline +eT +gmm_full::sum_log_p(const Base& expr) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(expr.get_ref()); + + const Mat& X = tmp.M; + + return internal_sum_log_p(X); + } + + + +template +template +inline +eT +gmm_full::sum_log_p(const Base& expr, const uword gaus_id) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(expr.get_ref()); + + const Mat& X = tmp.M; + + return internal_sum_log_p(X, gaus_id); + } + + + +template +template +inline +eT +gmm_full::avg_log_p(const Base& expr) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(expr.get_ref()); + + const Mat& X = tmp.M; + + return internal_avg_log_p(X); + } + + + +template +template +inline +eT +gmm_full::avg_log_p(const Base& expr, const uword gaus_id) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(expr.get_ref()); + + const Mat& X = tmp.M; + + return internal_avg_log_p(X, gaus_id); + } + + + +template +template +inline +uword +gmm_full::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const quasi_unwrap tmp(expr); + + const Mat& X = tmp.M; + + return internal_scalar_assign(X, dist); + } + + + +template +template +inline +urowvec +gmm_full::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk) const + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + urowvec out; + + const quasi_unwrap tmp(expr); + + const Mat& X = tmp.M; + + internal_vec_assign(out, X, dist); + + return out; + } + + + +template +template +inline +urowvec +gmm_full::raw_hist(const Base& expr, const gmm_dist_mode& dist_mode) const + { + arma_extra_debug_sigprint(); + + const unwrap tmp(expr.get_ref()); + const Mat& X = tmp.M; + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::raw_hist(): incompatible dimensions" ); + + arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_full::raw_hist(): unsupported distance mode" ); + + urowvec hist; + + internal_raw_hist(hist, X, dist_mode); + + return hist; + } + + + +template +template +inline +Row +gmm_full::norm_hist(const Base& expr, const gmm_dist_mode& dist_mode) const + { + arma_extra_debug_sigprint(); + + const unwrap tmp(expr.get_ref()); + const Mat& X = tmp.M; + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::norm_hist(): incompatible dimensions" ); + + arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_full::norm_hist(): unsupported distance mode" ); + + urowvec hist; + + internal_raw_hist(hist, X, dist_mode); + + const uword hist_n_elem = hist.n_elem; + const uword* hist_mem = hist.memptr(); + + eT acc = eT(0); + for(uword i=0; i out(hist_n_elem, arma_nozeros_indicator()); + + eT* out_mem = out.memptr(); + + for(uword i=0; i +template +inline +bool +gmm_full::learn + ( + const Base& data, + const uword N_gaus, + const gmm_dist_mode& dist_mode, + const gmm_seed_mode& seed_mode, + const uword km_iter, + const uword em_iter, + const eT var_floor, + const bool print_mode + ) + { + arma_extra_debug_sigprint(); + + const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist); + + const bool seed_mode_ok = \ + (seed_mode == keep_existing) + || (seed_mode == static_subset) + || (seed_mode == static_spread) + || (seed_mode == random_subset) + || (seed_mode == random_spread); + + arma_debug_check( (dist_mode_ok == false), "gmm_full::learn(): dist_mode must be eucl_dist or maha_dist" ); + arma_debug_check( (seed_mode_ok == false), "gmm_full::learn(): unknown seed_mode" ); + arma_debug_check( (var_floor < eT(0) ), "gmm_full::learn(): variance floor is negative" ); + + const unwrap tmp_X(data.get_ref()); + const Mat& X = tmp_X.M; + + if(X.is_empty() ) { arma_debug_warn_level(3, "gmm_full::learn(): given matrix is empty" ); return false; } + if(X.internal_has_nonfinite()) { arma_debug_warn_level(3, "gmm_full::learn(): given matrix has non-finite values"); return false; } + + if(N_gaus == 0) { reset(); return true; } + + if(dist_mode == maha_dist) + { + mah_aux = var(X,1,1); + + const uword mah_aux_n_elem = mah_aux.n_elem; + eT* mah_aux_mem = mah_aux.memptr(); + + for(uword i=0; i < mah_aux_n_elem; ++i) + { + const eT val = mah_aux_mem[i]; + + mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1); + } + } + + + // copy current model, in case of failure by k-means and/or EM + + const gmm_full orig = (*this); + + + // initial means + + if(seed_mode == keep_existing) + { + if(means.is_empty() ) { arma_debug_warn_level(3, "gmm_full::learn(): no existing means" ); return false; } + if(X.n_rows != means.n_rows) { arma_debug_warn_level(3, "gmm_full::learn(): dimensionality mismatch"); return false; } + + // TODO: also check for number of vectors? + } + else + { + if(X.n_cols < N_gaus) { arma_debug_warn_level(3, "gmm_full::learn(): number of vectors is less than number of gaussians"); return false; } + + reset(X.n_rows, N_gaus); + + if(print_mode) { get_cout_stream() << "gmm_full::learn(): generating initial means\n"; get_cout_stream().flush(); } + + if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mode); } + else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mode); } + } + + + // k-means + + if(km_iter > 0) + { + const arma_ostream_state stream_state(get_cout_stream()); + + bool status = false; + + if(dist_mode == eucl_dist) { status = km_iterate<1>(X, km_iter, print_mode); } + else if(dist_mode == maha_dist) { status = km_iterate<2>(X, km_iter, print_mode); } + + stream_state.restore(get_cout_stream()); + + if(status == false) { arma_debug_warn_level(3, "gmm_full::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; } + } + + + // initial fcovs + + const eT var_floor_actual = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits::min(); + + if(seed_mode != keep_existing) + { + if(print_mode) { get_cout_stream() << "gmm_full::learn(): generating initial covariances\n"; get_cout_stream().flush(); } + + if(dist_mode == eucl_dist) { generate_initial_params<1>(X, var_floor_actual); } + else if(dist_mode == maha_dist) { generate_initial_params<2>(X, var_floor_actual); } + } + + + // EM algorithm + + if(em_iter > 0) + { + const arma_ostream_state stream_state(get_cout_stream()); + + const bool status = em_iterate(X, em_iter, var_floor_actual, print_mode); + + stream_state.restore(get_cout_stream()); + + if(status == false) { arma_debug_warn_level(3, "gmm_full::learn(): EM algorithm failed"); init(orig); return false; } + } + + mah_aux.reset(); + + init_constants(); + + return true; + } + + + +// +// +// + + + +template +inline +void +gmm_full::init(const gmm_full& x) + { + arma_extra_debug_sigprint(); + + gmm_full& t = *this; + + if(&t != &x) + { + access::rw(t.means) = x.means; + access::rw(t.fcovs) = x.fcovs; + access::rw(t.hefts) = x.hefts; + + init_constants(); + } + } + + + +template +inline +void +gmm_full::init(const gmm_diag& x) + { + arma_extra_debug_sigprint(); + + access::rw(hefts) = x.hefts; + access::rw(means) = x.means; + + const uword N_dims = x.means.n_rows; + const uword N_gaus = x.means.n_cols; + + access::rw(fcovs).zeros(N_dims,N_dims,N_gaus); + + for(uword g=0; g < N_gaus; ++g) + { + Mat& fcov = access::rw(fcovs).slice(g); + + const eT* dcov_mem = x.dcovs.colptr(g); + + for(uword d=0; d < N_dims; ++d) + { + fcov.at(d,d) = dcov_mem[d]; + } + } + + init_constants(); + } + + + +template +inline +void +gmm_full::init(const uword in_n_dims, const uword in_n_gaus) + { + arma_extra_debug_sigprint(); + + access::rw(means).zeros(in_n_dims, in_n_gaus); + + access::rw(fcovs).zeros(in_n_dims, in_n_dims, in_n_gaus); + + for(uword g=0; g < in_n_gaus; ++g) + { + access::rw(fcovs).slice(g).diag().ones(); + } + + access::rw(hefts).set_size(in_n_gaus); + access::rw(hefts).fill(eT(1) / eT(in_n_gaus)); + + init_constants(); + } + + + +template +inline +void +gmm_full::init_constants(const bool calc_chol) + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + const eT tmp = (eT(N_dims)/eT(2)) * std::log(Datum::tau); + + // + + inv_fcovs.copy_size(fcovs); + log_det_etc.set_size(N_gaus); + + Mat tmp_inv; + + for(uword g=0; g < N_gaus; ++g) + { + const Mat& fcov = fcovs.slice(g); + Mat& inv_fcov = inv_fcovs.slice(g); + + //const bool inv_ok = auxlib::inv(tmp_inv, fcov); + const bool inv_ok = auxlib::inv_sympd(tmp_inv, fcov); + + eT log_det_val = eT(0); + eT log_det_sign = eT(0); + + const bool log_det_status = log_det(log_det_val, log_det_sign, fcov); + + const bool log_det_ok = ( log_det_status && (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) ); + + if(inv_ok && log_det_ok) + { + inv_fcov = tmp_inv; + } + else + { + // last resort: treat the covariance matrix as diagonal + + inv_fcov.zeros(); + + log_det_val = eT(0); + + for(uword d=0; d < N_dims; ++d) + { + const eT sanitised_val = (std::max)( eT(fcov.at(d,d)), eT(std::numeric_limits::min()) ); + + inv_fcov.at(d,d) = eT(1) / sanitised_val; + + log_det_val += std::log(sanitised_val); + } + } + + log_det_etc[g] = eT(-1) * ( tmp + eT(0.5) * log_det_val ); + } + + // + + eT* hefts_mem = access::rw(hefts).memptr(); + + for(uword g=0; g < N_gaus; ++g) + { + hefts_mem[g] = (std::max)( hefts_mem[g], std::numeric_limits::min() ); + } + + log_hefts = log(hefts); + + + if(calc_chol) + { + chol_fcovs.copy_size(fcovs); + + Mat tmp_chol; + + for(uword g=0; g < N_gaus; ++g) + { + const Mat& fcov = fcovs.slice(g); + Mat& chol_fcov = chol_fcovs.slice(g); + + const uword chol_layout = 1; // indicates "lower" + + const bool chol_ok = op_chol::apply_direct(tmp_chol, fcov, chol_layout); + + if(chol_ok) + { + chol_fcov = tmp_chol; + } + else + { + // last resort: treat the covariance matrix as diagonal + + chol_fcov.zeros(); + + for(uword d=0; d < N_dims; ++d) + { + const eT sanitised_val = (std::max)( eT(fcov.at(d,d)), eT(std::numeric_limits::min()) ); + + chol_fcov.at(d,d) = std::sqrt(sanitised_val); + } + } + } + } + } + + + +template +inline +umat +gmm_full::internal_gen_boundaries(const uword N) const + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_OPENMP) + const uword n_threads_avail = uword(omp_get_max_threads()); + const uword n_threads = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1; + #else + static constexpr uword n_threads = 1; + #endif + + // get_cout_stream() << "gmm_full::internal_gen_boundaries(): n_threads: " << n_threads << '\n'; + + umat boundaries(2, n_threads, arma_nozeros_indicator()); + + if(N > 0) + { + const uword chunk_size = N / n_threads; + + uword count = 0; + + for(uword t=0; t +inline +eT +gmm_full::internal_scalar_log_p(const eT* x) const + { + arma_extra_debug_sigprint(); + + const eT* log_hefts_mem = log_hefts.mem; + + const uword N_gaus = means.n_cols; + + if(N_gaus > 0) + { + eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0]; + + for(uword g=1; g < N_gaus; ++g) + { + const eT log_val = internal_scalar_log_p(x, g) + log_hefts_mem[g]; + + log_sum = log_add_exp(log_sum, log_val); + } + + return log_sum; + } + else + { + return -Datum::inf; + } + } + + + +template +inline +eT +gmm_full::internal_scalar_log_p(const eT* x, const uword g) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const eT* mean_mem = means.colptr(g); + + eT outer_acc = eT(0); + + const eT* inv_fcov_coldata = inv_fcovs.slice(g).memptr(); + + for(uword i=0; i < N_dims; ++i) + { + eT inner_acc = eT(0); + + for(uword j=0; j < N_dims; ++j) + { + inner_acc += (x[j] - mean_mem[j]) * inv_fcov_coldata[j]; + } + + inv_fcov_coldata += N_dims; + + outer_acc += inner_acc * (x[i] - mean_mem[i]); + } + + return eT(-0.5)*outer_acc + log_det_etc.mem[g]; + } + + + +template +inline +Row +gmm_full::internal_vec_log_p(const Mat& X) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_samples = X.n_cols; + + arma_debug_check( (X.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" ); + + Row out(N_samples, arma_nozeros_indicator()); + + if(N_samples > 0) + { + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N_samples); + + const uword n_threads = boundaries.n_cols; + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + eT* out_mem = out.memptr(); + + for(uword i=start_index; i <= end_index; ++i) + { + out_mem[i] = internal_scalar_log_p( X.colptr(i) ); + } + } + } + #else + { + eT* out_mem = out.memptr(); + + for(uword i=0; i < N_samples; ++i) + { + out_mem[i] = internal_scalar_log_p( X.colptr(i) ); + } + } + #endif + } + + return out; + } + + + +template +inline +Row +gmm_full::internal_vec_log_p(const Mat& X, const uword gaus_id) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_samples = X.n_cols; + + arma_debug_check( (X.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" ); + arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::log_p(): specified gaussian is out of range" ); + + Row out(N_samples, arma_nozeros_indicator()); + + if(N_samples > 0) + { + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N_samples); + + const uword n_threads = boundaries.n_cols; + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + eT* out_mem = out.memptr(); + + for(uword i=start_index; i <= end_index; ++i) + { + out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id ); + } + } + } + #else + { + eT* out_mem = out.memptr(); + + for(uword i=0; i < N_samples; ++i) + { + out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id ); + } + } + #endif + } + + return out; + } + + + +template +inline +eT +gmm_full::internal_sum_log_p(const Mat& X) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::sum_log_p(): incompatible dimensions" ); + + const uword N = X.n_cols; + + if(N == 0) { return (-Datum::inf); } + + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N); + + const uword n_threads = boundaries.n_cols; + + Col t_accs(n_threads, arma_zeros_indicator()); + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + eT t_acc = eT(0); + + for(uword i=start_index; i <= end_index; ++i) + { + t_acc += internal_scalar_log_p( X.colptr(i) ); + } + + t_accs[t] = t_acc; + } + + return eT(accu(t_accs)); + } + #else + { + eT acc = eT(0); + + for(uword i=0; i +inline +eT +gmm_full::internal_sum_log_p(const Mat& X, const uword gaus_id) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::sum_log_p(): incompatible dimensions" ); + arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::sum_log_p(): specified gaussian is out of range" ); + + const uword N = X.n_cols; + + if(N == 0) { return (-Datum::inf); } + + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N); + + const uword n_threads = boundaries.n_cols; + + Col t_accs(n_threads, arma_zeros_indicator()); + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + eT t_acc = eT(0); + + for(uword i=start_index; i <= end_index; ++i) + { + t_acc += internal_scalar_log_p( X.colptr(i), gaus_id ); + } + + t_accs[t] = t_acc; + } + + return eT(accu(t_accs)); + } + #else + { + eT acc = eT(0); + + for(uword i=0; i +inline +eT +gmm_full::internal_avg_log_p(const Mat& X) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_samples = X.n_cols; + + arma_debug_check( (X.n_rows != N_dims), "gmm_full::avg_log_p(): incompatible dimensions" ); + + if(N_samples == 0) { return (-Datum::inf); } + + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N_samples); + + const uword n_threads = boundaries.n_cols; + + field< running_mean_scalar > t_running_means(n_threads); + + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + running_mean_scalar& current_running_mean = t_running_means[t]; + + for(uword i=start_index; i <= end_index; ++i) + { + current_running_mean( internal_scalar_log_p( X.colptr(i) ) ); + } + } + + + eT avg = eT(0); + + for(uword t=0; t < n_threads; ++t) + { + running_mean_scalar& current_running_mean = t_running_means[t]; + + const eT w = eT(current_running_mean.count()) / eT(N_samples); + + avg += w * current_running_mean.mean(); + } + + return avg; + } + #else + { + running_mean_scalar running_mean; + + for(uword i=0; i < N_samples; ++i) + { + running_mean( internal_scalar_log_p( X.colptr(i) ) ); + } + + return running_mean.mean(); + } + #endif + } + + + +template +inline +eT +gmm_full::internal_avg_log_p(const Mat& X, const uword gaus_id) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_samples = X.n_cols; + + arma_debug_check( (X.n_rows != N_dims), "gmm_full::avg_log_p(): incompatible dimensions" ); + arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::avg_log_p(): specified gaussian is out of range" ); + + if(N_samples == 0) { return (-Datum::inf); } + + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(N_samples); + + const uword n_threads = boundaries.n_cols; + + field< running_mean_scalar > t_running_means(n_threads); + + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + running_mean_scalar& current_running_mean = t_running_means[t]; + + for(uword i=start_index; i <= end_index; ++i) + { + current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) ); + } + } + + + eT avg = eT(0); + + for(uword t=0; t < n_threads; ++t) + { + running_mean_scalar& current_running_mean = t_running_means[t]; + + const eT w = eT(current_running_mean.count()) / eT(N_samples); + + avg += w * current_running_mean.mean(); + } + + return avg; + } + #else + { + running_mean_scalar running_mean; + + for(uword i=0; i +inline +uword +gmm_full::internal_scalar_assign(const Mat& X, const gmm_dist_mode& dist_mode) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + arma_debug_check( (X.n_rows != N_dims), "gmm_full::assign(): incompatible dimensions" ); + arma_debug_check( (N_gaus == 0), "gmm_full::assign(): model has no means" ); + + const eT* X_mem = X.colptr(0); + + if(dist_mode == eucl_dist) + { + eT best_dist = Datum::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_dist = distance::eval(N_dims, X_mem, means.colptr(g), X_mem); + + if(tmp_dist <= best_dist) + { + best_dist = tmp_dist; + best_g = g; + } + } + + return best_g; + } + else + if(dist_mode == prob_dist) + { + const eT* log_hefts_mem = log_hefts.memptr(); + + eT best_p = -Datum::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g]; + + if(tmp_p >= best_p) + { + best_p = tmp_p; + best_g = g; + } + } + + return best_g; + } + else + { + arma_debug_check(true, "gmm_full::assign(): unsupported distance mode"); + } + + return uword(0); + } + + + +template +inline +void +gmm_full::internal_vec_assign(urowvec& out, const Mat& X, const gmm_dist_mode& dist_mode) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + arma_debug_check( (X.n_rows != N_dims), "gmm_full::assign(): incompatible dimensions" ); + + const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0; + + out.set_size(1,X_n_cols); + + uword* out_mem = out.memptr(); + + if(dist_mode == eucl_dist) + { + #if defined(ARMA_USE_OPENMP) + { + #pragma omp parallel for schedule(static) + for(uword i=0; i::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), X_colptr); + + if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } + } + + out_mem[i] = best_g; + } + } + #else + { + for(uword i=0; i::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), X_colptr); + + if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } + } + + out_mem[i] = best_g; + } + } + #endif + } + else + if(dist_mode == prob_dist) + { + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(X_n_cols); + + const uword n_threads = boundaries.n_cols; + + const eT* log_hefts_mem = log_hefts.memptr(); + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + for(uword i=start_index; i <= end_index; ++i) + { + const eT* X_colptr = X.colptr(i); + + eT best_p = -Datum::inf; + uword best_g = 0; + + for(uword g=0; g= best_p) { best_p = tmp_p; best_g = g; } + } + + out_mem[i] = best_g; + } + } + } + #else + { + const eT* log_hefts_mem = log_hefts.memptr(); + + for(uword i=0; i::inf; + uword best_g = 0; + + for(uword g=0; g= best_p) { best_p = tmp_p; best_g = g; } + } + + out_mem[i] = best_g; + } + } + #endif + } + else + { + arma_debug_check(true, "gmm_full::assign(): unsupported distance mode"); + } + } + + + + +template +inline +void +gmm_full::internal_raw_hist(urowvec& hist, const Mat& X, const gmm_dist_mode& dist_mode) const + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + const uword X_n_cols = X.n_cols; + + hist.zeros(N_gaus); + + if(N_gaus == 0) { return; } + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(X_n_cols); + + const uword n_threads = boundaries.n_cols; + + field thread_hist(n_threads); + + for(uword t=0; t < n_threads; ++t) { thread_hist(t).zeros(N_gaus); } + + + if(dist_mode == eucl_dist) + { + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + uword* thread_hist_mem = thread_hist(t).memptr(); + + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + for(uword i=start_index; i <= end_index; ++i) + { + const eT* X_colptr = X.colptr(i); + + eT best_dist = Datum::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_dist = distance::eval(N_dims, X_colptr, means.colptr(g), X_colptr); + + if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } + } + + thread_hist_mem[best_g]++; + } + } + } + else + if(dist_mode == prob_dist) + { + const eT* log_hefts_mem = log_hefts.memptr(); + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + uword* thread_hist_mem = thread_hist(t).memptr(); + + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + for(uword i=start_index; i <= end_index; ++i) + { + const eT* X_colptr = X.colptr(i); + + eT best_p = -Datum::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g]; + + if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; } + } + + thread_hist_mem[best_g]++; + } + } + } + + // reduction + for(uword t=0; t < n_threads; ++t) + { + hist += thread_hist(t); + } + } + #else + { + uword* hist_mem = hist.memptr(); + + if(dist_mode == eucl_dist) + { + for(uword i=0; i::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_dist = distance::eval(N_dims, X_colptr, means.colptr(g), X_colptr); + + if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } + } + + hist_mem[best_g]++; + } + } + else + if(dist_mode == prob_dist) + { + const eT* log_hefts_mem = log_hefts.memptr(); + + for(uword i=0; i::inf; + uword best_g = 0; + + for(uword g=0; g < N_gaus; ++g) + { + const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g]; + + if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; } + } + + hist_mem[best_g]++; + } + } + } + #endif + } + + + +template +template +inline +void +gmm_full::generate_initial_means(const Mat& X, const gmm_seed_mode& seed_mode) + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + if( (seed_mode == static_subset) || (seed_mode == random_subset) ) + { + uvec initial_indices; + + if(seed_mode == static_subset) { initial_indices = linspace(0, X.n_cols-1, N_gaus); } + else if(seed_mode == random_subset) { initial_indices = randperm(X.n_cols, N_gaus); } + + // initial_indices.print("initial_indices:"); + + access::rw(means) = X.cols(initial_indices); + } + else + if( (seed_mode == static_spread) || (seed_mode == random_spread) ) + { + // going through all of the samples can be extremely time consuming; + // instead, if there are enough samples, randomly choose samples with probability 0.1 + + const bool use_sampling = ((X.n_cols/uword(100)) > N_gaus); + const uword step = (use_sampling) ? uword(10) : uword(1); + + uword start_index = 0; + + if(seed_mode == static_spread) { start_index = X.n_cols / 2; } + else if(seed_mode == random_spread) { start_index = as_scalar(randi(1, distr_param(0,X.n_cols-1))); } + + access::rw(means).col(0) = X.unsafe_col(start_index); + + const eT* mah_aux_mem = mah_aux.memptr(); + + running_stat rs; + + for(uword g=1; g < N_gaus; ++g) + { + eT max_dist = eT(0); + uword best_i = uword(0); + uword start_i = uword(0); + + if(use_sampling) + { + uword start_i_proposed = uword(0); + + if(seed_mode == static_spread) { start_i_proposed = g % uword(10); } + if(seed_mode == random_spread) { start_i_proposed = as_scalar(randi(1, distr_param(0,9))); } + + if(start_i_proposed < X.n_cols) { start_i = start_i_proposed; } + } + + + for(uword i=start_i; i < X.n_cols; i += step) + { + rs.reset(); + + const eT* X_colptr = X.colptr(i); + + bool ignore_i = false; + + // find the average distance between sample i and the means so far + for(uword h = 0; h < g; ++h) + { + const eT dist = distance::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem); + + // ignore sample already selected as a mean + if(dist == eT(0)) { ignore_i = true; break; } + else { rs(dist); } + } + + if( (rs.mean() >= max_dist) && (ignore_i == false)) + { + max_dist = eT(rs.mean()); best_i = i; + } + } + + // set the mean to the sample that is the furthest away from the means so far + access::rw(means).col(g) = X.unsafe_col(best_i); + } + } + + // get_cout_stream() << "generate_initial_means():" << '\n'; + // means.print(); + } + + + +template +template +inline +void +gmm_full::generate_initial_params(const Mat& X, const eT var_floor) + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + const eT* mah_aux_mem = mah_aux.memptr(); + + const uword X_n_cols = X.n_cols; + + if(X_n_cols == 0) { return; } + + // as the covariances are calculated via accumulators, + // the means also need to be calculated via accumulators to ensure numerical consistency + + Mat acc_means(N_dims, N_gaus); + Mat acc_dcovs(N_dims, N_gaus); + + Row acc_hefts(N_gaus, arma_zeros_indicator()); + + uword* acc_hefts_mem = acc_hefts.memptr(); + + #if defined(ARMA_USE_OPENMP) + { + const umat boundaries = internal_gen_boundaries(X_n_cols); + + const uword n_threads = boundaries.n_cols; + + field< Mat > t_acc_means(n_threads); + field< Mat > t_acc_dcovs(n_threads); + field< Row > t_acc_hefts(n_threads); + + for(uword t=0; t < n_threads; ++t) + { + t_acc_means(t).zeros(N_dims, N_gaus); + t_acc_dcovs(t).zeros(N_dims, N_gaus); + t_acc_hefts(t).zeros(N_gaus); + } + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + uword* t_acc_hefts_mem = t_acc_hefts(t).memptr(); + + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + for(uword i=start_index; i <= end_index; ++i) + { + const eT* X_colptr = X.colptr(i); + + eT min_dist = Datum::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem); + + if(dist < min_dist) { min_dist = dist; best_g = g; } + } + + eT* t_acc_mean = t_acc_means(t).colptr(best_g); + eT* t_acc_dcov = t_acc_dcovs(t).colptr(best_g); + + for(uword d=0; d::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem); + + if(dist < min_dist) { min_dist = dist; best_g = g; } + } + + eT* acc_mean = acc_means.colptr(best_g); + eT* acc_dcov = acc_dcovs.colptr(best_g); + + for(uword d=0; d& fcov = access::rw(fcovs).slice(g); + fcov.zeros(); + + for(uword d=0; d= 1) ? tmp : eT(0); + fcov.at(d,d) = (acc_heft >= 2) ? eT((acc_dcov[d] / eT(acc_heft)) - (tmp*tmp)) : eT(var_floor); + } + + hefts_mem[g] = eT(acc_heft) / eT(X_n_cols); + } + + em_fix_params(var_floor); + } + + + +//! multi-threaded implementation of k-means, inspired by MapReduce +template +template +inline +bool +gmm_full::km_iterate(const Mat& X, const uword max_iter, const bool verbose) + { + arma_extra_debug_sigprint(); + + if(verbose) + { + get_cout_stream().unsetf(ios::showbase); + get_cout_stream().unsetf(ios::uppercase); + get_cout_stream().unsetf(ios::showpos); + get_cout_stream().unsetf(ios::scientific); + + get_cout_stream().setf(ios::right); + get_cout_stream().setf(ios::fixed); + } + + const uword X_n_cols = X.n_cols; + + if(X_n_cols == 0) { return true; } + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + const eT* mah_aux_mem = mah_aux.memptr(); + + Mat acc_means(N_dims, N_gaus, arma_zeros_indicator()); + Row acc_hefts( N_gaus, arma_zeros_indicator()); + Row last_indx( N_gaus, arma_zeros_indicator()); + + Mat new_means = means; + Mat old_means = means; + + running_mean_scalar rs_delta; + + #if defined(ARMA_USE_OPENMP) + const umat boundaries = internal_gen_boundaries(X_n_cols); + const uword n_threads = boundaries.n_cols; + + field< Mat > t_acc_means(n_threads); + field< Row > t_acc_hefts(n_threads); + field< Row > t_last_indx(n_threads); + #else + const uword n_threads = 1; + #endif + + if(verbose) { get_cout_stream() << "gmm_full::learn(): k-means: n_threads: " << n_threads << '\n'; get_cout_stream().flush(); } + + for(uword iter=1; iter <= max_iter; ++iter) + { + #if defined(ARMA_USE_OPENMP) + { + for(uword t=0; t < n_threads; ++t) + { + t_acc_means(t).zeros(N_dims, N_gaus); + t_acc_hefts(t).zeros(N_gaus); + t_last_indx(t).zeros(N_gaus); + } + + #pragma omp parallel for schedule(static) + for(uword t=0; t < n_threads; ++t) + { + Mat& t_acc_means_t = t_acc_means(t); + uword* t_acc_hefts_mem = t_acc_hefts(t).memptr(); + uword* t_last_indx_mem = t_last_indx(t).memptr(); + + const uword start_index = boundaries.at(0,t); + const uword end_index = boundaries.at(1,t); + + for(uword i=start_index; i <= end_index; ++i) + { + const eT* X_colptr = X.colptr(i); + + eT min_dist = Datum::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem); + + if(dist < min_dist) { min_dist = dist; best_g = g; } + } + + eT* t_acc_mean = t_acc_means_t.colptr(best_g); + + for(uword d=0; d= 1 ) { last_indx(g) = t_last_indx(t)(g); } + } + } + #else + { + acc_hefts.zeros(); + acc_means.zeros(); + last_indx.zeros(); + + uword* acc_hefts_mem = acc_hefts.memptr(); + uword* last_indx_mem = last_indx.memptr(); + + for(uword i=0; i < X_n_cols; ++i) + { + const eT* X_colptr = X.colptr(i); + + eT min_dist = Datum::inf; + uword best_g = 0; + + for(uword g=0; g::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem); + + if(dist < min_dist) { min_dist = dist; best_g = g; } + } + + eT* acc_mean = acc_means.colptr(best_g); + + for(uword d=0; d= 1) ? (acc_mean[d] / eT(acc_heft)) : eT(0); + } + } + + + // heuristics to resurrect dead means + + const uvec dead_gs = find(acc_hefts == uword(0)); + + if(dead_gs.n_elem > 0) + { + if(verbose) { get_cout_stream() << "gmm_full::learn(): k-means: recovering from dead means\n"; get_cout_stream().flush(); } + + uword* last_indx_mem = last_indx.memptr(); + + const uvec live_gs = sort( find(acc_hefts >= uword(2)), "descend" ); + + if(live_gs.n_elem == 0) { return false; } + + uword live_gs_count = 0; + + for(uword dead_gs_count = 0; dead_gs_count < dead_gs.n_elem; ++dead_gs_count) + { + const uword dead_g_id = dead_gs(dead_gs_count); + + uword proposed_i = 0; + + if(live_gs_count < live_gs.n_elem) + { + const uword live_g_id = live_gs(live_gs_count); ++live_gs_count; + + if(live_g_id == dead_g_id) { return false; } + + // recover by using a sample from a known good mean + proposed_i = last_indx_mem[live_g_id]; + } + else + { + // recover by using a randomly seleced sample (last resort) + proposed_i = as_scalar(randi(1, distr_param(0,X_n_cols-1))); + } + + if(proposed_i >= X_n_cols) { return false; } + + new_means.col(dead_g_id) = X.col(proposed_i); + } + } + + rs_delta.reset(); + + for(uword g=0; g < N_gaus; ++g) + { + rs_delta( distance::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) ); + } + + if(verbose) + { + get_cout_stream() << "gmm_full::learn(): k-means: iteration: "; + get_cout_stream().unsetf(ios::scientific); + get_cout_stream().setf(ios::fixed); + get_cout_stream().width(std::streamsize(4)); + get_cout_stream() << iter; + get_cout_stream() << " delta: "; + get_cout_stream().unsetf(ios::fixed); + //get_cout_stream().setf(ios::scientific); + get_cout_stream() << rs_delta.mean() << '\n'; + get_cout_stream().flush(); + } + + arma::swap(old_means, new_means); + + if(rs_delta.mean() <= Datum::eps) { break; } + } + + access::rw(means) = old_means; + + if(means.internal_has_nonfinite()) { return false; } + + return true; + } + + + +//! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce +template +inline +bool +gmm_full::em_iterate(const Mat& X, const uword max_iter, const eT var_floor, const bool verbose) + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + if(verbose) + { + get_cout_stream().unsetf(ios::showbase); + get_cout_stream().unsetf(ios::uppercase); + get_cout_stream().unsetf(ios::showpos); + get_cout_stream().unsetf(ios::scientific); + + get_cout_stream().setf(ios::right); + get_cout_stream().setf(ios::fixed); + } + + const umat boundaries = internal_gen_boundaries(X.n_cols); + + const uword n_threads = boundaries.n_cols; + + field< Mat > t_acc_means(n_threads); + field< Cube > t_acc_fcovs(n_threads); + + field< Col > t_acc_norm_lhoods(n_threads); + field< Col > t_gaus_log_lhoods(n_threads); + + Col t_progress_log_lhood(n_threads, arma_nozeros_indicator()); + + for(uword t=0; t::inf; + + const bool calc_chol = false; + + for(uword iter=1; iter <= max_iter; ++iter) + { + init_constants(calc_chol); + + em_update_params(X, boundaries, t_acc_means, t_acc_fcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood, var_floor); + + em_fix_params(var_floor); + + const eT new_avg_log_p = accu(t_progress_log_lhood) / eT(t_progress_log_lhood.n_elem); + + if(verbose) + { + get_cout_stream() << "gmm_full::learn(): EM: iteration: "; + get_cout_stream().unsetf(ios::scientific); + get_cout_stream().setf(ios::fixed); + get_cout_stream().width(std::streamsize(4)); + get_cout_stream() << iter; + get_cout_stream() << " avg_log_p: "; + get_cout_stream().unsetf(ios::fixed); + //get_cout_stream().setf(ios::scientific); + get_cout_stream() << new_avg_log_p << '\n'; + get_cout_stream().flush(); + } + + if(arma_isfinite(new_avg_log_p) == false) { return false; } + + if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum::eps) { break; } + + + old_avg_log_p = new_avg_log_p; + } + + + for(uword g=0; g < N_gaus; ++g) + { + const Mat& fcov = fcovs.slice(g); + + if(any(vectorise(fcov.diag()) <= eT(0))) { return false; } + } + + if(means.internal_has_nonfinite()) { return false; } + if(fcovs.internal_has_nonfinite()) { return false; } + if(hefts.internal_has_nonfinite()) { return false; } + + return true; + } + + + + +template +inline +void +gmm_full::em_update_params + ( + const Mat& X, + const umat& boundaries, + field< Mat >& t_acc_means, + field< Cube >& t_acc_fcovs, + field< Col >& t_acc_norm_lhoods, + field< Col >& t_gaus_log_lhoods, + Col& t_progress_log_lhood, + const eT var_floor + ) + { + arma_extra_debug_sigprint(); + + const uword n_threads = boundaries.n_cols; + + + // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts + + #if defined(ARMA_USE_OPENMP) + { + #pragma omp parallel for schedule(static) + for(uword t=0; t& acc_means = t_acc_means[t]; + Cube& acc_fcovs = t_acc_fcovs[t]; + Col& acc_norm_lhoods = t_acc_norm_lhoods[t]; + Col& gaus_log_lhoods = t_gaus_log_lhoods[t]; + eT& progress_log_lhood = t_progress_log_lhood[t]; + + em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_fcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood); + } + } + #else + { + em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_fcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]); + } + #endif + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + Mat& final_acc_means = t_acc_means[0]; + Cube& final_acc_fcovs = t_acc_fcovs[0]; + + Col& final_acc_norm_lhoods = t_acc_norm_lhoods[0]; + + + // the "reduce" operation, which combines the partial accumulators produced by the separate threads + + for(uword t=1; t mean_outer(N_dims, N_dims, arma_nozeros_indicator()); + + + //// update each component without sanity checking + //for(uword g=0; g < N_gaus; ++g) + // { + // const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits::min() ); + // + // hefts_mem[g] = acc_norm_lhood / eT(X.n_cols); + // + // eT* mean_mem = access::rw(means).colptr(g); + // eT* acc_mean_mem = final_acc_means.colptr(g); + // + // for(uword d=0; d < N_dims; ++d) + // { + // mean_mem[d] = acc_mean_mem[d] / acc_norm_lhood; + // } + // + // const Col mean(mean_mem, N_dims, false, true); + // + // mean_outer = mean * mean.t(); + // + // Mat& fcov = access::rw(fcovs).slice(g); + // Mat& acc_fcov = final_acc_fcovs.slice(g); + // + // fcov = acc_fcov / acc_norm_lhood - mean_outer; + // } + + + // conditionally update each component; if only a subset of the hefts was updated, em_fix_params() will sanitise them + for(uword g=0; g < N_gaus; ++g) + { + const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits::min() ); + + if(arma_isfinite(acc_norm_lhood) == false) { continue; } + + eT* acc_mean_mem = final_acc_means.colptr(g); + + for(uword d=0; d < N_dims; ++d) + { + acc_mean_mem[d] /= acc_norm_lhood; + } + + const Col new_mean(acc_mean_mem, N_dims, false, true); + + mean_outer = new_mean * new_mean.t(); + + Mat& acc_fcov = final_acc_fcovs.slice(g); + + acc_fcov /= acc_norm_lhood; + acc_fcov -= mean_outer; + + for(uword d=0; d < N_dims; ++d) + { + eT& val = acc_fcov.at(d,d); + + if(val < var_floor) { val = var_floor; } + } + + if(acc_fcov.internal_has_nonfinite()) { continue; } + + eT log_det_val = eT(0); + eT log_det_sign = eT(0); + + const bool log_det_status = log_det(log_det_val, log_det_sign, acc_fcov); + + const bool log_det_ok = ( log_det_status && (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) ); + + const bool inv_ok = (log_det_ok) ? bool(auxlib::inv_sympd(mean_outer, acc_fcov)) : bool(false); // mean_outer is used as a junk matrix + + if(log_det_ok && inv_ok) + { + hefts_mem[g] = acc_norm_lhood / eT(X.n_cols); + + eT* mean_mem = access::rw(means).colptr(g); + + for(uword d=0; d < N_dims; ++d) + { + mean_mem[d] = acc_mean_mem[d]; + } + + Mat& fcov = access::rw(fcovs).slice(g); + + fcov = acc_fcov; + } + } + } + + + +template +inline +void +gmm_full::em_generate_acc + ( + const Mat& X, + const uword start_index, + const uword end_index, + Mat& acc_means, + Cube& acc_fcovs, + Col& acc_norm_lhoods, + Col& gaus_log_lhoods, + eT& progress_log_lhood + ) + const + { + arma_extra_debug_sigprint(); + + progress_log_lhood = eT(0); + + acc_means.zeros(); + acc_fcovs.zeros(); + + acc_norm_lhoods.zeros(); + gaus_log_lhoods.zeros(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + const eT* log_hefts_mem = log_hefts.memptr(); + eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr(); + + + for(uword i=start_index; i <= end_index; i++) + { + const eT* x = X.colptr(i); + + for(uword g=0; g < N_gaus; ++g) + { + gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g]; + } + + eT log_lhood_sum = gaus_log_lhoods_mem[0]; + + for(uword g=1; g < N_gaus; ++g) + { + log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]); + } + + progress_log_lhood += log_lhood_sum; + + for(uword g=0; g < N_gaus; ++g) + { + const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum); + + acc_norm_lhoods[g] += norm_lhood; + + eT* acc_mean_mem = acc_means.colptr(g); + + for(uword d=0; d < N_dims; ++d) + { + acc_mean_mem[d] += x[d] * norm_lhood; + } + + Mat& acc_fcov = access::rw(acc_fcovs).slice(g); + + // specialised version of acc_fcov += norm_lhood * (xx * xx.t()); + + for(uword d=0; d < N_dims; ++d) + { + const uword dp1 = d+1; + + const eT xd = x[d]; + + eT* acc_fcov_col_d = acc_fcov.colptr(d) + d; + eT* acc_fcov_row_d = &(acc_fcov.at(d,dp1)); + + (*acc_fcov_col_d) += norm_lhood * (xd * xd); acc_fcov_col_d++; + + for(uword e=dp1; e < N_dims; ++e) + { + const eT val = norm_lhood * (xd * x[e]); + + (*acc_fcov_col_d) += val; acc_fcov_col_d++; + (*acc_fcov_row_d) += val; acc_fcov_row_d += N_dims; + } + } + } + } + + progress_log_lhood /= eT((end_index - start_index) + 1); + } + + + +template +inline +void +gmm_full::em_fix_params(const eT var_floor) + { + arma_extra_debug_sigprint(); + + const uword N_dims = means.n_rows; + const uword N_gaus = means.n_cols; + + const eT var_ceiling = std::numeric_limits::max(); + + for(uword g=0; g < N_gaus; ++g) + { + Mat& fcov = access::rw(fcovs).slice(g); + + for(uword d=0; d < N_dims; ++d) + { + eT& var_val = fcov.at(d,d); + + if(var_val < var_floor ) { var_val = var_floor; } + else if(var_val > var_ceiling) { var_val = var_ceiling; } + else if(arma_isnan(var_val) ) { var_val = eT(1); } + } + } + + + eT* hefts_mem = access::rw(hefts).memptr(); + + for(uword g1=0; g1 < N_gaus; ++g1) + { + if(hefts_mem[g1] > eT(0)) + { + const eT* means_colptr_g1 = means.colptr(g1); + + for(uword g2=(g1+1); g2 < N_gaus; ++g2) + { + if( (hefts_mem[g2] > eT(0)) && (std::abs(hefts_mem[g1] - hefts_mem[g2]) <= std::numeric_limits::epsilon()) ) + { + const eT dist = distance::eval(N_dims, means_colptr_g1, means.colptr(g2), means_colptr_g1); + + if(dist == eT(0)) { hefts_mem[g2] = eT(0); } + } + } + } + } + + const eT heft_floor = std::numeric_limits::min(); + const eT heft_initial = eT(1) / eT(N_gaus); + + for(uword i=0; i < N_gaus; ++i) + { + eT& heft_val = hefts_mem[i]; + + if(heft_val < heft_floor) { heft_val = heft_floor; } + else if(heft_val > eT(1) ) { heft_val = eT(1); } + else if(arma_isnan(heft_val) ) { heft_val = heft_initial; } + } + + const eT heft_sum = accu(hefts); + + if((heft_sum < (eT(1) - Datum::eps)) || (heft_sum > (eT(1) + Datum::eps))) { access::rw(hefts) /= heft_sum; } + } + + + +} // namespace gmm_priv + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/gmm_misc_bones.hpp b/src/armadillo/include/armadillo_bits/gmm_misc_bones.hpp new file mode 100644 index 0000000..44507d4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/gmm_misc_bones.hpp @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup gmm_misc +//! @{ + + +struct gmm_dist_mode { const uword id; inline constexpr explicit gmm_dist_mode(const uword in_id) : id(in_id) {} }; + +inline bool operator==(const gmm_dist_mode& a, const gmm_dist_mode& b) { return (a.id == b.id); } +inline bool operator!=(const gmm_dist_mode& a, const gmm_dist_mode& b) { return (a.id != b.id); } + +struct gmm_dist_eucl : public gmm_dist_mode { inline constexpr gmm_dist_eucl() : gmm_dist_mode(1) {} }; +struct gmm_dist_maha : public gmm_dist_mode { inline constexpr gmm_dist_maha() : gmm_dist_mode(2) {} }; +struct gmm_dist_prob : public gmm_dist_mode { inline constexpr gmm_dist_prob() : gmm_dist_mode(3) {} }; + +static constexpr gmm_dist_eucl eucl_dist; +static constexpr gmm_dist_maha maha_dist; +static constexpr gmm_dist_prob prob_dist; + + + +struct gmm_seed_mode { const uword id; inline constexpr explicit gmm_seed_mode(const uword in_id) : id(in_id) {} }; + +inline bool operator==(const gmm_seed_mode& a, const gmm_seed_mode& b) { return (a.id == b.id); } +inline bool operator!=(const gmm_seed_mode& a, const gmm_seed_mode& b) { return (a.id != b.id); } + +struct gmm_seed_keep_existing : public gmm_seed_mode { inline constexpr gmm_seed_keep_existing() : gmm_seed_mode(1) {} }; +struct gmm_seed_static_subset : public gmm_seed_mode { inline constexpr gmm_seed_static_subset() : gmm_seed_mode(2) {} }; +struct gmm_seed_static_spread : public gmm_seed_mode { inline constexpr gmm_seed_static_spread() : gmm_seed_mode(3) {} }; +struct gmm_seed_random_subset : public gmm_seed_mode { inline constexpr gmm_seed_random_subset() : gmm_seed_mode(4) {} }; +struct gmm_seed_random_spread : public gmm_seed_mode { inline constexpr gmm_seed_random_spread() : gmm_seed_mode(5) {} }; + +static constexpr gmm_seed_keep_existing keep_existing; +static constexpr gmm_seed_static_subset static_subset; +static constexpr gmm_seed_static_spread static_spread; +static constexpr gmm_seed_random_subset random_subset; +static constexpr gmm_seed_random_spread random_spread; + + +namespace gmm_priv +{ + + +template class gmm_diag; +template class gmm_full; + + +struct gmm_empty_arg {}; + + +// running_mean_scalar + +template +class running_mean_scalar + { + public: + + inline running_mean_scalar(); + inline running_mean_scalar(const running_mean_scalar& in_rms); + + inline const running_mean_scalar& operator=(const running_mean_scalar& in_rms); + + arma_hot inline void operator() (const eT X); + + inline void reset(); + + inline uword count() const; + inline eT mean() const; + + + private: + + arma_aligned uword counter; + arma_aligned eT r_mean; + }; + + + +// distance + +template +struct distance {}; + + +template +struct distance + { + arma_inline static eT eval(const uword N, const eT* A, const eT* B, const eT*); + }; + + + +template +struct distance + { + arma_inline static eT eval(const uword N, const eT* A, const eT* B, const eT* C); + }; + + +} + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/gmm_misc_meat.hpp b/src/armadillo/include/armadillo_bits/gmm_misc_meat.hpp new file mode 100644 index 0000000..3276b46 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/gmm_misc_meat.hpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup gmm_misc +//! @{ + + +namespace gmm_priv +{ + + +template +inline +running_mean_scalar::running_mean_scalar() + : counter(uword(0)) + , r_mean ( eT(0)) + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +running_mean_scalar::running_mean_scalar(const running_mean_scalar& in) + : counter(in.counter) + , r_mean (in.r_mean ) + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +const running_mean_scalar& +running_mean_scalar::operator=(const running_mean_scalar& in) + { + arma_extra_debug_sigprint(); + + counter = in.counter; + r_mean = in.r_mean; + + return *this; + } + + + +template +inline +void +running_mean_scalar::operator() (const eT X) + { + arma_extra_debug_sigprint(); + + counter++; + + if(counter > 1) + { + const eT old_r_mean = r_mean; + + r_mean = old_r_mean + (X - old_r_mean)/counter; + } + else + { + r_mean = X; + } + } + + + +template +inline +void +running_mean_scalar::reset() + { + arma_extra_debug_sigprint(); + + counter = 0; + r_mean = eT(0); + } + + + +template +inline +uword +running_mean_scalar::count() const + { + return counter; + } + + + +template +inline +eT +running_mean_scalar::mean() const + { + return r_mean; + } + + + +// +// +// + + + +template +arma_inline +eT +distance::eval(const uword N, const eT* A, const eT* B, const eT*) + { + eT acc1 = eT(0); + eT acc2 = eT(0); + + uword i,j; + for(i=0, j=1; j +arma_inline +eT +distance::eval(const uword N, const eT* A, const eT* B, const eT* C) + { + eT acc1 = eT(0); + eT acc2 = eT(0); + + uword i,j; + for(i=0, j=1; j +inline +hid_t +get_hdf5_type() + { + return -1; // Return invalid. + } + + + +//! Specializations for each valid element type +//! (taken from all the possible typedefs of {u8, s8, ..., u64, s64} and the other native types. +//! We can't use the actual u8/s8 typedefs because their relations to the H5T_... types are unclear. +template<> +inline +hid_t +get_hdf5_type< unsigned char >() + { + return H5Tcopy(H5T_NATIVE_UCHAR); + } + +template<> +inline +hid_t +get_hdf5_type< char >() + { + return H5Tcopy(H5T_NATIVE_CHAR); + } + +template<> +inline +hid_t +get_hdf5_type< short >() + { + return H5Tcopy(H5T_NATIVE_SHORT); + } + +template<> +inline +hid_t +get_hdf5_type< unsigned short >() + { + return H5Tcopy(H5T_NATIVE_USHORT); + } + +template<> +inline +hid_t +get_hdf5_type< int >() + { + return H5Tcopy(H5T_NATIVE_INT); + } + +template<> +inline +hid_t +get_hdf5_type< unsigned int >() + { + return H5Tcopy(H5T_NATIVE_UINT); + } + +template<> +inline +hid_t +get_hdf5_type< long >() + { + return H5Tcopy(H5T_NATIVE_LONG); + } + +template<> +inline +hid_t +get_hdf5_type< unsigned long >() + { + return H5Tcopy(H5T_NATIVE_ULONG); + } + +template<> +inline +hid_t +get_hdf5_type< long long >() + { + return H5Tcopy(H5T_NATIVE_LLONG); + } + +template<> +inline +hid_t +get_hdf5_type< unsigned long long >() + { + return H5Tcopy(H5T_NATIVE_ULLONG); + } + +template<> +inline +hid_t +get_hdf5_type< float >() + { + return H5Tcopy(H5T_NATIVE_FLOAT); + } + +template<> +inline +hid_t +get_hdf5_type< double >() + { + return H5Tcopy(H5T_NATIVE_DOUBLE); + } + + + +//! Utility hid_t since HOFFSET() won't work with std::complex. +template +struct hdf5_complex_t + { + eT real; + eT imag; + }; + + + +template<> +inline +hid_t +get_hdf5_type< std::complex >() + { + hid_t type = H5Tcreate(H5T_COMPOUND, sizeof(hdf5_complex_t)); + + H5Tinsert(type, "real", HOFFSET(hdf5_complex_t, real), H5T_NATIVE_FLOAT); + H5Tinsert(type, "imag", HOFFSET(hdf5_complex_t, imag), H5T_NATIVE_FLOAT); + + return type; + } + + + +template<> +inline +hid_t +get_hdf5_type< std::complex >() + { + hid_t type = H5Tcreate(H5T_COMPOUND, sizeof(hdf5_complex_t)); + + H5Tinsert(type, "real", HOFFSET(hdf5_complex_t, real), H5T_NATIVE_DOUBLE); + H5Tinsert(type, "imag", HOFFSET(hdf5_complex_t, imag), H5T_NATIVE_DOUBLE); + + return type; + } + + + +// Compare datatype against all supported types. +inline +bool +is_supported_arma_hdf5_type(hid_t datatype) + { + hid_t search_type; + + bool is_equal; + + + // start with most likely used types: double, complex, float, complex + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type< std::complex >(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type< std::complex >(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + + // remaining supported types: u8, s8, u16, s16, u32, s32, u64, s64, ulng_t, slng_t + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + return false; + } + + + +//! Auxiliary functions and structs for search_hdf5_file. +struct hdf5_search_info + { + const std::vector& names; + int num_dims; + bool exact; + hid_t best_match; + size_t best_match_position; // Position of best match in names vector. + }; + + + +inline +herr_t +hdf5_search_callback + ( + hid_t loc_id, + const char* name, + const H5O_info_t* info, + void* operator_data // hdf5_search_info + ) + { + hdf5_search_info* search_info = (hdf5_search_info*) operator_data; + + // We are looking for datasets. + if(info->type == H5O_TYPE_DATASET) + { + // Check type of dataset to see if we could even load it. + hid_t dataset = H5Dopen(loc_id, name, H5P_DEFAULT); + hid_t datatype = H5Dget_type(dataset); + + const bool is_supported = is_supported_arma_hdf5_type(datatype); + + H5Tclose(datatype); + H5Dclose(dataset); + + if(is_supported == false) + { + // Forget about it and move on. + return 0; + } + + // Now we have to check against our set of names. + // Only check names which could be better. + for(size_t string_pos = 0; string_pos < search_info->best_match_position; ++string_pos) + { + // name is the full path (/path/to/dataset); names[string_pos] may be + // "dataset", "/to/dataset", or "/path/to/dataset". + // So if we count the number of forward slashes in names[string_pos], + // and then simply take the last substring of name containing that number of slashes, + // we can do the comparison. + + // Count the number of forward slashes in names[string_pos]. + uword name_count = 0; + for(uword i = 0; i < search_info->names[string_pos].length(); ++i) + { + if((search_info->names[string_pos])[i] == '/') { ++name_count; } + } + + // Count the number of forward slashes in the full name. + uword count = 0; + const std::string str = std::string(name); + for(uword i = 0; i < str.length(); ++i) + { + if(str[i] == '/') { ++count; } + } + + // Is the full string the same? + if(str == search_info->names[string_pos]) + { + // We found it exactly. + hid_t match_candidate = H5Dopen(loc_id, name, H5P_DEFAULT); + + if(match_candidate < 0) + { + return -1; + } + + // Ensure that the dataset is valid and of the correct dimensionality. + hid_t filespace = H5Dget_space(match_candidate); + int num_dims = H5Sget_simple_extent_ndims(filespace); + + if(num_dims <= search_info->num_dims) + { + // Valid dataset -- we'll keep it. + // If we already have an existing match we have to close it. + if(search_info->best_match != -1) + { + H5Dclose(search_info->best_match); + } + + search_info->best_match_position = string_pos; + search_info->best_match = match_candidate; + } + + H5Sclose(filespace); + // There is no possibility of anything better, so terminate the search. + return 1; + } + + // If we are asking for more slashes than we have, this can't be a match. + // Skip to below, where we decide whether or not to keep it anyway based + // on the exactness condition of the search. + if(count <= name_count) + { + size_t start_pos = (count == 0) ? 0 : std::string::npos; + while(count > 0) + { + // Move pointer to previous slash. + start_pos = str.rfind('/', start_pos); + + // Break if we've run out of slashes. + if(start_pos == std::string::npos) { break; } + + --count; + } + + // Now take the substring (this may end up being the full string). + const std::string substring = str.substr(start_pos); + + // Are they the same? + if(substring == search_info->names[string_pos]) + { + // We have found the object; it must be better than our existing match. + hid_t match_candidate = H5Dopen(loc_id, name, H5P_DEFAULT); + + + // arma_check(match_candidate < 0, "Mat::load(): cannot open an HDF5 dataset"); + if(match_candidate < 0) + { + return -1; + } + + + // Ensure that the dataset is valid and of the correct dimensionality. + hid_t filespace = H5Dget_space(match_candidate); + int num_dims = H5Sget_simple_extent_ndims(filespace); + + if(num_dims <= search_info->num_dims) + { + // Valid dataset -- we'll keep it. + // If we already have an existing match we have to close it. + if(search_info->best_match != -1) + { + H5Dclose(search_info->best_match); + } + + search_info->best_match_position = string_pos; + search_info->best_match = match_candidate; + } + + H5Sclose(filespace); + } + } + + + // If they are not the same, but we have not found anything and we don't need an exact match, take this. + if((search_info->exact == false) && (search_info->best_match == -1)) + { + hid_t match_candidate = H5Dopen(loc_id, name, H5P_DEFAULT); + + // arma_check(match_candidate < 0, "Mat::load(): cannot open an HDF5 dataset"); + if(match_candidate < 0) + { + return -1; + } + + hid_t filespace = H5Dget_space(match_candidate); + int num_dims = H5Sget_simple_extent_ndims(filespace); + + if(num_dims <= search_info->num_dims) + { + // Valid dataset -- we'll keep it. + search_info->best_match = H5Dopen(loc_id, name, H5P_DEFAULT); + } + + H5Sclose(filespace); + } + } + } + + return 0; + } + + + +//! Search an HDF5 file for the given dataset names. +//! If 'exact' is true, failure to find a dataset in the list of names means that -1 is returned. +//! If 'exact' is false and no datasets are found, -1 is returned. +//! The number of dimensions is used to help prune down invalid datasets; +//! 2 dimensions is a matrix, 1 dimension is a vector, and 3 dimensions is a cube. +//! If the number of dimensions in a dataset is less than or equal to num_dims, +//! it will be considered -- for instance, a one-dimensional HDF5 vector can be loaded as a single-column matrix. +inline +hid_t +search_hdf5_file + ( + const std::vector& names, + hid_t hdf5_file, + int num_dims = 2, + bool exact = false + ) + { + hdf5_search_info search_info = { names, num_dims, exact, -1, names.size() }; + + // We'll use the H5Ovisit to track potential entries. + herr_t status = H5Ovisit(hdf5_file, H5_INDEX_NAME, H5_ITER_NATIVE, hdf5_search_callback, void_ptr(&search_info)); + + // Return the best match; it will be -1 if there was a problem. + return (status < 0) ? -1 : search_info.best_match; + } + + + +//! Load an HDF5 matrix into an array of type specified by datatype, +//! then convert that into the desired array 'dest'. +//! This should only be called when eT is not the datatype. +template +inline +hid_t +load_and_convert_hdf5 + ( + eT *dest, + hid_t dataset, + hid_t datatype, + uword n_elem + ) + { + + // We can't use nice template specializations here + // as the determination of the type of 'datatype' must be done at runtime. + // So we end up with this ugliness... + hid_t search_type; + + bool is_equal; + + + // u8 + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // s8 + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // u16 + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // s16 + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // u32 + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // s32 + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // u64 + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // s64 + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // ulng_t + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // slng_t + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // float + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // double + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); + + return status; + } + + + // complex float + search_type = get_hdf5_type< std::complex >(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + if(is_cx::no) + { + return -1; // can't read complex data into non-complex matrix/cube + } + + Col< std::complex > v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert_cx(dest, v.memptr(), n_elem); + + return status; + } + + + // complex double + search_type = get_hdf5_type< std::complex >(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + if(is_cx::no) + { + return -1; // can't read complex data into non-complex matrix/cube + } + + Col< std::complex > v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert_cx(dest, v.memptr(), n_elem); + + return status; + } + + + return -1; // Failure. + } + + + +struct hdf5_suspend_printing_errors + { + #if (ARMA_WARN_LEVEL >= 3) + + inline + hdf5_suspend_printing_errors() {} + + #else + + herr_t (*old_client_func)(hid_t, void*); + void* old_client_data; + + inline + hdf5_suspend_printing_errors() + { + // Save old error handler. + H5Eget_auto(H5E_DEFAULT, &old_client_func, &old_client_data); + + // Disable annoying HDF5 error messages. + H5Eset_auto(H5E_DEFAULT, NULL, NULL); + } + + inline + ~hdf5_suspend_printing_errors() + { + H5Eset_auto(H5E_DEFAULT, old_client_func, old_client_data); + } + + #endif + }; + + + +} // namespace hdf5_misc +#endif // #if defined(ARMA_USE_HDF5) + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/hdf5_name.hpp b/src/armadillo/include/armadillo_bits/hdf5_name.hpp new file mode 100644 index 0000000..8dd38cc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/hdf5_name.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup diskio +//! @{ + + +namespace hdf5_opts + { + typedef unsigned int flag_type; + + struct opts + { + const flag_type flags; + + inline constexpr explicit opts(const flag_type in_flags); + + inline const opts operator+(const opts& rhs) const; + }; + + inline + constexpr + opts::opts(const flag_type in_flags) + : flags(in_flags) + {} + + inline + const opts + opts::operator+(const opts& rhs) const + { + const opts result( flags | rhs.flags ); + + return result; + } + + // The values below (eg. 1u << 0) are for internal Armadillo use only. + // The values can change without notice. + + static constexpr flag_type flag_none = flag_type(0 ); + static constexpr flag_type flag_trans = flag_type(1u << 0); + static constexpr flag_type flag_append = flag_type(1u << 1); + static constexpr flag_type flag_replace = flag_type(1u << 2); + + struct opts_none : public opts { inline constexpr opts_none() : opts(flag_none ) {} }; + struct opts_trans : public opts { inline constexpr opts_trans() : opts(flag_trans ) {} }; + struct opts_append : public opts { inline constexpr opts_append() : opts(flag_append ) {} }; + struct opts_replace : public opts { inline constexpr opts_replace() : opts(flag_replace) {} }; + + static constexpr opts_none none; + static constexpr opts_trans trans; + static constexpr opts_append append; + static constexpr opts_replace replace; + } + + +struct hdf5_name + { + const std::string filename; + const std::string dsname; + const hdf5_opts::opts opts; + + inline + hdf5_name(const std::string& in_filename) + : filename(in_filename ) + , dsname (std::string() ) + , opts (hdf5_opts::none) + {} + + inline + hdf5_name(const std::string& in_filename, const std::string& in_dsname, const hdf5_opts::opts& in_opts = hdf5_opts::none) + : filename(in_filename) + , dsname (in_dsname ) + , opts (in_opts ) + {} + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/include_hdf5.hpp b/src/armadillo/include/armadillo_bits/include_hdf5.hpp new file mode 100644 index 0000000..a639f78 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/include_hdf5.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +#if defined(ARMA_USE_HDF5) + + #undef H5_USE_110_API + #define H5_USE_110_API + + #if defined(__has_include) + #if __has_include() + #include + #else + #undef ARMA_USE_HDF5 + #pragma message ("WARNING: use of HDF5 disabled; hdf5.h header not found") + #endif + #else + #include + #endif + + #if defined(H5_USE_16_API) || defined(H5_USE_16_API_DEFAULT) + #pragma message ("WARNING: use of HDF5 disabled; incompatible configuration: H5_USE_16_API or H5_USE_16_API_DEFAULT") + #undef ARMA_USE_HDF5 + #endif + + // // TODO + // #if defined(H5_USE_18_API) || defined(H5_USE_18_API_DEFAULT) + // #pragma message ("WARNING: detected possibly incompatible configuration of HDF5: H5_USE_18_API or H5_USE_18_API_DEFAULT") + // #endif + +#endif diff --git a/src/armadillo/include/armadillo_bits/include_superlu.hpp b/src/armadillo/include/armadillo_bits/include_superlu.hpp new file mode 100644 index 0000000..43fa0a7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/include_superlu.hpp @@ -0,0 +1,393 @@ +// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ------------------------------------------------------------------------ +// +// This file includes portions of SuperLU 5.2 software, +// licensed under the following conditions. +// +// Copyright (c) 2003, The Regents of the University of California, through +// Lawrence Berkeley National Laboratory (subject to receipt of any required +// approvals from U.S. Dept. of Energy) +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// (1) Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// (2) Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// (3) Neither the name of Lawrence Berkeley National Laboratory, U.S. Dept. of +// Energy nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +// OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +// EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ------------------------------------------------------------------------ + + +#if defined(ARMA_USE_SUPERLU) + +#undef ARMA_SLU_HEADERS_FOUND + +// Since we need to suport float, double, cx_float and cx_double, +// as well as preserve the sanity of the user, +// we cannot simply include all the SuperLU headers due to their messy state +// (duplicate definitions, pollution of global namespace, bizarro defines). +// As such we are forced to include only a subset of the headers +// and manually specify a few SuperLU structures and function prototypes. +// +// CAVEAT: +// This code requires SuperLU version 5.2, +// and assumes that newer 5.x versions will have no API changes. + +namespace arma +{ +namespace superlu + { + // slu_*defs.h has int typedefed to int_t. + // I'll just write it as int for simplicity, where I can, but supermatrix.h needs int_t. + typedef int int_t; + } +} + +#if defined(ARMA_USE_SUPERLU_HEADERS) || defined(ARMA_SUPERLU_INCLUDE_DIR) + +namespace arma +{ +namespace superlu + { + // Include supermatrix.h. This gives us SuperMatrix. + // Put it in the slu namespace. + // For versions of SuperLU I am familiar with, supermatrix.h does not include any other files. + // Therefore, putting it in the superlu namespace is reasonably safe. + // This same reasoning is true for superlu_enum_consts.h. + + #undef ARMA_SLU_HEADER_A + #undef ARMA_SLU_HEADER_B + + #if defined(ARMA_SUPERLU_INCLUDE_DIR) + #undef ARMA_SLU_STR1 + #undef ARMA_SLU_STR2 + + #define ARMA_SLU_STR1(x) x + #define ARMA_SLU_STR2(x) ARMA_SLU_STR1(x) + + #define ARMA_SLU_HEADER_A ARMA_SLU_STR2(ARMA_SUPERLU_INCLUDE_DIR)ARMA_SLU_STR2(supermatrix.h) + #define ARMA_SLU_HEADER_B ARMA_SLU_STR2(ARMA_SUPERLU_INCLUDE_DIR)ARMA_SLU_STR2(superlu_enum_consts.h) + #else + #define ARMA_SLU_HEADER_A supermatrix.h + #define ARMA_SLU_HEADER_B superlu_enum_consts.h + #endif + + #if defined(__has_include) + #if __has_include(ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_A)) && __has_include(ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_B)) + #include ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_A) + #include ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_B) + #define ARMA_SLU_HEADERS_FOUND + #endif + #else + #include ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_A) + #include ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_B) + #define ARMA_SLU_HEADERS_FOUND + #endif + + #undef ARMA_SLU_STR1 + #undef ARMA_SLU_STR2 + + #undef ARMA_SLU_HEADER_A + #undef ARMA_SLU_HEADER_B + + #if defined(ARMA_SLU_HEADERS_FOUND) + + typedef struct + { + int* panel_histo; + double* utime; + float* ops; + int TinyPivots; + int RefineSteps; + int expansions; + } SuperLUStat_t; + + typedef struct + { + fact_t Fact; + yes_no_t Equil; + colperm_t ColPerm; + trans_t Trans; + IterRefine_t IterRefine; + double DiagPivotThresh; + yes_no_t SymmetricMode; + yes_no_t PivotGrowth; + yes_no_t ConditionNumber; + rowperm_t RowPerm; + int ILU_DropRule; + double ILU_DropTol; + double ILU_FillFactor; + norm_t ILU_Norm; + double ILU_FillTol; + milu_t ILU_MILU; + double ILU_MILU_Dim; + yes_no_t ParSymbFact; + yes_no_t ReplaceTinyPivot; + yes_no_t SolveInitialized; + yes_no_t RefineInitialized; + yes_no_t PrintStat; + int nnzL, nnzU; + int num_lookaheads; + yes_no_t lookahead_etree; + yes_no_t SymPattern; + } superlu_options_t; + + typedef struct + { + float for_lu; + float total_needed; + } mem_usage_t; + + typedef struct e_node + { + int size; + void* mem; + } ExpHeader; + + typedef struct + { + int size; + int used; + int top1; + int top2; + void* array; + } LU_stack_t; + + typedef struct + { + int* xsup; + int* supno; + int* lsub; + int* xlsub; + void* lusup; + int* xlusup; + void* ucol; + int* usub; + int* xusub; + int nzlmax; + int nzumax; + int nzlumax; + int n; + LU_space_t MemModel; + int num_expansions; + ExpHeader* expanders; + LU_stack_t stack; + } GlobalLU_t; + + #endif + } +} + + +#endif + +#if defined(ARMA_USE_SUPERLU_HEADERS) && !defined(ARMA_SLU_HEADERS_FOUND) + #undef ARMA_USE_SUPERLU + #pragma message ("WARNING: use of SuperLU disabled; required headers not found") +#endif + +#endif + + + +#if defined(ARMA_USE_SUPERLU) && !defined(ARMA_SLU_HEADERS_FOUND) + +// Not using any SuperLU headers, so define all required enums and structs. + +#if defined(ARMA_SUPERLU_INCLUDE_DIR) + #pragma message ("WARNING: SuperLU headers not found; using built-in definitions") +#endif + +namespace arma +{ +namespace superlu + { + typedef enum + { + SLU_NC, + SLU_NCP, + SLU_NR, + SLU_SC, + SLU_SCP, + SLU_SR, + SLU_DN, + SLU_NR_loc + } Stype_t; + + typedef enum + { + SLU_S, + SLU_D, + SLU_C, + SLU_Z + } Dtype_t; + + typedef enum + { + SLU_GE, + SLU_TRLU, + SLU_TRUU, + SLU_TRL, + SLU_TRU, + SLU_SYL, + SLU_SYU, + SLU_HEL, + SLU_HEU + } Mtype_t; + + typedef struct + { + Stype_t Stype; + Dtype_t Dtype; + Mtype_t Mtype; + int_t nrow; + int_t ncol; + void* Store; + } SuperMatrix; + + typedef struct + { + int* panel_histo; + double* utime; + float* ops; + int TinyPivots; + int RefineSteps; + int expansions; + } SuperLUStat_t; + + typedef enum {NO, YES} yes_no_t; + typedef enum {DOFACT, SamePattern, SamePattern_SameRowPerm, FACTORED} fact_t; + typedef enum {NOROWPERM, LargeDiag, MY_PERMR} rowperm_t; + typedef enum {NATURAL, MMD_ATA, MMD_AT_PLUS_A, COLAMD, + METIS_AT_PLUS_A, PARMETIS, ZOLTAN, MY_PERMC} colperm_t; + typedef enum {NOTRANS, TRANS, CONJ} trans_t; + typedef enum {NOREFINE, SLU_SINGLE=1, SLU_DOUBLE, SLU_EXTRA} IterRefine_t; + typedef enum {SYSTEM, USER} LU_space_t; + typedef enum {ONE_NORM, TWO_NORM, INF_NORM} norm_t; + typedef enum {SILU, SMILU_1, SMILU_2, SMILU_3} milu_t; + + typedef struct + { + fact_t Fact; + yes_no_t Equil; + colperm_t ColPerm; + trans_t Trans; + IterRefine_t IterRefine; + double DiagPivotThresh; + yes_no_t SymmetricMode; + yes_no_t PivotGrowth; + yes_no_t ConditionNumber; + rowperm_t RowPerm; + int ILU_DropRule; + double ILU_DropTol; + double ILU_FillFactor; + norm_t ILU_Norm; + double ILU_FillTol; + milu_t ILU_MILU; + double ILU_MILU_Dim; + yes_no_t ParSymbFact; + yes_no_t ReplaceTinyPivot; + yes_no_t SolveInitialized; + yes_no_t RefineInitialized; + yes_no_t PrintStat; + int nnzL, nnzU; + int num_lookaheads; + yes_no_t lookahead_etree; + yes_no_t SymPattern; + } superlu_options_t; + + typedef struct + { + float for_lu; + float total_needed; + } mem_usage_t; + + typedef struct + { + int_t nnz; + void* nzval; + int_t* rowind; + int_t* colptr; + } NCformat; + + typedef struct + { + int_t lda; + void* nzval; + } DNformat; + + typedef struct e_node + { + int size; + void* mem; + } ExpHeader; + + typedef struct + { + int size; + int used; + int top1; + int top2; + void* array; + } LU_stack_t; + + typedef struct + { + int* xsup; + int* supno; + int* lsub; + int* xlsub; + void* lusup; + int* xlusup; + void* ucol; + int* usub; + int* xusub; + int nzlmax; + int nzumax; + int nzlumax; + int n; + LU_space_t MemModel; + int num_expansions; + ExpHeader* expanders; + LU_stack_t stack; + } GlobalLU_t; + } +} + +#undef ARMA_SLU_HEADERS_FOUND + +#endif diff --git a/src/armadillo/include/armadillo_bits/injector_bones.hpp b/src/armadillo/include/armadillo_bits/injector_bones.hpp new file mode 100644 index 0000000..80e2a17 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/injector_bones.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup injector +//! @{ + + + +template +class mat_injector + { + public: + + typedef typename T1::elem_type elem_type; + + arma_cold inline void insert(const elem_type val) const; + arma_cold inline void end_of_row() const; + arma_cold inline ~mat_injector(); + + + private: + + inline mat_injector(T1& in_X, const elem_type val); + inline mat_injector(T1& in_X, const injector_end_of_row<>&); + + T1& parent; + + mutable std::vector values; + mutable std::vector rowend; + + friend class Mat; + friend class Row; + friend class Col; + }; + + + +// + + + +template +class field_injector + { + public: + + typedef typename T1::object_type object_type; + + arma_cold inline void insert(const object_type& val) const; + arma_cold inline void end_of_row() const; + arma_cold inline ~field_injector(); + + + private: + + inline field_injector(T1& in_X, const object_type& val); + inline field_injector(T1& in_X, const injector_end_of_row<>&); + + T1& parent; + + mutable std::vector values; + mutable std::vector rowend; + + friend class field; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/injector_meat.hpp b/src/armadillo/include/armadillo_bits/injector_meat.hpp new file mode 100644 index 0000000..81962c0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/injector_meat.hpp @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup injector +//! @{ + + + +template +inline +mat_injector::mat_injector(T1& in_parent, const typename mat_injector::elem_type val) + : parent(in_parent) + { + arma_extra_debug_sigprint(); + + values.reserve(16); + rowend.reserve(16); + + insert(val); + } + + + +template +inline +mat_injector::mat_injector(T1& in_parent, const injector_end_of_row<>&) + : parent(in_parent) + { + arma_extra_debug_sigprint(); + + values.reserve(16); + rowend.reserve(16); + + end_of_row(); + } + + + +template +inline +mat_injector::~mat_injector() + { + arma_extra_debug_sigprint(); + + const uword N = values.size(); + + if(N == 0) { return; } + + uword n_rows = 1; + uword n_cols = 0; + + for(uword i=0; i::value) + { + arma_debug_check( (n_rows > 1), "matrix initialisation: incompatible dimensions" ); + + parent.zeros(1,n_cols); + + uword col = 0; + + for(uword i=0; i::value) + { + const bool is_vec = ((n_cols == 1) || (n_rows == 1)); + + arma_debug_check( (is_vec == false), "matrix initialisation: incompatible dimensions" ); + + if(n_cols == 1) + { + parent.zeros(n_rows,1); + + uword row = 0; + + for(uword i=0; i 0) && rowend[i-1]) { ++row; } + } + else + { + parent.at(row) = values[i]; + ++row; + } + } + } + else + if(n_rows == 1) + { + parent.zeros(n_cols,1); + + uword row = 0; + + for(uword i=0; i +inline +void +mat_injector::insert(const typename mat_injector::elem_type val) const + { + arma_extra_debug_sigprint(); + + values.push_back(val ); + rowend.push_back(char(0)); + } + + + + +template +inline +void +mat_injector::end_of_row() const + { + arma_extra_debug_sigprint(); + + typedef typename mat_injector::elem_type eT; + + values.push_back( eT(0)); + rowend.push_back(char(1)); + } + + + +template +inline +const mat_injector& +operator<<(const mat_injector& ref, const typename mat_injector::elem_type val) + { + arma_extra_debug_sigprint(); + + ref.insert(val); + + return ref; + } + + + +template +inline +const mat_injector& +operator<<(const mat_injector& ref, const injector_end_of_row<>&) + { + arma_extra_debug_sigprint(); + + ref.end_of_row(); + + return ref; + } + + + +// +// +// + + + +template +inline +field_injector::field_injector(T1& in_parent, const typename field_injector::object_type& val) + : parent(in_parent) + { + arma_extra_debug_sigprint(); + + insert(val); + } + + + +template +inline +field_injector::field_injector(T1& in_parent, const injector_end_of_row<>&) + : parent(in_parent) + { + arma_extra_debug_sigprint(); + + end_of_row(); + } + + + +template +inline +field_injector::~field_injector() + { + arma_extra_debug_sigprint(); + + const uword N = values.size(); + + if(N == 0) { return; } + + uword n_rows = 1; + uword n_cols = 0; + + for(uword i=0; i +inline +void +field_injector::insert(const typename field_injector::object_type& val) const + { + arma_extra_debug_sigprint(); + + values.push_back(val ); + rowend.push_back(char(0)); + } + + + + +template +inline +void +field_injector::end_of_row() const + { + arma_extra_debug_sigprint(); + + typedef typename field_injector::object_type oT; + + values.push_back(oT() ); + rowend.push_back(char(1)); + } + + + +template +inline +const field_injector& +operator<<(const field_injector& ref, const typename field_injector::object_type& val) + { + arma_extra_debug_sigprint(); + + ref.insert(val); + + return ref; + } + + + +template +inline +const field_injector& +operator<<(const field_injector& ref, const injector_end_of_row<>&) + { + arma_extra_debug_sigprint(); + + ref.end_of_row(); + + return ref; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/memory.hpp b/src/armadillo/include/armadillo_bits/memory.hpp new file mode 100644 index 0000000..ffa4d2c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/memory.hpp @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup memory +//! @{ + + +class memory + { + public: + + template arma_malloc inline static eT* acquire(const uword n_elem); + + template arma_inline static void release(eT* mem); + + template arma_inline static bool is_aligned(const eT* mem); + template arma_inline static void mark_as_aligned( eT*& mem); + template arma_inline static void mark_as_aligned(const eT*& mem); + }; + + + +template +arma_malloc +inline +eT* +memory::acquire(const uword n_elem) + { + if(n_elem == 0) { return nullptr; } + + arma_debug_check + ( + ( size_t(n_elem) > (std::numeric_limits::max() / sizeof(eT)) ), + "arma::memory::acquire(): requested size is too large" + ); + + eT* out_memptr; + + #if defined(ARMA_ALIEN_MEM_ALLOC_FUNCTION) + { + out_memptr = (eT *) ARMA_ALIEN_MEM_ALLOC_FUNCTION(sizeof(eT)*n_elem); + } + #elif defined(ARMA_USE_TBB_ALLOC) + { + out_memptr = (eT *) scalable_malloc(sizeof(eT)*n_elem); + } + #elif defined(ARMA_USE_MKL_ALLOC) + { + out_memptr = (eT *) mkl_malloc( sizeof(eT)*n_elem, 32 ); + } + #elif defined(ARMA_HAVE_POSIX_MEMALIGN) + { + eT* memptr = nullptr; + + const size_t n_bytes = sizeof(eT)*size_t(n_elem); + const size_t alignment = (n_bytes >= size_t(1024)) ? size_t(32) : size_t(16); + + // TODO: investigate apparent memory leak when using alignment >= 64 (as shown on Fedora 28, glibc 2.27) + int status = posix_memalign((void **)&memptr, ( (alignment >= sizeof(void*)) ? alignment : sizeof(void*) ), n_bytes); + + out_memptr = (status == 0) ? memptr : nullptr; + } + #elif defined(_MSC_VER) + { + // Windoze is too primitive to handle C++17 std::aligned_alloc() + + //out_memptr = (eT *) malloc(sizeof(eT)*n_elem); + //out_memptr = (eT *) _aligned_malloc( sizeof(eT)*n_elem, 16 ); // lives in malloc.h + + const size_t n_bytes = sizeof(eT)*size_t(n_elem); + const size_t alignment = (n_bytes >= size_t(1024)) ? size_t(32) : size_t(16); + + out_memptr = (eT *) _aligned_malloc( n_bytes, alignment ); + } + #else + { + //return ( new(std::nothrow) eT[n_elem] ); + out_memptr = (eT *) malloc(sizeof(eT)*n_elem); + } + #endif + + // TODO: for mingw, use __mingw_aligned_malloc + + arma_check_bad_alloc( (out_memptr == nullptr), "arma::memory::acquire(): out of memory" ); + + return out_memptr; + } + + + +template +arma_inline +void +memory::release(eT* mem) + { + if(mem == nullptr) { return; } + + #if defined(ARMA_ALIEN_MEM_FREE_FUNCTION) + { + ARMA_ALIEN_MEM_FREE_FUNCTION( (void *)(mem) ); + } + #elif defined(ARMA_USE_TBB_ALLOC) + { + scalable_free( (void *)(mem) ); + } + #elif defined(ARMA_USE_MKL_ALLOC) + { + mkl_free( (void *)(mem) ); + } + #elif defined(ARMA_HAVE_POSIX_MEMALIGN) + { + free( (void *)(mem) ); + } + #elif defined(_MSC_VER) + { + //free( (void *)(mem) ); + _aligned_free( (void *)(mem) ); + } + #else + { + //delete [] mem; + free( (void *)(mem) ); + } + #endif + + // TODO: for mingw, use __mingw_aligned_free + } + + + +template +arma_inline +bool +memory::is_aligned(const eT* mem) + { + #if (defined(ARMA_HAVE_ICC_ASSUME_ALIGNED) || defined(ARMA_HAVE_GCC_ASSUME_ALIGNED)) && !defined(ARMA_DONT_CHECK_ALIGNMENT) + { + return (sizeof(std::size_t) >= sizeof(eT*)) ? ((std::size_t(mem) & 0x0F) == 0) : false; + } + #else + { + arma_ignore(mem); + + return false; + } + #endif + } + + + +template +arma_inline +void +memory::mark_as_aligned(eT*& mem) + { + #if defined(ARMA_HAVE_ICC_ASSUME_ALIGNED) + { + __assume_aligned(mem, 16); + } + #elif defined(ARMA_HAVE_GCC_ASSUME_ALIGNED) + { + mem = (eT*)__builtin_assume_aligned(mem, 16); + } + #else + { + arma_ignore(mem); + } + #endif + + // TODO: look into C++20 std::assume_aligned() + // TODO: https://en.cppreference.com/w/cpp/memory/assume_aligned + + // TODO: MSVC? __assume( (mem & 0x0F) == 0 ); + // + // http://comments.gmane.org/gmane.comp.gcc.patches/239430 + // GCC __builtin_assume_aligned is similar to ICC's __assume_aligned, + // so for lvalue first argument ICC's __assume_aligned can be emulated using + // #define __assume_aligned(lvalueptr, align) lvalueptr = __builtin_assume_aligned (lvalueptr, align) + // + // http://www.inf.ethz.ch/personal/markusp/teaching/263-2300-ETH-spring11/slides/class19.pdf + // http://software.intel.com/sites/products/documentation/hpc/composerxe/en-us/cpp/lin/index.htm + // http://d3f8ykwhia686p.cloudfront.net/1live/intel/CompilerAutovectorizationGuide.pdf + } + + + +template +arma_inline +void +memory::mark_as_aligned(const eT*& mem) + { + #if defined(ARMA_HAVE_ICC_ASSUME_ALIGNED) + { + __assume_aligned(mem, 16); + } + #elif defined(ARMA_HAVE_GCC_ASSUME_ALIGNED) + { + mem = (const eT*)__builtin_assume_aligned(mem, 16); + } + #else + { + arma_ignore(mem); + } + #endif + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mp_misc.hpp b/src/armadillo/include/armadillo_bits/mp_misc.hpp new file mode 100644 index 0000000..b323ffe --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mp_misc.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mp_misc +//! @{ + + + + +template +struct mp_gate + { + arma_inline + static + bool + eval(const uword n_elem) + { + #if defined(ARMA_USE_OPENMP) + { + const bool length_ok = (is_cx::yes || use_smaller_thresh) ? (n_elem >= (arma_config::mp_threshold/uword(2))) : (n_elem >= arma_config::mp_threshold); + + if(length_ok) + { + if(omp_in_parallel()) { return false; } + } + + return length_ok; + } + #else + { + arma_ignore(n_elem); + + return false; + } + #endif + } + }; + + + +struct mp_thread_limit + { + arma_inline + static + int + get() + { + #if defined(ARMA_USE_OPENMP) + int n_threads = (std::min)(int(arma_config::mp_threads), int((std::max)(int(1), int(omp_get_max_threads())))); + #else + int n_threads = int(1); + #endif + + return n_threads; + } + + arma_inline + static + bool + in_parallel() + { + #if defined(ARMA_USE_OPENMP) + { + return bool(omp_in_parallel()); + } + #else + { + return false; + } + #endif + } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtGlueCube_bones.hpp b/src/armadillo/include/armadillo_bits/mtGlueCube_bones.hpp new file mode 100644 index 0000000..846d805 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtGlueCube_bones.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtGlueCube +//! @{ + + + +template +class mtGlueCube : public BaseCube< out_eT, mtGlueCube > + { + public: + + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + + arma_inline mtGlueCube(const T1& in_A, const T2& in_B); + arma_inline mtGlueCube(const T1& in_A, const T2& in_B, const uword in_aux_uword); + arma_inline ~mtGlueCube(); + + arma_aligned const T1& A; //!< first operand; must be derived from BaseCube + arma_aligned const T2& B; //!< second operand; must be derived from BaseCube + arma_aligned uword aux_uword; //!< storage of auxiliary data, uword format + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtGlueCube_meat.hpp b/src/armadillo/include/armadillo_bits/mtGlueCube_meat.hpp new file mode 100644 index 0000000..dd27ecd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtGlueCube_meat.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtGlueCube +//! @{ + + + +template +inline +mtGlueCube::mtGlueCube(const T1& in_A, const T2& in_B) + : A(in_A) + , B(in_B) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtGlueCube::mtGlueCube(const T1& in_A, const T2& in_B, const uword in_aux_uword) + : A(in_A) + , B(in_B) + , aux_uword(in_aux_uword) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtGlueCube::~mtGlueCube() + { + arma_extra_debug_sigprint(); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtGlue_bones.hpp b/src/armadillo/include/armadillo_bits/mtGlue_bones.hpp new file mode 100644 index 0000000..5937d89 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtGlue_bones.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtGlue +//! @{ + + + +template +class mtGlue : public Base< out_eT, mtGlue > + { + public: + + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = glue_type::template traits::is_row; + static constexpr bool is_col = glue_type::template traits::is_col; + static constexpr bool is_xvec = glue_type::template traits::is_xvec; + + arma_inline mtGlue(const T1& in_A, const T2& in_B); + arma_inline mtGlue(const T1& in_A, const T2& in_B, const uword in_aux_uword); + arma_inline ~mtGlue(); + + arma_aligned const T1& A; //!< first operand; must be derived from Base + arma_aligned const T2& B; //!< second operand; must be derived from Base + arma_aligned uword aux_uword; //!< storage of auxiliary data, uword format + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtGlue_meat.hpp b/src/armadillo/include/armadillo_bits/mtGlue_meat.hpp new file mode 100644 index 0000000..cf3afcc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtGlue_meat.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtGlue +//! @{ + + + +template +inline +mtGlue::mtGlue(const T1& in_A, const T2& in_B) + : A(in_A) + , B(in_B) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtGlue::mtGlue(const T1& in_A, const T2& in_B, const uword in_aux_uword) + : A(in_A) + , B(in_B) + , aux_uword(in_aux_uword) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtGlue::~mtGlue() + { + arma_extra_debug_sigprint(); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtOpCube_bones.hpp b/src/armadillo/include/armadillo_bits/mtOpCube_bones.hpp new file mode 100644 index 0000000..ea9addf --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtOpCube_bones.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtOpCube +//! @{ + + + +struct mtOpCube_dual_aux_indicator {}; + + +template +class mtOpCube : public BaseCube< out_eT, mtOpCube > + { + public: + + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + + typedef typename T1::elem_type in_eT; + + inline explicit mtOpCube(const T1& in_m); + inline mtOpCube(const T1& in_m, const in_eT in_aux); + inline mtOpCube(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c); + inline mtOpCube(const T1& in_m, const in_eT in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c); + + inline mtOpCube(const char junk, const T1& in_m, const out_eT in_aux); + + inline mtOpCube(const mtOpCube_dual_aux_indicator&, const T1& in_m, const in_eT in_aux_a, const out_eT in_aux_b); + + inline ~mtOpCube(); + + + arma_aligned const T1& m; //!< the operand; must be derived from BaseCube + arma_aligned in_eT aux; //!< auxiliary data, using the element type as used by T1 + arma_aligned out_eT aux_out_eT; //!< auxiliary data, using the element type as specified by the out_eT template parameter + arma_aligned uword aux_uword_a; //!< auxiliary data, uword format + arma_aligned uword aux_uword_b; //!< auxiliary data, uword format + arma_aligned uword aux_uword_c; //!< auxiliary data, uword format + + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtOpCube_meat.hpp b/src/armadillo/include/armadillo_bits/mtOpCube_meat.hpp new file mode 100644 index 0000000..9ce8b17 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtOpCube_meat.hpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtOpCube +//! @{ + + + +template +inline +mtOpCube::mtOpCube(const T1& in_m) + : m(in_m) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtOpCube::mtOpCube(const T1& in_m, const typename T1::elem_type in_aux) + : m(in_m) + , aux(in_aux) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtOpCube::mtOpCube(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c) + : m(in_m) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + , aux_uword_c(in_aux_uword_c) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtOpCube::mtOpCube(const T1& in_m, const typename T1::elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c) + : m(in_m) + , aux(in_aux) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + , aux_uword_c(in_aux_uword_c) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtOpCube::mtOpCube(const char junk, const T1& in_m, const out_eT in_aux) + : m(in_m) + , aux_out_eT(in_aux) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + } + + + +template +inline +mtOpCube::mtOpCube(const mtOpCube_dual_aux_indicator&, const T1& in_m, const typename T1::elem_type in_aux_a, const out_eT in_aux_b) + : m (in_m ) + , aux (in_aux_a) + , aux_out_eT(in_aux_b) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtOpCube::~mtOpCube() + { + arma_extra_debug_sigprint(); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtOp_bones.hpp b/src/armadillo/include/armadillo_bits/mtOp_bones.hpp new file mode 100644 index 0000000..ff0e4c3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtOp_bones.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtOp +//! @{ + + +struct mtOp_dual_aux_indicator {}; + + +template +class mtOp : public Base< out_eT, mtOp > + { + public: + + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + + typedef typename T1::elem_type in_eT; + + static constexpr bool is_row = op_type::template traits::is_row; + static constexpr bool is_col = op_type::template traits::is_col; + static constexpr bool is_xvec = op_type::template traits::is_xvec; + + inline explicit mtOp(const T1& in_m); + inline mtOp(const T1& in_m, const in_eT in_aux); + inline mtOp(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); + inline mtOp(const T1& in_m, const in_eT in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b); + + inline mtOp(const char junk, const T1& in_m, const out_eT in_aux); + + inline mtOp(const mtOp_dual_aux_indicator&, const T1& in_m, const in_eT in_aux_a, const out_eT in_aux_b); + + inline ~mtOp(); + + + arma_aligned const T1& m; //!< the operand; must be derived from Base + arma_aligned in_eT aux; //!< auxiliary data, using the element type as used by T1 + arma_aligned out_eT aux_out_eT; //!< auxiliary data, using the element type as specified by the out_eT template parameter + arma_aligned uword aux_uword_a; //!< auxiliary data, uword format + arma_aligned uword aux_uword_b; //!< auxiliary data, uword format + + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtOp_meat.hpp b/src/armadillo/include/armadillo_bits/mtOp_meat.hpp new file mode 100644 index 0000000..c5b53b8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtOp_meat.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtOp +//! @{ + + + +template +inline +mtOp::mtOp(const T1& in_m) + : m(in_m) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtOp::mtOp(const T1& in_m, const typename T1::elem_type in_aux) + : m(in_m) + , aux(in_aux) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtOp::mtOp(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b) + : m(in_m) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtOp::mtOp(const T1& in_m, const typename T1::elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b) + : m(in_m) + , aux(in_aux) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtOp::mtOp(const char junk, const T1& in_m, const out_eT in_aux) + : m(in_m) + , aux_out_eT(in_aux) + { + arma_ignore(junk); + + arma_extra_debug_sigprint(); + } + + + +template +inline +mtOp::mtOp(const mtOp_dual_aux_indicator&, const T1& in_m, const typename T1::elem_type in_aux_a, const out_eT in_aux_b) + : m (in_m ) + , aux (in_aux_a) + , aux_out_eT(in_aux_b) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtOp::~mtOp() + { + arma_extra_debug_sigprint(); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtSpGlue_bones.hpp b/src/armadillo/include/armadillo_bits/mtSpGlue_bones.hpp new file mode 100644 index 0000000..3690914 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtSpGlue_bones.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtSpGlue +//! @{ + + + +template +class mtSpGlue : public SpBase< out_eT, mtSpGlue > + { + public: + + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = spglue_type::template traits::is_row; + static constexpr bool is_col = spglue_type::template traits::is_col; + static constexpr bool is_xvec = spglue_type::template traits::is_xvec; + + inline mtSpGlue(const T1& in_A, const T2& in_B); + inline ~mtSpGlue(); + + template + arma_inline bool is_alias(const SpMat& X) const; + + const T1& A; //!< first operand; must be derived from SpBase + const T2& B; //!< second operand; must be derived from SpBase + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtSpGlue_meat.hpp b/src/armadillo/include/armadillo_bits/mtSpGlue_meat.hpp new file mode 100644 index 0000000..41ede45 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtSpGlue_meat.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtSpGlue +//! @{ + + + +template +inline +mtSpGlue::mtSpGlue(const T1& in_A, const T2& in_B) + : A(in_A) + , B(in_B) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtSpGlue::~mtSpGlue() + { + arma_extra_debug_sigprint(); + } + + + +template +template +arma_inline +bool +mtSpGlue::is_alias(const SpMat& X) const + { + return (A.is_alias(X) || B.is_alias(X)); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtSpOp_bones.hpp b/src/armadillo/include/armadillo_bits/mtSpOp_bones.hpp new file mode 100644 index 0000000..9c73727 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtSpOp_bones.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtSpOp +//! @{ + +// Class for delayed multi-type sparse operations. These are operations where +// the resulting type is different than the stored type. + + + +template +class mtSpOp : public SpBase< out_eT, mtSpOp > + { + public: + + typedef out_eT elem_type; + typedef typename get_pod_type::result pod_type; + + typedef typename T1::elem_type in_eT; + + static constexpr bool is_row = spop_type::template traits::is_row; + static constexpr bool is_col = spop_type::template traits::is_col; + static constexpr bool is_xvec = spop_type::template traits::is_xvec; + + inline explicit mtSpOp(const T1& in_m); + inline mtSpOp(const T1& in_m, const uword aux_uword_a, const uword aux_uword_b); + inline mtSpOp(const char junk, const T1& in_m, const out_eT in_aux); + inline ~mtSpOp(); + + template + arma_inline bool is_alias(const SpMat& X) const; + + arma_aligned const T1& m; //!< the operand; must be derived from SpBase + arma_aligned out_eT aux_out_eT; //!< auxiliary data, using the element type as specified by the out_eT template parameter + arma_aligned uword aux_uword_a; + arma_aligned uword aux_uword_b; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mtSpOp_meat.hpp b/src/armadillo/include/armadillo_bits/mtSpOp_meat.hpp new file mode 100644 index 0000000..2273f08 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mtSpOp_meat.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup mtSpOp +//! @{ + + + +template +inline +mtSpOp::mtSpOp(const T1& in_m) + : m(in_m) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtSpOp::mtSpOp(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b) + : m(in_m) + , aux_uword_a(in_aux_uword_a) + , aux_uword_b(in_aux_uword_b) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +mtSpOp::mtSpOp(const char junk, const T1& in_m, const out_eT in_aux) + : m(in_m) + , aux_out_eT(in_aux) + { + arma_ignore(junk); + + arma_extra_debug_sigprint(); + } + + + +template +inline +mtSpOp::~mtSpOp() + { + arma_extra_debug_sigprint(); + } + + + +template +template +arma_inline +bool +mtSpOp::is_alias(const SpMat& X) const + { + return (void_ptr(&X) == void_ptr(&m)); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mul_gemm.hpp b/src/armadillo/include/armadillo_bits/mul_gemm.hpp new file mode 100644 index 0000000..27e3183 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mul_gemm.hpp @@ -0,0 +1,435 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup gemm +//! @{ + + + +//! for tiny square matrices, size <= 4x4 +template +class gemm_emul_tinysq + { + public: + + + template + arma_cold + inline + static + void + apply + ( + Mat& C, + const TA& A, + const TB& B, + const eT alpha = eT(1), + const eT beta = eT(0) + ) + { + arma_extra_debug_sigprint(); + + switch(A.n_rows) + { + case 4: gemv_emul_tinysq::apply( C.colptr(3), A, B.colptr(3), alpha, beta ); + // fallthrough + case 3: gemv_emul_tinysq::apply( C.colptr(2), A, B.colptr(2), alpha, beta ); + // fallthrough + case 2: gemv_emul_tinysq::apply( C.colptr(1), A, B.colptr(1), alpha, beta ); + // fallthrough + case 1: gemv_emul_tinysq::apply( C.colptr(0), A, B.colptr(0), alpha, beta ); + // fallthrough + default: ; + } + } + + }; + + + +//! emulation of gemm(), for non-complex matrices only, as it assumes only simple transposes (ie. doesn't do hermitian transposes) +template +class gemm_emul_large + { + public: + + template + arma_hot + inline + static + void + apply + ( + Mat& C, + const TA& A, + const TB& B, + const eT alpha = eT(1), + const eT beta = eT(0) + ) + { + arma_extra_debug_sigprint(); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (do_trans_A == false) && (do_trans_B == false) ) + { + arma_aligned podarray tmp(A_n_cols); + + eT* A_rowdata = tmp.memptr(); + + for(uword row_A=0; row_A < A_n_rows; ++row_A) + { + tmp.copy_row(A, row_A); + + for(uword col_B=0; col_B < B_n_cols; ++col_B) + { + const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B)); + + if( (use_alpha == false) && (use_beta == false) ) { C.at(row_A,col_B) = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { C.at(row_A,col_B) = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); } + else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); } + } + } + } + else + if( (do_trans_A == true) && (do_trans_B == false) ) + { + for(uword col_A=0; col_A < A_n_cols; ++col_A) + { + // col_A is interpreted as row_A when storing the results in matrix C + + const eT* A_coldata = A.colptr(col_A); + + for(uword col_B=0; col_B < B_n_cols; ++col_B) + { + const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B)); + + if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,col_B) = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,col_B) = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); } + else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); } + } + } + } + else + if( (do_trans_A == false) && (do_trans_B == true) ) + { + Mat BB; + op_strans::apply_mat_noalias(BB, B); + + gemm_emul_large::apply(C, A, BB, alpha, beta); + } + else + if( (do_trans_A == true) && (do_trans_B == true) ) + { + // mat B_tmp = trans(B); + // dgemm_arma::apply(C, A, B_tmp, alpha, beta); + + + // By using the trans(A)*trans(B) = trans(B*A) equivalency, + // transpose operations are not needed + + arma_aligned podarray tmp(B.n_cols); + eT* B_rowdata = tmp.memptr(); + + for(uword row_B=0; row_B < B_n_rows; ++row_B) + { + tmp.copy_row(B, row_B); + + for(uword col_A=0; col_A < A_n_cols; ++col_A) + { + const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A)); + + if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,row_B) = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,row_B) = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); } + else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); } + } + } + } + } + + }; + + + +template +class gemm_emul + { + public: + + + template + arma_hot + inline + static + void + apply + ( + Mat& C, + const TA& A, + const TB& B, + const eT alpha = eT(1), + const eT beta = eT(0), + const typename arma_not_cx::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + gemm_emul_large::apply(C, A, B, alpha, beta); + } + + + + template + arma_hot + inline + static + void + apply + ( + Mat& C, + const Mat& A, + const Mat& B, + const eT alpha = eT(1), + const eT beta = eT(0), + const typename arma_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + // "better than nothing" handling of hermitian transposes for complex number matrices + + Mat tmp_A; + Mat tmp_B; + + if(do_trans_A) { op_htrans::apply_mat_noalias(tmp_A, A); } + if(do_trans_B) { op_htrans::apply_mat_noalias(tmp_B, B); } + + const Mat& AA = (do_trans_A == false) ? A : tmp_A; + const Mat& BB = (do_trans_B == false) ? B : tmp_B; + + gemm_emul_large::apply(C, AA, BB, alpha, beta); + } + + }; + + + +//! \brief +//! Wrapper for BLAS dgemm function, using template arguments to control the arguments passed to dgemm. +//! Matrix 'C' is assumed to have been set to the correct size (ie. taking into account transposes) + +template +class gemm + { + public: + + template + inline + static + void + apply_blas_type( Mat& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) ) + { + arma_extra_debug_sigprint(); + + if( (A.n_rows <= 4) && (A.n_rows == A.n_cols) && (A.n_rows == B.n_rows) && (B.n_rows == B.n_cols) && (is_cx::no) ) + { + if(do_trans_B == false) + { + gemm_emul_tinysq::apply(C, A, B, alpha, beta); + } + else + { + Mat BB(B.n_rows, B.n_rows, arma_nozeros_indicator()); + + op_strans::apply_mat_noalias_tinysq(BB, B); + + gemm_emul_tinysq::apply(C, A, BB, alpha, beta); + } + } + else + { + #if defined(ARMA_USE_ATLAS) + { + arma_extra_debug_print("atlas::cblas_gemm()"); + + arma_debug_assert_atlas_size(A,B); + + atlas::cblas_gemm + ( + atlas_CblasColMajor, + (do_trans_A) ? ( is_cx::yes ? atlas_CblasConjTrans : atlas_CblasTrans ) : atlas_CblasNoTrans, + (do_trans_B) ? ( is_cx::yes ? atlas_CblasConjTrans : atlas_CblasTrans ) : atlas_CblasNoTrans, + C.n_rows, + C.n_cols, + (do_trans_A) ? A.n_rows : A.n_cols, + (use_alpha) ? alpha : eT(1), + A.mem, + (do_trans_A) ? A.n_rows : C.n_rows, + B.mem, + (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ), + (use_beta) ? beta : eT(0), + C.memptr(), + C.n_rows + ); + } + #elif defined(ARMA_USE_BLAS) + { + arma_extra_debug_print("blas::gemm()"); + + arma_debug_assert_blas_size(A,B); + + const char trans_A = (do_trans_A) ? ( is_cx::yes ? 'C' : 'T' ) : 'N'; + const char trans_B = (do_trans_B) ? ( is_cx::yes ? 'C' : 'T' ) : 'N'; + + const blas_int m = blas_int(C.n_rows); + const blas_int n = blas_int(C.n_cols); + const blas_int k = (do_trans_A) ? blas_int(A.n_rows) : blas_int(A.n_cols); + + const eT local_alpha = (use_alpha) ? alpha : eT(1); + + const blas_int lda = (do_trans_A) ? k : m; + const blas_int ldb = (do_trans_B) ? n : k; + + const eT local_beta = (use_beta) ? beta : eT(0); + + arma_extra_debug_print( arma_str::format("blas::gemm(): trans_A = %c") % trans_A ); + arma_extra_debug_print( arma_str::format("blas::gemm(): trans_B = %c") % trans_B ); + + blas::gemm + ( + &trans_A, + &trans_B, + &m, + &n, + &k, + &local_alpha, + A.mem, + &lda, + B.mem, + &ldb, + &local_beta, + C.memptr(), + &m + ); + } + #else + { + gemm_emul::apply(C,A,B,alpha,beta); + } + #endif + } + } + + + + //! immediate multiplication of matrices A and B, storing the result in C + template + inline + static + void + apply( Mat& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) ) + { + gemm_emul::apply(C,A,B,alpha,beta); + } + + + + template + arma_inline + static + void + apply + ( + Mat& C, + const TA& A, + const TB& B, + const float alpha = float(1), + const float beta = float(0) + ) + { + gemm::apply_blas_type(C,A,B,alpha,beta); + } + + + + template + arma_inline + static + void + apply + ( + Mat& C, + const TA& A, + const TB& B, + const double alpha = double(1), + const double beta = double(0) + ) + { + gemm::apply_blas_type(C,A,B,alpha,beta); + } + + + + template + arma_inline + static + void + apply + ( + Mat< std::complex >& C, + const TA& A, + const TB& B, + const std::complex alpha = std::complex(1), + const std::complex beta = std::complex(0) + ) + { + gemm::apply_blas_type(C,A,B,alpha,beta); + } + + + + template + arma_inline + static + void + apply + ( + Mat< std::complex >& C, + const TA& A, + const TB& B, + const std::complex alpha = std::complex(1), + const std::complex beta = std::complex(0) + ) + { + gemm::apply_blas_type(C,A,B,alpha,beta); + } + + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mul_gemm_mixed.hpp b/src/armadillo/include/armadillo_bits/mul_gemm_mixed.hpp new file mode 100644 index 0000000..749cdb1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mul_gemm_mixed.hpp @@ -0,0 +1,291 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup gemm_mixed +//! @{ + + + +//! \brief +//! Matrix multplication where the matrices have differing element types. +//! Uses caching for speedup. +//! Matrix 'C' is assumed to have been set to the correct size (ie. taking into account transposes) + +template +class gemm_mixed_large + { + public: + + template + arma_hot + inline + static + void + apply + ( + Mat& C, + const Mat& A, + const Mat& B, + const out_eT alpha = out_eT(1), + const out_eT beta = out_eT(0) + ) + { + arma_extra_debug_sigprint(); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (do_trans_A == false) && (do_trans_B == false) ) + { + podarray tmp(A_n_cols); + in_eT1* A_rowdata = tmp.memptr(); + + #if defined(ARMA_USE_OPENMP) + const bool use_mp = (B_n_cols >= 2) && (B.n_elem >= 8192) && (mp_thread_limit::in_parallel() == false); + #else + const bool use_mp = false; + #endif + + if(use_mp) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)( uword(mp_thread_limit::get()), uword(B_n_cols) ) ); + + for(uword row_A=0; row_A < A_n_rows; ++row_A) + { + tmp.copy_row(A, row_A); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col_B=0; col_B < B_n_cols; ++col_B) + { + const in_eT2* B_coldata = B.colptr(col_B); + + out_eT acc = out_eT(0); + for(uword i=0; i < B_n_rows; ++i) + { + acc += upgrade_val::apply(A_rowdata[i]) * upgrade_val::apply(B_coldata[i]); + } + + if( (use_alpha == false) && (use_beta == false) ) { C.at(row_A,col_B) = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { C.at(row_A,col_B) = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); } + else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); } + } + } + } + #endif + } + else + { + for(uword row_A=0; row_A < A_n_rows; ++row_A) + { + tmp.copy_row(A, row_A); + + for(uword col_B=0; col_B < B_n_cols; ++col_B) + { + const in_eT2* B_coldata = B.colptr(col_B); + + out_eT acc = out_eT(0); + for(uword i=0; i < B_n_rows; ++i) + { + acc += upgrade_val::apply(A_rowdata[i]) * upgrade_val::apply(B_coldata[i]); + } + + if( (use_alpha == false) && (use_beta == false) ) { C.at(row_A,col_B) = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { C.at(row_A,col_B) = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); } + else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); } + } + } + } + } + else + if( (do_trans_A == true) && (do_trans_B == false) ) + { + #if defined(ARMA_USE_OPENMP) + const bool use_mp = (B_n_cols >= 2) && (B.n_elem >= 8192) && (mp_thread_limit::in_parallel() == false); + #else + const bool use_mp = false; + #endif + + if(use_mp) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)( uword(mp_thread_limit::get()), uword(B_n_cols) ) ); + + for(uword col_A=0; col_A < A_n_cols; ++col_A) + { + // col_A is interpreted as row_A when storing the results in matrix C + + const in_eT1* A_coldata = A.colptr(col_A); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col_B=0; col_B < B_n_cols; ++col_B) + { + const in_eT2* B_coldata = B.colptr(col_B); + + out_eT acc = out_eT(0); + for(uword i=0; i < B_n_rows; ++i) + { + acc += upgrade_val::apply(A_coldata[i]) * upgrade_val::apply(B_coldata[i]); + } + + if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,col_B) = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,col_B) = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); } + else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); } + } + } + } + #endif + } + else + { + for(uword col_A=0; col_A < A_n_cols; ++col_A) + { + // col_A is interpreted as row_A when storing the results in matrix C + + const in_eT1* A_coldata = A.colptr(col_A); + + for(uword col_B=0; col_B < B_n_cols; ++col_B) + { + const in_eT2* B_coldata = B.colptr(col_B); + + out_eT acc = out_eT(0); + for(uword i=0; i < B_n_rows; ++i) + { + acc += upgrade_val::apply(A_coldata[i]) * upgrade_val::apply(B_coldata[i]); + } + + if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,col_B) = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,col_B) = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); } + else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); } + } + } + } + } + else + if( (do_trans_A == false) && (do_trans_B == true) ) + { + Mat B_tmp; + + op_strans::apply_mat_noalias(B_tmp, B); + + gemm_mixed_large::apply(C, A, B_tmp, alpha, beta); + } + else + if( (do_trans_A == true) && (do_trans_B == true) ) + { + // mat B_tmp = trans(B); + // dgemm_arma::apply(C, A, B_tmp, alpha, beta); + + + // By using the trans(A)*trans(B) = trans(B*A) equivalency, + // transpose operations are not needed + + podarray tmp(B_n_cols); + in_eT2* B_rowdata = tmp.memptr(); + + for(uword row_B=0; row_B < B_n_rows; ++row_B) + { + tmp.copy_row(B, row_B); + + for(uword col_A=0; col_A < A_n_cols; ++col_A) + { + const in_eT1* A_coldata = A.colptr(col_A); + + out_eT acc = out_eT(0); + for(uword i=0; i < A_n_rows; ++i) + { + acc += upgrade_val::apply(B_rowdata[i]) * upgrade_val::apply(A_coldata[i]); + } + + if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,row_B) = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,row_B) = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); } + else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); } + } + } + + } + } + + }; + + + +//! \brief +//! Matrix multplication where the matrices have differing element types. + +template +class gemm_mixed + { + public: + + //! immediate multiplication of matrices A and B, storing the result in C + template + inline + static + void + apply + ( + Mat& C, + const Mat& A, + const Mat& B, + const out_eT alpha = out_eT(1), + const out_eT beta = out_eT(0) + ) + { + arma_extra_debug_sigprint(); + + if((is_cx::yes && do_trans_A) || (is_cx::yes && do_trans_B)) + { + // better-than-nothing handling of hermitian transpose + + Mat tmp_A; + Mat tmp_B; + + const bool predo_trans_A = ( (do_trans_A == true) && (is_cx::yes) ); + const bool predo_trans_B = ( (do_trans_B == true) && (is_cx::yes) ); + + if(predo_trans_A) { op_htrans::apply_mat_noalias(tmp_A, A); } + if(predo_trans_B) { op_htrans::apply_mat_noalias(tmp_B, B); } + + const Mat& AA = (predo_trans_A == false) ? A : tmp_A; + const Mat& BB = (predo_trans_B == false) ? B : tmp_B; + + gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); + } + else + { + gemm_mixed_large::apply(C, A, B, alpha, beta); + } + } + + + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mul_gemv.hpp b/src/armadillo/include/armadillo_bits/mul_gemv.hpp new file mode 100644 index 0000000..2580e4a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mul_gemv.hpp @@ -0,0 +1,495 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup gemv +//! @{ + + + +//! for tiny square matrices, size <= 4x4 +template +class gemv_emul_tinysq + { + public: + + + template + struct pos + { + static constexpr uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2); + static constexpr uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3); + static constexpr uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4); + }; + + + + template + arma_inline + static + void + assign(eT* y, const eT acc, const eT alpha, const eT beta) + { + if(use_beta == false) + { + y[i] = (use_alpha == false) ? acc : alpha*acc; + } + else + { + const eT tmp = y[i]; + + y[i] = beta*tmp + ( (use_alpha == false) ? acc : alpha*acc ); + } + } + + + + template + arma_cold + inline + static + void + apply( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) + { + arma_extra_debug_sigprint(); + + const eT* Am = A.memptr(); + + switch(A.n_rows) + { + case 1: + { + const eT acc = Am[0] * x[0]; + + assign(y, acc, alpha, beta); + } + break; + + + case 2: + { + const eT x0 = x[0]; + const eT x1 = x[1]; + + const eT acc0 = Am[pos<0,0>::n2]*x0 + Am[pos<0,1>::n2]*x1; + const eT acc1 = Am[pos<1,0>::n2]*x0 + Am[pos<1,1>::n2]*x1; + + assign(y, acc0, alpha, beta); + assign(y, acc1, alpha, beta); + } + break; + + + case 3: + { + const eT x0 = x[0]; + const eT x1 = x[1]; + const eT x2 = x[2]; + + const eT acc0 = Am[pos<0,0>::n3]*x0 + Am[pos<0,1>::n3]*x1 + Am[pos<0,2>::n3]*x2; + const eT acc1 = Am[pos<1,0>::n3]*x0 + Am[pos<1,1>::n3]*x1 + Am[pos<1,2>::n3]*x2; + const eT acc2 = Am[pos<2,0>::n3]*x0 + Am[pos<2,1>::n3]*x1 + Am[pos<2,2>::n3]*x2; + + assign(y, acc0, alpha, beta); + assign(y, acc1, alpha, beta); + assign(y, acc2, alpha, beta); + } + break; + + + case 4: + { + const eT x0 = x[0]; + const eT x1 = x[1]; + const eT x2 = x[2]; + const eT x3 = x[3]; + + const eT acc0 = Am[pos<0,0>::n4]*x0 + Am[pos<0,1>::n4]*x1 + Am[pos<0,2>::n4]*x2 + Am[pos<0,3>::n4]*x3; + const eT acc1 = Am[pos<1,0>::n4]*x0 + Am[pos<1,1>::n4]*x1 + Am[pos<1,2>::n4]*x2 + Am[pos<1,3>::n4]*x3; + const eT acc2 = Am[pos<2,0>::n4]*x0 + Am[pos<2,1>::n4]*x1 + Am[pos<2,2>::n4]*x2 + Am[pos<2,3>::n4]*x3; + const eT acc3 = Am[pos<3,0>::n4]*x0 + Am[pos<3,1>::n4]*x1 + Am[pos<3,2>::n4]*x2 + Am[pos<3,3>::n4]*x3; + + assign(y, acc0, alpha, beta); + assign(y, acc1, alpha, beta); + assign(y, acc2, alpha, beta); + assign(y, acc3, alpha, beta); + } + break; + + + default: + ; + } + } + + }; + + + +class gemv_emul_helper + { + public: + + template + arma_hot + inline + static + typename arma_not_cx::result + dot_row_col( const TA& A, const eT* x, const uword row, const uword N ) + { + eT acc1 = eT(0); + eT acc2 = eT(0); + + uword i,j; + for(i=0, j=1; j < N; i+=2, j+=2) + { + const eT xi = x[i]; + const eT xj = x[j]; + + acc1 += A.at(row,i) * xi; + acc2 += A.at(row,j) * xj; + } + + if(i < N) + { + acc1 += A.at(row,i) * x[i]; + } + + return (acc1 + acc2); + } + + + + template + arma_hot + inline + static + typename arma_cx_only::result + dot_row_col( const TA& A, const eT* x, const uword row, const uword N ) + { + typedef typename get_pod_type::result T; + + T val_real = T(0); + T val_imag = T(0); + + for(uword i=0; i& Ai = A.at(row,i); + const std::complex& xi = x[i]; + + const T a = Ai.real(); + const T b = Ai.imag(); + + const T c = xi.real(); + const T d = xi.imag(); + + val_real += (a*c) - (b*d); + val_imag += (a*d) + (b*c); + } + + return std::complex(val_real, val_imag); + } + + }; + + + +//! \brief +//! Partial emulation of BLAS gemv(). +//! 'y' is assumed to have been set to the correct size (ie. taking into account the transpose) + +template +class gemv_emul + { + public: + + template + arma_hot + inline + static + void + apply( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) + { + arma_extra_debug_sigprint(); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + if(do_trans_A == false) + { + if(A_n_rows == 1) + { + const eT acc = op_dot::direct_dot_arma(A_n_cols, A.memptr(), x); + + if( (use_alpha == false) && (use_beta == false) ) { y[0] = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { y[0] = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { y[0] = acc + beta*y[0]; } + else if( (use_alpha == true ) && (use_beta == true ) ) { y[0] = alpha*acc + beta*y[0]; } + } + else + for(uword row=0; row < A_n_rows; ++row) + { + const eT acc = gemv_emul_helper::dot_row_col(A, x, row, A_n_cols); + + if( (use_alpha == false) && (use_beta == false) ) { y[row] = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { y[row] = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { y[row] = acc + beta*y[row]; } + else if( (use_alpha == true ) && (use_beta == true ) ) { y[row] = alpha*acc + beta*y[row]; } + } + } + else + if(do_trans_A == true) + { + if(is_cx::no) + { + for(uword col=0; col < A_n_cols; ++col) + { + // col is interpreted as row when storing the results in 'y' + + + // const eT* A_coldata = A.colptr(col); + // + // eT acc = eT(0); + // for(uword row=0; row < A_n_rows; ++row) + // { + // acc += A_coldata[row] * x[row]; + // } + + const eT acc = op_dot::direct_dot_arma(A_n_rows, A.colptr(col), x); + + if( (use_alpha == false) && (use_beta == false) ) { y[col] = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { y[col] = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { y[col] = acc + beta*y[col]; } + else if( (use_alpha == true ) && (use_beta == true ) ) { y[col] = alpha*acc + beta*y[col]; } + } + } + else + { + Mat AA; + + op_htrans::apply_mat_noalias(AA, A); + + gemv_emul::apply(y, AA, x, alpha, beta); + } + } + } + + }; + + + +//! \brief +//! Wrapper for BLAS gemv function, using template arguments to control the arguments passed to gemv. +//! 'y' is assumed to have been set to the correct size (ie. taking into account the transpose) + +template +class gemv + { + public: + + template + inline + static + void + apply_blas_type( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) + { + arma_extra_debug_sigprint(); + + if( (A.n_rows <= 4) && (A.n_rows == A.n_cols) && (is_cx::no) ) + { + gemv_emul_tinysq::apply(y, A, x, alpha, beta); + } + else + { + #if defined(ARMA_USE_ATLAS) + { + arma_debug_assert_atlas_size(A); + + if(is_cx::no) + { + // use gemm() instead of gemv() to work around a speed issue in Atlas 3.8.4 + + arma_extra_debug_print("atlas::cblas_gemm()"); + + atlas::cblas_gemm + ( + atlas_CblasColMajor, + (do_trans_A) ? ( is_cx::yes ? atlas_CblasConjTrans : atlas_CblasTrans ) : atlas_CblasNoTrans, + atlas_CblasNoTrans, + (do_trans_A) ? A.n_cols : A.n_rows, + 1, + (do_trans_A) ? A.n_rows : A.n_cols, + (use_alpha) ? alpha : eT(1), + A.mem, + A.n_rows, + x, + (do_trans_A) ? A.n_rows : A.n_cols, + (use_beta) ? beta : eT(0), + y, + (do_trans_A) ? A.n_cols : A.n_rows + ); + } + else + { + arma_extra_debug_print("atlas::cblas_gemv()"); + + atlas::cblas_gemv + ( + atlas_CblasColMajor, + (do_trans_A) ? ( is_cx::yes ? atlas_CblasConjTrans : atlas_CblasTrans ) : atlas_CblasNoTrans, + A.n_rows, + A.n_cols, + (use_alpha) ? alpha : eT(1), + A.mem, + A.n_rows, + x, + 1, + (use_beta) ? beta : eT(0), + y, + 1 + ); + } + } + #elif defined(ARMA_USE_BLAS) + { + arma_extra_debug_print("blas::gemv()"); + + arma_debug_assert_blas_size(A); + + const char trans_A = (do_trans_A) ? ( is_cx::yes ? 'C' : 'T' ) : 'N'; + const blas_int m = blas_int(A.n_rows); + const blas_int n = blas_int(A.n_cols); + const eT local_alpha = (use_alpha) ? alpha : eT(1); + //const blas_int lda = A.n_rows; + const blas_int inc = blas_int(1); + const eT local_beta = (use_beta) ? beta : eT(0); + + arma_extra_debug_print( arma_str::format("blas::gemv(): trans_A = %c") % trans_A ); + + blas::gemv + ( + &trans_A, + &m, + &n, + &local_alpha, + A.mem, + &m, // lda + x, + &inc, + &local_beta, + y, + &inc + ); + } + #else + { + gemv_emul::apply(y,A,x,alpha,beta); + } + #endif + } + + } + + + + template + arma_inline + static + void + apply( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) + { + gemv_emul::apply(y,A,x,alpha,beta); + } + + + + template + arma_inline + static + void + apply + ( + float* y, + const TA& A, + const float* x, + const float alpha = float(1), + const float beta = float(0) + ) + { + gemv::apply_blas_type(y,A,x,alpha,beta); + } + + + + template + arma_inline + static + void + apply + ( + double* y, + const TA& A, + const double* x, + const double alpha = double(1), + const double beta = double(0) + ) + { + gemv::apply_blas_type(y,A,x,alpha,beta); + } + + + + template + arma_inline + static + void + apply + ( + std::complex* y, + const TA& A, + const std::complex* x, + const std::complex alpha = std::complex(1), + const std::complex beta = std::complex(0) + ) + { + gemv::apply_blas_type(y,A,x,alpha,beta); + } + + + + template + arma_inline + static + void + apply + ( + std::complex* y, + const TA& A, + const std::complex* x, + const std::complex alpha = std::complex(1), + const std::complex beta = std::complex(0) + ) + { + gemv::apply_blas_type(y,A,x,alpha,beta); + } + + + + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mul_herk.hpp b/src/armadillo/include/armadillo_bits/mul_herk.hpp new file mode 100644 index 0000000..e6b13b2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mul_herk.hpp @@ -0,0 +1,492 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup herk +//! @{ + + + +class herk_helper + { + public: + + template + inline + static + void + inplace_conj_copy_upper_tri_to_lower_tri(Mat& C) + { + // under the assumption that C is a square matrix + + const uword N = C.n_rows; + + for(uword k=0; k < N; ++k) + { + eT* colmem = C.colptr(k); + + for(uword i=(k+1); i < N; ++i) + { + colmem[i] = std::conj( C.at(k,i) ); + } + } + } + + + template + arma_hot + inline + static + eT + dot_conj_row(const uword n_elem, const eT* const A, const Mat& B, const uword row) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + T val_real = T(0); + T val_imag = T(0); + + for(uword i=0; i& X = A[i]; + const std::complex& Y = B.at(row,i); + + const T a = X.real(); + const T b = X.imag(); + + const T c = Y.real(); + const T d = Y.imag(); + + val_real += (a*c) + (b*d); + val_imag += (b*c) - (a*d); + } + + return std::complex(val_real, val_imag); + } + + }; + + + +template +class herk_vec + { + public: + + template + arma_hot + inline + static + void + apply + ( + Mat< std::complex >& C, + const TA& A, + const T alpha = T(1), + const T beta = T(0) + ) + { + arma_extra_debug_sigprint(); + + typedef std::complex eT; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + // for beta != 0, C is assumed to be hermitian + + // do_trans_A == false -> C = alpha * A * A^H + beta*C + // do_trans_A == true -> C = alpha * A^H * A + beta*C + + const eT* A_mem = A.memptr(); + + if(do_trans_A == false) + { + if(A_n_rows == 1) + { + const eT acc = op_cdot::direct_cdot(A_n_cols, A_mem, A_mem); + + if( (use_alpha == false) && (use_beta == false) ) { C[0] = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { C[0] = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { C[0] = acc + beta*C[0]; } + else if( (use_alpha == true ) && (use_beta == true ) ) { C[0] = alpha*acc + beta*C[0]; } + } + else + for(uword row_A=0; row_A < A_n_rows; ++row_A) + { + const eT& A_rowdata = A_mem[row_A]; + + for(uword k=row_A; k < A_n_rows; ++k) + { + const eT acc = A_rowdata * std::conj( A_mem[k] ); + + if( (use_alpha == false) && (use_beta == false) ) + { + C.at(row_A, k) = acc; + if(row_A != k) { C.at(k, row_A) = std::conj(acc); } + } + else + if( (use_alpha == true) && (use_beta == false) ) + { + const eT val = alpha*acc; + + C.at(row_A, k) = val; + if(row_A != k) { C.at(k, row_A) = std::conj(val); } + } + else + if( (use_alpha == false) && (use_beta == true) ) + { + C.at(row_A, k) = acc + beta*C.at(row_A, k); + if(row_A != k) { C.at(k, row_A) = std::conj(acc) + beta*C.at(k, row_A); } + } + else + if( (use_alpha == true) && (use_beta == true) ) + { + const eT val = alpha*acc; + + C.at(row_A, k) = val + beta*C.at(row_A, k); + if(row_A != k) { C.at(k, row_A) = std::conj(val) + beta*C.at(k, row_A); } + } + } + } + } + else + if(do_trans_A == true) + { + if(A_n_cols == 1) + { + const eT acc = op_cdot::direct_cdot(A_n_rows, A_mem, A_mem); + + if( (use_alpha == false) && (use_beta == false) ) { C[0] = acc; } + else if( (use_alpha == true ) && (use_beta == false) ) { C[0] = alpha*acc; } + else if( (use_alpha == false) && (use_beta == true ) ) { C[0] = acc + beta*C[0]; } + else if( (use_alpha == true ) && (use_beta == true ) ) { C[0] = alpha*acc + beta*C[0]; } + } + else + for(uword col_A=0; col_A < A_n_cols; ++col_A) + { + // col_A is interpreted as row_A when storing the results in matrix C + + const eT A_coldata = std::conj( A_mem[col_A] ); + + for(uword k=col_A; k < A_n_cols ; ++k) + { + const eT acc = A_coldata * A_mem[k]; + + if( (use_alpha == false) && (use_beta == false) ) + { + C.at(col_A, k) = acc; + if(col_A != k) { C.at(k, col_A) = std::conj(acc); } + } + else + if( (use_alpha == true ) && (use_beta == false) ) + { + const eT val = alpha*acc; + + C.at(col_A, k) = val; + if(col_A != k) { C.at(k, col_A) = std::conj(val); } + } + else + if( (use_alpha == false) && (use_beta == true ) ) + { + C.at(col_A, k) = acc + beta*C.at(col_A, k); + if(col_A != k) { C.at(k, col_A) = std::conj(acc) + beta*C.at(k, col_A); } + } + else + if( (use_alpha == true ) && (use_beta == true ) ) + { + const eT val = alpha*acc; + + C.at(col_A, k) = val + beta*C.at(col_A, k); + if(col_A != k) { C.at(k, col_A) = std::conj(val) + beta*C.at(k, col_A); } + } + } + } + } + } + + }; + + + +template +class herk_emul + { + public: + + template + arma_hot + inline + static + void + apply + ( + Mat< std::complex >& C, + const TA& A, + const T alpha = T(1), + const T beta = T(0) + ) + { + arma_extra_debug_sigprint(); + + typedef std::complex eT; + + // do_trans_A == false -> C = alpha * A * A^H + beta*C + // do_trans_A == true -> C = alpha * A^H * A + beta*C + + if(do_trans_A == false) + { + Mat AA; + + op_htrans::apply_mat_noalias(AA, A); + + herk_emul::apply(C, AA, alpha, beta); + } + else + if(do_trans_A == true) + { + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + for(uword col_A=0; col_A < A_n_cols; ++col_A) + { + // col_A is interpreted as row_A when storing the results in matrix C + + const eT* A_coldata = A.colptr(col_A); + + for(uword k=col_A; k < A_n_cols ; ++k) + { + const eT acc = op_cdot::direct_cdot(A_n_rows, A_coldata, A.colptr(k)); + + if( (use_alpha == false) && (use_beta == false) ) + { + C.at(col_A, k) = acc; + if(col_A != k) { C.at(k, col_A) = std::conj(acc); } + } + else + if( (use_alpha == true) && (use_beta == false) ) + { + const eT val = alpha*acc; + + C.at(col_A, k) = val; + if(col_A != k) { C.at(k, col_A) = std::conj(val); } + } + else + if( (use_alpha == false) && (use_beta == true) ) + { + C.at(col_A, k) = acc + beta*C.at(col_A, k); + if(col_A != k) { C.at(k, col_A) = std::conj(acc) + beta*C.at(k, col_A); } + } + else + if( (use_alpha == true) && (use_beta == true) ) + { + const eT val = alpha*acc; + + C.at(col_A, k) = val + beta*C.at(col_A, k); + if(col_A != k) { C.at(k, col_A) = std::conj(val) + beta*C.at(k, col_A); } + } + } + } + } + } + + }; + + + +template +class herk + { + public: + + template + inline + static + void + apply_blas_type( Mat>& C, const TA& A, const T alpha = T(1), const T beta = T(0) ) + { + arma_extra_debug_sigprint(); + + const uword threshold = 16; + + if(A.is_vec()) + { + // work around poor handling of vectors by herk() in standard BLAS + + herk_vec::apply(C,A,alpha,beta); + + return; + } + + + if( (A.n_elem <= threshold) ) + { + herk_emul::apply(C,A,alpha,beta); + } + else + { + #if defined(ARMA_USE_ATLAS) + { + if(use_beta == true) + { + typedef typename std::complex eT; + + // use a temporary matrix, as we can't assume that matrix C is already symmetric + Mat D(C.n_rows, C.n_cols, arma_nozeros_indicator()); + + herk::apply_blas_type(D,A,alpha); + + // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1 + arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem); + + return; + } + + atlas::cblas_herk + ( + atlas_CblasColMajor, + atlas_CblasUpper, + (do_trans_A) ? atlas_CblasConjTrans : atlas_CblasNoTrans, + C.n_cols, + (do_trans_A) ? A.n_rows : A.n_cols, + (use_alpha) ? alpha : T(1), + A.mem, + (do_trans_A) ? A.n_rows : C.n_cols, + (use_beta) ? beta : T(0), + C.memptr(), + C.n_cols + ); + + herk_helper::inplace_conj_copy_upper_tri_to_lower_tri(C); + } + #elif defined(ARMA_USE_BLAS) + { + if(use_beta == true) + { + typedef typename std::complex eT; + + // use a temporary matrix, as we can't assume that matrix C is already symmetric + Mat D(C.n_rows, C.n_cols, arma_nozeros_indicator()); + + herk::apply_blas_type(D,A,alpha); + + // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1 + arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem); + + return; + } + + arma_extra_debug_print("blas::herk()"); + + const char uplo = 'U'; + + const char trans_A = (do_trans_A) ? 'C' : 'N'; + + const blas_int n = blas_int(C.n_cols); + const blas_int k = (do_trans_A) ? blas_int(A.n_rows) : blas_int(A.n_cols); + + const T local_alpha = (use_alpha) ? alpha : T(1); + const T local_beta = (use_beta) ? beta : T(0); + + const blas_int lda = (do_trans_A) ? k : n; + + arma_extra_debug_print( arma_str::format("blas::herk(): trans_A = %c") % trans_A ); + + blas::herk + ( + &uplo, + &trans_A, + &n, + &k, + &local_alpha, + A.mem, + &lda, + &local_beta, + C.memptr(), + &n // &ldc + ); + + herk_helper::inplace_conj_copy_upper_tri_to_lower_tri(C); + } + #else + { + herk_emul::apply(C,A,alpha,beta); + } + #endif + } + + } + + + + template + inline + static + void + apply( Mat& C, const TA& A, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx::result* junk = nullptr ) + { + arma_ignore(C); + arma_ignore(A); + arma_ignore(alpha); + arma_ignore(beta); + arma_ignore(junk); + + // herk() cannot be used by non-complex matrices + + return; + } + + + + template + arma_inline + static + void + apply + ( + Mat< std::complex >& C, + const TA& A, + const float alpha = float(1), + const float beta = float(0) + ) + { + herk::apply_blas_type(C,A,alpha,beta); + } + + + + template + arma_inline + static + void + apply + ( + Mat< std::complex >& C, + const TA& A, + const double alpha = double(1), + const double beta = double(0) + ) + { + herk::apply_blas_type(C,A,alpha,beta); + } + + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/mul_syrk.hpp b/src/armadillo/include/armadillo_bits/mul_syrk.hpp new file mode 100644 index 0000000..c2da3a2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/mul_syrk.hpp @@ -0,0 +1,501 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup syrk +//! @{ + + + +class syrk_helper + { + public: + + template + inline + static + void + inplace_copy_upper_tri_to_lower_tri(Mat& C) + { + // under the assumption that C is a square matrix + + const uword N = C.n_rows; + + for(uword k=0; k < N; ++k) + { + eT* colmem = C.colptr(k); + + uword i, j; + for(i=(k+1), j=(k+2); j < N; i+=2, j+=2) + { + const eT tmp_i = C.at(k,i); + const eT tmp_j = C.at(k,j); + + colmem[i] = tmp_i; + colmem[j] = tmp_j; + } + + if(i < N) + { + colmem[i] = C.at(k,i); + } + } + } + }; + + + +//! partial emulation of BLAS function syrk(), specialised for A being a vector +template +class syrk_vec + { + public: + + template + arma_hot + inline + static + void + apply + ( + Mat& C, + const TA& A, + const eT alpha = eT(1), + const eT beta = eT(0) + ) + { + arma_extra_debug_sigprint(); + + const uword A_n1 = (do_trans_A == false) ? A.n_rows : A.n_cols; + const uword A_n2 = (do_trans_A == false) ? A.n_cols : A.n_rows; + + const eT* A_mem = A.memptr(); + + if(A_n1 == 1) + { + const eT acc1 = op_dot::direct_dot(A_n2, A_mem, A_mem); + + if( (use_alpha == false) && (use_beta == false) ) { C[0] = acc1; } + else if( (use_alpha == true ) && (use_beta == false) ) { C[0] = alpha*acc1; } + else if( (use_alpha == false) && (use_beta == true ) ) { C[0] = acc1 + beta*C[0]; } + else if( (use_alpha == true ) && (use_beta == true ) ) { C[0] = alpha*acc1 + beta*C[0]; } + } + else + for(uword k=0; k < A_n1; ++k) + { + const eT A_k = A_mem[k]; + + uword i,j; + for(i=(k), j=(k+1); j < A_n1; i+=2, j+=2) + { + const eT acc1 = A_k * A_mem[i]; + const eT acc2 = A_k * A_mem[j]; + + if( (use_alpha == false) && (use_beta == false) ) + { + C.at(k, i) = acc1; + C.at(k, j) = acc2; + + C.at(i, k) = acc1; + C.at(j, k) = acc2; + } + else + if( (use_alpha == true ) && (use_beta == false) ) + { + const eT val1 = alpha*acc1; + const eT val2 = alpha*acc2; + + C.at(k, i) = val1; + C.at(k, j) = val2; + + C.at(i, k) = val1; + C.at(j, k) = val2; + } + else + if( (use_alpha == false) && (use_beta == true) ) + { + C.at(k, i) = acc1 + beta*C.at(k, i); + C.at(k, j) = acc2 + beta*C.at(k, j); + + if(i != k) { C.at(i, k) = acc1 + beta*C.at(i, k); } + C.at(j, k) = acc2 + beta*C.at(j, k); + } + else + if( (use_alpha == true ) && (use_beta == true) ) + { + const eT val1 = alpha*acc1; + const eT val2 = alpha*acc2; + + C.at(k, i) = val1 + beta*C.at(k, i); + C.at(k, j) = val2 + beta*C.at(k, j); + + if(i != k) { C.at(i, k) = val1 + beta*C.at(i, k); } + C.at(j, k) = val2 + beta*C.at(j, k); + } + } + + if(i < A_n1) + { + const eT acc1 = A_k * A_mem[i]; + + if( (use_alpha == false) && (use_beta == false) ) + { + C.at(k, i) = acc1; + C.at(i, k) = acc1; + } + else + if( (use_alpha == true) && (use_beta == false) ) + { + const eT val1 = alpha*acc1; + + C.at(k, i) = val1; + C.at(i, k) = val1; + } + else + if( (use_alpha == false) && (use_beta == true) ) + { + C.at(k, i) = acc1 + beta*C.at(k, i); + if(i != k) { C.at(i, k) = acc1 + beta*C.at(i, k); } + } + else + if( (use_alpha == true) && (use_beta == true) ) + { + const eT val1 = alpha*acc1; + + C.at(k, i) = val1 + beta*C.at(k, i); + if(i != k) { C.at(i, k) = val1 + beta*C.at(i, k); } + } + } + } + } + + }; + + + +//! partial emulation of BLAS function syrk() +template +class syrk_emul + { + public: + + template + arma_hot + inline + static + void + apply + ( + Mat& C, + const TA& A, + const eT alpha = eT(1), + const eT beta = eT(0) + ) + { + arma_extra_debug_sigprint(); + + // do_trans_A == false -> C = alpha * A * A^T + beta*C + // do_trans_A == true -> C = alpha * A^T * A + beta*C + + if(do_trans_A == false) + { + Mat AA; + + op_strans::apply_mat_noalias(AA, A); + + syrk_emul::apply(C, AA, alpha, beta); + } + else + if(do_trans_A == true) + { + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + for(uword col_A=0; col_A < A_n_cols; ++col_A) + { + // col_A is interpreted as row_A when storing the results in matrix C + + const eT* A_coldata = A.colptr(col_A); + + for(uword k=col_A; k < A_n_cols; ++k) + { + const eT acc = op_dot::direct_dot_arma(A_n_rows, A_coldata, A.colptr(k)); + + if( (use_alpha == false) && (use_beta == false) ) + { + C.at(col_A, k) = acc; + C.at(k, col_A) = acc; + } + else + if( (use_alpha == true ) && (use_beta == false) ) + { + const eT val = alpha*acc; + + C.at(col_A, k) = val; + C.at(k, col_A) = val; + } + else + if( (use_alpha == false) && (use_beta == true ) ) + { + C.at(col_A, k) = acc + beta*C.at(col_A, k); + if(col_A != k) { C.at(k, col_A) = acc + beta*C.at(k, col_A); } + } + else + if( (use_alpha == true ) && (use_beta == true ) ) + { + const eT val = alpha*acc; + + C.at(col_A, k) = val + beta*C.at(col_A, k); + if(col_A != k) { C.at(k, col_A) = val + beta*C.at(k, col_A); } + } + } + } + } + } + + }; + + + +template +class syrk + { + public: + + template + inline + static + void + apply_blas_type( Mat& C, const TA& A, const eT alpha = eT(1), const eT beta = eT(0) ) + { + arma_extra_debug_sigprint(); + + if(A.is_vec()) + { + // work around poor handling of vectors by syrk() in standard BLAS + + syrk_vec::apply(C,A,alpha,beta); + + return; + } + + const uword threshold = (is_cx::yes ? 16u : 48u); + + if( A.n_elem <= threshold ) + { + syrk_emul::apply(C,A,alpha,beta); + } + else + { + #if defined(ARMA_USE_ATLAS) + { + if(use_beta == true) + { + // use a temporary matrix, as we can't assume that matrix C is already symmetric + Mat D(C.n_rows, C.n_cols, arma_nozeros_indicator()); + + syrk::apply_blas_type(D,A,alpha); + + // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1 + arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem); + + return; + } + + atlas::cblas_syrk + ( + atlas_CblasColMajor, + atlas_CblasUpper, + (do_trans_A) ? atlas_CblasTrans : atlas_CblasNoTrans, + C.n_cols, + (do_trans_A) ? A.n_rows : A.n_cols, + (use_alpha) ? alpha : eT(1), + A.mem, + (do_trans_A) ? A.n_rows : C.n_cols, + (use_beta) ? beta : eT(0), + C.memptr(), + C.n_cols + ); + + syrk_helper::inplace_copy_upper_tri_to_lower_tri(C); + } + #elif defined(ARMA_USE_BLAS) + { + if(use_beta == true) + { + // use a temporary matrix, as we can't assume that matrix C is already symmetric + Mat D(C.n_rows, C.n_cols, arma_nozeros_indicator()); + + syrk::apply_blas_type(D,A,alpha); + + // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1 + arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem); + + return; + } + + arma_extra_debug_print("blas::syrk()"); + + const char uplo = 'U'; + + const char trans_A = (do_trans_A) ? 'T' : 'N'; + + const blas_int n = blas_int(C.n_cols); + const blas_int k = (do_trans_A) ? blas_int(A.n_rows) : blas_int(A.n_cols); + + const eT local_alpha = (use_alpha) ? alpha : eT(1); + const eT local_beta = (use_beta) ? beta : eT(0); + + const blas_int lda = (do_trans_A) ? k : n; + + arma_extra_debug_print( arma_str::format("blas::syrk(): trans_A = %c") % trans_A ); + + blas::syrk + ( + &uplo, + &trans_A, + &n, + &k, + &local_alpha, + A.mem, + &lda, + &local_beta, + C.memptr(), + &n // &ldc + ); + + syrk_helper::inplace_copy_upper_tri_to_lower_tri(C); + } + #else + { + syrk_emul::apply(C,A,alpha,beta); + } + #endif + } + } + + + + template + inline + static + void + apply( Mat& C, const TA& A, const eT alpha = eT(1), const eT beta = eT(0) ) + { + if(is_cx::no) + { + if(A.is_vec()) + { + syrk_vec::apply(C,A,alpha,beta); + } + else + { + syrk_emul::apply(C,A,alpha,beta); + } + } + else + { + // handling of complex matrix by syrk_emul() is not yet implemented + return; + } + } + + + + template + arma_inline + static + void + apply + ( + Mat& C, + const TA& A, + const float alpha = float(1), + const float beta = float(0) + ) + { + syrk::apply_blas_type(C,A,alpha,beta); + } + + + + template + arma_inline + static + void + apply + ( + Mat& C, + const TA& A, + const double alpha = double(1), + const double beta = double(0) + ) + { + syrk::apply_blas_type(C,A,alpha,beta); + } + + + + template + arma_inline + static + void + apply + ( + Mat< std::complex >& C, + const TA& A, + const std::complex alpha = std::complex(1), + const std::complex beta = std::complex(0) + ) + { + arma_ignore(C); + arma_ignore(A); + arma_ignore(alpha); + arma_ignore(beta); + + // handling of complex matrix by syrk() is not yet implemented + return; + } + + + + template + arma_inline + static + void + apply + ( + Mat< std::complex >& C, + const TA& A, + const std::complex alpha = std::complex(1), + const std::complex beta = std::complex(0) + ) + { + arma_ignore(C); + arma_ignore(A); + arma_ignore(alpha); + arma_ignore(beta); + + // handling of complex matrix by syrk() is not yet implemented + return; + } + + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/newarp_DenseGenMatProd_bones.hpp b/src/armadillo/include/armadillo_bits/newarp_DenseGenMatProd_bones.hpp new file mode 100644 index 0000000..90c3b5a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_DenseGenMatProd_bones.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! Define matrix operations on existing matrix objects +template +class DenseGenMatProd + { + private: + + const Mat& op_mat; + + + public: + + const uword n_rows; // number of rows of the underlying matrix + const uword n_cols; // number of columns of the underlying matrix + + inline DenseGenMatProd(const Mat& mat_obj); + + inline void perform_op(eT* x_in, eT* y_out) const; + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_DenseGenMatProd_meat.hpp b/src/armadillo/include/armadillo_bits/newarp_DenseGenMatProd_meat.hpp new file mode 100644 index 0000000..8909245 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_DenseGenMatProd_meat.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +DenseGenMatProd::DenseGenMatProd(const Mat& mat_obj) + : op_mat(mat_obj) + , n_rows(mat_obj.n_rows) + , n_cols(mat_obj.n_cols) + { + arma_extra_debug_sigprint(); + } + + + +// Perform the matrix-vector multiplication operation \f$y=Ax\f$. +// y_out = A * x_in +template +inline +void +DenseGenMatProd::perform_op(eT* x_in, eT* y_out) const + { + arma_extra_debug_sigprint(); + + const Col x(x_in , n_cols, false, true); + Col y(y_out, n_rows, false, true); + + y = op_mat * x; + } + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_DoubleShiftQR_bones.hpp b/src/armadillo/include/armadillo_bits/newarp_DoubleShiftQR_bones.hpp new file mode 100644 index 0000000..1599568 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_DoubleShiftQR_bones.hpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +class DoubleShiftQR + { + private: + + uword n; // Dimension of the matrix + Mat mat_H; // A copy of the matrix to be factorised + eT shift_s; // Shift constant + eT shift_t; // Shift constant + Mat ref_u; // Householder reflectors + Col ref_nr; // How many rows does each reflector affects + // 3 - A general reflector + // 2 - A Givens rotation + // 1 - An identity transformation + const eT prec; // Approximately zero + const eT eps_rel; + const eT eps_abs; + bool computed; // Whether matrix has been factorised + + inline void compute_reflector(const eT& x1, const eT& x2, const eT& x3, uword ind); + arma_inline void compute_reflector(const eT* x, uword ind); + + // Update the block X = H(il:iu, il:iu) + inline void update_block(uword il, uword iu); + + // P = I - 2 * u * u' = P' + // PX = X - 2 * u * (u'X) + inline void apply_PX(Mat& X, uword oi, uword oj, uword nrow, uword ncol, uword u_ind); + + // x is a pointer to a vector + // Px = x - 2 * dot(x, u) * u + inline void apply_PX(eT* x, uword u_ind); + + // XP = X - 2 * (X * u) * u' + inline void apply_XP(Mat& X, uword oi, uword oj, uword nrow, uword ncol, uword u_ind); + + + public: + + inline DoubleShiftQR(uword size); + + inline DoubleShiftQR(const Mat& mat_obj, eT s, eT t); + + inline void compute(const Mat& mat_obj, eT s, eT t); + + inline Mat matrix_QtHQ(); + + inline void apply_QtY(Col& y); + + inline void apply_YQ(Mat& Y); + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_DoubleShiftQR_meat.hpp b/src/armadillo/include/armadillo_bits/newarp_DoubleShiftQR_meat.hpp new file mode 100644 index 0000000..1c7497d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_DoubleShiftQR_meat.hpp @@ -0,0 +1,399 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +void +DoubleShiftQR::compute_reflector(const eT& x1, const eT& x2, const eT& x3, uword ind) + { + arma_extra_debug_sigprint(); + + // In general case the reflector affects 3 rows + ref_nr(ind) = 3; + eT x2x3 = eT(0); + // If x3 is zero, decrease nr by 1 + if(std::abs(x3) < prec) + { + // If x2 is also zero, nr will be 1, and we can exit this function + if(std::abs(x2) < prec) + { + ref_nr(ind) = 1; + return; + } + else + { + ref_nr(ind) = 2; + } + x2x3 = std::abs(x2); + } + else + { + x2x3 = arma_hypot(x2, x3); + } + + // x1' = x1 - rho * ||x|| + // rho = -sign(x1), if x1 == 0, we choose rho = 1 + eT x1_new = x1 - ((x1 <= 0) - (x1 > 0)) * arma_hypot(x1, x2x3); + eT x_norm = arma_hypot(x1_new, x2x3); + // Double check the norm of new x + if(x_norm < prec) + { + ref_nr(ind) = 1; + return; + } + ref_u(0, ind) = x1_new / x_norm; + ref_u(1, ind) = x2 / x_norm; + ref_u(2, ind) = x3 / x_norm; + } + + +template +arma_inline +void +DoubleShiftQR::compute_reflector(const eT* x, uword ind) + { + arma_extra_debug_sigprint(); + + compute_reflector(x[0], x[1], x[2], ind); + } + + + +template +inline +void +DoubleShiftQR::update_block(uword il, uword iu) + { + arma_extra_debug_sigprint(); + + // Block size + uword bsize = iu - il + 1; + + // If block size == 1, there is no need to apply reflectors + if(bsize == 1) + { + ref_nr(il) = 1; + return; + } + + // For block size == 2, do a Givens rotation on M = X * X - s * X + t * I + if(bsize == 2) + { + // m00 = x00 * (x00 - s) + x01 * x10 + t + eT m00 = mat_H(il, il) * (mat_H(il, il) - shift_s) + + mat_H(il, il + 1) * mat_H(il + 1, il) + + shift_t; + // m10 = x10 * (x00 + x11 - s) + eT m10 = mat_H(il + 1, il) * (mat_H(il, il) + mat_H(il + 1, il + 1) - shift_s); + // This causes nr=2 + compute_reflector(m00, m10, 0, il); + // Apply the reflector to X + apply_PX(mat_H, il, il, 2, n - il, il); + apply_XP(mat_H, 0, il, il + 2, 2, il); + + ref_nr(il + 1) = 1; + return; + } + + // For block size >=3, use the regular strategy + eT m00 = mat_H(il, il) * (mat_H(il, il) - shift_s) + + mat_H(il, il + 1) * mat_H(il + 1, il) + + shift_t; + eT m10 = mat_H(il + 1, il) * (mat_H(il, il) + mat_H(il + 1, il + 1) - shift_s); + // m20 = x21 * x10 + eT m20 = mat_H(il + 2, il + 1) * mat_H(il + 1, il); + compute_reflector(m00, m10, m20, il); + + // Apply the first reflector + apply_PX(mat_H, il, il, 3, n - il, il); + apply_XP(mat_H, 0, il, il + (std::min)(bsize, uword(4)), 3, il); + + // Calculate the following reflectors + // If entering this loop, block size is at least 4. + for(uword i = 1; i < bsize - 2; i++) + { + compute_reflector(mat_H.colptr(il + i - 1) + il + i, il + i); + // Apply the reflector to X + apply_PX(mat_H, il + i, il + i - 1, 3, n + 1 - il - i, il + i); + apply_XP(mat_H, 0, il + i, il + (std::min)(bsize, uword(i + 4)), 3, il + i); + } + + // The last reflector + // This causes nr=2 + compute_reflector(mat_H(iu - 1, iu - 2), mat_H(iu, iu - 2), 0, iu - 1); + // Apply the reflector to X + apply_PX(mat_H, iu - 1, iu - 2, 2, n + 2 - iu, iu - 1); + apply_XP(mat_H, 0, iu - 1, il + bsize, 2, iu - 1); + + ref_nr(iu) = 1; + } + + + +template +inline +void +DoubleShiftQR::apply_PX(Mat& X, uword oi, uword oj, uword nrow, uword ncol, uword u_ind) + { + arma_extra_debug_sigprint(); + + if(ref_nr(u_ind) == 1) { return; } + + // Householder reflectors at index u_ind + Col u(ref_u.colptr(u_ind), 3, false); + + const uword stride = X.n_rows; + const eT u0_2 = 2 * u(0); + const eT u1_2 = 2 * u(1); + + eT* xptr = &X(oi, oj); + if(ref_nr(u_ind) == 2 || nrow == 2) + { + for(uword i = 0; i < ncol; i++, xptr += stride) + { + eT tmp = u0_2 * xptr[0] + u1_2 * xptr[1]; + xptr[0] -= tmp * u(0); + xptr[1] -= tmp * u(1); + } + } + else + { + const eT u2_2 = 2 * u(2); + for(uword i = 0; i < ncol; i++, xptr += stride) + { + eT tmp = u0_2 * xptr[0] + u1_2 * xptr[1] + u2_2 * xptr[2]; + xptr[0] -= tmp * u(0); + xptr[1] -= tmp * u(1); + xptr[2] -= tmp * u(2); + } + } + } + + + +template +inline +void +DoubleShiftQR::apply_PX(eT* x, uword u_ind) + { + arma_extra_debug_sigprint(); + + if(ref_nr(u_ind) == 1) { return; } + + eT u0 = ref_u(0, u_ind), + u1 = ref_u(1, u_ind), + u2 = ref_u(2, u_ind); + + // When the reflector only contains two elements, u2 has been set to zero + bool nr_is_2 = (ref_nr(u_ind) == 2); + eT dot2 = x[0] * u0 + x[1] * u1 + (nr_is_2 ? 0 : (x[2] * u2)); + dot2 *= 2; + x[0] -= dot2 * u0; + x[1] -= dot2 * u1; + if(!nr_is_2) { x[2] -= dot2 * u2; } + } + + + +template +inline +void +DoubleShiftQR::apply_XP(Mat& X, uword oi, uword oj, uword nrow, uword ncol, uword u_ind) + { + arma_extra_debug_sigprint(); + + if(ref_nr(u_ind) == 1) { return; } + + // Householder reflectors at index u_ind + Col u(ref_u.colptr(u_ind), 3, false); + uword stride = X.n_rows; + const eT u0_2 = 2 * u(0); + const eT u1_2 = 2 * u(1); + eT* X0 = &X(oi, oj); + eT* X1 = X0 + stride; // X0 => X(oi, oj), X1 => X(oi, oj + 1) + + if(ref_nr(u_ind) == 2 || ncol == 2) + { + // tmp = 2 * u0 * X0 + 2 * u1 * X1 + // X0 => X0 - u0 * tmp + // X1 => X1 - u1 * tmp + for(uword i = 0; i < nrow; i++) + { + eT tmp = u0_2 * X0[i] + u1_2 * X1[i]; + X0[i] -= tmp * u(0); + X1[i] -= tmp * u(1); + } + } + else + { + eT* X2 = X1 + stride; // X2 => X(oi, oj + 2) + const eT u2_2 = 2 * u(2); + for(uword i = 0; i < nrow; i++) + { + eT tmp = u0_2 * X0[i] + u1_2 * X1[i] + u2_2 * X2[i]; + X0[i] -= tmp * u(0); + X1[i] -= tmp * u(1); + X2[i] -= tmp * u(2); + } + } + } + + + +template +inline +DoubleShiftQR::DoubleShiftQR(uword size) + : n(size) + , prec(std::numeric_limits::epsilon()) + , eps_rel(prec) + , eps_abs(prec) + , computed(false) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +DoubleShiftQR::DoubleShiftQR(const Mat& mat_obj, eT s, eT t) + : n(mat_obj.n_rows) + , mat_H(n, n) + , shift_s(s) + , shift_t(t) + , ref_u(3, n) + , ref_nr(n) + , prec(std::numeric_limits::epsilon()) + , eps_rel(prec) + , eps_abs(prec) + , computed(false) + { + arma_extra_debug_sigprint(); + + compute(mat_obj, s, t); + } + + + +template +void +DoubleShiftQR::compute(const Mat& mat_obj, eT s, eT t) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (mat_obj.is_square() == false), "newarp::DoubleShiftQR::compute(): matrix must be square" ); + + n = mat_obj.n_rows; + mat_H.set_size(n, n); + shift_s = s; + shift_t = t; + ref_u.set_size(3, n); + ref_nr.set_size(n); + + // Make a copy of mat_obj + mat_H = mat_obj; + + // Obtain the indices of zero elements in the subdiagonal, + // so that H can be divided into several blocks + std::vector zero_ind; + zero_ind.reserve(n - 1); + zero_ind.push_back(0); + eT* Hii = mat_H.memptr(); + for(uword i = 0; i < n - 2; i++, Hii += (n + 1)) + { + // Hii[1] => mat_H(i + 1, i) + const eT h = std::abs(Hii[1]); + if(h <= eps_abs || h <= eps_rel * (std::abs(Hii[0]) + std::abs(Hii[n + 1]))) + { + Hii[1] = 0; + zero_ind.push_back(i + 1); + } + // Make sure mat_H is upper Hessenberg + // Zero the elements below mat_H(i + 1, i) + std::fill(Hii + 2, Hii + n - i, eT(0)); + } + zero_ind.push_back(n); + + for(std::vector::size_type i = 0; i < zero_ind.size() - 1; i++) + { + uword start = zero_ind[i]; + uword end = zero_ind[i + 1] - 1; + // Compute refelctors from each block X + update_block(start, end); + } + + computed = true; + } + + + +template +Mat +DoubleShiftQR::matrix_QtHQ() + { + arma_extra_debug_sigprint(); + + arma_debug_check( (computed == false), "newarp::DoubleShiftQR::matrix_QtHQ(): need to call compute() first" ); + + return mat_H; + } + + + +template +inline +void +DoubleShiftQR::apply_QtY(Col& y) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (computed == false), "newarp::DoubleShiftQR::apply_QtY(): need to call compute() first" ); + + eT* y_ptr = y.memptr(); + for(uword i = 0; i < n - 1; i++, y_ptr++) + { + apply_PX(y_ptr, i); + } + } + + + +template +inline +void +DoubleShiftQR::apply_YQ(Mat& Y) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (computed == false), "newarp::DoubleShiftQR::apply_YQ(): need to call compute() first" ); + + uword nrow = Y.n_rows; + for(uword i = 0; i < n - 2; i++) + { + apply_XP(Y, 0, i, nrow, 3, i); + } + + apply_XP(Y, 0, n - 2, nrow, 2, n - 2); + } + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_EigsSelect.hpp b/src/armadillo/include/armadillo_bits/newarp_EigsSelect.hpp new file mode 100644 index 0000000..d518c64 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_EigsSelect.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! The enumeration of selection rules of desired eigenvalues. +struct EigsSelect + { + enum SELECT_EIGENVALUE + { + LARGEST_MAGN = 0, //!< Select eigenvalues with largest magnitude. + //!< Magnitude means the absolute value for real numbers and norm for complex numbers. + //!< Applies to both symmetric and general eigen solvers. + + LARGEST_REAL, //!< Select eigenvalues with largest real part. Only for general eigen solvers. + + LARGEST_IMAG, //!< Select eigenvalues with largest imaginary part (in magnitude). Only for general eigen solvers. + + LARGEST_ALGE, //!< Select eigenvalues with largest algebraic value, considering any negative sign. Only for symmetric eigen solvers. + + SMALLEST_MAGN, //!< Select eigenvalues with smallest magnitude. Applies to both symmetric and general eigen solvers. + + SMALLEST_REAL, //!< Select eigenvalues with smallest real part. Only for general eigen solvers. + + SMALLEST_IMAG, //!< Select eigenvalues with smallest imaginary part (in magnitude). Only for general eigen solvers. + + SMALLEST_ALGE, //!< Select eigenvalues with smallest algebraic value. Only for symmetric eigen solvers. + + BOTH_ENDS //!< Select eigenvalues half from each end of the spectrum. + //!< When `nev` is odd, compute more from the high end. Only for symmetric eigen solvers. + }; + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_GenEigsSolver_bones.hpp b/src/armadillo/include/armadillo_bits/newarp_GenEigsSolver_bones.hpp new file mode 100644 index 0000000..eabaf06 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_GenEigsSolver_bones.hpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! This class implements the eigen solver for general real matrices. +template +class GenEigsSolver + { + protected: + + const OpType& op; // object to conduct matrix operation, eg. matrix-vector product + const uword nev; // number of eigenvalues requested + Col< std::complex > ritz_val; // ritz values + + // Sort the first nev Ritz pairs in decreasing magnitude order + // This is used to return the final results + virtual void sort_ritzpair(); + + + private: + + const uword dim_n; // dimension of matrix A + const uword ncv; // number of ritz values + uword nmatop; // number of matrix operations called + uword niter; // number of restarting iterations + Mat fac_V; // V matrix in the Arnoldi factorisation + Mat fac_H; // H matrix in the Arnoldi factorisation + Col fac_f; // residual in the Arnoldi factorisation + Mat< std::complex > ritz_vec; // ritz vectors + Col< std::complex > ritz_est; // last row of ritz_vec + std::vector ritz_conv; // indicator of the convergence of ritz values + const eT eps; // the machine precision + // eg. ~= 1e-16 for double type + const eT approx0; // a number that is approximately zero + // approx0 = eps^(2/3) + // used to test the orthogonality of vectors, + // and in convergence test, tol*approx0 is + // the absolute tolerance + + std::mt19937_64 local_rng; // local random number generator + + inline void fill_rand(eT* dest, const uword N, const uword seed_val); + + // Arnoldi factorisation starting from step-k + inline void factorise_from(uword from_k, uword to_m, const Col& fk); + + // Implicitly restarted Arnoldi factorisation + inline void restart(uword k); + + // Calculate the number of converged Ritz values + inline uword num_converged(eT tol); + + // Return the adjusted nev for restarting + inline uword nev_adjusted(uword nconv); + + // Retrieve and sort ritz values and ritz vectors + inline void retrieve_ritzpair(); + + + public: + + //! Constructor to create a solver object. + inline GenEigsSolver(const OpType& op_, uword nev_, uword ncv_); + + //! Providing the initial residual vector for the algorithm. + inline void init(eT* init_resid); + + //! Providing a random initial residual vector. + inline void init(); + + //! Conducting the major computation procedure. + inline uword compute(uword maxit = 1000, eT tol = 1e-10); + + //! Returning the number of iterations used in the computation. + inline int num_iterations() { return niter; } + + //! Returning the number of matrix operations used in the computation. + inline int num_operations() { return nmatop; } + + //! Returning the converged eigenvalues. + inline Col< std::complex > eigenvalues(); + + //! Returning the eigenvectors associated with the converged eigenvalues. + inline Mat< std::complex > eigenvectors(uword nvec); + + //! Returning all converged eigenvectors. + inline Mat< std::complex > eigenvectors() { return eigenvectors(nev); } + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_GenEigsSolver_meat.hpp b/src/armadillo/include/armadillo_bits/newarp_GenEigsSolver_meat.hpp new file mode 100644 index 0000000..290fa4f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_GenEigsSolver_meat.hpp @@ -0,0 +1,492 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +void +GenEigsSolver::fill_rand(eT* dest, const uword N, const uword seed_val) + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type seed_type; + + local_rng.seed( seed_type(seed_val) ); + + std::uniform_real_distribution dist(-1.0, +1.0); + + for(uword i=0; i < N; ++i) { dest[i] = eT(dist(local_rng)); } + } + + + +template +inline +void +GenEigsSolver::factorise_from(uword from_k, uword to_m, const Col& fk) + { + arma_extra_debug_sigprint(); + + if(to_m <= from_k) { return; } + + fac_f = fk; + + Col w(dim_n, arma_zeros_indicator()); + eT beta = norm(fac_f); + // Keep the upperleft k x k submatrix of H and set other elements to 0 + fac_H.tail_cols(ncv - from_k).zeros(); + fac_H.submat(span(from_k, ncv - 1), span(0, from_k - 1)).zeros(); + for(uword i = from_k; i <= to_m - 1; i++) + { + bool restart = false; + // If beta = 0, then the next V is not full rank + // We need to generate a new residual vector that is orthogonal + // to the current V, which we call a restart + if(beta < eps) + { + // // Generate new random vector for fac_f + // blas_int idist = 2; + // blas_int iseed[4] = {1, 3, 5, 7}; + // iseed[0] = (i + 100) % 4095; + // blas_int n = dim_n; + // lapack::larnv(&idist, &iseed[0], &n, fac_f.memptr()); + + // Generate new random vector for fac_f + fill_rand(fac_f.memptr(), dim_n, i+1); + + // f <- f - V * V' * f, so that f is orthogonal to V + Mat Vs(fac_V.memptr(), dim_n, i, false); // First i columns + Col Vf = Vs.t() * fac_f; + fac_f -= Vs * Vf; + // beta <- ||f|| + beta = norm(fac_f); + + restart = true; + } + + // v <- f / ||f|| + fac_V.col(i) = fac_f / beta; // The (i+1)-th column + + // Note that H[i+1, i] equals to the unrestarted beta + if(restart) { fac_H(i, i - 1) = 0.0; } else { fac_H(i, i - 1) = beta; } + + // w <- A * v, v = fac_V.col(i) + op.perform_op(fac_V.colptr(i), w.memptr()); + nmatop++; + + // First i+1 columns of V + Mat Vs(fac_V.memptr(), dim_n, i + 1, false); + // h = fac_H(0:i, i) + Col h(fac_H.colptr(i), i + 1, false); + // h <- V' * w + h = Vs.t() * w; + + // f <- w - V * h + fac_f = w - Vs * h; + beta = norm(fac_f); + + if(beta > 0.717 * norm(h)) { continue; } + + // f/||f|| is going to be the next column of V, so we need to test + // whether V' * (f/||f||) ~= 0 + Col Vf = Vs.t() * fac_f; + // If not, iteratively correct the residual + uword count = 0; + while(count < 5 && abs(Vf).max() > approx0 * beta) + { + // f <- f - V * Vf + fac_f -= Vs * Vf; + // h <- h + Vf + h += Vf; + // beta <- ||f|| + beta = norm(fac_f); + + Vf = Vs.t() * fac_f; + count++; + } + } + } + + + +template +inline +void +GenEigsSolver::restart(uword k) + { + arma_extra_debug_sigprint(); + + if(k >= ncv) { return; } + + DoubleShiftQR decomp_ds(ncv); + UpperHessenbergQR decomp; + + Mat Q(ncv, ncv, fill::eye); + + for(uword i = k; i < ncv; i++) + { + if(cx_attrib::is_complex(ritz_val(i), eT(0)) && (i < (ncv - 1)) && cx_attrib::is_conj(ritz_val(i), ritz_val(i + 1), eT(0))) + { + // H - mu * I = Q1 * R1 + // H <- R1 * Q1 + mu * I = Q1' * H * Q1 + // H - conj(mu) * I = Q2 * R2 + // H <- R2 * Q2 + conj(mu) * I = Q2' * H * Q2 + // + // (H - mu * I) * (H - conj(mu) * I) = Q1 * Q2 * R2 * R1 = Q * R + eT s = 2 * ritz_val(i).real(); + eT t = std::norm(ritz_val(i)); + decomp_ds.compute(fac_H, s, t); + + // Q -> Q * Qi + decomp_ds.apply_YQ(Q); + // H -> Q'HQ + fac_H = decomp_ds.matrix_QtHQ(); + + i++; + } + else + { + // QR decomposition of H - mu * I, mu is real + fac_H.diag() -= ritz_val(i).real(); + decomp.compute(fac_H); + + // Q -> Q * Qi + decomp.apply_YQ(Q); + // H -> Q'HQ = RQ + mu * I + fac_H = decomp.matrix_RQ(); + fac_H.diag() += ritz_val(i).real(); + } + } + + // V -> VQ + // Q has some elements being zero + // The first (ncv - k + i) elements of the i-th column of Q are non-zero + Mat Vs(dim_n, k + 1, arma_nozeros_indicator()); + uword nnz; + for(uword i = 0; i < k; i++) + { + nnz = ncv - k + i + 1; + Mat V(fac_V.memptr(), dim_n, nnz, false); + Col q(Q.colptr(i), nnz, false); + Col v(Vs.colptr(i), dim_n, false); + v = V * q; + } + + Vs.col(k) = fac_V * Q.col(k); + fac_V.head_cols(k + 1) = Vs; + + Col fk = fac_f * Q(ncv - 1, k - 1) + fac_V.col(k) * fac_H(k, k - 1); + factorise_from(k, ncv, fk); + retrieve_ritzpair(); + } + + + +template +inline +uword +GenEigsSolver::num_converged(eT tol) + { + arma_extra_debug_sigprint(); + + // thresh = tol * max(prec, abs(theta)), theta for ritz value + const eT f_norm = arma::norm(fac_f); + for(uword i = 0; i < nev; i++) + { + eT thresh = tol * (std::max)(approx0, std::abs(ritz_val(i))); + eT resid = std::abs(ritz_est(i)) * f_norm; + ritz_conv[i] = (resid < thresh); + } + + return std::count(ritz_conv.begin(), ritz_conv.end(), true); + } + + + +template +inline +uword +GenEigsSolver::nev_adjusted(uword nconv) + { + arma_extra_debug_sigprint(); + + uword nev_new = nev; + + for(uword i = nev; i < ncv; i++) + { + if(std::abs(ritz_est(i)) < eps) { nev_new++; } + } + // Adjust nev_new again, according to dnaup2.f line 660~674 in ARPACK + nev_new += (std::min)(nconv, (ncv - nev_new) / 2); + if(nev_new == 1 && ncv >= 6) + { + nev_new = ncv / 2; + } + else + if(nev_new == 1 && ncv > 3) + { + nev_new = 2; + } + + if(nev_new > ncv - 2) { nev_new = ncv - 2; } + + // Increase nev by one if ritz_val[nev - 1] and + // ritz_val[nev] are conjugate pairs + if(cx_attrib::is_complex(ritz_val(nev_new - 1), eps) && cx_attrib::is_conj(ritz_val(nev_new - 1), ritz_val(nev_new), eps)) + { + nev_new++; + } + + return nev_new; + } + + + +template +inline +void +GenEigsSolver::retrieve_ritzpair() + { + arma_extra_debug_sigprint(); + + UpperHessenbergEigen decomp(fac_H); + + Col< std::complex > evals = decomp.eigenvalues(); + Mat< std::complex > evecs = decomp.eigenvectors(); + + SortEigenvalue< std::complex, SelectionRule > sorting(evals.memptr(), evals.n_elem); + std::vector ind = sorting.index(); + + // Copy the ritz values and vectors to ritz_val and ritz_vec, respectively + for(uword i = 0; i < ncv; i++) + { + ritz_val(i) = evals(ind[i]); + ritz_est(i) = evecs(ncv - 1, ind[i]); + } + for(uword i = 0; i < nev; i++) + { + ritz_vec.col(i) = evecs.col(ind[i]); + } + } + + + +template +inline +void +GenEigsSolver::sort_ritzpair() + { + arma_extra_debug_sigprint(); + + // SortEigenvalue< std::complex, EigsSelect::LARGEST_MAGN > sorting(ritz_val.memptr(), nev); + + // sort Ritz values according to SelectionRule, to be consistent with ARPACK + SortEigenvalue< std::complex, SelectionRule > sorting(ritz_val.memptr(), nev); + + std::vector ind = sorting.index(); + + Col< std::complex > new_ritz_val(ncv, arma_zeros_indicator() ); + Mat< std::complex > new_ritz_vec(ncv, nev, arma_nozeros_indicator()); + std::vector new_ritz_conv(nev); + + for(uword i = 0; i < nev; i++) + { + new_ritz_val(i) = ritz_val(ind[i]); + new_ritz_vec.col(i) = ritz_vec.col(ind[i]); + new_ritz_conv[i] = ritz_conv[ind[i]]; + } + + ritz_val.swap(new_ritz_val); + ritz_vec.swap(new_ritz_vec); + ritz_conv.swap(new_ritz_conv); + } + + + +template +inline +GenEigsSolver::GenEigsSolver(const OpType& op_, uword nev_, uword ncv_) + : op(op_) + , nev(nev_) + , dim_n(op.n_rows) + , ncv(ncv_ > dim_n ? dim_n : ncv_) + , nmatop(0) + , niter(0) + , eps(std::numeric_limits::epsilon()) + , approx0(std::pow(eps, eT(2.0) / 3)) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (nev_ < 1 || nev_ > dim_n - 2), "newarp::GenEigsSolver: nev must satisfy 1 <= nev <= n - 2, n is the size of matrix" ); + arma_debug_check( (ncv_ < nev_ + 2 || ncv_ > dim_n), "newarp::GenEigsSolver: ncv must satisfy nev + 2 <= ncv <= n, n is the size of matrix" ); + } + + + +template +inline +void +GenEigsSolver::init(eT* init_resid) + { + arma_extra_debug_sigprint(); + + // Reset all matrices/vectors to zero + fac_V.zeros(dim_n, ncv); + fac_H.zeros(ncv, ncv); + fac_f.zeros(dim_n); + ritz_val.zeros(ncv); + ritz_vec.zeros(ncv, nev); + ritz_est.zeros(ncv); + ritz_conv.assign(nev, false); + + nmatop = 0; + niter = 0; + + Col r(init_resid, dim_n, false); + // The first column of fac_V + Col v(fac_V.colptr(0), dim_n, false); + eT rnorm = norm(r); + arma_check( (rnorm < eps), "newarp::GenEigsSolver::init(): initial residual vector cannot be zero" ); + v = r / rnorm; + + Col w(dim_n, arma_zeros_indicator()); + op.perform_op(v.memptr(), w.memptr()); + nmatop++; + + fac_H(0, 0) = dot(v, w); + fac_f = w - v * fac_H(0, 0); + } + + + +template +inline +void +GenEigsSolver::init() + { + arma_extra_debug_sigprint(); + + // podarray init_resid(dim_n); + // blas_int idist = 2; // Uniform(-1, 1) + // blas_int iseed[4] = {1, 3, 5, 7}; // Fixed random seed + // blas_int n = dim_n; + // lapack::larnv(&idist, &iseed[0], &n, init_resid.memptr()); + // init(init_resid.memptr()); + + podarray init_resid(dim_n); + + fill_rand(init_resid.memptr(), dim_n, 0); + + init(init_resid.memptr()); + } + + + +template +inline +uword +GenEigsSolver::compute(uword maxit, eT tol) + { + arma_extra_debug_sigprint(); + + // The m-step Arnoldi factorisation + factorise_from(1, ncv, fac_f); + retrieve_ritzpair(); + // Restarting + uword i, nconv = 0, nev_adj; + for(i = 0; i < maxit; i++) + { + nconv = num_converged(tol); + if(nconv >= nev) { break; } + + nev_adj = nev_adjusted(nconv); + restart(nev_adj); + } + // Sorting results + sort_ritzpair(); + + niter = i + 1; + + return (std::min)(nev, nconv); + } + + + +template +inline +Col< std::complex > +GenEigsSolver::eigenvalues() + { + arma_extra_debug_sigprint(); + + uword nconv = std::count(ritz_conv.begin(), ritz_conv.end(), true); + Col< std::complex > res(nconv, arma_zeros_indicator()); + + if(nconv > 0) + { + uword j = 0; + for(uword i = 0; i < nev; i++) + { + if(ritz_conv[i]) + { + res(j) = ritz_val(i); + j++; + } + } + } + + return res; + } + + + +template +inline +Mat< std::complex > +GenEigsSolver::eigenvectors(uword nvec) + { + arma_extra_debug_sigprint(); + + uword nconv = std::count(ritz_conv.begin(), ritz_conv.end(), true); + nvec = (std::min)(nvec, nconv); + Mat< std::complex > res(dim_n, nvec); + + if(nvec > 0) + { + Mat< std::complex > ritz_vec_conv(ncv, nvec, arma_zeros_indicator()); + uword j = 0; + for(uword i = 0; (i < nev) && (j < nvec); i++) + { + if(ritz_conv[i]) + { + ritz_vec_conv.col(j) = ritz_vec.col(i); + j++; + } + } + + res = fac_V * ritz_vec_conv; + } + + return res; + } + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_SortEigenvalue.hpp b/src/armadillo/include/armadillo_bits/newarp_SortEigenvalue.hpp new file mode 100644 index 0000000..5f2c357 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_SortEigenvalue.hpp @@ -0,0 +1,203 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +// When comparing eigenvalues, we first calculate the "target" to sort. +// For example, if we want to choose the eigenvalues with largest magnitude, the target will be -std::abs(x). +// The minus sign is due to the fact that std::sort() sorts in ascending order. + + +// default target: throw an exception +template +struct SortingTarget + { + arma_inline static typename get_pod_type::result get(const eT& val) + { + arma_ignore(val); + arma_stop_logic_error("newarp::SortingTarget: incompatible selection rule"); + + typedef typename get_pod_type::result out_T; + return out_T(0); + } + }; + + +// specialisation for LARGEST_MAGN: this covers [float, double, complex] x [LARGEST_MAGN] +template +struct SortingTarget + { + arma_inline static typename get_pod_type::result get(const eT& val) + { + return -std::abs(val); + } + }; + + +// specialisation for LARGEST_REAL: this covers [complex] x [LARGEST_REAL] +template +struct SortingTarget, EigsSelect::LARGEST_REAL> + { + arma_inline static T get(const std::complex& val) + { + return -val.real(); + } + }; + + +// specialisation for LARGEST_IMAG: this covers [complex] x [LARGEST_IMAG] +template +struct SortingTarget, EigsSelect::LARGEST_IMAG> + { + arma_inline static T get(const std::complex& val) + { + return -std::abs(val.imag()); + } + }; + + +// specialisation for LARGEST_ALGE: this covers [float, double] x [LARGEST_ALGE] +template +struct SortingTarget + { + arma_inline static eT get(const eT& val) + { + return -val; + } + }; + + +// Here BOTH_ENDS is the same as LARGEST_ALGE, but we need some additional steps, +// which are done in SymEigsSolver => retrieve_ritzpair(). +// There we move the smallest values to the proper locations. +template +struct SortingTarget + { + arma_inline static eT get(const eT& val) + { + return -val; + } + }; + + +// specialisation for SMALLEST_MAGN: this covers [float, double, complex] x [SMALLEST_MAGN] +template +struct SortingTarget + { + arma_inline static typename get_pod_type::result get(const eT& val) + { + return std::abs(val); + } + }; + + +// specialisation for SMALLEST_REAL: this covers [complex] x [SMALLEST_REAL] +template +struct SortingTarget, EigsSelect::SMALLEST_REAL> + { + arma_inline static T get(const std::complex& val) + { + return val.real(); + } + }; + + +// specialisation for SMALLEST_IMAG: this covers [complex] x [SMALLEST_IMAG] +template +struct SortingTarget, EigsSelect::SMALLEST_IMAG> + { + arma_inline static T get(const std::complex& val) + { + return std::abs(val.imag()); + } + }; + + +// specialisation for SMALLEST_ALGE: this covers [float, double] x [SMALLEST_ALGE] +template +struct SortingTarget + { + arma_inline static eT get(const eT& val) + { + return val; + } + }; + + +// Sort eigenvalues and return the order index +template +struct PairComparator + { + arma_inline bool operator() (const PairType& v1, const PairType& v2) + { + return v1.first < v2.first; + } + }; + + +template +class SortEigenvalue + { + private: + + typedef typename get_pod_type::result TargetType; // type of the sorting target, will be a floating number type, eg. double + typedef std::pair PairType; // type of the sorting pair, including the sorting target and the index + + std::vector pair_sort; + + + public: + + inline + SortEigenvalue(const eT* start, const uword size) + : pair_sort(size) + { + arma_extra_debug_sigprint(); + + for(uword i = 0; i < size; i++) + { + pair_sort[i].first = SortingTarget::get(start[i]); + pair_sort[i].second = i; + } + + PairComparator comp; + + std::sort(pair_sort.begin(), pair_sort.end(), comp); + } + + + inline + std::vector + index() + { + arma_extra_debug_sigprint(); + + const uword len = pair_sort.size(); + + std::vector ind(len); + + for(uword i = 0; i < len; i++) { ind[i] = pair_sort[i].second; } + + return ind; + } + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_SparseGenMatProd_bones.hpp b/src/armadillo/include/armadillo_bits/newarp_SparseGenMatProd_bones.hpp new file mode 100644 index 0000000..2028aee --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_SparseGenMatProd_bones.hpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! Define matrix operations on existing matrix objects +template +class SparseGenMatProd + { + private: + + const SpMat& op_mat; + SpMat op_mat_st; + + + public: + + const uword n_rows; // number of rows of the underlying matrix + const uword n_cols; // number of columns of the underlying matrix + + inline SparseGenMatProd(const SpMat& mat_obj); + + inline void perform_op(eT* x_in, eT* y_out) const; + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_SparseGenMatProd_meat.hpp b/src/armadillo/include/armadillo_bits/newarp_SparseGenMatProd_meat.hpp new file mode 100644 index 0000000..bbe539a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_SparseGenMatProd_meat.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +SparseGenMatProd::SparseGenMatProd(const SpMat& mat_obj) + : op_mat(mat_obj) + , n_rows(mat_obj.n_rows) + , n_cols(mat_obj.n_cols) + { + arma_extra_debug_sigprint(); + + op_mat_st = op_mat.st(); // pre-calculate transpose + } + + + +// Perform the matrix-vector multiplication operation \f$y=Ax\f$. +// y_out = A * x_in +template +inline +void +SparseGenMatProd::perform_op(eT* x_in, eT* y_out) const + { + arma_extra_debug_sigprint(); + + // // OLD METHOD + // + // const Col x(x_in , n_cols, false, true); + // Col y(y_out, n_rows, false, true); + // + // y = op_mat * x; + + + // NEW METHOD + + const Row x(x_in , n_cols, false, true); + Row y(y_out, n_rows, false, true); + + y = x * op_mat_st; + } + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_SparseGenRealShiftSolve_bones.hpp b/src/armadillo/include/armadillo_bits/newarp_SparseGenRealShiftSolve_bones.hpp new file mode 100644 index 0000000..a47575d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_SparseGenRealShiftSolve_bones.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! Define matrix operations on existing matrix objects +template +class SparseGenRealShiftSolve + { + private: + + #if defined(ARMA_USE_SUPERLU) + // The following objects are read-only in perform_op() + mutable superlu_supermatrix_wrangler l; + mutable superlu_supermatrix_wrangler u; + mutable superlu_array_wrangler perm_c; + mutable superlu_array_wrangler perm_r; + #endif + + + public: + + bool valid = false; + + const uword n_rows; // number of rows of the underlying matrix + const uword n_cols; // number of columns of the underlying matrix + + inline SparseGenRealShiftSolve(const SpMat& mat_obj, const eT shift); + + inline void perform_op(eT* x_in, eT* y_out) const; + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_SparseGenRealShiftSolve_meat.hpp b/src/armadillo/include/armadillo_bits/newarp_SparseGenRealShiftSolve_meat.hpp new file mode 100644 index 0000000..ea20618 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_SparseGenRealShiftSolve_meat.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +SparseGenRealShiftSolve::SparseGenRealShiftSolve(const SpMat& mat_obj, const eT shift) + #if defined(ARMA_USE_SUPERLU) + : perm_c(mat_obj.n_cols + 1) + , perm_r(mat_obj.n_rows + 1) + , n_rows(mat_obj.n_rows) + , n_cols(mat_obj.n_cols) + #else + : n_rows(0) + , n_cols(0) + #endif + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_SUPERLU) + { + // Derived from sp_auxlib::run_aupd_shiftinvert() + superlu_opts superlu_opts_default; + superlu::superlu_options_t options; + sp_auxlib::set_superlu_opts(options, superlu_opts_default); + + superlu::GlobalLU_t Glu; + arrayops::fill_zeros(reinterpret_cast(&Glu), sizeof(superlu::GlobalLU_t)); + + superlu_supermatrix_wrangler x; + superlu_supermatrix_wrangler xC; + superlu_array_wrangler etree(mat_obj.n_cols+1); + + // Copy A-shift*I to x + const bool status_x = sp_auxlib::copy_to_supermatrix_with_shift(x.get_ref(), mat_obj, shift); + + if(status_x == false) { arma_stop_runtime_error("newarp::SparseGenRealShiftSolve::SparseGenRealShiftSolve(): could not construct SuperLU matrix"); return; } + + int panel_size = superlu::sp_ispec_environ(1); + int relax = superlu::sp_ispec_environ(2); + int slu_info = 0; // Return code + int lwork = 0; // lwork = 0: allocate space internally by system malloc + + superlu_stat_wrangler stat; + + arma_extra_debug_print("superlu::gstrf()"); + superlu::get_permutation_c(options.ColPerm, x.get_ptr(), perm_c.get_ptr()); + superlu::sp_preorder_mat(&options, x.get_ptr(), perm_c.get_ptr(), etree.get_ptr(), xC.get_ptr()); + superlu::gstrf(&options, xC.get_ptr(), relax, panel_size, etree.get_ptr(), NULL, lwork, perm_c.get_ptr(), perm_r.get_ptr(), l.get_ptr(), u.get_ptr(), &Glu, stat.get_ptr(), &slu_info); + + if(slu_info != 0) + { + arma_debug_warn_level(2, "matrix is singular to working precision"); + return; + } + + eT x_norm_val = sp_auxlib::norm1(x.get_ptr()); + eT x_rcond = sp_auxlib::lu_rcond(l.get_ptr(), u.get_ptr(), x_norm_val); + + if( (x_rcond < std::numeric_limits::epsilon()) || arma_isnan(x_rcond) ) + { + if(x_rcond == eT(0)) { arma_debug_warn_level(2, "matrix is singular to working precision"); } + else { arma_debug_warn_level(2, "matrix is singular to working precision (rcond: ", x_rcond, ")"); } + return; + } + + valid = true; + } + #else + { + arma_ignore(mat_obj); + arma_ignore(shift); + } + #endif + } + + + +// Perform the shift-solve operation \f$y=(A-\sigma I)^{-1}x\f$. +// y_out = inv(A - sigma * I) * x_in +template +inline +void +SparseGenRealShiftSolve::perform_op(eT* x_in, eT* y_out) const + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_SUPERLU) + { + const Col x(x_in , n_cols, false, true); + Col y(y_out, n_rows, false, true); + + // Derived from sp_auxlib::run_aupd_shiftinvert() + y = x; + superlu_supermatrix_wrangler out_slu; + + const bool status_out_slu = sp_auxlib::wrap_to_supermatrix(out_slu.get_ref(), y); + + if(status_out_slu == false) { arma_stop_runtime_error("newarp::SparseGenRealShiftSolve::perform_op(): could not construct SuperLU matrix"); return; } + + superlu_stat_wrangler stat; + int info = 0; + + arma_extra_debug_print("superlu::gstrs()"); + superlu::gstrs(superlu::NOTRANS, l.get_ptr(), u.get_ptr(), perm_c.get_ptr(), perm_r.get_ptr(), out_slu.get_ptr(), stat.get_ptr(), &info); + + if(info != 0) { arma_stop_runtime_error("newarp::SparseGenRealShiftSolve::perform_op(): could not solve linear equation"); return; } + + // No need to modify memory further since it was all done in-place. + } + #else + { + arma_ignore(x_in); + arma_ignore(y_out); + } + #endif + } + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_SymEigsShiftSolver_bones.hpp b/src/armadillo/include/armadillo_bits/newarp_SymEigsShiftSolver_bones.hpp new file mode 100644 index 0000000..bf3231f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_SymEigsShiftSolver_bones.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! This class implements the eigen solver for real symmetric matrices in the shift-and-invert mode. +template +class SymEigsShiftSolver : public SymEigsSolver + { + private: + + const eT sigma; + + // Sort the first nev Ritz pairs in ascending algebraic order + // This is used to return the final results + void sort_ritzpair(); + + + public: + + //! Constructor to create a solver object. + inline SymEigsShiftSolver(const OpType& op_, uword nev_, uword ncv_, const eT sigma_); + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_SymEigsShiftSolver_meat.hpp b/src/armadillo/include/armadillo_bits/newarp_SymEigsShiftSolver_meat.hpp new file mode 100644 index 0000000..bfb2913 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_SymEigsShiftSolver_meat.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +void +SymEigsShiftSolver::sort_ritzpair() + { + arma_extra_debug_sigprint(); + + // First transform back the Ritz values, and then sort + for(uword i = 0; i < this->nev; i++) + { + this->ritz_val(i) = eT(1.0) / this->ritz_val(i) + sigma; + } + SymEigsSolver::sort_ritzpair(); + } + + + +template +inline +SymEigsShiftSolver::SymEigsShiftSolver(const OpType& op_, uword nev_, uword ncv_, const eT sigma_) + : SymEigsSolver::SymEigsSolver(op_, nev_, ncv_) + , sigma(sigma_) + { + arma_extra_debug_sigprint(); + } + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_SymEigsSolver_bones.hpp b/src/armadillo/include/armadillo_bits/newarp_SymEigsSolver_bones.hpp new file mode 100644 index 0000000..612f92a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_SymEigsSolver_bones.hpp @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! This class implements the eigen solver for real symmetric matrices. +template +class SymEigsSolver + { + protected: + + const OpType& op; // object to conduct matrix operation, eg. matrix-vector product + const uword nev; // number of eigenvalues requested + Col ritz_val; // ritz values + + // Sort the first nev Ritz pairs in ascending algebraic order + // This is used to return the final results + virtual void sort_ritzpair(); + + + private: + + const uword dim_n; // dimension of matrix A + const uword ncv; // number of ritz values + uword nmatop; // number of matrix operations called + uword niter; // number of restarting iterations + Mat fac_V; // V matrix in the Arnoldi factorisation + Mat fac_H; // H matrix in the Arnoldi factorisation + Col fac_f; // residual in the Arnoldi factorisation + Mat ritz_vec; // ritz vectors + Col ritz_est; // last row of ritz_vec + std::vector ritz_conv; // indicator of the convergence of ritz values + const eT eps; // the machine precision + // eg. ~= 1e-16 for double type + const eT eps23; // eps^(2/3), used in convergence test + // tol*eps23 is the absolute tolerance + const eT near0; // a very small value, but 1/near0 does not overflow + + std::mt19937_64 local_rng; // local random number generator + + inline void fill_rand(eT* dest, const uword N, const uword seed_val); + + // Arnoldi factorisation starting from step-k + inline void factorise_from(uword from_k, uword to_m, const Col& fk); + + // Implicitly restarted Arnoldi factorisation + inline void restart(uword k); + + // Calculate the number of converged Ritz values + inline uword num_converged(eT tol); + + // Return the adjusted nev for restarting + inline uword nev_adjusted(uword nconv); + + // Retrieve and sort ritz values and ritz vectors + inline void retrieve_ritzpair(); + + + public: + + //! Constructor to create a solver object. + inline SymEigsSolver(const OpType& op_, uword nev_, uword ncv_); + + //! Providing the initial residual vector for the algorithm. + inline void init(eT* init_resid); + + //! Providing a random initial residual vector. + inline void init(); + + //! Conducting the major computation procedure. + inline uword compute(uword maxit = 1000, eT tol = 1e-10); + + //! Returning the number of iterations used in the computation. + inline uword num_iterations() { return niter; } + + //! Returning the number of matrix operations used in the computation. + inline uword num_operations() { return nmatop; } + + //! Returning the converged eigenvalues. + inline Col eigenvalues(); + + //! Returning the eigenvectors associated with the converged eigenvalues. + inline Mat eigenvectors(uword nvec); + + //! Returning all converged eigenvectors. + inline Mat eigenvectors() { return eigenvectors(nev); } + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_SymEigsSolver_meat.hpp b/src/armadillo/include/armadillo_bits/newarp_SymEigsSolver_meat.hpp new file mode 100644 index 0000000..2223328 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_SymEigsSolver_meat.hpp @@ -0,0 +1,508 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +void +SymEigsSolver::fill_rand(eT* dest, const uword N, const uword seed_val) + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type seed_type; + + local_rng.seed( seed_type(seed_val) ); + + std::uniform_real_distribution dist(-1.0, +1.0); + + for(uword i=0; i < N; ++i) { dest[i] = eT(dist(local_rng)); } + } + + + +template +inline +void +SymEigsSolver::factorise_from(uword from_k, uword to_m, const Col& fk) + { + arma_extra_debug_sigprint(); + + if(to_m <= from_k) { return; } + + fac_f = fk; + + Col w(dim_n, arma_zeros_indicator()); + // Norm of f + eT beta = norm(fac_f); + // Used to test beta~=0 + const eT beta_thresh = eps * eop_aux::sqrt(dim_n); + // Keep the upperleft k x k submatrix of H and set other elements to 0 + fac_H.tail_cols(ncv - from_k).zeros(); + fac_H.submat(span(from_k, ncv - 1), span(0, from_k - 1)).zeros(); + for(uword i = from_k; i <= to_m - 1; i++) + { + bool restart = false; + // If beta = 0, then the next V is not full rank + // We need to generate a new residual vector that is orthogonal + // to the current V, which we call a restart + if(beta < near0) + { + // // Generate new random vector for fac_f + // blas_int idist = 2; + // blas_int iseed[4] = {1, 3, 5, 7}; + // iseed[0] = (i + 100) % 4095; + // blas_int n = dim_n; + // lapack::larnv(&idist, &iseed[0], &n, fac_f.memptr()); + + // Generate new random vector for fac_f + fill_rand(fac_f.memptr(), dim_n, i+1); + + // f <- f - V * V' * f, so that f is orthogonal to V + Mat Vs(fac_V.memptr(), dim_n, i, false); // First i columns + Col Vf = Vs.t() * fac_f; + fac_f -= Vs * Vf; + // beta <- ||f|| + beta = norm(fac_f); + + restart = true; + } + + // v <- f / ||f|| + Col v(fac_V.colptr(i), dim_n, false); // The (i+1)-th column + v = fac_f / beta; + + // Note that H[i+1, i] equals to the unrestarted beta + fac_H(i, i - 1) = restart ? eT(0) : beta; + + // w <- A * v, v = fac_V.col(i) + op.perform_op(v.memptr(), w.memptr()); + nmatop++; + + fac_H(i - 1, i) = fac_H(i, i - 1); // Due to symmetry + eT Hii = dot(v, w); + fac_H(i, i) = Hii; + + // f <- w - V * V' * w = w - H[i+1, i] * V{i} - H[i+1, i+1] * V{i+1} + // If restarting, we know that H[i+1, i] = 0 + if(restart) + { + fac_f = w - Hii * v; + } + else + { + fac_f = w - fac_H(i, i - 1) * fac_V.col(i - 1) - Hii * v; + } + + beta = norm(fac_f); + + // f/||f|| is going to be the next column of V, so we need to test + // whether V' * (f/||f||) ~= 0 + Mat Vs(fac_V.memptr(), dim_n, i + 1, false); // First i+1 columns + Col Vf = Vs.t() * fac_f; + eT ortho_err = abs(Vf).max(); + // If not, iteratively correct the residual + uword count = 0; + while(count < 5 && ortho_err > eps * beta) + { + // There is an edge case: when beta=||f|| is close to zero, f mostly consists + // of rounding errors, so the test [ortho_err < eps * beta] is very + // likely to fail. In particular, if beta=0, then the test is ensured to fail. + // Hence when this happens, we force f to be zero, and then restart in the + // next iteration. + if(beta < beta_thresh) + { + fac_f.zeros(); + beta = eT(0); + break; + } + + // f <- f - V * Vf + fac_f -= Vs * Vf; + // h <- h + Vf + fac_H(i - 1, i) += Vf[i - 1]; + fac_H(i, i - 1) = fac_H(i - 1, i); + fac_H(i, i) += Vf[i]; + // beta <- ||f|| + beta = norm(fac_f); + + Vf = Vs.t() * fac_f; + ortho_err = abs(Vf).max(); + count++; + } + } + } + + + +template +inline +void +SymEigsSolver::restart(uword k) + { + arma_extra_debug_sigprint(); + + if(k >= ncv) { return; } + + TridiagQR decomp; + Mat Q(ncv, ncv, fill::eye); + + for(uword i = k; i < ncv; i++) + { + // QR decomposition of H-mu*I, mu is the shift + fac_H.diag() -= ritz_val(i); + decomp.compute(fac_H); + + // Q -> Q * Qi + decomp.apply_YQ(Q); + + // H -> Q'HQ + // Since QR = H - mu * I, we have H = QR + mu * I + // and therefore Q'HQ = RQ + mu * I + fac_H = decomp.matrix_RQ(); + fac_H.diag() += ritz_val(i); + } + + // V -> VQ, only need to update the first k+1 columns + // Q has some elements being zero + // The first (ncv - k + i) elements of the i-th column of Q are non-zero + Mat Vs(dim_n, k + 1, arma_nozeros_indicator()); + uword nnz; + for(uword i = 0; i < k; i++) + { + nnz = ncv - k + i + 1; + Mat V(fac_V.memptr(), dim_n, nnz, false); + Col q(Q.colptr(i), nnz, false); + // OLD CODE: + // Vs.col(i) = V * q; + // NEW CODE: + Col v(Vs.colptr(i), dim_n, false, true); + v = V * q; + } + + Vs.col(k) = fac_V * Q.col(k); + fac_V.head_cols(k + 1) = Vs; + + Col fk = fac_f * Q(ncv - 1, k - 1) + fac_V.col(k) * fac_H(k, k - 1); + factorise_from(k, ncv, fk); + retrieve_ritzpair(); + } + + + +template +inline +uword +SymEigsSolver::num_converged(eT tol) + { + arma_extra_debug_sigprint(); + + // thresh = tol * max(approx0, abs(theta)), theta for ritz value + const eT f_norm = norm(fac_f); + for(uword i = 0; i < nev; i++) + { + eT thresh = tol * (std::max)(eps23, std::abs(ritz_val(i))); + eT resid = std::abs(ritz_est(i)) * f_norm; + ritz_conv[i] = (resid < thresh); + } + + return std::count(ritz_conv.begin(), ritz_conv.end(), true); + } + + + +template +inline +uword +SymEigsSolver::nev_adjusted(uword nconv) + { + arma_extra_debug_sigprint(); + + uword nev_new = nev; + for(uword i = nev; i < ncv; i++) + { + if(std::abs(ritz_est(i)) < near0) { nev_new++; } + } + + // Adjust nev_new, according to dsaup2.f line 677~684 in ARPACK + nev_new += (std::min)(nconv, (ncv - nev_new) / 2); + + if(nev_new >= ncv) { nev_new = ncv - 1; } + + if(nev_new == 1) + { + if(ncv >= 6) { nev_new = ncv / 2; } + else if(ncv > 2) { nev_new = 2; } + } + + return nev_new; + } + + + +template +inline +void +SymEigsSolver::retrieve_ritzpair() + { + arma_extra_debug_sigprint(); + + TridiagEigen decomp(fac_H); + Col evals = decomp.eigenvalues(); + Mat evecs = decomp.eigenvectors(); + + SortEigenvalue sorting(evals.memptr(), evals.n_elem); + std::vector ind = sorting.index(); + + // For BOTH_ENDS, the eigenvalues are sorted according + // to the LARGEST_ALGE rule, so we need to move those smallest + // values to the left + // The order would be + // Largest => Smallest => 2nd largest => 2nd smallest => ... + // We keep this order since the first k values will always be + // the wanted collection, no matter k is nev_updated (used in restart()) + // or is nev (used in sort_ritzpair()) + if(SelectionRule == EigsSelect::BOTH_ENDS) + { + std::vector ind_copy(ind); + for(uword i = 0; i < ncv; i++) + { + // If i is even, pick values from the left (large values) + // If i is odd, pick values from the right (small values) + + ind[i] = (i % 2 == 0) ? ind_copy[i / 2] : ind_copy[ncv - 1 - i / 2]; + } + } + + // Copy the ritz values and vectors to ritz_val and ritz_vec, respectively + for(uword i = 0; i < ncv; i++) + { + ritz_val(i) = evals(ind[i]); + ritz_est(i) = evecs(ncv - 1, ind[i]); + } + for(uword i = 0; i < nev; i++) + { + ritz_vec.col(i) = evecs.col(ind[i]); + } + } + + + +template +inline +void +SymEigsSolver::sort_ritzpair() + { + arma_extra_debug_sigprint(); + + // SortEigenvalue sorting(ritz_val.memptr(), nev); + + // Sort Ritz values in ascending algebraic, to be consistent with ARPACK + SortEigenvalue sorting(ritz_val.memptr(), nev); + + std::vector ind = sorting.index(); + + Col new_ritz_val(ncv, arma_zeros_indicator() ); + Mat new_ritz_vec(ncv, nev, arma_nozeros_indicator()); + std::vector new_ritz_conv(nev); + + for(uword i = 0; i < nev; i++) + { + new_ritz_val(i) = ritz_val(ind[i]); + new_ritz_vec.col(i) = ritz_vec.col(ind[i]); + new_ritz_conv[i] = ritz_conv[ind[i]]; + } + + ritz_val.swap(new_ritz_val); + ritz_vec.swap(new_ritz_vec); + ritz_conv.swap(new_ritz_conv); + } + + + +template +inline +SymEigsSolver::SymEigsSolver(const OpType& op_, uword nev_, uword ncv_) + : op(op_) + , nev(nev_) + , dim_n(op.n_rows) + , ncv(ncv_ > dim_n ? dim_n : ncv_) + , nmatop(0) + , niter(0) + , eps(std::numeric_limits::epsilon()) + , eps23(std::pow(eps, eT(2.0) / 3)) + , near0(std::numeric_limits::min() * eT(10)) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (nev_ < 1 || nev_ > dim_n - 1), "newarp::SymEigsSolver: nev must satisfy 1 <= nev <= n - 1, n is the size of matrix" ); + arma_debug_check( (ncv_ <= nev_ || ncv_ > dim_n), "newarp::SymEigsSolver: ncv must satisfy nev < ncv <= n, n is the size of matrix" ); + } + + + +template +inline +void +SymEigsSolver::init(eT* init_resid) + { + arma_extra_debug_sigprint(); + + // Reset all matrices/vectors to zero + fac_V.zeros(dim_n, ncv); + fac_H.zeros(ncv, ncv); + fac_f.zeros(dim_n); + ritz_val.zeros(ncv); + ritz_vec.zeros(ncv, nev); + ritz_est.zeros(ncv); + ritz_conv.assign(nev, false); + + nmatop = 0; + niter = 0; + + Col r(init_resid, dim_n, false); + // The first column of fac_V + Col v(fac_V.colptr(0), dim_n, false); + eT rnorm = norm(r); + arma_check( (rnorm < near0), "newarp::SymEigsSolver::init(): initial residual vector cannot be zero" ); + v = r / rnorm; + + Col w(dim_n, arma_zeros_indicator()); + op.perform_op(v.memptr(), w.memptr()); + nmatop++; + + fac_H(0, 0) = dot(v, w); + fac_f = w - v * fac_H(0, 0); + + // In some cases f is zero in exact arithmetics, but due to rounding errors + // it may contain tiny fluctuations. When this happens, we force f to be zero + if(abs(fac_f).max() < eps) { fac_f.zeros(); } + } + + + +template +inline +void +SymEigsSolver::init() + { + arma_extra_debug_sigprint(); + + // podarray init_resid(dim_n); + // blas_int idist = 2; // Uniform(-1, 1) + // blas_int iseed[4] = {1, 3, 5, 7}; // Fixed random seed + // blas_int n = dim_n; + // lapack::larnv(&idist, &iseed[0], &n, init_resid.memptr()); + // init(init_resid.memptr()); + + podarray init_resid(dim_n); + + fill_rand(init_resid.memptr(), dim_n, 0); + + init(init_resid.memptr()); + } + + + +template +inline +uword +SymEigsSolver::compute(uword maxit, eT tol) + { + arma_extra_debug_sigprint(); + + // The m-step Arnoldi factorisation + factorise_from(1, ncv, fac_f); + retrieve_ritzpair(); + // Restarting + uword i, nconv = 0, nev_adj; + for(i = 0; i < maxit; i++) + { + nconv = num_converged(tol); + if(nconv >= nev) { break; } + + nev_adj = nev_adjusted(nconv); + restart(nev_adj); + } + // Sorting results + sort_ritzpair(); + + niter = i + 1; + + return (std::min)(nev, nconv); + } + + + +template +inline +Col +SymEigsSolver::eigenvalues() + { + arma_extra_debug_sigprint(); + + uword nconv = std::count(ritz_conv.begin(), ritz_conv.end(), true); + Col res(nconv, arma_zeros_indicator()); + + if(nconv > 0) + { + uword j = 0; + + for(uword i=0; i < nev; i++) + { + if(ritz_conv[i]) { res(j) = ritz_val(i); j++; } + } + } + + return res; + } + + + +template +inline +Mat +SymEigsSolver::eigenvectors(uword nvec) + { + arma_extra_debug_sigprint(); + + uword nconv = std::count(ritz_conv.begin(), ritz_conv.end(), true); + nvec = (std::min)(nvec, nconv); + Mat res(dim_n, nvec); + + if(nvec > 0) + { + Mat ritz_vec_conv(ncv, nvec, arma_zeros_indicator()); + + uword j = 0; + + for(uword i=0; i < nev && j < nvec; i++) + { + if(ritz_conv[i]) { ritz_vec_conv.col(j) = ritz_vec.col(i); j++; } + } + + res = fac_V * ritz_vec_conv; + } + + return res; + } + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_TridiagEigen_bones.hpp b/src/armadillo/include/armadillo_bits/newarp_TridiagEigen_bones.hpp new file mode 100644 index 0000000..9664a3c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_TridiagEigen_bones.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! Calculate the eigenvalues and eigenvectors of a symmetric tridiagonal matrix. +//! This class is a wrapper of the Lapack functions `_steqr`. +template +class TridiagEigen + { + private: + + blas_int n; + Col main_diag; // Main diagonal elements of the matrix + Col sub_diag; // Sub-diagonal elements of the matrix + Mat evecs; // To store eigenvectors + bool computed; + + + public: + + //! Default constructor. Computation can + //! be performed later by calling the compute() method. + inline TridiagEigen(); + + //! Constructor to create an object that calculates the eigenvalues + //! and eigenvectors of a symmetric tridiagonal matrix `mat_obj`. + inline TridiagEigen(const Mat& mat_obj); + + //! Compute the eigenvalue decomposition of a symmetric tridiagonal matrix. + inline void compute(const Mat& mat_obj); + + //! Retrieve the eigenvalues. + inline Col eigenvalues(); + + //! Retrieve the eigenvectors. + inline Mat eigenvectors(); + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_TridiagEigen_meat.hpp b/src/armadillo/include/armadillo_bits/newarp_TridiagEigen_meat.hpp new file mode 100644 index 0000000..b11cfec --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_TridiagEigen_meat.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +TridiagEigen::TridiagEigen() + : n(0) + , computed(false) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +TridiagEigen::TridiagEigen(const Mat& mat_obj) + : n(mat_obj.n_rows) + , computed(false) + { + arma_extra_debug_sigprint(); + + compute(mat_obj); + } + + + +template +inline +void +TridiagEigen::compute(const Mat& mat_obj) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (mat_obj.is_square() == false), "newarp::TridiagEigen::compute(): matrix must be square" ); + + n = blas_int(mat_obj.n_rows); + + main_diag = mat_obj.diag(); + sub_diag = mat_obj.diag(-1); + + evecs.set_size(n, n); + + char compz = 'I'; + blas_int lwork_min = 1 + 4*n + n*n; + blas_int liwork_min = 3 + 5*n; + blas_int info = blas_int(0); + + blas_int lwork_proposed = 0; + blas_int liwork_proposed = 0; + + if(n >= 32) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + blas_int iwork_query[2] = {}; + blas_int liwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::stedc()"); + lapack::stedc(&compz, &n, main_diag.memptr(), sub_diag.memptr(), evecs.memptr(), &n, &work_query[0], &lwork_query, &iwork_query[0], &liwork_query, &info); + + if(info != 0) { arma_stop_runtime_error("lapack::stedc(): couldn't get size of work arrays"); return; } + + lwork_proposed = static_cast( work_query[0] ); + liwork_proposed = iwork_query[0]; + } + + blas_int lwork = (std::max)( lwork_min, lwork_proposed); + blas_int liwork = (std::max)(liwork_min, liwork_proposed); + + podarray work( static_cast( lwork) ); + podarray iwork( static_cast(liwork) ); + + arma_extra_debug_print("lapack::stedc()"); + lapack::stedc(&compz, &n, main_diag.memptr(), sub_diag.memptr(), evecs.memptr(), &n, work.memptr(), &lwork, iwork.memptr(), &liwork, &info); + + if(info != 0) { arma_stop_runtime_error("lapack::stedc(): failed to compute all eigenvalues"); return; } + + computed = true; + } + + + +template +inline +Col +TridiagEigen::eigenvalues() + { + arma_extra_debug_sigprint(); + + arma_debug_check( (computed == false), "newarp::TridiagEigen::eigenvalues(): need to call compute() first" ); + + // After calling compute(), main_diag will contain the eigenvalues. + return main_diag; + } + + + +template +inline +Mat +TridiagEigen::eigenvectors() + { + arma_extra_debug_sigprint(); + + arma_debug_check( (computed == false), "newarp::TridiagEigen::eigenvectors(): need to call compute() first" ); + + return evecs; + } + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergEigen_bones.hpp b/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergEigen_bones.hpp new file mode 100644 index 0000000..668adbe --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergEigen_bones.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! Calculate the eigenvalues and eigenvectors of an upper Hessenberg matrix. +//! This class is uses lapack::lahqr() and lapack::trevc() +template +class UpperHessenbergEigen + { + private: + + uword n_rows; + Mat mat_Z; // In the first stage, H = ZTZ', Z is an orthogonal matrix + // In the second stage, Z will be overwritten by the eigenvectors of H + Mat mat_T; // H = ZTZ', T is a Schur form matrix + Col< std::complex > evals; // eigenvalues of H + bool computed; + + + public: + + //! Default constructor. Computation can + //! be performed later by calling the compute() method. + inline UpperHessenbergEigen(); + + //! Constructor to create an object that calculates the eigenvalues + //! and eigenvectors of an upper Hessenberg matrix `mat_obj`. + inline UpperHessenbergEigen(const Mat& mat_obj); + + //! Compute the eigenvalue decomposition of an upper Hessenberg matrix. + inline void compute(const Mat& mat_obj); + + //! Retrieve the eigenvalues. + inline Col< std::complex > eigenvalues(); + + //! Retrieve the eigenvectors. + inline Mat< std::complex > eigenvectors(); + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergEigen_meat.hpp b/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergEigen_meat.hpp new file mode 100644 index 0000000..a758205 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergEigen_meat.hpp @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +UpperHessenbergEigen::UpperHessenbergEigen() + : n_rows(0) + , computed(false) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +UpperHessenbergEigen::UpperHessenbergEigen(const Mat& mat_obj) + : n_rows(mat_obj.n_rows) + , computed(false) + { + arma_extra_debug_sigprint(); + + compute(mat_obj); + } + + + +template +inline +void +UpperHessenbergEigen::compute(const Mat& mat_obj) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (mat_obj.is_square() == false), "newarp::UpperHessenbergEigen::compute(): matrix must be square" ); + + n_rows = mat_obj.n_rows; + + mat_Z.set_size(n_rows, n_rows); + mat_T.set_size(n_rows, n_rows); + evals.set_size(n_rows); + + mat_Z.eye(); + mat_T = mat_obj; + + blas_int want_T = blas_int(1); + blas_int want_Z = blas_int(1); + + blas_int n = blas_int(n_rows); + blas_int ilo = blas_int(1); + blas_int ihi = blas_int(n_rows); + blas_int iloz = blas_int(1); + blas_int ihiz = blas_int(n_rows); + + blas_int info = blas_int(0); + + podarray wr(n_rows); + podarray wi(n_rows); + + arma_extra_debug_print("lapack::lahqr()"); + lapack::lahqr(&want_T, &want_Z, &n, &ilo, &ihi, mat_T.memptr(), &n, wr.memptr(), wi.memptr(), &iloz, &ihiz, mat_Z.memptr(), &n, &info); + + if(info != 0) { arma_stop_runtime_error("lapack::lahqr(): failed to compute all eigenvalues"); return; } + + for(uword i=0; i < n_rows; i++) { evals(i) = std::complex(wr[i], wi[i]); } + + char side = 'R'; + char howmny = 'B'; + blas_int m = blas_int(0); + + podarray work(3*n); + + arma_extra_debug_print("lapack::trevc()"); + lapack::trevc(&side, &howmny, (blas_int*) NULL, &n, mat_T.memptr(), &n, (eT*) NULL, &n, mat_Z.memptr(), &n, &n, &m, work.memptr(), &info); + + if(info != 0) { arma_stop_runtime_error("lapack::trevc(): illegal value"); return; } + + computed = true; + } + + + +template +inline +Col< std::complex > +UpperHessenbergEigen::eigenvalues() + { + arma_extra_debug_sigprint(); + + arma_debug_check( (computed == false), "newarp::UpperHessenbergEigen::eigenvalues(): need to call compute() first" ); + + return evals; + } + + + +template +inline +Mat< std::complex > +UpperHessenbergEigen::eigenvectors() + { + arma_extra_debug_sigprint(); + + arma_debug_check( (computed == false), "newarp::UpperHessenbergEigen::eigenvectors(): need to call compute() first" ); + + // Lapack will set the imaginary parts of real eigenvalues to be exact zero + Mat< std::complex > evecs(n_rows, n_rows, arma_zeros_indicator()); + + std::complex* col_ptr = evecs.memptr(); + + for(uword i=0; i < n_rows; i++) + { + if(cx_attrib::is_real(evals(i), eT(0))) + { + // for real eigenvector, normalise and copy + const eT z_norm = norm(mat_Z.col(i)); + + for(uword j=0; j < n_rows; j++) + { + col_ptr[j] = std::complex(mat_Z(j, i) / z_norm, eT(0)); + } + + col_ptr += n_rows; + } + else + { + // complex eigenvectors are stored in consecutive columns + const eT r2 = dot(mat_Z.col(i ), mat_Z.col(i )); + const eT i2 = dot(mat_Z.col(i+1), mat_Z.col(i+1)); + + const eT z_norm = std::sqrt(r2 + i2); + const eT* z_ptr = mat_Z.colptr(i); + + for(uword j=0; j < n_rows; j++) + { + col_ptr[j ] = std::complex(z_ptr[j] / z_norm, z_ptr[j + n_rows] / z_norm); + col_ptr[j + n_rows] = std::conj(col_ptr[j]); + } + + i++; + col_ptr += 2 * n_rows; + } + } + + return evecs; + } + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergQR_bones.hpp b/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergQR_bones.hpp new file mode 100644 index 0000000..4d07f8c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergQR_bones.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! Perform the QR decomposition of an upper Hessenberg matrix. +template +class UpperHessenbergQR + { + protected: + + uword n; + Mat mat_T; + // Gi = [ cos[i] sin[i]] + // [-sin[i] cos[i]] + // Q = G1 * G2 * ... * G_{n-1} + Col rot_cos; + Col rot_sin; + bool computed; + + + public: + + //! Default constructor. Computation can + //! be performed later by calling the compute() method. + inline UpperHessenbergQR(); + + //! Constructor to create an object that performs and stores the + //! QR decomposition of an upper Hessenberg matrix `mat_obj`. + inline UpperHessenbergQR(const Mat& mat_obj); + + //! Conduct the QR factorisation of an upper Hessenberg matrix. + virtual void compute(const Mat& mat_obj); + + //! Return the \f$RQ\f$ matrix, the multiplication of \f$R\f$ and \f$Q\f$, + //! which is an upper Hessenberg matrix. + virtual Mat matrix_RQ(); + + //! Apply the \f$Q\f$ matrix to another matrix \f$Y\f$. + inline void apply_YQ(Mat& Y); + }; + + + +//! Perform the QR decomposition of a tridiagonal matrix, a special +//! case of upper Hessenberg matrices. +template +class TridiagQR : public UpperHessenbergQR + { + public: + + //! Default constructor. Computation can + //! be performed later by calling the compute() method. + inline TridiagQR(); + + //! Constructor to create an object that performs and stores the + //! QR decomposition of a tridiagonal matrix `mat_obj`. + inline TridiagQR(const Mat& mat_obj); + + //! Conduct the QR factorisation of a tridiagonal matrix. + inline void compute(const Mat& mat_obj); + + //! Return the \f$RQ\f$ matrix, the multiplication of \f$R\f$ and \f$Q\f$, + //! which is a tridiagonal matrix. + inline Mat matrix_RQ(); + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergQR_meat.hpp b/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergQR_meat.hpp new file mode 100644 index 0000000..c3a6fa8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_UpperHessenbergQR_meat.hpp @@ -0,0 +1,310 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +UpperHessenbergQR::UpperHessenbergQR() + : n(0) + , computed(false) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +UpperHessenbergQR::UpperHessenbergQR(const Mat& mat_obj) + : n(mat_obj.n_rows) + , mat_T(n, n) + , rot_cos(n - 1) + , rot_sin(n - 1) + , computed(false) + { + arma_extra_debug_sigprint(); + + compute(mat_obj); + } + + + +template +void +UpperHessenbergQR::compute(const Mat& mat_obj) + { + arma_extra_debug_sigprint(); + + n = mat_obj.n_rows; + mat_T.set_size(n, n); + rot_cos.set_size(n - 1); + rot_sin.set_size(n - 1); + + // Make a copy of mat_obj + mat_T = mat_obj; + + eT xi, xj, r, c, s, eps = std::numeric_limits::epsilon(); + eT *ptr; + for(uword i = 0; i < n - 1; i++) + { + // Make sure mat_T is upper Hessenberg + // Zero the elements below mat_T(i + 1, i) + if(i < n - 2) { mat_T(span(i + 2, n - 1), i).zeros(); } + + xi = mat_T(i, i); // mat_T(i, i) + xj = mat_T(i + 1, i); // mat_T(i + 1, i) + r = arma_hypot(xi, xj); + if(r <= eps) + { + r = 0; + rot_cos(i) = c = 1; + rot_sin(i) = s = 0; + } + else + { + rot_cos(i) = c = xi / r; + rot_sin(i) = s = -xj / r; + } + + // For a complete QR decomposition, + // we first obtain the rotation matrix + // G = [ cos sin] + // [-sin cos] + // and then do T[i:(i + 1), i:(n - 1)] = G' * T[i:(i + 1), i:(n - 1)] + + // mat_T.submat(i, i, i + 1, n - 1) = Gt * mat_T.submat(i, i, i + 1, n - 1); + mat_T(i, i) = r; // mat_T(i, i) => r + mat_T(i + 1, i) = 0; // mat_T(i + 1, i) => 0 + ptr = &mat_T(i, i + 1); // mat_T(i, k), k = i+1, i+2, ..., n-1 + for(uword j = i + 1; j < n; j++, ptr += n) + { + eT tmp = ptr[0]; + ptr[0] = c * tmp - s * ptr[1]; + ptr[1] = s * tmp + c * ptr[1]; + } + } + + computed = true; + } + + + +template +Mat +UpperHessenbergQR::matrix_RQ() + { + arma_extra_debug_sigprint(); + + arma_debug_check( (computed == false), "newarp::UpperHessenbergQR::matrix_RQ(): need to call compute() first" ); + + // Make a copy of the R matrix + Mat RQ = trimatu(mat_T); + + for(uword i = 0; i < n - 1; i++) + { + // RQ[, i:(i + 1)] = RQ[, i:(i + 1)] * Gi + // Gi = [ cos[i] sin[i]] + // [-sin[i] cos[i]] + const eT c = rot_cos(i); + const eT s = rot_sin(i); + eT *Yi, *Yi1; + Yi = RQ.colptr(i); + Yi1 = RQ.colptr(i + 1); + for(uword j = 0; j < i + 2; j++) + { + eT tmp = Yi[j]; + Yi[j] = c * tmp - s * Yi1[j]; + Yi1[j] = s * tmp + c * Yi1[j]; + } + + /* Yi = RQ(span(0, i + 1), i); + RQ(span(0, i + 1), i) = (*c) * Yi - (*s) * RQ(span(0, i + 1), i + 1); + RQ(span(0, i + 1), i + 1) = (*s) * Yi + (*c) * RQ(span(0, i + 1), i + 1); */ + } + + return RQ; + } + + + +template +inline +void +UpperHessenbergQR::apply_YQ(Mat& Y) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (computed == false), "newarp::UpperHessenbergQR::apply_YQ(): need to call compute() first" ); + + eT *Y_col_i, *Y_col_i1; + uword nrow = Y.n_rows; + for(uword i = 0; i < n - 1; i++) + { + const eT c = rot_cos(i); + const eT s = rot_sin(i); + Y_col_i = Y.colptr(i); + Y_col_i1 = Y.colptr(i + 1); + for(uword j = 0; j < nrow; j++) + { + eT tmp = Y_col_i[j]; + Y_col_i[j] = c * tmp - s * Y_col_i1[j]; + Y_col_i1[j] = s * tmp + c * Y_col_i1[j]; + } + } + } + + + +template +inline +TridiagQR::TridiagQR() + : UpperHessenbergQR() + { + arma_extra_debug_sigprint(); + } + + + +template +inline +TridiagQR::TridiagQR(const Mat& mat_obj) + : UpperHessenbergQR() + { + arma_extra_debug_sigprint(); + + this->compute(mat_obj); + } + + + +template +inline +void +TridiagQR::compute(const Mat& mat_obj) + { + arma_extra_debug_sigprint(); + + this->n = mat_obj.n_rows; + this->mat_T.set_size(this->n, this->n); + this->rot_cos.set_size(this->n - 1); + this->rot_sin.set_size(this->n - 1); + + this->mat_T.zeros(); + this->mat_T.diag() = mat_obj.diag(); + this->mat_T.diag(1) = mat_obj.diag(-1); + this->mat_T.diag(-1) = mat_obj.diag(-1); + + eT xi, xj, r, c, s, tmp, eps = std::numeric_limits::epsilon(); + eT *ptr; // A number of pointers to avoid repeated address calculation + for(uword i = 0; i < this->n - 1; i++) + { + xi = this->mat_T(i, i); // mat_T(i, i) + xj = this->mat_T(i + 1, i); // mat_T(i + 1, i) + r = arma_hypot(xi, xj); + if(r <= eps) + { + r = 0; + this->rot_cos(i) = c = 1; + this->rot_sin(i) = s = 0; + } + else + { + this->rot_cos(i) = c = xi / r; + this->rot_sin(i) = s = -xj / r; + } + + // For a complete QR decomposition, + // we first obtain the rotation matrix + // G = [ cos sin] + // [-sin cos] + // and then do T[i:(i + 1), i:(i + 2)] = G' * T[i:(i + 1), i:(i + 2)] + + // Update T[i, i] and T[i + 1, i] + // The updated value of T[i, i] is known to be r + // The updated value of T[i + 1, i] is known to be 0 + this->mat_T(i, i) = r; + this->mat_T(i + 1, i) = 0; + // Update T[i, i + 1] and T[i + 1, i + 1] + // ptr[0] == T[i, i + 1] + // ptr[1] == T[i + 1, i + 1] + ptr = &(this->mat_T(i, i + 1)); + tmp = *ptr; + ptr[0] = c * tmp - s * ptr[1]; + ptr[1] = s * tmp + c * ptr[1]; + + if(i < this->n - 2) + { + // Update T[i, i + 2] and T[i + 1, i + 2] + // ptr[0] == T[i, i + 2] == 0 + // ptr[1] == T[i + 1, i + 2] + ptr = &(this->mat_T(i, i + 2)); + ptr[0] = -s * ptr[1]; + ptr[1] *= c; + } + } + + this->computed = true; + } + + + +template +Mat +TridiagQR::matrix_RQ() + { + arma_extra_debug_sigprint(); + + arma_debug_check( (this->computed == false), "newarp::TridiagQR::matrix_RQ(): need to call compute() first" ); + + // Make a copy of the R matrix + Mat RQ(this->n, this->n, arma_zeros_indicator()); + RQ.diag() = this->mat_T.diag(); + RQ.diag(1) = this->mat_T.diag(1); + + // [m11 m12] will point to RQ[i:(i+1), i:(i+1)] + // [m21 m22] + eT *m11 = RQ.memptr(), *m12, *m21, *m22, tmp; + for(uword i = 0; i < this->n - 1; i++) + { + const eT c = this->rot_cos(i); + const eT s = this->rot_sin(i); + m21 = m11 + 1; + m12 = m11 + this->n; + m22 = m12 + 1; + tmp = *m21; + + // Update diagonal and the below-subdiagonal + *m11 = c * (*m11) - s * (*m12); + *m21 = c * tmp - s * (*m22); + *m22 = s * tmp + c * (*m22); + + // Move m11 to RQ[i+1, i+1] + m11 = m22; + } + + // Copy the below-subdiagonal to above-subdiagonal + RQ.diag(1) = RQ.diag(-1); + + return RQ; + } + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/newarp_cx_attrib.hpp b/src/armadillo/include/armadillo_bits/newarp_cx_attrib.hpp new file mode 100644 index 0000000..e654dc4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/newarp_cx_attrib.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! Tiny functions to check attributes of complex numbers +struct cx_attrib + { + template + arma_inline static bool is_real (const std::complex& v, const T eps) { return (std::abs(v.imag()) <= eps); } + + template + arma_inline static bool is_complex(const std::complex& v, const T eps) { return (std::abs(v.imag()) > eps); } + + template + arma_inline static bool is_conj(const std::complex& v1, const std::complex& v2, const T eps) { return (std::abs(v1 - std::conj(v2)) <= eps); } + }; + + +} // namespace newarp diff --git a/src/armadillo/include/armadillo_bits/op_all_bones.hpp b/src/armadillo/include/armadillo_bits/op_all_bones.hpp new file mode 100644 index 0000000..b8faf9a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_all_bones.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_all +//! @{ + + + +class op_all + : public traits_op_xvec + { + public: + + + template + static inline bool + all_vec_helper(const Base& X); + + + template + static inline bool + all_vec_helper(const subview& X); + + + template + static inline bool + all_vec_helper(const Op& X); + + + template + static inline bool + all_vec_helper + ( + const mtOp& X, + const typename arma_op_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr + ); + + + template + static inline bool + all_vec_helper + ( + const mtGlue& X, + const typename arma_glue_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr, + const typename arma_not_cx::result* junk3 = nullptr + ); + + + template + static inline bool all_vec(T1& X); + + + template + static inline void apply_helper(Mat& out, const Proxy& P, const uword dim); + + + template + static inline void apply(Mat& out, const mtOp& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_all_meat.hpp b/src/armadillo/include/armadillo_bits/op_all_meat.hpp new file mode 100644 index 0000000..5dff3ec --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_all_meat.hpp @@ -0,0 +1,406 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_all +//! @{ + + + +template +inline +bool +op_all::all_vec_helper(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + + uword count = 0; + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i +inline +bool +op_all::all_vec_helper(const subview& X) + { + arma_extra_debug_sigprint(); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + uword count = 0; + + if(X_n_rows == 1) + { + for(uword col=0; col < X_n_cols; ++col) + { + count += (X.at(0,col) != eT(0)) ? uword(1) : uword(0); + } + } + else + { + for(uword col=0; col < X_n_cols; ++col) + { + const eT* X_colmem = X.colptr(col); + + for(uword row=0; row < X_n_rows; ++row) + { + count += (X_colmem[row] != eT(0)) ? uword(1) : uword(0); + } + } + } + + return (X.n_elem == count); + } + + + +template +inline +bool +op_all::all_vec_helper(const Op& X) + { + arma_extra_debug_sigprint(); + + return op_all::all_vec_helper(X.m); + } + + + +template +inline +bool +op_all::all_vec_helper + ( + const mtOp& X, + const typename arma_op_rel_only::result* junk1, + const typename arma_not_cx::result* junk2 + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + typedef typename T1::elem_type eT; + + const eT val = X.aux; + + const Proxy P(X.m); + + const uword n_elem = P.get_n_elem(); + + uword count = 0; + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + const eT tmp = Pea[i]; + + if(is_same_type::yes) { count += (val < tmp) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp < val) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (val > tmp) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp > val) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (val <= tmp) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp <= val) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (val >= tmp) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp >= val) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp == val) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp != val) ? uword(1) : uword(0); } + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT tmp = P.at(row,col); + + if(is_same_type::yes) { if(val < tmp) { ++count; } } + else if(is_same_type::yes) { if(tmp < val) { ++count; } } + else if(is_same_type::yes) { if(val > tmp) { ++count; } } + else if(is_same_type::yes) { if(tmp > val) { ++count; } } + else if(is_same_type::yes) { if(val <= tmp) { ++count; } } + else if(is_same_type::yes) { if(tmp <= val) { ++count; } } + else if(is_same_type::yes) { if(val >= tmp) { ++count; } } + else if(is_same_type::yes) { if(tmp >= val) { ++count; } } + else if(is_same_type::yes) { if(tmp == val) { ++count; } } + else if(is_same_type::yes) { if(tmp != val) { ++count; } } + } + } + + return (n_elem == count); + } + + + +template +inline +bool +op_all::all_vec_helper + ( + const mtGlue& X, + const typename arma_glue_rel_only::result* junk1, + const typename arma_not_cx::result* junk2, + const typename arma_not_cx::result* junk3 + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + arma_ignore(junk3); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename Proxy::ea_type ea_type1; + typedef typename Proxy::ea_type ea_type2; + + const Proxy A(X.A); + const Proxy B(X.B); + + arma_debug_assert_same_size(A, B, "relational operator"); + + const uword n_elem = A.get_n_elem(); + + uword count = 0; + + const bool use_at = (Proxy::use_at || Proxy::use_at); + + if(use_at == false) + { + ea_type1 PA = A.get_ea(); + ea_type2 PB = B.get_ea(); + + for(uword i=0; i::yes) { count += (tmp1 < tmp2) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp1 > tmp2) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp1 <= tmp2) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp1 >= tmp2) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp1 == tmp2) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp1 != tmp2) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp1 && tmp2) ? uword(1) : uword(0); } + else if(is_same_type::yes) { count += (tmp1 || tmp2) ? uword(1) : uword(0); } + } + } + else + { + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT1 tmp1 = A.at(row,col); + const eT2 tmp2 = B.at(row,col); + + if(is_same_type::yes) { if(tmp1 < tmp2) { ++count; } } + else if(is_same_type::yes) { if(tmp1 > tmp2) { ++count; } } + else if(is_same_type::yes) { if(tmp1 <= tmp2) { ++count; } } + else if(is_same_type::yes) { if(tmp1 >= tmp2) { ++count; } } + else if(is_same_type::yes) { if(tmp1 == tmp2) { ++count; } } + else if(is_same_type::yes) { if(tmp1 != tmp2) { ++count; } } + else if(is_same_type::yes) { if(tmp1 && tmp2) { ++count; } } + else if(is_same_type::yes) { if(tmp1 || tmp2) { ++count; } } + } + } + + return (n_elem == count); + } + + + +template +inline +bool +op_all::all_vec(T1& X) + { + arma_extra_debug_sigprint(); + + return op_all::all_vec_helper(X); + } + + + +template +inline +void +op_all::apply_helper(Mat& out, const Proxy& P, const uword dim) + { + arma_extra_debug_sigprint(); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + typedef typename Proxy::elem_type eT; + + if(dim == 0) // traverse rows (ie. process each column) + { + out.zeros(1, n_cols); + + if(out.n_elem == 0) { return; } + + uword* out_mem = out.memptr(); + + if(is_Mat::stored_type>::value) + { + const unwrap::stored_type> U(P.Q); + + for(uword col=0; col < n_cols; ++col) + { + const eT* colmem = U.M.colptr(col); + + uword count = 0; + + for(uword row=0; row < n_rows; ++row) + { + count += (colmem[row] != eT(0)) ? uword(1) : uword(0); + } + + out_mem[col] = (n_rows == count) ? uword(1) : uword(0); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + uword count = 0; + + for(uword row=0; row < n_rows; ++row) + { + if(P.at(row,col) != eT(0)) { ++count; } + } + + out_mem[col] = (n_rows == count) ? uword(1) : uword(0); + } + } + } + else + { + out.zeros(n_rows, 1); + + uword* out_mem = out.memptr(); + + // internal dual use of 'out': keep the counts for each row + + if(is_Mat::stored_type>::value) + { + const unwrap::stored_type> U(P.Q); + + for(uword col=0; col < n_cols; ++col) + { + const eT* colmem = U.M.colptr(col); + + for(uword row=0; row < n_rows; ++row) + { + out_mem[row] += (colmem[row] != eT(0)) ? uword(1) : uword(0); + } + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + for(uword row=0; row < n_rows; ++row) + { + if(P.at(row,col) != eT(0)) { ++out_mem[row]; } + } + } + } + + + // see what the counts tell us + + for(uword row=0; row < n_rows; ++row) + { + out_mem[row] = (n_cols == out_mem[row]) ? uword(1) : uword(0); + } + + } + } + + + +template +inline +void +op_all::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + const uword dim = X.aux_uword_a; + + const Proxy P(X.m); + + if(P.is_alias(out) == false) + { + op_all::apply_helper(out, P, dim); + } + else + { + Mat out2; + + op_all::apply_helper(out2, P, dim); + + out.steal_mem(out2); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_any_bones.hpp b/src/armadillo/include/armadillo_bits/op_any_bones.hpp new file mode 100644 index 0000000..ffb197b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_any_bones.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_any +//! @{ + + + +class op_any + : public traits_op_xvec + { + public: + + + template + static inline bool + any_vec_helper(const Base& X); + + + template + static inline bool + any_vec_helper(const subview& X); + + + template + static inline bool + any_vec_helper(const Op& X); + + + template + static inline bool + any_vec_helper + ( + const mtOp& X, + const typename arma_op_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr + ); + + + template + static inline bool + any_vec_helper + ( + const mtGlue& X, + const typename arma_glue_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr, + const typename arma_not_cx::result* junk3 = nullptr + ); + + + template + static inline bool any_vec(T1& X); + + + template + static inline void apply_helper(Mat& out, const Proxy& P, const uword dim); + + + template + static inline void apply(Mat& out, const mtOp& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_any_meat.hpp b/src/armadillo/include/armadillo_bits/op_any_meat.hpp new file mode 100644 index 0000000..3356ec7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_any_meat.hpp @@ -0,0 +1,377 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_any +//! @{ + + + +template +inline +bool +op_any::any_vec_helper(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i +inline +bool +op_any::any_vec_helper(const subview& X) + { + arma_extra_debug_sigprint(); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(X_n_rows == 1) + { + for(uword col=0; col < X_n_cols; ++col) + { + if(X.at(0,col) != eT(0)) { return true; } + } + } + else + { + for(uword col=0; col < X_n_cols; ++col) + { + const eT* X_colmem = X.colptr(col); + + for(uword row=0; row < X_n_rows; ++row) + { + if(X_colmem[row] != eT(0)) { return true; } + } + } + } + + return false; + } + + + +template +inline +bool +op_any::any_vec_helper(const Op& X) + { + arma_extra_debug_sigprint(); + + return op_any::any_vec_helper(X.m); + } + + + +template +inline +bool +op_any::any_vec_helper + ( + const mtOp& X, + const typename arma_op_rel_only::result* junk1, + const typename arma_not_cx::result* junk2 + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + typedef typename T1::elem_type eT; + + const eT val = X.aux; + + const Proxy P(X.m); + + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + for(uword i=0; i < n_elem; ++i) + { + const eT tmp = Pea[i]; + + if(is_same_type::yes) { if(val < tmp) { return true; } } + else if(is_same_type::yes) { if(tmp < val) { return true; } } + else if(is_same_type::yes) { if(val > tmp) { return true; } } + else if(is_same_type::yes) { if(tmp > val) { return true; } } + else if(is_same_type::yes) { if(val <= tmp) { return true; } } + else if(is_same_type::yes) { if(tmp <= val) { return true; } } + else if(is_same_type::yes) { if(val >= tmp) { return true; } } + else if(is_same_type::yes) { if(tmp >= val) { return true; } } + else if(is_same_type::yes) { if(tmp == val) { return true; } } + else if(is_same_type::yes) { if(tmp != val) { return true; } } + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT tmp = P.at(row,col); + + if(is_same_type::yes) { if(val < tmp) { return true; } } + else if(is_same_type::yes) { if(tmp < val) { return true; } } + else if(is_same_type::yes) { if(val > tmp) { return true; } } + else if(is_same_type::yes) { if(tmp > val) { return true; } } + else if(is_same_type::yes) { if(val <= tmp) { return true; } } + else if(is_same_type::yes) { if(tmp <= val) { return true; } } + else if(is_same_type::yes) { if(val >= tmp) { return true; } } + else if(is_same_type::yes) { if(tmp >= val) { return true; } } + else if(is_same_type::yes) { if(tmp == val) { return true; } } + else if(is_same_type::yes) { if(tmp != val) { return true; } } + } + } + + return false; + } + + + +template +inline +bool +op_any::any_vec_helper + ( + const mtGlue& X, + const typename arma_glue_rel_only::result* junk1, + const typename arma_not_cx::result* junk2, + const typename arma_not_cx::result* junk3 + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + arma_ignore(junk3); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename Proxy::ea_type ea_type1; + typedef typename Proxy::ea_type ea_type2; + + const Proxy A(X.A); + const Proxy B(X.B); + + arma_debug_assert_same_size(A, B, "relational operator"); + + const bool use_at = (Proxy::use_at || Proxy::use_at); + + if(use_at == false) + { + ea_type1 PA = A.get_ea(); + ea_type2 PB = B.get_ea(); + + const uword n_elem = A.get_n_elem(); + + for(uword i=0; i::yes) { if(tmp1 < tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 > tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 <= tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 >= tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 == tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 != tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 && tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 || tmp2) { return true; } } + } + } + else + { + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT1 tmp1 = A.at(row,col); + const eT2 tmp2 = B.at(row,col); + + if(is_same_type::yes) { if(tmp1 < tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 > tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 <= tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 >= tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 == tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 != tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 && tmp2) { return true; } } + else if(is_same_type::yes) { if(tmp1 || tmp2) { return true; } } + } + } + + return false; + } + + + +template +inline +bool +op_any::any_vec(T1& X) + { + arma_extra_debug_sigprint(); + + return op_any::any_vec_helper(X); + } + + + +template +inline +void +op_any::apply_helper(Mat& out, const Proxy& P, const uword dim) + { + arma_extra_debug_sigprint(); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + typedef typename Proxy::elem_type eT; + + if(dim == 0) // traverse rows (ie. process each column) + { + out.zeros(1, n_cols); + + uword* out_mem = out.memptr(); + + if(is_Mat::stored_type>::value) + { + const unwrap::stored_type> U(P.Q); + + for(uword col=0; col < n_cols; ++col) + { + const eT* colmem = U.M.colptr(col); + + for(uword row=0; row < n_rows; ++row) + { + if(colmem[row] != eT(0)) { out_mem[col] = uword(1); break; } + } + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + for(uword row=0; row < n_rows; ++row) + { + if(P.at(row,col) != eT(0)) { out_mem[col] = uword(1); break; } + } + } + } + } + else + { + out.zeros(n_rows, 1); + + uword* out_mem = out.memptr(); + + if(is_Mat::stored_type>::value) + { + const unwrap::stored_type> U(P.Q); + + for(uword col=0; col < n_cols; ++col) + { + const eT* colmem = U.M.colptr(col); + + for(uword row=0; row < n_rows; ++row) + { + if(colmem[row] != eT(0)) { out_mem[row] = uword(1); } + } + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + for(uword row=0; row < n_rows; ++row) + { + if(P.at(row,col) != eT(0)) { out_mem[row] = uword(1); } + } + } + } + } + } + + + +template +inline +void +op_any::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + const uword dim = X.aux_uword_a; + + const Proxy P(X.m); + + if(P.is_alias(out) == false) + { + op_any::apply_helper(out, P, dim); + } + else + { + Mat out2; + + op_any::apply_helper(out2, P, dim); + + out.steal_mem(out2); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_chi2rnd_bones.hpp b/src/armadillo/include/armadillo_bits/op_chi2rnd_bones.hpp new file mode 100644 index 0000000..540bdfb --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_chi2rnd_bones.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_chi2rnd +//! @{ + + +class op_chi2rnd + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static void apply_noalias(Mat& out, const Proxy& P); + + template + inline static void fill_constant_df(Mat& out, const eT df); + }; + + + +template +class op_chi2rnd_varying_df + { + public: + + arma_aligned std::mt19937_64 motor; + + inline ~op_chi2rnd_varying_df(); + inline op_chi2rnd_varying_df(); + + inline eT operator()(const eT df); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_chi2rnd_meat.hpp b/src/armadillo/include/armadillo_bits/op_chi2rnd_meat.hpp new file mode 100644 index 0000000..1b681ae --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_chi2rnd_meat.hpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_chi2rnd +//! @{ + + + +template +inline +void +op_chi2rnd::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(in.m); + + if(P.is_alias(out) == false) + { + op_chi2rnd::apply_noalias(out, P); + } + else + { + Mat tmp; + + op_chi2rnd::apply_noalias(tmp, P); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_chi2rnd::apply_noalias(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + op_chi2rnd_varying_df generator; + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + out.set_size(n_rows, n_cols); + + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + const uword N = P.get_n_elem(); + + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i +inline +void +op_chi2rnd::fill_constant_df(Mat& out, const eT df) + { + arma_extra_debug_sigprint(); + + if(df > eT(0)) + { + typedef std::mt19937_64 motor_type; + typedef std::mt19937_64::result_type seed_type; + typedef std::chi_squared_distribution distr_type; + + motor_type motor; motor.seed( seed_type(arma_rng::randi()) ); + distr_type distr(df); + + const uword N = out.n_elem; + + eT* out_mem = out.memptr(); + + for(uword i=0; i::nan ); + } + } + + + +// + + + +template +inline +op_chi2rnd_varying_df::~op_chi2rnd_varying_df() + { + arma_extra_debug_sigprint(); + } + + + +template +inline +op_chi2rnd_varying_df::op_chi2rnd_varying_df() + { + arma_extra_debug_sigprint(); + + typedef std::mt19937_64::result_type seed_type; + + motor.seed( seed_type(arma_rng::randi()) ); + } + + + +template +inline +eT +op_chi2rnd_varying_df::operator()(const eT df) + { + arma_extra_debug_sigprint(); + + // as C++11 doesn't seem to provide a way to explicitly set the parameter + // of an existing chi_squared_distribution object, + // we need to create a new object each time + + if(df > eT(0)) + { + std::chi_squared_distribution distr(df); + + return eT( distr(motor) ); + } + else + { + return Datum::nan; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_chol_bones.hpp b/src/armadillo/include/armadillo_bits/op_chol_bones.hpp new file mode 100644 index 0000000..e3b3a9c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_chol_bones.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_chol +//! @{ + + + +class op_chol + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& X); + + template + inline static bool apply_direct(Mat& out, const Base& A_expr, const uword layout); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_chol_meat.hpp b/src/armadillo/include/armadillo_bits/op_chol_meat.hpp new file mode 100644 index 0000000..ebc6448 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_chol_meat.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_chol +//! @{ + + + +template +inline +void +op_chol::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + const bool status = op_chol::apply_direct(out, X.m, X.aux_uword_a); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("chol(): decomposition failed"); + } + } + + + +template +inline +bool +op_chol::apply_direct(Mat& out, const Base& A_expr, const uword layout) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + out = A_expr.get_ref(); + + arma_debug_check( (out.is_square() == false), "chol(): given matrix must be square sized", [&](){ out.soft_reset(); } ); + + if(out.is_empty()) { return true; } + + if((arma_config::debug) && (auxlib::rudimentary_sym_check(out) == false)) + { + if(is_cx::no ) { arma_debug_warn_level(1, "chol(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "chol(): given matrix is not hermitian"); } + } + + uword KD = 0; + + const bool is_band = arma_config::optimise_band && ((auxlib::crippled_lapack(out)) ? false : ((layout == 0) ? band_helper::is_band_upper(KD, out, uword(32)) : band_helper::is_band_lower(KD, out, uword(32)))); + + const bool status = (is_band) ? auxlib::chol_band(out, KD, layout) : auxlib::chol(out, layout); + + return status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_clamp_bones.hpp b/src/armadillo/include/armadillo_bits/op_clamp_bones.hpp new file mode 100644 index 0000000..89e5342 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_clamp_bones.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_clamp +//! @{ + + + +class op_clamp + : public traits_op_passthru + { + public: + + // matrices + + template inline static void apply(Mat& out, const mtOp& in); + + template inline static void apply_direct(Mat& out, const Mat& X, const eT min_val, const eT max_val); + + template inline static void apply_proxy_noalias(Mat& out, const Proxy& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val); + + // cubes + + template inline static void apply(Cube& out, const mtOpCube& in); + + template inline static void apply_direct(Cube& out, const Cube& X, const eT min_val, const eT max_val); + + template inline static void apply_proxy_noalias(Cube& out, const ProxyCube& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val); + }; + + + +class op_clamp_cx + : public traits_op_passthru + { + public: + + // matrices + + template inline static void apply(Mat& out, const mtOp& in); + + template inline static void apply_direct(Mat& out, const Mat& X, const eT min_val, const eT max_val); + + template inline static void apply_proxy_noalias(Mat& out, const Proxy& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val); + + + // cubes + + template inline static void apply(Cube& out, const mtOpCube& in); + + template inline static void apply_direct(Cube& out, const Cube& X, const eT min_val, const eT max_val); + + template inline static void apply_proxy_noalias(Cube& out, const ProxyCube& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_clamp_meat.hpp b/src/armadillo/include/armadillo_bits/op_clamp_meat.hpp new file mode 100644 index 0000000..a14cd78 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_clamp_meat.hpp @@ -0,0 +1,577 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_clamp +//! @{ + + + +template +inline +void +op_clamp::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const eT min_val = in.aux; + const eT max_val = in.aux_out_eT; + + arma_debug_check( (min_val > max_val), "clamp(): min_val must be less than max_val" ); + + if(is_Mat::value) + { + const unwrap U(in.m); + + op_clamp::apply_direct(out, U.M, min_val, max_val); + } + else + { + const Proxy P(in.m); + + if(P.is_alias(out)) + { + Mat tmp; + + op_clamp::apply_proxy_noalias(tmp, P, min_val, max_val); + + out.steal_mem(tmp); + } + else + { + op_clamp::apply_proxy_noalias(out, P, min_val, max_val); + } + } + } + + + +template +inline +void +op_clamp::apply_direct(Mat& out, const Mat& X, const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(&out != &X) + { + out.set_size(X.n_rows, X.n_cols); + + const uword N = out.n_elem; + + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + for(uword i=0; i max_val) ? max_val : val); + } + } + else + { + arma_extra_debug_print("op_clamp::apply_direct(): inplace operation"); + + arrayops::clamp(out.memptr(), out.n_elem, min_val, max_val); + } + } + + + +template +inline +void +op_clamp::apply_proxy_noalias(Mat& out, const Proxy& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + out.set_size(n_rows, n_cols); + + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + const uword N = P.get_n_elem(); + + typename Proxy::ea_type A = P.get_ea(); + + for(uword i=0; i max_val) ? max_val : val); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT val = P.at(row,col); + + (*out_mem) = (val < min_val) ? min_val : ((val > max_val) ? max_val : val); + + out_mem++; + } + } + } + + + +// + + + +template +inline +void +op_clamp::apply(Cube& out, const mtOpCube& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const eT min_val = in.aux; + const eT max_val = in.aux_out_eT; + + arma_debug_check( (min_val > max_val), "clamp(): min_val must be less than max_val" ); + + if(is_Cube::value) + { + const unwrap_cube U(in.m); + + op_clamp::apply_direct(out, U.M, min_val, max_val); + } + else + { + const ProxyCube P(in.m); + + if(P.is_alias(out)) + { + Cube tmp; + + op_clamp::apply_proxy_noalias(tmp, P, min_val, max_val); + + out.steal_mem(tmp); + } + else + { + op_clamp::apply_proxy_noalias(out, P, min_val, max_val); + } + } + } + + + +template +inline +void +op_clamp::apply_direct(Cube& out, const Cube& X, const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(&out != &X) + { + out.set_size(X.n_rows, X.n_cols, X.n_slices); + + const uword N = out.n_elem; + + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + for(uword i=0; i max_val) ? max_val : val); + } + } + else + { + arma_extra_debug_print("op_clamp::apply_direct(): inplace operation"); + + arrayops::clamp(out.memptr(), out.n_elem, min_val, max_val); + } + } + + + +template +inline +void +op_clamp::apply_proxy_noalias(Cube& out, const ProxyCube& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + const uword N = P.get_n_elem(); + + typename ProxyCube::ea_type A = P.get_ea(); + + for(uword i=0; i max_val) ? max_val : val); + } + } + else + { + for(uword s=0; s < n_slices; ++s) + for(uword c=0; c < n_cols; ++c) + for(uword r=0; r < n_rows; ++r) + { + const eT val = P.at(r,c,s); + + (*out_mem) = (val < min_val) ? min_val : ((val > max_val) ? max_val : val); + + out_mem++; + } + } + } + + + +// + + + +template +inline +void +op_clamp_cx::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_Mat::value) + { + const unwrap U(in.m); + + op_clamp_cx::apply_direct(out, U.M, in.aux, in.aux_out_eT); + } + else + { + const Proxy P(in.m); + + if(P.is_alias(out)) + { + Mat tmp; + + op_clamp_cx::apply_proxy_noalias(tmp, P, in.aux, in.aux_out_eT); + + out.steal_mem(tmp); + } + else + { + op_clamp_cx::apply_proxy_noalias(out, P, in.aux, in.aux_out_eT); + } + } + } + + + +template +inline +void +op_clamp_cx::apply_direct(Mat& out, const Mat& X, const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const T min_val_real = std::real(min_val); + const T min_val_imag = std::imag(min_val); + + const T max_val_real = std::real(max_val); + const T max_val_imag = std::imag(max_val); + + arma_debug_check( (min_val_real > max_val_real), "clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (min_val_imag > max_val_imag), "clamp(): imag(min_val) must be less than imag(max_val)" ); + + if(&out != &X) + { + out.set_size(X.n_rows, X.n_cols); + + const uword N = out.n_elem; + + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + for(uword i=0; i max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + out_mem[i] = std::complex(val_real,val_imag); + } + } + else + { + arma_extra_debug_print("op_clamp_cx::apply_direct(): inplace operation"); + + arrayops::clamp(out.memptr(), out.n_elem, min_val, max_val); + } + } + + + +template +inline +void +op_clamp_cx::apply_proxy_noalias(Mat& out, const Proxy& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const T min_val_real = std::real(min_val); + const T min_val_imag = std::imag(min_val); + + const T max_val_real = std::real(max_val); + const T max_val_imag = std::imag(max_val); + + arma_debug_check( (min_val_real > max_val_real), "clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (min_val_imag > max_val_imag), "clamp(): imag(min_val) must be less than imag(max_val)" ); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + out.set_size(n_rows, n_cols); + + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + const uword N = P.get_n_elem(); + + typename Proxy::ea_type A = P.get_ea(); + + for(uword i=0; i max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + out_mem[i] = std::complex(val_real,val_imag); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT val = P.at(row,col); + + T val_real = std::real(val); + T val_imag = std::imag(val); + + val_real = (val_real < min_val_real) ? min_val_real : ((val_real > max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + (*out_mem) = std::complex(val_real,val_imag); out_mem++; + } + } + } + + + +// + + + +template +inline +void +op_clamp_cx::apply(Cube& out, const mtOpCube& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_Cube::value) + { + const unwrap_cube U(in.m); + + op_clamp_cx::apply_direct(out, U.M, in.aux, in.aux_out_eT); + } + else + { + const ProxyCube P(in.m); + + if(P.is_alias(out)) + { + Cube tmp; + + op_clamp_cx::apply_proxy_noalias(tmp, P, in.aux, in.aux_out_eT); + + out.steal_mem(tmp); + } + else + { + op_clamp_cx::apply_proxy_noalias(out, P, in.aux, in.aux_out_eT); + } + } + } + + + +template +inline +void +op_clamp_cx::apply_direct(Cube& out, const Cube& X, const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const T min_val_real = std::real(min_val); + const T min_val_imag = std::imag(min_val); + + const T max_val_real = std::real(max_val); + const T max_val_imag = std::imag(max_val); + + arma_debug_check( (min_val_real > max_val_real), "clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (min_val_imag > max_val_imag), "clamp(): imag(min_val) must be less than imag(max_val)" ); + + if(&out != &X) + { + out.set_size(X.n_rows, X.n_cols, X.n_slices); + + const uword N = out.n_elem; + + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + for(uword i=0; i max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + out_mem[i] = std::complex(val_real,val_imag); + } + } + else + { + arma_extra_debug_print("op_clamp_cx::apply_direct(): inplace operation"); + + arrayops::clamp(out.memptr(), out.n_elem, min_val, max_val); + } + } + + + +template +inline +void +op_clamp_cx::apply_proxy_noalias(Cube& out, const ProxyCube& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const T min_val_real = std::real(min_val); + const T min_val_imag = std::imag(min_val); + + const T max_val_real = std::real(max_val); + const T max_val_imag = std::imag(max_val); + + arma_debug_check( (min_val_real > max_val_real), "clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (min_val_imag > max_val_imag), "clamp(): imag(min_val) must be less than imag(max_val)" ); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + const uword N = P.get_n_elem(); + + typename ProxyCube::ea_type A = P.get_ea(); + + for(uword i=0; i max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + out_mem[i] = std::complex(val_real,val_imag); + } + } + else + { + for(uword s=0; s < n_slices; ++s) + for(uword c=0; c < n_cols; ++c) + for(uword r=0; r < n_rows; ++r) + { + const eT val = P.at(r,c,s); + + T val_real = std::real(val); + T val_imag = std::imag(val); + + val_real = (val_real < min_val_real) ? min_val_real : ((val_real > max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + (*out_mem) = std::complex(val_real,val_imag); out_mem++; + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_col_as_mat_bones.hpp b/src/armadillo/include/armadillo_bits/op_col_as_mat_bones.hpp new file mode 100644 index 0000000..6e653ea --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_col_as_mat_bones.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_col_as_mat +//! @{ + + +class op_col_as_mat + : public traits_op_default + { + public: + + template inline static void apply(Mat& out, const CubeToMatOp& expr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_col_as_mat_meat.hpp b/src/armadillo/include/armadillo_bits/op_col_as_mat_meat.hpp new file mode 100644 index 0000000..2e0f0cd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_col_as_mat_meat.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_col_as_mat +//! @{ + + + +template +inline +void +op_col_as_mat::apply(Mat& out, const CubeToMatOp& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube U(expr.m); + const Cube& A = U.M; + + const uword in_col = expr.aux_uword; + + arma_debug_check_bounds( (in_col >= A.n_cols), "Cube::col_as_mat(): index out of bounds" ); + + const uword A_n_rows = A.n_rows; + const uword A_n_slices = A.n_slices; + + out.set_size(A_n_rows, A_n_slices); + + for(uword s=0; s < A_n_slices; ++s) + { + arrayops::copy(out.colptr(s), A.slice_colptr(s, in_col), A_n_rows); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_cond_bones.hpp b/src/armadillo/include/armadillo_bits/op_cond_bones.hpp new file mode 100644 index 0000000..b764898 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cond_bones.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_cond +//! @{ + + +class op_cond + : public traits_op_default + { + public: + + template static inline typename T1::pod_type apply(const Base& X); + + template static inline typename get_pod_type::result apply_diag(const Mat& A); + template static inline typename get_pod_type::result apply_sym ( Mat& A); + template static inline typename get_pod_type::result apply_gen ( Mat& A); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_cond_meat.hpp b/src/armadillo/include/armadillo_bits/op_cond_meat.hpp new file mode 100644 index 0000000..f73ef9a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cond_meat.hpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_cond +//! @{ + + + +template +inline +typename T1::pod_type +op_cond::apply(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + Mat A(X.get_ref()); + + if(A.n_elem == 0) { return T(0); } + + if(is_op_diagmat::value || A.is_diagmat()) + { + arma_extra_debug_print("op_cond::apply(): detected diagonal matrix"); + + return op_cond::apply_diag(A); + } + + bool is_approx_sym = false; + bool is_approx_sympd = false; + + sym_helper::analyse_matrix(is_approx_sym, is_approx_sympd, A); + + const bool do_sym = (is_cx::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd); + + if(do_sym) + { + arma_extra_debug_print("op_cond: symmetric/hermitian optimisation"); + + return op_cond::apply_sym(A); + } + + return op_cond::apply_gen(A); + } + + + +template +inline +typename get_pod_type::result +op_cond::apply_diag(const Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword N = (std::min)(A.n_rows, A.n_cols); + + T abs_min = Datum::inf; + T abs_max = T(0); + + for(uword i=0; i < N; ++i) + { + const T abs_val = std::abs(A.at(i,i)); + + if(arma_isnan(abs_val)) + { + arma_debug_warn_level(3, "cond(): failed"); + + return Datum::nan; + } + + abs_min = (abs_val < abs_min) ? abs_val : abs_min; + abs_max = (abs_val > abs_max) ? abs_val : abs_max; + } + + if((abs_min == T(0)) || (abs_max == T(0))) { return Datum::inf; } + + return T(abs_max / abs_min); + } + + + +template +inline +typename get_pod_type::result +op_cond::apply_sym(Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + Col eigval; + + const bool status = auxlib::eig_sym(eigval, A); + + if(status == false) + { + arma_debug_warn_level(3, "cond(): failed"); + + return Datum::nan; + } + + if(eigval.n_elem == 0) { return T(0); } + + const T* eigval_mem = eigval.memptr(); + + T abs_min = std::abs(eigval_mem[0]); + T abs_max = abs_min; + + for(uword i=1; i < eigval.n_elem; ++i) + { + const T abs_val = std::abs(eigval_mem[i]); + + abs_min = (abs_val < abs_min) ? abs_val : abs_min; + abs_max = (abs_val > abs_max) ? abs_val : abs_max; + } + + if((abs_min == T(0)) || (abs_max == T(0))) { return Datum::inf; } + + return T(abs_max / abs_min); + } + + + +template +inline +typename get_pod_type::result +op_cond::apply_gen(Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + Col S; + + const bool status = auxlib::svd_dc(S, A); + + if(status == false) + { + arma_debug_warn_level(3, "cond(): failed"); + + return Datum::nan; + } + + if(S.n_elem == 0) { return T(0); } + + const T S_max = S[0]; + const T S_min = S[S.n_elem-1]; + + if((S_max == T(0)) || (S_min == T(0))) { return Datum::inf; } + + return T(S_max / S_min); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_cor_bones.hpp b/src/armadillo/include/armadillo_bits/op_cor_bones.hpp new file mode 100644 index 0000000..7a506c3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cor_bones.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_cor +//! @{ + + + +class op_cor + : public traits_op_default + { + public: + + template inline static void apply(Mat& out, const Op< T1, op_cor>& in); + template inline static void apply(Mat& out, const Op< Op, op_cor>& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_cor_meat.hpp b/src/armadillo/include/armadillo_bits/op_cor_meat.hpp new file mode 100644 index 0000000..6763964 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cor_meat.hpp @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_cor +//! @{ + + + +template +inline +void +op_cor::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword norm_type = in.aux_uword_a; + + const unwrap U(in.m); + const Mat& A = U.M; + + if(A.n_elem == 0) + { + out.reset(); + return; + } + + if(A.n_elem == 1) + { + out.set_size(1,1); + out[0] = eT(1); + return; + } + + const Mat& AA = (A.n_rows == 1) + ? Mat(const_cast(A.memptr()), A.n_cols, A.n_rows, false, false) + : Mat(const_cast(A.memptr()), A.n_rows, A.n_cols, false, false); + + const uword N = AA.n_rows; + const eT norm_val = (norm_type == 0) ? ( (N > 1) ? eT(N-1) : eT(1) ) : eT(N); + + const Mat tmp = AA.each_row() - mean(AA,0); + + out = tmp.t() * tmp; + out /= norm_val; + + const Col s = sqrt(out.diag()); + + out /= (s * s.t()); // TODO: check for zeros in s? + } + + + +template +inline +void +op_cor::apply(Mat& out, const Op< Op, op_cor>& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword norm_type = in.aux_uword_a; + + if(is_cx::yes) + { + const Mat tmp = in.m; // force the evaluation of Op + + out = cor(tmp, norm_type); + } + else + { + const unwrap U(in.m.m); + const Mat& A = U.M; + + if(A.n_elem == 0) + { + out.reset(); + return; + } + + if(A.n_elem == 1) + { + out.set_size(1,1); + out[0] = eT(1); + return; + } + + const Mat& AA = (A.n_cols == 1) + ? Mat(const_cast(A.memptr()), A.n_cols, A.n_rows, false, false) + : Mat(const_cast(A.memptr()), A.n_rows, A.n_cols, false, false); + + const uword N = AA.n_cols; + const eT norm_val = (norm_type == 0) ? ( (N > 1) ? eT(N-1) : eT(1) ) : eT(N); + + const Mat tmp = AA.each_col() - mean(AA,1); + + out = tmp * tmp.t(); + out /= norm_val; + + const Col s = sqrt(out.diag()); + + out /= (s * s.t()); // TODO: check for zeros in s? + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_cov_bones.hpp b/src/armadillo/include/armadillo_bits/op_cov_bones.hpp new file mode 100644 index 0000000..5de43ba --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cov_bones.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_cov +//! @{ + + + +class op_cov + : public traits_op_default + { + public: + + template inline static void apply(Mat& out, const Op< T1, op_cov>& in); + template inline static void apply(Mat& out, const Op< Op, op_cov>& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_cov_meat.hpp b/src/armadillo/include/armadillo_bits/op_cov_meat.hpp new file mode 100644 index 0000000..30944b7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cov_meat.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_cov +//! @{ + + + +template +inline +void +op_cov::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword norm_type = in.aux_uword_a; + + const unwrap U(in.m); + const Mat& A = U.M; + + if(A.n_elem == 0) + { + out.reset(); + return; + } + + const Mat& AA = (A.n_rows == 1) + ? Mat(const_cast(A.memptr()), A.n_cols, A.n_rows, false, false) + : Mat(const_cast(A.memptr()), A.n_rows, A.n_cols, false, false); + + const uword N = AA.n_rows; + const eT norm_val = (norm_type == 0) ? ( (N > 1) ? eT(N-1) : eT(1) ) : eT(N); + + const Mat tmp = AA.each_row() - mean(AA,0); + + out = tmp.t() * tmp; + out /= norm_val; + } + + + +template +inline +void +op_cov::apply(Mat& out, const Op< Op, op_cov>& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword norm_type = in.aux_uword_a; + + if(is_cx::yes) + { + const Mat tmp = in.m; // force the evaluation of Op + + out = cov(tmp, norm_type); + } + else + { + const unwrap U(in.m.m); + const Mat& A = U.M; + + if(A.n_elem == 0) + { + out.reset(); + return; + } + + const Mat& AA = (A.n_cols == 1) + ? Mat(const_cast(A.memptr()), A.n_cols, A.n_rows, false, false) + : Mat(const_cast(A.memptr()), A.n_rows, A.n_cols, false, false); + + const uword N = AA.n_cols; + const eT norm_val = (norm_type == 0) ? ( (N > 1) ? eT(N-1) : eT(1) ) : eT(N); + + const Mat tmp = AA.each_col() - mean(AA,1); + + out = tmp * tmp.t(); + out /= norm_val; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_cumprod_bones.hpp b/src/armadillo/include/armadillo_bits/op_cumprod_bones.hpp new file mode 100644 index 0000000..ce3b686 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cumprod_bones.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_cumprod +//! @{ + + + +class op_cumprod + : public traits_op_default + { + public: + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim); + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_cumprod_vec + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_cumprod_meat.hpp b/src/armadillo/include/armadillo_bits/op_cumprod_meat.hpp new file mode 100644 index 0000000..14dc224 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cumprod_meat.hpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_cumprod +//! @{ + + + +template +inline +void +op_cumprod::apply_noalias(Mat& out, const Mat& X, const uword dim) + { + arma_extra_debug_sigprint(); + + uword n_rows = X.n_rows; + uword n_cols = X.n_cols; + + out.set_size(n_rows,n_cols); + + if(out.n_elem == 0) { return; } + + if(dim == 0) + { + if(n_cols == 1) + { + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + eT acc = eT(1); + + for(uword row=0; row < n_rows; ++row) + { + acc *= X_mem[row]; + + out_mem[row] = acc; + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + const eT* X_colmem = X.colptr(col); + eT* out_colmem = out.colptr(col); + + eT acc = eT(1); + + for(uword row=0; row < n_rows; ++row) + { + acc *= X_colmem[row]; + + out_colmem[row] = acc; + } + } + } + } + else + if(dim == 1) + { + if(n_rows == 1) + { + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + eT acc = eT(1); + + for(uword col=0; col < n_cols; ++col) + { + acc *= X_mem[col]; + + out_mem[col] = acc; + } + } + else + { + if(n_cols > 0) + { + arrayops::copy( out.colptr(0), X.colptr(0), n_rows ); + + for(uword col=1; col < n_cols; ++col) + { + const eT* out_colmem_prev = out.colptr(col-1); + eT* out_colmem = out.colptr(col ); + const eT* X_colmem = X.colptr(col ); + + for(uword row=0; row < n_rows; ++row) + { + out_colmem[row] = out_colmem_prev[row] * X_colmem[row]; + } + } + } + } + } + } + + + +template +inline +void +op_cumprod::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + + arma_debug_check( (dim > 1), "cumprod(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_cumprod::apply_noalias(tmp, U.M, dim); + + out.steal_mem(tmp); + } + else + { + op_cumprod::apply_noalias(out, U.M, dim); + } + } + + + +template +inline +void +op_cumprod_vec::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(in.m); + + const uword dim = (T1::is_xvec) ? uword(U.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); + + if(U.is_alias(out)) + { + Mat tmp; + + op_cumprod::apply_noalias(tmp, U.M, dim); + + out.steal_mem(tmp); + } + else + { + op_cumprod::apply_noalias(out, U.M, dim); + } + } + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/op_cumsum_bones.hpp b/src/armadillo/include/armadillo_bits/op_cumsum_bones.hpp new file mode 100644 index 0000000..007d3f3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cumsum_bones.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_cumsum +//! @{ + + + +class op_cumsum + : public traits_op_default + { + public: + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim); + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_cumsum_vec + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_cumsum_meat.hpp b/src/armadillo/include/armadillo_bits/op_cumsum_meat.hpp new file mode 100644 index 0000000..d46eda2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cumsum_meat.hpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_cumsum +//! @{ + + + +template +inline +void +op_cumsum::apply_noalias(Mat& out, const Mat& X, const uword dim) + { + arma_extra_debug_sigprint(); + + uword n_rows = X.n_rows; + uword n_cols = X.n_cols; + + out.set_size(n_rows,n_cols); + + if(out.n_elem == 0) { return; } + + if(dim == 0) + { + if(n_cols == 1) + { + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + eT acc = eT(0); + + for(uword row=0; row < n_rows; ++row) + { + acc += X_mem[row]; + + out_mem[row] = acc; + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + const eT* X_colmem = X.colptr(col); + eT* out_colmem = out.colptr(col); + + eT acc = eT(0); + + for(uword row=0; row < n_rows; ++row) + { + acc += X_colmem[row]; + + out_colmem[row] = acc; + } + } + } + } + else + if(dim == 1) + { + if(n_rows == 1) + { + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + eT acc = eT(0); + + for(uword col=0; col < n_cols; ++col) + { + acc += X_mem[col]; + + out_mem[col] = acc; + } + } + else + { + if(n_cols > 0) + { + arrayops::copy( out.colptr(0), X.colptr(0), n_rows ); + + for(uword col=1; col < n_cols; ++col) + { + const eT* out_colmem_prev = out.colptr(col-1); + eT* out_colmem = out.colptr(col ); + const eT* X_colmem = X.colptr(col ); + + for(uword row=0; row < n_rows; ++row) + { + out_colmem[row] = out_colmem_prev[row] + X_colmem[row]; + } + } + } + } + } + } + + + +template +inline +void +op_cumsum::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + + arma_debug_check( (dim > 1), "cumsum(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_cumsum::apply_noalias(tmp, U.M, dim); + + out.steal_mem(tmp); + } + else + { + op_cumsum::apply_noalias(out, U.M, dim); + } + } + + + +template +inline +void +op_cumsum_vec::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(in.m); + + const uword dim = (T1::is_xvec) ? uword(U.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); + + if(U.is_alias(out)) + { + Mat tmp; + + op_cumsum::apply_noalias(tmp, U.M, dim); + + out.steal_mem(tmp); + } + else + { + op_cumsum::apply_noalias(out, U.M, dim); + } + } + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/op_cx_scalar_bones.hpp b/src/armadillo/include/armadillo_bits/op_cx_scalar_bones.hpp new file mode 100644 index 0000000..c1c2847 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cx_scalar_bones.hpp @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_cx_scalar +//! @{ + + + +class op_cx_scalar_times + : public traits_op_passthru + { + public: + + template + inline static void + apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_times>& X + ); + + template + inline static void + apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_times>& X + ); + }; + + + +class op_cx_scalar_plus + : public traits_op_passthru + { + public: + + template + inline static void + apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_plus>& X + ); + + template + inline static void + apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_plus>& X + ); + }; + + + +class op_cx_scalar_minus_pre + : public traits_op_passthru + { + public: + + template + inline static void + apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_minus_pre>& X + ); + + template + inline static void + apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_minus_pre>& X + ); + }; + + + +class op_cx_scalar_minus_post + : public traits_op_passthru + { + public: + + template + inline static void + apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_minus_post>& X + ); + + template + inline static void + apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_minus_post>& X + ); + }; + + + +class op_cx_scalar_div_pre + : public traits_op_passthru + { + public: + + template + inline static void + apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_div_pre>& X + ); + + template + inline static void + apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_div_pre>& X + ); + }; + + + +class op_cx_scalar_div_post + : public traits_op_passthru + { + public: + + template + inline static void + apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_div_post>& X + ); + + template + inline static void + apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_div_post>& X + ); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_cx_scalar_meat.hpp b/src/armadillo/include/armadillo_bits/op_cx_scalar_meat.hpp new file mode 100644 index 0000000..9552c01 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_cx_scalar_meat.hpp @@ -0,0 +1,564 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_cx_scalar +//! @{ + + + +template +inline +void +op_cx_scalar_times::apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_times>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const Proxy A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + out.set_size(n_rows, n_cols); + + const eT k = X.aux_out_eT; + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + const uword n_elem = A.get_n_elem(); + + for(uword i=0; i +inline +void +op_cx_scalar_plus::apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_plus>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const Proxy A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + out.set_size(n_rows, n_cols); + + const eT k = X.aux_out_eT; + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + const uword n_elem = A.get_n_elem(); + + for(uword i=0; i +inline +void +op_cx_scalar_minus_pre::apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_minus_pre>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const Proxy A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + out.set_size(n_rows, n_cols); + + const eT k = X.aux_out_eT; + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + const uword n_elem = A.get_n_elem(); + + for(uword i=0; i +inline +void +op_cx_scalar_minus_post::apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_minus_post>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const Proxy A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + out.set_size(n_rows, n_cols); + + const eT k = X.aux_out_eT; + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + const uword n_elem = A.get_n_elem(); + + for(uword i=0; i +inline +void +op_cx_scalar_div_pre::apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_div_pre>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const Proxy A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + out.set_size(n_rows, n_cols); + + const eT k = X.aux_out_eT; + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + const uword n_elem = A.get_n_elem(); + + for(uword i=0; i +inline +void +op_cx_scalar_div_post::apply + ( + Mat< typename std::complex >& out, + const mtOp, T1, op_cx_scalar_div_post>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const Proxy A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + out.set_size(n_rows, n_cols); + + const eT k = X.aux_out_eT; + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + const uword n_elem = A.get_n_elem(); + + for(uword i=0; i +inline +void +op_cx_scalar_times::apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_times>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const ProxyCube A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + const uword n_slices = A.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + const eT k = X.aux_out_eT; + const uword n_elem = out.n_elem; + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + for(uword i=0; i +inline +void +op_cx_scalar_plus::apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_plus>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const ProxyCube A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + const uword n_slices = A.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + const eT k = X.aux_out_eT; + const uword n_elem = out.n_elem; + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + for(uword i=0; i +inline +void +op_cx_scalar_minus_pre::apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_minus_pre>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const ProxyCube A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + const uword n_slices = A.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + const eT k = X.aux_out_eT; + const uword n_elem = out.n_elem; + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + for(uword i=0; i +inline +void +op_cx_scalar_minus_post::apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_minus_post>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const ProxyCube A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + const uword n_slices = A.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + const eT k = X.aux_out_eT; + const uword n_elem = out.n_elem; + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + for(uword i=0; i +inline +void +op_cx_scalar_div_pre::apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_div_pre>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const ProxyCube A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + const uword n_slices = A.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + const eT k = X.aux_out_eT; + const uword n_elem = out.n_elem; + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + for(uword i=0; i +inline +void +op_cx_scalar_div_post::apply + ( + Cube< typename std::complex >& out, + const mtOpCube, T1, op_cx_scalar_div_post>& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const ProxyCube A(X.m); + + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + const uword n_slices = A.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + const eT k = X.aux_out_eT; + const uword n_elem = out.n_elem; + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + for(uword i=0; i + struct pos + { + static constexpr uword n2 = row + col*2; + static constexpr uword n3 = row + col*3; + }; + + template + inline static bool apply_direct(typename T1::elem_type& out_val, const Base& expr); + + template + inline static typename T1::elem_type apply_diagmat(const Base& expr); + + template + inline static typename T1::elem_type apply_trimat(const Base& expr); + + template + arma_cold inline static eT apply_tiny_2x2(const Mat& X); + + template + arma_cold inline static eT apply_tiny_3x3(const Mat& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_det_meat.hpp b/src/armadillo/include/armadillo_bits/op_det_meat.hpp new file mode 100644 index 0000000..81c9b39 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_det_meat.hpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_det +//! @{ + + + +template +inline +bool +op_det::apply_direct(typename T1::elem_type& out_val, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(strip_diagmat::do_diagmat) + { + const strip_diagmat strip(expr.get_ref()); + + out_val = op_det::apply_diagmat(strip.M); + + return true; + } + + if(strip_trimat::do_trimat) + { + const strip_trimat strip(expr.get_ref()); + + out_val = op_det::apply_trimat(strip.M); + + return true; + } + + Mat A(expr.get_ref()); + + arma_debug_check( (A.is_square() == false), "det(): given matrix must be square sized" ); + + const uword N = A.n_rows; + + if(N == 0) { out_val = eT(1); return true; } + if(N == 1) { out_val = A[0]; return true; } + + if((is_cx::no) && (N <= 3)) + { + constexpr T det_min = std::numeric_limits::epsilon(); + constexpr T det_max = T(1) / std::numeric_limits::epsilon(); + + eT det_val = eT(0); + + if(N == 2) { det_val = op_det::apply_tiny_2x2(A); } + if(N == 3) { det_val = op_det::apply_tiny_3x3(A); } + + const T abs_det_val = std::abs(det_val); + + if((abs_det_val > det_min) && (abs_det_val < det_max)) { out_val = det_val; return true; } + + // fallthrough if det_val is suspect + } + + if(A.is_diagmat()) { out_val = op_det::apply_diagmat(A); return true; } + + const bool is_triu = trimat_helper::is_triu(A); + const bool is_tril = is_triu ? false : trimat_helper::is_tril(A); + + if(is_triu || is_tril) { out_val = op_det::apply_trimat(A); return true; } + + return auxlib::det(out_val, A); + } + + + +template +inline +typename T1::elem_type +op_det::apply_diagmat(const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const diagmat_proxy A(expr.get_ref()); + + arma_debug_check( (A.n_rows != A.n_cols), "det(): given matrix must be square sized" ); + + const uword N = (std::min)(A.n_rows, A.n_cols); + + eT val = eT(1); + + for(uword i=0; i +inline +typename T1::elem_type +op_det::apply_trimat(const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(expr.get_ref()); + + const uword N = P.get_n_rows(); + + arma_debug_check( (N != P.get_n_cols()), "det(): given matrix must be square sized" ); + + eT val = eT(1); + + for(uword i=0; i +inline +eT +op_det::apply_tiny_2x2(const Mat& X) + { + arma_extra_debug_sigprint(); + + const eT* Xm = X.memptr(); + + return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] ); + } + + + +template +inline +eT +op_det::apply_tiny_3x3(const Mat& X) + { + arma_extra_debug_sigprint(); + + const eT* Xm = X.memptr(); + + // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2); + // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0); + // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1); + // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2); + // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0); + // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1); + // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6); + + const eT val1 = Xm[pos<0,0>::n3]*(Xm[pos<2,2>::n3]*Xm[pos<1,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<1,2>::n3]); + const eT val2 = Xm[pos<1,0>::n3]*(Xm[pos<2,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<0,2>::n3]); + const eT val3 = Xm[pos<2,0>::n3]*(Xm[pos<1,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<1,1>::n3]*Xm[pos<0,2>::n3]); + + return ( val1 - val2 + val3 ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_diagmat_bones.hpp b/src/armadillo/include/armadillo_bits/op_diagmat_bones.hpp new file mode 100644 index 0000000..50f7dc7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_diagmat_bones.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_diagmat +//! @{ + + + +class op_diagmat + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& X); + + template + inline static void apply(Mat& out, const Proxy& P); + + template + inline static void apply(Mat& out, const Op< Glue, op_diagmat>& X); + + template + inline static void apply_times(Mat& out, const T1& X, const T2& Y, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply_times(Mat& out, const T1& X, const T2& Y, const typename arma_cx_only::result* junk = nullptr); + }; + + + +class op_diagmat2 + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& X); + + template + inline static void apply(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_diagmat_meat.hpp b/src/armadillo/include/armadillo_bits/op_diagmat_meat.hpp new file mode 100644 index 0000000..727da7b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_diagmat_meat.hpp @@ -0,0 +1,767 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_diagmat +//! @{ + + + +template +inline +void +op_diagmat::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_Mat::value) + { + // allow detection of in-place operation + + const unwrap U(X.m); + const Mat& A = U.M; + + if(&out != &A) // no aliasing + { + const Proxy< Mat > P(A); + + op_diagmat::apply(out, P); + } + else // we have aliasing + { + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + + if((n_rows == 1) || (n_cols == 1)) // create diagonal matrix from vector + { + const eT* out_mem = out.memptr(); + const uword N = out.n_elem; + + Mat tmp(N,N, arma_zeros_indicator()); + + for(uword i=0; i P(X.m); + + if(P.is_alias(out)) + { + Mat tmp; + + op_diagmat::apply(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_diagmat::apply(out, P); + } + } + } + + + +template +inline +void +op_diagmat::apply(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) { out.reset(); return; } + + const bool P_is_vec = (T1::is_row) || (T1::is_col) || (n_rows == 1) || (n_cols == 1); + + if(P_is_vec) + { + out.zeros(n_elem, n_elem); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) { out.at(i,i) = Pea[i]; } + } + else + { + if(n_rows == 1) + { + for(uword i=0; i < n_elem; ++i) { out.at(i,i) = P.at(0,i); } + } + else + { + for(uword i=0; i < n_elem; ++i) { out.at(i,i) = P.at(i,0); } + } + } + } + else // P represents a matrix + { + out.zeros(n_rows, n_cols); + + const uword N = (std::min)(n_rows, n_cols); + + for(uword i=0; i +inline +void +op_diagmat::apply(Mat& out, const Op< Glue, op_diagmat>& X) + { + arma_extra_debug_sigprint(); + + op_diagmat::apply_times(out, X.m.A, X.m.B); + } + + + +template +inline +void +op_diagmat::apply_times(Mat& actual_out, const T1& X, const T2& Y, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + const partial_unwrap UA(X); + const partial_unwrap UB(Y); + + const typename partial_unwrap::stored_type& A = UA.M; + const typename partial_unwrap::stored_type& B = UB.M; + + arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (UA.get_val() * UB.get_val()) : eT(0); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + // check if the multiplication results in a vector + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) + { + if((A_n_rows == 1) || (B_n_cols == 1)) + { + arma_extra_debug_print("trans_A = false; trans_B = false; vector result"); + + const Mat C = A*B; + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == true ) && (partial_unwrap::do_trans == false) ) + { + if((A_n_cols == 1) || (B_n_cols == 1)) + { + arma_extra_debug_print("trans_A = true; trans_B = false; vector result"); + + const Mat C = trans(A)*B; + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == false) && (partial_unwrap::do_trans == true ) ) + { + if((A_n_rows == 1) || (B_n_rows == 1)) + { + arma_extra_debug_print("trans_A = false; trans_B = true; vector result"); + + const Mat C = A*trans(B); + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == true ) && (partial_unwrap::do_trans == true ) ) + { + if((A_n_cols == 1) || (B_n_rows == 1)) + { + arma_extra_debug_print("trans_A = true; trans_B = true; vector result"); + + const Mat C = trans(A)*trans(B); + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = false; trans_B = false; matrix result"); + + out.zeros(A_n_rows, B_n_cols); + + const uword N = (std::min)(A_n_rows, B_n_cols); + + for(uword k=0; k < N; ++k) + { + eT acc1 = eT(0); + eT acc2 = eT(0); + + const eT* B_colptr = B.colptr(k); + + // condition: A_n_cols = B_n_rows + + uword j; + + for(j=1; j < A_n_cols; j+=2) + { + const uword i = (j-1); + + const eT tmp_i = B_colptr[i]; + const eT tmp_j = B_colptr[j]; + + acc1 += A.at(k, i) * tmp_i; + acc2 += A.at(k, j) * tmp_j; + } + + const uword i = (j-1); + + if(i < A_n_cols) + { + acc1 += A.at(k, i) * B_colptr[i]; + } + + const eT acc = acc1 + acc2; + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true ) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = true; trans_B = false; matrix result"); + + out.zeros(A_n_cols, B_n_cols); + + const uword N = (std::min)(A_n_cols, B_n_cols); + + for(uword k=0; k < N; ++k) + { + const eT* A_colptr = A.colptr(k); + const eT* B_colptr = B.colptr(k); + + // condition: A_n_rows = B_n_rows + + const eT acc = op_dot::direct_dot(A_n_rows, A_colptr, B_colptr); + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == true ) ) + { + arma_extra_debug_print("trans_A = false; trans_B = true; matrix result"); + + out.zeros(A_n_rows, B_n_rows); + + const uword N = (std::min)(A_n_rows, B_n_rows); + + for(uword k=0; k < N; ++k) + { + eT acc = eT(0); + + // condition: A_n_cols = B_n_cols + + for(uword i=0; i < A_n_cols; ++i) + { + acc += A.at(k,i) * B.at(k,i); + } + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true ) && (partial_unwrap::do_trans == true ) ) + { + arma_extra_debug_print("trans_A = true; trans_B = true; matrix result"); + + out.zeros(A_n_cols, B_n_rows); + + const uword N = (std::min)(A_n_cols, B_n_rows); + + for(uword k=0; k < N; ++k) + { + eT acc = eT(0); + + const eT* A_colptr = A.colptr(k); + + // condition: A_n_rows = B_n_cols + + for(uword i=0; i < A_n_rows; ++i) + { + acc += A_colptr[i] * B.at(k,i); + } + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + + if(is_alias) { actual_out.steal_mem(tmp); } + } + + + +template +inline +void +op_diagmat::apply_times(Mat& actual_out, const T1& X, const T2& Y, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + const partial_unwrap UA(X); + const partial_unwrap UB(Y); + + const typename partial_unwrap::stored_type& A = UA.M; + const typename partial_unwrap::stored_type& B = UB.M; + + arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (UA.get_val() * UB.get_val()) : eT(0); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + // check if the multiplication results in a vector + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) + { + if((A_n_rows == 1) || (B_n_cols == 1)) + { + arma_extra_debug_print("trans_A = false; trans_B = false; vector result"); + + const Mat C = A*B; + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == true ) && (partial_unwrap::do_trans == false) ) + { + if((A_n_cols == 1) || (B_n_cols == 1)) + { + arma_extra_debug_print("trans_A = true; trans_B = false; vector result"); + + const Mat C = trans(A)*B; + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == false) && (partial_unwrap::do_trans == true ) ) + { + if((A_n_rows == 1) || (B_n_rows == 1)) + { + arma_extra_debug_print("trans_A = false; trans_B = true; vector result"); + + const Mat C = A*trans(B); + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == true ) && (partial_unwrap::do_trans == true ) ) + { + if((A_n_cols == 1) || (B_n_rows == 1)) + { + arma_extra_debug_print("trans_A = true; trans_B = true; vector result"); + + const Mat C = trans(A)*trans(B); + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = false; trans_B = false; matrix result"); + + out.zeros(A_n_rows, B_n_cols); + + const uword N = (std::min)(A_n_rows, B_n_cols); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + const eT* B_colptr = B.colptr(k); + + // condition: A_n_cols = B_n_rows + + for(uword i=0; i < A_n_cols; ++i) + { + // acc += A.at(k, i) * B_colptr[i]; + + const std::complex& xx = A.at(k, i); + const std::complex& yy = B_colptr[i]; + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = yy.imag(); + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = true; trans_B = false; matrix result"); + + out.zeros(A_n_cols, B_n_cols); + + const uword N = (std::min)(A_n_cols, B_n_cols); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + const eT* A_colptr = A.colptr(k); + const eT* B_colptr = B.colptr(k); + + // condition: A_n_rows = B_n_rows + + for(uword i=0; i < A_n_rows; ++i) + { + // acc += std::conj(A_colptr[i]) * B_colptr[i]; + + const std::complex& xx = A_colptr[i]; + const std::complex& yy = B_colptr[i]; + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = yy.imag(); + + // take into account the complex conjugate of xx + + acc_real += (a*c) + (b*d); + acc_imag += (a*d) - (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == true) ) + { + arma_extra_debug_print("trans_A = false; trans_B = true; matrix result"); + + out.zeros(A_n_rows, B_n_rows); + + const uword N = (std::min)(A_n_rows, B_n_rows); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + // condition: A_n_cols = B_n_cols + + for(uword i=0; i < A_n_cols; ++i) + { + // acc += A.at(k,i) * std::conj(B.at(k,i)); + + const std::complex& xx = A.at(k, i); + const std::complex& yy = B.at(k, i); + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = -yy.imag(); // take the conjugate + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true) && (partial_unwrap::do_trans == true) ) + { + arma_extra_debug_print("trans_A = true; trans_B = true; matrix result"); + + out.zeros(A_n_cols, B_n_rows); + + const uword N = (std::min)(A_n_cols, B_n_rows); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + const eT* A_colptr = A.colptr(k); + + // condition: A_n_rows = B_n_cols + + for(uword i=0; i < A_n_rows; ++i) + { + // acc += std::conj(A_colptr[i]) * std::conj(B.at(k,i)); + + const std::complex& xx = A_colptr[i]; + const std::complex& yy = B.at(k, i); + + const T a = xx.real(); + const T b = -xx.imag(); // take the conjugate + + const T c = yy.real(); + const T d = -yy.imag(); // take the conjugate + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + + if(is_alias) { actual_out.steal_mem(tmp); } + } + + + +// +// +// + + + +template +inline +void +op_diagmat2::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword row_offset = X.aux_uword_a; + const uword col_offset = X.aux_uword_b; + + const Proxy P(X.m); + + if(P.is_alias(out)) + { + Mat tmp; + + op_diagmat2::apply(tmp, P, row_offset, col_offset); + + out.steal_mem(tmp); + } + else + { + op_diagmat2::apply(out, P, row_offset, col_offset); + } + } + + + +template +inline +void +op_diagmat2::apply(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset) + { + arma_extra_debug_sigprint(); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) { out.reset(); return; } + + const bool P_is_vec = (T1::is_row) || (T1::is_col) || (n_rows == 1) || (n_cols == 1); + + if(P_is_vec) + { + const uword n_pad = (std::max)(row_offset, col_offset); + + out.zeros(n_elem + n_pad, n_elem + n_pad); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) { out.at(row_offset + i, col_offset + i) = Pea[i]; } + } + else + { + if(n_rows == 1) + { + for(uword i=0; i < n_elem; ++i) { out.at(row_offset + i, col_offset + i) = P.at(0,i); } + } + else + { + for(uword i=0; i < n_elem; ++i) { out.at(row_offset + i, col_offset + i) = P.at(i,0); } + } + } + } + else // P represents a matrix + { + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "diagmat(): requested diagonal out of bounds" + ); + + out.zeros(n_rows, n_cols); + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + for(uword i=0; i + inline static void apply(Mat& out, const Op& X); + + template + inline static void apply_proxy(Mat& out, const Proxy& P); + + template + inline static void apply(Mat& out, const Op< Glue, op_diagvec>& X, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply(Mat& out, const Op< Glue, op_diagvec>& X, const typename arma_cx_only::result* junk = nullptr); + }; + + + +class op_diagvec2 + : public traits_op_col + { + public: + + template + inline static void apply(Mat& out, const Op& X); + + template + inline static void apply_proxy(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_diagvec_meat.hpp b/src/armadillo/include/armadillo_bits/op_diagvec_meat.hpp new file mode 100644 index 0000000..f337192 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_diagvec_meat.hpp @@ -0,0 +1,536 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_diagvec +//! @{ + + + +template +inline +void +op_diagvec::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(X.m); + + if(P.is_alias(out) == false) + { + op_diagvec::apply_proxy(out, P); + } + else + { + Mat tmp; + + op_diagvec::apply_proxy(tmp, P); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_diagvec::apply_proxy(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + const uword len = (std::min)(n_rows, n_cols); + + out.set_size(len, 1); + + eT* out_mem = out.memptr(); + + uword i,j; + for(i=0, j=1; j < len; i+=2, j+=2) + { + const eT tmp_i = P.at(i, i); + const eT tmp_j = P.at(j, j); + + out_mem[i] = tmp_i; + out_mem[j] = tmp_j; + } + + if(i < len) + { + out_mem[i] = P.at(i, i); + } + } + + + +template +inline +void +op_diagvec::apply(Mat& actual_out, const Op< Glue, op_diagvec>& X, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + const partial_unwrap UA(X.m.A); + const partial_unwrap UB(X.m.B); + + const typename partial_unwrap::stored_type& A = UA.M; + const typename partial_unwrap::stored_type& B = UB.M; + + arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + if( (A.n_elem == 0) || (B.n_elem == 0) ) { actual_out.reset(); return; } + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (UA.get_val() * UB.get_val()) : eT(0); + + const bool is_alias = (UA.is_alias(actual_out) || UB.is_alias(actual_out)); + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = false; trans_B = false;"); + + const uword N = (std::min)(A_n_rows, B_n_cols); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + eT acc1 = eT(0); + eT acc2 = eT(0); + + const eT* B_colptr = B.colptr(k); + + // condition: A_n_cols = B_n_rows + + uword j; + + for(j=1; j < A_n_cols; j+=2) + { + const uword i = (j-1); + + const eT tmp_i = B_colptr[i]; + const eT tmp_j = B_colptr[j]; + + acc1 += A.at(k, i) * tmp_i; + acc2 += A.at(k, j) * tmp_j; + } + + const uword i = (j-1); + + if(i < A_n_cols) + { + acc1 += A.at(k, i) * B_colptr[i]; + } + + const eT acc = acc1 + acc2; + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true ) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = true; trans_B = false;"); + + const uword N = (std::min)(A_n_cols, B_n_cols); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + const eT* A_colptr = A.colptr(k); + const eT* B_colptr = B.colptr(k); + + // condition: A_n_rows = B_n_rows + + const eT acc = op_dot::direct_dot(A_n_rows, A_colptr, B_colptr); + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == true ) ) + { + arma_extra_debug_print("trans_A = false; trans_B = true;"); + + const uword N = (std::min)(A_n_rows, B_n_rows); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + eT acc = eT(0); + + // condition: A_n_cols = B_n_cols + + for(uword i=0; i < A_n_cols; ++i) + { + acc += A.at(k,i) * B.at(k,i); + } + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true ) && (partial_unwrap::do_trans == true ) ) + { + arma_extra_debug_print("trans_A = true; trans_B = true;"); + + const uword N = (std::min)(A_n_cols, B_n_rows); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + eT acc = eT(0); + + const eT* A_colptr = A.colptr(k); + + // condition: A_n_rows = B_n_cols + + for(uword i=0; i < A_n_rows; ++i) + { + acc += A_colptr[i] * B.at(k,i); + } + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + + if(is_alias) { actual_out.steal_mem(tmp); } + } + + + +template +inline +void +op_diagvec::apply(Mat& actual_out, const Op< Glue, op_diagvec>& X, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + const partial_unwrap UA(X.m.A); + const partial_unwrap UB(X.m.B); + + const typename partial_unwrap::stored_type& A = UA.M; + const typename partial_unwrap::stored_type& B = UB.M; + + arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + if( (A.n_elem == 0) || (B.n_elem == 0) ) { actual_out.reset(); return; } + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (UA.get_val() * UB.get_val()) : eT(0); + + const bool is_alias = (UA.is_alias(actual_out) || UB.is_alias(actual_out)); + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = false; trans_B = false;"); + + const uword N = (std::min)(A_n_rows, B_n_cols); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + const eT* B_colptr = B.colptr(k); + + // condition: A_n_cols = B_n_rows + + for(uword i=0; i < A_n_cols; ++i) + { + // acc += A.at(k, i) * B_colptr[i]; + + const std::complex& xx = A.at(k, i); + const std::complex& yy = B_colptr[i]; + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = yy.imag(); + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = true; trans_B = false;"); + + const uword N = (std::min)(A_n_cols, B_n_cols); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + const eT* A_colptr = A.colptr(k); + const eT* B_colptr = B.colptr(k); + + // condition: A_n_rows = B_n_rows + + for(uword i=0; i < A_n_rows; ++i) + { + // acc += std::conj(A_colptr[i]) * B_colptr[i]; + + const std::complex& xx = A_colptr[i]; + const std::complex& yy = B_colptr[i]; + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = yy.imag(); + + // take into account the complex conjugate of xx + + acc_real += (a*c) + (b*d); + acc_imag += (a*d) - (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == true) ) + { + arma_extra_debug_print("trans_A = false; trans_B = true;"); + + const uword N = (std::min)(A_n_rows, B_n_rows); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + // condition: A_n_cols = B_n_cols + + for(uword i=0; i < A_n_cols; ++i) + { + // acc += A.at(k,i) * std::conj(B.at(k,i)); + + const std::complex& xx = A.at(k, i); + const std::complex& yy = B.at(k, i); + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = -yy.imag(); // take the conjugate + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true) && (partial_unwrap::do_trans == true) ) + { + arma_extra_debug_print("trans_A = true; trans_B = true;"); + + const uword N = (std::min)(A_n_cols, B_n_rows); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + const eT* A_colptr = A.colptr(k); + + // condition: A_n_rows = B_n_cols + + for(uword i=0; i < A_n_rows; ++i) + { + // acc += std::conj(A_colptr[i]) * std::conj(B.at(k,i)); + + const std::complex& xx = A_colptr[i]; + const std::complex& yy = B.at(k, i); + + const T a = xx.real(); + const T b = -xx.imag(); // take the conjugate + + const T c = yy.real(); + const T d = -yy.imag(); // take the conjugate + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + + if(is_alias) { actual_out.steal_mem(tmp); } + } + + + +// +// +// + + + +template +inline +void +op_diagvec2::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword a = X.aux_uword_a; + const uword b = X.aux_uword_b; + + const uword row_offset = (b > 0) ? a : 0; + const uword col_offset = (b == 0) ? a : 0; + + const Proxy P(X.m); + + if(P.is_alias(out) == false) + { + op_diagvec2::apply_proxy(out, P, row_offset, col_offset); + } + else + { + Mat tmp; + + op_diagvec2::apply_proxy(tmp, P, row_offset, col_offset); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_diagvec2::apply_proxy(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "diagvec(): requested diagonal is out of bounds" + ); + + const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset); + + out.set_size(len, 1); + + eT* out_mem = out.memptr(); + + uword i,j; + for(i=0, j=1; j < len; i+=2, j+=2) + { + const eT tmp_i = P.at( i + row_offset, i + col_offset ); + const eT tmp_j = P.at( j + row_offset, j + col_offset ); + + out_mem[i] = tmp_i; + out_mem[j] = tmp_j; + } + + if(i < len) + { + out_mem[i] = P.at( i + row_offset, i + col_offset ); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_diff_bones.hpp b/src/armadillo/include/armadillo_bits/op_diff_bones.hpp new file mode 100644 index 0000000..a6844ab --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_diff_bones.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_diff +//! @{ + + + +class op_diff + : public traits_op_default + { + public: + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword k, const uword dim); + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_diff_vec + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_diff_meat.hpp b/src/armadillo/include/armadillo_bits/op_diff_meat.hpp new file mode 100644 index 0000000..a5b309b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_diff_meat.hpp @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_diff +//! @{ + + +template +inline +void +op_diff::apply_noalias(Mat& out, const Mat& X, const uword k, const uword dim) + { + arma_extra_debug_sigprint(); + + uword n_rows = X.n_rows; + uword n_cols = X.n_cols; + + if(dim == 0) + { + if(n_rows <= k) { out.set_size(0,n_cols); return; } + + --n_rows; + + out.set_size(n_rows,n_cols); + + for(uword col=0; col < n_cols; ++col) + { + eT* out_colmem = out.colptr(col); + const eT* X_colmem = X.colptr(col); + + for(uword row=0; row < n_rows; ++row) + { + const eT val0 = X_colmem[row ]; + const eT val1 = X_colmem[row+1]; + + out_colmem[row] = val1 - val0; + } + } + + if(k >= 2) + { + for(uword iter=2; iter <= k; ++iter) + { + --n_rows; + + for(uword col=0; col < n_cols; ++col) + { + eT* colmem = out.colptr(col); + + for(uword row=0; row < n_rows; ++row) + { + const eT val0 = colmem[row ]; + const eT val1 = colmem[row+1]; + + colmem[row] = val1 - val0; + } + } + } + + out = out( span(0,n_rows-1), span::all ); + } + } + else + if(dim == 1) + { + if(n_cols <= k) { out.set_size(n_rows,0); return; } + + --n_cols; + + out.set_size(n_rows,n_cols); + + if(n_rows == 1) + { + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + for(uword col=0; col < n_cols; ++col) + { + const eT val0 = X_mem[col ]; + const eT val1 = X_mem[col+1]; + + out_mem[col] = val1 - val0; + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + eT* out_col_mem = out.colptr(col); + + const eT* X_col0_mem = X.colptr(col ); + const eT* X_col1_mem = X.colptr(col+1); + + for(uword row=0; row < n_rows; ++row) + { + out_col_mem[row] = X_col1_mem[row] - X_col0_mem[row]; + } + } + } + + if(k >= 2) + { + for(uword iter=2; iter <= k; ++iter) + { + --n_cols; + + if(n_rows == 1) + { + eT* out_mem = out.memptr(); + + for(uword col=0; col < n_cols; ++col) + { + const eT val0 = out_mem[col ]; + const eT val1 = out_mem[col+1]; + + out_mem[col] = val1 - val0; + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + eT* col0_mem = out.colptr(col ); + const eT* col1_mem = out.colptr(col+1); + + for(uword row=0; row < n_rows; ++row) + { + col0_mem[row] = col1_mem[row] - col0_mem[row]; + } + } + } + } + + out = out( span::all, span(0,n_cols-1) ); + } + } + } + + + +template +inline +void +op_diff::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword k = in.aux_uword_a; + const uword dim = in.aux_uword_b; + + arma_debug_check( (dim > 1), "diff(): parameter 'dim' must be 0 or 1" ); + + if(k == 0) { out = in.m; return; } + + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_diff::apply_noalias(tmp, U.M, k, dim); + + out.steal_mem(tmp); + } + else + { + op_diff::apply_noalias(out, U.M, k, dim); + } + } + + + +template +inline +void +op_diff_vec::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword k = in.aux_uword_a; + + if(k == 0) { out = in.m; return; } + + const quasi_unwrap U(in.m); + + const uword dim = (T1::is_xvec) ? uword(U.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); + + if(U.is_alias(out)) + { + Mat tmp; + + op_diff::apply_noalias(tmp, U.M, k, dim); + + out.steal_mem(tmp); + } + else + { + op_diff::apply_noalias(out, U.M, k, dim); + } + } + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/op_dot_bones.hpp b/src/armadillo/include/armadillo_bits/op_dot_bones.hpp new file mode 100644 index 0000000..23068de --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_dot_bones.hpp @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_dot +//! @{ + +//! \brief +//! dot product operation + +class op_dot + : public traits_op_default + { + public: + + template + arma_inline static + typename arma_not_cx::result + direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B); + + template + arma_hot inline static + typename arma_cx_only::result + direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B); + + template + arma_hot inline static typename arma_real_only::result + direct_dot(const uword n_elem, const eT* const A, const eT* const B); + + template + arma_hot inline static typename arma_cx_only::result + direct_dot(const uword n_elem, const eT* const A, const eT* const B); + + template + arma_hot inline static typename arma_integral_only::result + direct_dot(const uword n_elem, const eT* const A, const eT* const B); + + + template + arma_hot inline static eT direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C); + + template + arma_hot inline static typename T1::elem_type apply(const T1& X, const T2& Y); + + template + arma_hot inline static typename arma_not_cx::result apply_proxy(const Proxy& PA, const Proxy& PB); + + template + arma_hot inline static typename arma_cx_only::result apply_proxy(const Proxy& PA, const Proxy& PB); + }; + + + +//! \brief +//! normalised dot product operation + +class op_norm_dot + : public traits_op_default + { + public: + + template + arma_hot inline static typename T1::elem_type apply(const T1& X, const T2& Y); + }; + + + +//! \brief +//! complex conjugate dot product operation + +class op_cdot + : public traits_op_default + { + public: + + template + arma_hot inline static eT direct_cdot_arma(const uword n_elem, const eT* const A, const eT* const B); + + template + arma_hot inline static eT direct_cdot(const uword n_elem, const eT* const A, const eT* const B); + + template + arma_hot inline static typename T1::elem_type apply (const T1& X, const T2& Y); + + template + arma_hot inline static typename T1::elem_type apply_unwrap(const T1& X, const T2& Y); + + template + arma_hot inline static typename T1::elem_type apply_proxy (const T1& X, const T2& Y); + }; + + + +class op_dot_mixed + : public traits_op_default + { + public: + + template + arma_hot inline static + typename promote_type::result + apply(const T1& A, const T2& B); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_dot_meat.hpp b/src/armadillo/include/armadillo_bits/op_dot_meat.hpp new file mode 100644 index 0000000..e94c76d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_dot_meat.hpp @@ -0,0 +1,580 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_dot +//! @{ + + + +//! for two arrays, generic version for non-complex values +template +arma_inline +typename arma_not_cx::result +op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) + { + arma_extra_debug_sigprint(); + + #if defined(__FAST_MATH__) + { + eT val = eT(0); + + for(uword i=0; i +inline +typename arma_cx_only::result +op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + T val_real = T(0); + T val_imag = T(0); + + for(uword i=0; i& X = A[i]; + const std::complex& Y = B[i]; + + const T a = X.real(); + const T b = X.imag(); + + const T c = Y.real(); + const T d = Y.imag(); + + val_real += (a*c) - (b*d); + val_imag += (a*d) + (b*c); + } + + return std::complex(val_real, val_imag); + } + + + +//! for two arrays, float and double version +template +inline +typename arma_real_only::result +op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) + { + arma_extra_debug_sigprint(); + + if( n_elem <= 32u ) + { + return op_dot::direct_dot_arma(n_elem, A, B); + } + else + { + #if defined(ARMA_USE_ATLAS) + { + arma_extra_debug_print("atlas::cblas_dot()"); + + return atlas::cblas_dot(n_elem, A, B); + } + #elif defined(ARMA_USE_BLAS) + { + arma_extra_debug_print("blas::dot()"); + + return blas::dot(n_elem, A, B); + } + #else + { + return op_dot::direct_dot_arma(n_elem, A, B); + } + #endif + } + } + + + +//! for two arrays, complex version +template +inline +typename arma_cx_only::result +op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) + { + if( n_elem <= 16u ) + { + return op_dot::direct_dot_arma(n_elem, A, B); + } + else + { + #if defined(ARMA_USE_ATLAS) + { + arma_extra_debug_print("atlas::cblas_cx_dot()"); + + return atlas::cblas_cx_dot(n_elem, A, B); + } + #elif defined(ARMA_USE_BLAS) + { + arma_extra_debug_print("blas::dot()"); + + return blas::dot(n_elem, A, B); + } + #else + { + return op_dot::direct_dot_arma(n_elem, A, B); + } + #endif + } + } + + + +//! for two arrays, integral version +template +inline +typename arma_integral_only::result +op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) + { + return op_dot::direct_dot_arma(n_elem, A, B); + } + + + + +//! for three arrays +template +inline +eT +op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C) + { + arma_extra_debug_sigprint(); + + eT val = eT(0); + + for(uword i=0; i +inline +typename T1::elem_type +op_dot::apply(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + const bool use_at = (Proxy::use_at) || (Proxy::use_at); + + const bool have_direct_mem = (quasi_unwrap::has_orig_mem) && (quasi_unwrap::has_orig_mem); + + if(use_at || have_direct_mem) + { + const quasi_unwrap A(X); + const quasi_unwrap B(Y); + + arma_debug_check( (A.M.n_elem != B.M.n_elem), "dot(): objects must have the same number of elements" ); + + return op_dot::direct_dot(A.M.n_elem, A.M.memptr(), B.M.memptr()); + } + else + { + if(is_subview_row::value && is_subview_row::value) + { + typedef typename T1::elem_type eT; + + const subview_row& A = reinterpret_cast< const subview_row& >(X); + const subview_row& B = reinterpret_cast< const subview_row& >(Y); + + if( (A.m.n_rows == 1) && (B.m.n_rows == 1) ) + { + arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" ); + + const eT* A_mem = A.m.memptr(); + const eT* B_mem = B.m.memptr(); + + return op_dot::direct_dot(A.n_elem, &A_mem[A.aux_col1], &B_mem[B.aux_col1]); + } + } + + const Proxy PA(X); + const Proxy PB(Y); + + arma_debug_check( (PA.get_n_elem() != PB.get_n_elem()), "dot(): objects must have the same number of elements" ); + + if(is_Mat::stored_type>::value && is_Mat::stored_type>::value) + { + const quasi_unwrap::stored_type> A(PA.Q); + const quasi_unwrap::stored_type> B(PB.Q); + + return op_dot::direct_dot(A.M.n_elem, A.M.memptr(), B.M.memptr()); + } + + return op_dot::apply_proxy(PA,PB); + } + } + + + +template +inline +typename arma_not_cx::result +op_dot::apply_proxy(const Proxy& PA, const Proxy& PB) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename Proxy::ea_type ea_type1; + typedef typename Proxy::ea_type ea_type2; + + const uword N = PA.get_n_elem(); + + ea_type1 A = PA.get_ea(); + ea_type2 B = PB.get_ea(); + + eT val1 = eT(0); + eT val2 = eT(0); + + uword i,j; + + for(i=0, j=1; j +inline +typename arma_cx_only::result +op_dot::apply_proxy(const Proxy& PA, const Proxy& PB) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + typedef typename Proxy::ea_type ea_type1; + typedef typename Proxy::ea_type ea_type2; + + const uword N = PA.get_n_elem(); + + ea_type1 A = PA.get_ea(); + ea_type2 B = PB.get_ea(); + + T val_real = T(0); + T val_imag = T(0); + + for(uword i=0; i xx = A[i]; + const std::complex yy = B[i]; + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = yy.imag(); + + val_real += (a*c) - (b*d); + val_imag += (a*d) + (b*c); + } + + return std::complex(val_real, val_imag); + } + + + +// +// op_norm_dot + + + +template +inline +typename T1::elem_type +op_norm_dot::apply(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const quasi_unwrap tmp1(X); + const quasi_unwrap tmp2(Y); + + const Col A( const_cast(tmp1.M.memptr()), tmp1.M.n_elem, false ); + const Col B( const_cast(tmp2.M.memptr()), tmp2.M.n_elem, false ); + + arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" ); + + const T denom = norm(A,2) * norm(B,2); + + return (denom != T(0)) ? ( op_dot::apply(A,B) / denom ) : eT(0); + } + + + +// +// op_cdot + + + +template +inline +eT +op_cdot::direct_cdot_arma(const uword n_elem, const eT* const A, const eT* const B) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + T val_real = T(0); + T val_imag = T(0); + + for(uword i=0; i& X = A[i]; + const std::complex& Y = B[i]; + + const T a = X.real(); + const T b = X.imag(); + + const T c = Y.real(); + const T d = Y.imag(); + + val_real += (a*c) + (b*d); + val_imag += (a*d) - (b*c); + } + + return std::complex(val_real, val_imag); + } + + + +template +inline +eT +op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const B) + { + arma_extra_debug_sigprint(); + + if( n_elem <= 32u ) + { + return op_cdot::direct_cdot_arma(n_elem, A, B); + } + else + { + #if defined(ARMA_USE_BLAS) + { + arma_extra_debug_print("blas::gemv()"); + + // using gemv() workaround due to compatibility issues with cdotc() and zdotc() + + const char trans = 'C'; + + const blas_int m = blas_int(n_elem); + const blas_int n = 1; + //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1); + const blas_int inc = 1; + + const eT alpha = eT(1); + const eT beta = eT(0); + + eT result[2]; // paranoia: using two elements instead of one + + //blas::gemv(&trans, &m, &n, &alpha, A, &lda, B, &inc, &beta, &result[0], &inc); + blas::gemv(&trans, &m, &n, &alpha, A, &m, B, &inc, &beta, &result[0], &inc); + + return result[0]; + } + #else + { + return op_cdot::direct_cdot_arma(n_elem, A, B); + } + #endif + } + } + + + +template +inline +typename T1::elem_type +op_cdot::apply(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + if(is_Mat::value && is_Mat::value) + { + return op_cdot::apply_unwrap(X,Y); + } + else + { + return op_cdot::apply_proxy(X,Y); + } + } + + + +template +inline +typename T1::elem_type +op_cdot::apply_unwrap(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap tmp1(X); + const unwrap tmp2(Y); + + const Mat& A = tmp1.M; + const Mat& B = tmp2.M; + + arma_debug_check( (A.n_elem != B.n_elem), "cdot(): objects must have the same number of elements" ); + + return op_cdot::direct_cdot( A.n_elem, A.mem, B.mem ); + } + + + +template +inline +typename T1::elem_type +op_cdot::apply_proxy(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + typedef typename Proxy::ea_type ea_type1; + typedef typename Proxy::ea_type ea_type2; + + const bool use_at = (Proxy::use_at) || (Proxy::use_at); + + if(use_at == false) + { + const Proxy PA(X); + const Proxy PB(Y); + + const uword N = PA.get_n_elem(); + + arma_debug_check( (N != PB.get_n_elem()), "cdot(): objects must have the same number of elements" ); + + ea_type1 A = PA.get_ea(); + ea_type2 B = PB.get_ea(); + + T val_real = T(0); + T val_imag = T(0); + + for(uword i=0; i AA = A[i]; + const std::complex BB = B[i]; + + const T a = AA.real(); + const T b = AA.imag(); + + const T c = BB.real(); + const T d = BB.imag(); + + val_real += (a*c) + (b*d); + val_imag += (a*d) - (b*c); + } + + return std::complex(val_real, val_imag); + } + else + { + return op_cdot::apply_unwrap( X, Y ); + } + } + + + +template +inline +typename promote_type::result +op_dot_mixed::apply(const T1& A, const T2& B) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_eT1; + typedef typename T2::elem_type in_eT2; + + typedef typename promote_type::result out_eT; + + const Proxy PA(A); + const Proxy PB(B); + + const uword N = PA.get_n_elem(); + + arma_debug_check( (N != PB.get_n_elem()), "dot(): objects must have the same number of elements" ); + + out_eT acc = out_eT(0); + + for(uword i=0; i < N; ++i) + { + acc += upgrade_val::apply(PA[i]) * upgrade_val::apply(PB[i]); + } + + return acc; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_dotext_bones.hpp b/src/armadillo/include/armadillo_bits/op_dotext_bones.hpp new file mode 100644 index 0000000..dc3b7b8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_dotext_bones.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_dotext +//! @{ + + + +class op_dotext + : public traits_op_default + { + public: + + + template + inline static eT direct_rowvec_mat_colvec (const eT* A_mem, const Mat& B, const eT* C_mem); + + template + inline static eT direct_rowvec_transmat_colvec (const eT* A_mem, const Mat& B, const eT* C_mem); + + template + inline static eT direct_rowvec_diagmat_colvec (const eT* A_mem, const Mat& B, const eT* C_mem); + + template + inline static eT direct_rowvec_invdiagmat_colvec(const eT* A_mem, const Mat& B, const eT* C_mem); + + template + inline static eT direct_rowvec_invdiagvec_colvec(const eT* A_mem, const Mat& B, const eT* C_mem); + + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/op_dotext_meat.hpp b/src/armadillo/include/armadillo_bits/op_dotext_meat.hpp new file mode 100644 index 0000000..c190b2c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_dotext_meat.hpp @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_dotext +//! @{ + + + +template +inline +eT +op_dotext::direct_rowvec_mat_colvec + ( + const eT* A_mem, + const Mat& B, + const eT* C_mem + ) + { + arma_extra_debug_sigprint(); + + const uword cost_AB = B.n_cols; + const uword cost_BC = B.n_rows; + + if(cost_AB <= cost_BC) + { + podarray tmp(B.n_cols); + + for(uword col=0; col tmp(B.n_rows); + + for(uword row=0; row +inline +eT +op_dotext::direct_rowvec_transmat_colvec + ( + const eT* A_mem, + const Mat& B, + const eT* C_mem + ) + { + arma_extra_debug_sigprint(); + + const uword cost_AB = B.n_rows; + const uword cost_BC = B.n_cols; + + if(cost_AB <= cost_BC) + { + podarray tmp(B.n_rows); + + for(uword row=0; row tmp(B.n_cols); + + for(uword col=0; col +inline +eT +op_dotext::direct_rowvec_diagmat_colvec + ( + const eT* A_mem, + const Mat& B, + const eT* C_mem + ) + { + arma_extra_debug_sigprint(); + + eT val = eT(0); + + for(uword i=0; i +inline +eT +op_dotext::direct_rowvec_invdiagmat_colvec + ( + const eT* A_mem, + const Mat& B, + const eT* C_mem + ) + { + arma_extra_debug_sigprint(); + + eT val = eT(0); + + for(uword i=0; i +inline +eT +op_dotext::direct_rowvec_invdiagvec_colvec + ( + const eT* A_mem, + const Mat& B, + const eT* C_mem + ) + { + arma_extra_debug_sigprint(); + + const eT* B_mem = B.mem; + + eT val = eT(0); + + for(uword i=0; i + inline static void apply(Mat& out, const Op& expr); + + template + inline static bool apply_direct(Mat& out, const Base& X); + }; + + + +class op_expmat_sym + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Base& expr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_expmat_meat.hpp b/src/armadillo/include/armadillo_bits/op_expmat_meat.hpp new file mode 100644 index 0000000..d45fb36 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_expmat_meat.hpp @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_expmat +//! @{ + + +//! implementation based on: +//! Cleve Moler, Charles Van Loan. +//! Nineteen Dubious Ways to Compute the Exponential of a Matrix, Twenty-Five Years Later. +//! SIAM Review, Vol. 45, No. 1, 2003, pp. 3-49. +//! http://dx.doi.org/10.1137/S00361445024180 + + +template +inline +void +op_expmat::apply(Mat& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + const bool status = op_expmat::apply_direct(out, expr.m); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("expmat(): given matrix appears ill-conditioned"); + } + } + + + +template +inline +bool +op_expmat::apply_direct(Mat& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(is_op_diagmat::value) + { + out = expr.get_ref(); // force the evaluation of diagmat() + + arma_debug_check( (out.is_square() == false), "expmat(): given matrix must be square sized", [&](){ out.soft_reset(); } ); + + const uword N = (std::min)(out.n_rows, out.n_cols); + + for(uword i=0; i A = expr.get_ref(); + + arma_debug_check( (A.is_square() == false), "expmat(): given matrix must be square sized" ); + + if(A.is_diagmat()) + { + arma_extra_debug_print("op_expmat: detected diagonal matrix"); + + const uword N = (std::min)(A.n_rows, A.n_cols); + + out.zeros(N,N); + + for(uword i=0; i::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd)); + } + + if(do_sym) + { + arma_extra_debug_print("op_expmat: symmetric/hermitian optimisation"); + + Col< T> eigval; + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, A, 'd', "expmat()"); + + if(eig_status == false) { return false; } + + eigval = exp(eigval); + + out = eigvec * diagmat(eigval) * eigvec.t(); + + return true; + } + + const T norm_val = arma::norm(A, "inf"); + + if(arma_isfinite(norm_val) == false) { return false; } + + const double log2_val = (norm_val > T(0)) ? double(eop_aux::log2(norm_val)) : double(0); + + int exponent = int(0); std::frexp(log2_val, &exponent); + + const uword s = uword( (std::max)(int(0), exponent + int(1)) ); + + A /= eT(eop_aux::pow(double(2), double(s))); + + T c = T(0.5); + + Mat E(A.n_rows, A.n_rows, fill::eye); E += c * A; + Mat D(A.n_rows, A.n_rows, fill::eye); D -= c * A; + + Mat X = A; + + bool positive = true; + + const uword N = 6; + + for(uword i = 2; i <= N; ++i) + { + c = c * T(N - i + 1) / T(i * (2*N - i + 1)); + + X = A * X; + + E += c * X; + + if(positive) { D += c * X; } else { D -= c * X; } + + positive = (positive) ? false : true; + } + + if( (D.internal_has_nonfinite()) || (E.internal_has_nonfinite()) ) { return false; } + + const bool status = solve(out, D, E, solve_opts::no_approx); + + if(status == false) { return false; } + + for(uword i=0; i < s; ++i) { out = out * out; } + + return true; + } + + + +template +inline +void +op_expmat_sym::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_expmat_sym::apply_direct(out, in.m); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("expmat_sym(): transformation failed"); + } + } + + + +template +inline +bool +op_expmat_sym::apply_direct(Mat& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const unwrap U(expr.get_ref()); + const Mat& X = U.M; + + arma_debug_check( (X.is_square() == false), "expmat_sym(): given matrix must be square sized" ); + + if((arma_config::debug) && (arma_config::warn_level > 0) && (is_cx::yes) && (sym_helper::check_diag_imag(X) == false)) + { + arma_debug_warn_level(1, "inv_sympd(): imaginary components on diagonal are non-zero"); + } + + if(is_op_diagmat::value || X.is_diagmat()) + { + arma_extra_debug_print("op_expmat_sym: detected diagonal matrix"); + + out = X; + + eT* colmem = out.memptr(); + + const uword N = X.n_rows; + + for(uword i=0; i eigval; + Mat eigvec; + + const bool status = eig_sym_helper(eigval, eigvec, X, 'd', "expmat_sym()"); + + if(status == false) { return false; } + + eigval = exp(eigval); + + out = eigvec * diagmat(eigval) * eigvec.t(); + + return true; + } + #else + { + arma_ignore(out); + arma_ignore(expr); + arma_stop_logic_error("expmat_sym(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_fft_bones.hpp b/src/armadillo/include/armadillo_bits/op_fft_bones.hpp new file mode 100644 index 0000000..b0dcbfd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_fft_bones.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_fft +//! @{ + + + +class op_fft_real + : public traits_op_passthru + { + public: + + template + inline static void apply( Mat< std::complex >& out, const mtOp,T1,op_fft_real>& in ); + }; + + + +class op_fft_cx + : public traits_op_passthru + { + public: + + template + inline static void apply( Mat& out, const Op& in ); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword a, const uword b); + }; + + + +class op_ifft_cx + : public traits_op_passthru + { + public: + + template + inline static void apply( Mat& out, const Op& in ); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_fft_meat.hpp b/src/armadillo/include/armadillo_bits/op_fft_meat.hpp new file mode 100644 index 0000000..4f5d93a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_fft_meat.hpp @@ -0,0 +1,325 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_fft +//! @{ + + +#if defined(ARMA_USE_FFTW3) + +template +class fft_engine_wrapper + { + public: + + static constexpr uword threshold = 512; + + fft_engine_kissfft* worker_kissfft = nullptr; + fft_engine_fftw3 * worker_fftw3 = nullptr; + + inline + ~fft_engine_wrapper() + { + arma_extra_debug_sigprint(); + + if(worker_kissfft != nullptr) { delete worker_kissfft; } + if(worker_fftw3 != nullptr) { delete worker_fftw3; } + } + + inline + fft_engine_wrapper(const uword N_samples, const uword N_exec) + { + arma_extra_debug_sigprint(); + + const bool use_fftw3 = N_samples >= (threshold / N_exec); + + worker_kissfft = (use_fftw3 == false) ? new fft_engine_kissfft(N_samples) : nullptr; + worker_fftw3 = (use_fftw3 == true ) ? new fft_engine_fftw3 (N_samples) : nullptr; + } + + inline + void + run(cx_type* Y, const cx_type* X) + { + arma_extra_debug_sigprint(); + + if(worker_kissfft != nullptr) { (*worker_kissfft).run(Y,X); } + else if(worker_fftw3 != nullptr) { (*worker_fftw3).run(Y,X); } + } + }; + +#endif + + +// +// op_fft_real + + +template +inline +void +op_fft_real::apply( Mat< std::complex >& out, const mtOp,T1,op_fft_real>& in ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type in_eT; + typedef typename std::complex out_eT; + + // no need to worry about aliasing, as we're going from a real object to complex complex, which by definition cannot alias + + const quasi_unwrap U(in.m); + const Mat& X = U.M; + + const uword n_rows = X.n_rows; + const uword n_cols = X.n_cols; + const uword n_elem = X.n_elem; + + const bool is_vec = ( (n_rows == 1) || (n_cols == 1) ); + + const uword N_orig = (is_vec) ? n_elem : n_rows; + const uword N_user = (in.aux_uword_b == 0) ? in.aux_uword_a : N_orig; + + #if defined(ARMA_USE_FFTW3) + const uword N_exec = (is_vec) ? uword(1) : n_cols; + fft_engine_wrapper worker(N_user, N_exec); + #else + fft_engine_kissfft worker(N_user); + #endif + + if(is_vec) + { + (n_cols == 1) ? out.set_size(N_user, 1) : out.set_size(1, N_user); + + if( (out.n_elem == 0) || (N_orig == 0) ) { out.zeros(); return; } + + if( (N_user == 1) && (N_orig >= 1) ) { out[0] = out_eT( X[0] ); return; } + + podarray data(N_user, arma_zeros_indicator()); + + out_eT* data_mem = data.memptr(); + const in_eT* X_mem = X.memptr(); + + const uword N = (std::min)(N_user, N_orig); + + for(uword i=0; i < N; ++i) { data_mem[i].real(X_mem[i]); } + + worker.run( out.memptr(), data_mem ); + } + else + { + // process each column seperately + + out.set_size(N_user, n_cols); + + if( (out.n_elem == 0) || (N_orig == 0) ) { out.zeros(); return; } + + if( (N_user == 1) && (N_orig >= 1) ) + { + for(uword col=0; col < n_cols; ++col) { out.at(0,col).real( X.at(0,col) ); } + + return; + } + + podarray data(N_user, arma_zeros_indicator()); + + out_eT* data_mem = data.memptr(); + + const uword N = (std::min)(N_user, N_orig); + + for(uword col=0; col < n_cols; ++col) + { + for(uword i=0; i < N; ++i) { data_mem[i].real( X.at(i, col) ); } + + worker.run( out.colptr(col), data_mem ); + } + } + } + + + +// +// op_fft_cx + + +template +inline +void +op_fft_cx::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_fft_cx::apply_noalias(tmp, U.M, in.aux_uword_a, in.aux_uword_b); + + out.steal_mem(tmp); + } + else + { + op_fft_cx::apply_noalias(out, U.M, in.aux_uword_a, in.aux_uword_b); + } + } + + + +template +inline +void +op_fft_cx::apply_noalias(Mat& out, const Mat& X, const uword a, const uword b) + { + arma_extra_debug_sigprint(); + + const uword n_rows = X.n_rows; + const uword n_cols = X.n_cols; + const uword n_elem = X.n_elem; + + const bool is_vec = ( (n_rows == 1) || (n_cols == 1) ); + + const uword N_orig = (is_vec) ? n_elem : n_rows; + const uword N_user = (b == 0) ? a : N_orig; + + #if defined(ARMA_USE_FFTW3) + const uword N_exec = (is_vec) ? uword(1) : n_cols; + fft_engine_wrapper worker(N_user, N_exec); + #else + fft_engine_kissfft worker(N_user); + #endif + + if(is_vec) + { + (n_cols == 1) ? out.set_size(N_user, 1) : out.set_size(1, N_user); + + if( (out.n_elem == 0) || (N_orig == 0) ) { out.zeros(); return; } + + if( (N_user == 1) && (N_orig >= 1) ) { out[0] = X[0]; return; } + + if(N_user > N_orig) + { + podarray data(N_user); + + eT* data_mem = data.memptr(); + + arrayops::fill_zeros( &data_mem[N_orig], (N_user - N_orig) ); + + arrayops::copy(data_mem, X.memptr(), (std::min)(N_user, N_orig)); + + worker.run( out.memptr(), data_mem ); + } + else + { + worker.run( out.memptr(), X.memptr() ); + } + } + else + { + // process each column seperately + + out.set_size(N_user, n_cols); + + if( (out.n_elem == 0) || (N_orig == 0) ) { out.zeros(); return; } + + if( (N_user == 1) && (N_orig >= 1) ) + { + for(uword col=0; col < n_cols; ++col) { out.at(0,col) = X.at(0,col); } + + return; + } + + if(N_user > N_orig) + { + podarray data(N_user); + + eT* data_mem = data.memptr(); + + arrayops::fill_zeros( &data_mem[N_orig], (N_user - N_orig) ); + + const uword N = (std::min)(N_user, N_orig); + + for(uword col=0; col < n_cols; ++col) + { + arrayops::copy(data_mem, X.colptr(col), N); + + worker.run( out.colptr(col), data_mem ); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + worker.run( out.colptr(col), X.colptr(col) ); + } + } + } + + + // correct the scaling for the inverse transform + if(inverse) + { + typedef typename get_pod_type::result T; + + const T k = T(1) / T(N_user); + + eT* out_mem = out.memptr(); + + const uword out_n_elem = out.n_elem; + + for(uword i=0; i < out_n_elem; ++i) { out_mem[i] *= k; } + } + } + + + +// +// op_ifft_cx + + +template +inline +void +op_ifft_cx::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_fft_cx::apply_noalias(tmp, U.M, in.aux_uword_a, in.aux_uword_b); + + out.steal_mem(tmp); + } + else + { + op_fft_cx::apply_noalias(out, U.M, in.aux_uword_a, in.aux_uword_b); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_find_bones.hpp b/src/armadillo/include/armadillo_bits/op_find_bones.hpp new file mode 100644 index 0000000..6e7c9cc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_find_bones.hpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_find +//! @{ + + + +class op_find + : public traits_op_col + { + public: + + template + inline static uword + helper + ( + Mat& indices, + const Base& X + ); + + template + inline static uword + helper + ( + Mat& indices, + const mtOp& X, + const typename arma_op_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr + ); + + template + inline static uword + helper + ( + Mat& indices, + const mtOp& X, + const typename arma_op_rel_only::result* junk1 = nullptr, + const typename arma_cx_only::result* junk2 = nullptr + ); + + template + inline static uword + helper + ( + Mat& indices, + const mtGlue& X, + const typename arma_glue_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr, + const typename arma_not_cx::result* junk3 = nullptr + ); + + template + inline static uword + helper + ( + Mat& indices, + const mtGlue& X, + const typename arma_glue_rel_only::result* junk1 = nullptr, + const typename arma_cx_only::result* junk2 = nullptr, + const typename arma_cx_only::result* junk3 = nullptr + ); + + template + inline static void apply(Mat& out, const mtOp& X); + }; + + + +class op_find_simple + : public traits_op_col + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + }; + + + +class op_find_finite + : public traits_op_col + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + }; + + + +class op_find_nonfinite + : public traits_op_col + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + }; + + + +class op_find_nan + : public traits_op_col + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_find_meat.hpp b/src/armadillo/include/armadillo_bits/op_find_meat.hpp new file mode 100644 index 0000000..b2fd6dd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_find_meat.hpp @@ -0,0 +1,660 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_find +//! @{ + + + +template +inline +uword +op_find::helper + ( + Mat& indices, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy A(X.get_ref()); + + const uword n_elem = A.get_n_elem(); + + indices.set_size(n_elem, 1); + + uword* indices_mem = indices.memptr(); + uword n_nz = 0; + + if(Proxy::use_at == false) + { + typename Proxy::ea_type PA = A.get_ea(); + + for(uword i=0; i +inline +uword +op_find::helper + ( + Mat& indices, + const mtOp& X, + const typename arma_op_rel_only::result* junk1, + const typename arma_not_cx::result* junk2 + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + typedef typename T1::elem_type eT; + + const eT val = X.aux; + + if((is_same_type::yes || is_same_type::yes) && arma_config::debug && arma_isnan(val)) + { + arma_debug_warn_level(1, "find(): NaN is not equal to anything; suggest to use find_nonfinite() instead"); + } + + const Proxy A(X.m); + + const uword n_elem = A.get_n_elem(); + + indices.set_size(n_elem, 1); + + uword* indices_mem = indices.memptr(); + uword n_nz = 0; + + if(Proxy::use_at == false) + { + typename Proxy::ea_type PA = A.get_ea(); + + uword i,j; + for(i=0, j=1; j < n_elem; i+=2, j+=2) + { + const eT tpi = PA[i]; + const eT tpj = PA[j]; + + bool not_zero_i; + bool not_zero_j; + + if(is_same_type::yes) { not_zero_i = (val < tpi); } + else if(is_same_type::yes) { not_zero_i = (tpi < val); } + else if(is_same_type::yes) { not_zero_i = (val > tpi); } + else if(is_same_type::yes) { not_zero_i = (tpi > val); } + else if(is_same_type::yes) { not_zero_i = (val <= tpi); } + else if(is_same_type::yes) { not_zero_i = (tpi <= val); } + else if(is_same_type::yes) { not_zero_i = (val >= tpi); } + else if(is_same_type::yes) { not_zero_i = (tpi >= val); } + else if(is_same_type::yes) { not_zero_i = (tpi == val); } + else if(is_same_type::yes) { not_zero_i = (tpi != val); } + else { not_zero_i = false; } + + if(is_same_type::yes) { not_zero_j = (val < tpj); } + else if(is_same_type::yes) { not_zero_j = (tpj < val); } + else if(is_same_type::yes) { not_zero_j = (val > tpj); } + else if(is_same_type::yes) { not_zero_j = (tpj > val); } + else if(is_same_type::yes) { not_zero_j = (val <= tpj); } + else if(is_same_type::yes) { not_zero_j = (tpj <= val); } + else if(is_same_type::yes) { not_zero_j = (val >= tpj); } + else if(is_same_type::yes) { not_zero_j = (tpj >= val); } + else if(is_same_type::yes) { not_zero_j = (tpj == val); } + else if(is_same_type::yes) { not_zero_j = (tpj != val); } + else { not_zero_j = false; } + + if(not_zero_i) { indices_mem[n_nz] = i; ++n_nz; } + if(not_zero_j) { indices_mem[n_nz] = j; ++n_nz; } + } + + if(i < n_elem) + { + bool not_zero; + + const eT tmp = PA[i]; + + if(is_same_type::yes) { not_zero = (val < tmp); } + else if(is_same_type::yes) { not_zero = (tmp < val); } + else if(is_same_type::yes) { not_zero = (val > tmp); } + else if(is_same_type::yes) { not_zero = (tmp > val); } + else if(is_same_type::yes) { not_zero = (val <= tmp); } + else if(is_same_type::yes) { not_zero = (tmp <= val); } + else if(is_same_type::yes) { not_zero = (val >= tmp); } + else if(is_same_type::yes) { not_zero = (tmp >= val); } + else if(is_same_type::yes) { not_zero = (tmp == val); } + else if(is_same_type::yes) { not_zero = (tmp != val); } + else { not_zero = false; } + + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } + } + } + else + { + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + uword i = 0; + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT tmp = A.at(row,col); + + bool not_zero; + + if(is_same_type::yes) { not_zero = (val < tmp); } + else if(is_same_type::yes) { not_zero = (tmp < val); } + else if(is_same_type::yes) { not_zero = (val > tmp); } + else if(is_same_type::yes) { not_zero = (tmp > val); } + else if(is_same_type::yes) { not_zero = (val <= tmp); } + else if(is_same_type::yes) { not_zero = (tmp <= val); } + else if(is_same_type::yes) { not_zero = (val >= tmp); } + else if(is_same_type::yes) { not_zero = (tmp >= val); } + else if(is_same_type::yes) { not_zero = (tmp == val); } + else if(is_same_type::yes) { not_zero = (tmp != val); } + else { not_zero = false; } + + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } + + ++i; + } + } + + return n_nz; + } + + + +template +inline +uword +op_find::helper + ( + Mat& indices, + const mtOp& X, + const typename arma_op_rel_only::result* junk1, + const typename arma_cx_only::result* junk2 + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + typedef typename T1::elem_type eT; + typedef typename Proxy::ea_type ea_type; + + const eT val = X.aux; + + if((is_same_type::yes || is_same_type::yes) && arma_config::debug && arma_isnan(val)) + { + arma_debug_warn_level(1, "find(): NaN is not equal to anything; suggest to use find_nonfinite() instead"); + } + + const Proxy A(X.m); + + const uword n_elem = A.get_n_elem(); + + indices.set_size(n_elem, 1); + + uword* indices_mem = indices.memptr(); + uword n_nz = 0; + + + if(Proxy::use_at == false) + { + ea_type PA = A.get_ea(); + + for(uword i=0; i::yes) { not_zero = (tmp == val); } + else if(is_same_type::yes) { not_zero = (tmp != val); } + else { not_zero = false; } + + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } + } + } + else + { + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + uword i = 0; + + for(uword col=0; col::yes) { not_zero = (tmp == val); } + else if(is_same_type::yes) { not_zero = (tmp != val); } + else { not_zero = false; } + + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } + + i++; + } + } + + return n_nz; + } + + + +template +inline +uword +op_find::helper + ( + Mat& indices, + const mtGlue& X, + const typename arma_glue_rel_only::result* junk1, + const typename arma_not_cx::result* junk2, + const typename arma_not_cx::result* junk3 + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + arma_ignore(junk3); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename Proxy::ea_type ea_type1; + typedef typename Proxy::ea_type ea_type2; + + const Proxy A(X.A); + const Proxy B(X.B); + + arma_debug_assert_same_size(A, B, "relational operator"); + + const uword n_elem = A.get_n_elem(); + + indices.set_size(n_elem, 1); + + uword* indices_mem = indices.memptr(); + uword n_nz = 0; + + if((Proxy::use_at == false) && (Proxy::use_at == false)) + { + ea_type1 PA = A.get_ea(); + ea_type2 PB = B.get_ea(); + + for(uword i=0; i::yes) { not_zero = (tmp1 < tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 > tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 <= tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 >= tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 == tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 != tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 && tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 || tmp2); } + else { not_zero = false; } + + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } + } + } + else + { + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + uword i = 0; + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT1 tmp1 = A.at(row,col); + const eT2 tmp2 = B.at(row,col); + + bool not_zero; + + if(is_same_type::yes) { not_zero = (tmp1 < tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 > tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 <= tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 >= tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 == tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 != tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 && tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 || tmp2); } + else { not_zero = false; } + + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } + + i++; + } + } + + return n_nz; + } + + + +template +inline +uword +op_find::helper + ( + Mat& indices, + const mtGlue& X, + const typename arma_glue_rel_only::result* junk1, + const typename arma_cx_only::result* junk2, + const typename arma_cx_only::result* junk3 + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + arma_ignore(junk3); + + typedef typename Proxy::ea_type ea_type1; + typedef typename Proxy::ea_type ea_type2; + + const Proxy A(X.A); + const Proxy B(X.B); + + arma_debug_assert_same_size(A, B, "relational operator"); + + const uword n_elem = A.get_n_elem(); + + indices.set_size(n_elem, 1); + + uword* indices_mem = indices.memptr(); + uword n_nz = 0; + + if((Proxy::use_at == false) && (Proxy::use_at == false)) + { + ea_type1 PA = A.get_ea(); + ea_type2 PB = B.get_ea(); + + for(uword i=0; i::yes) { not_zero = (PA[i] == PB[i]); } + else if(is_same_type::yes) { not_zero = (PA[i] != PB[i]); } + else { not_zero = false; } + + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } + } + } + else + { + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); + + uword i = 0; + + for(uword col=0; col::yes) { not_zero = (A.at(row,col) == B.at(row,col)); } + else if(is_same_type::yes) { not_zero = (A.at(row,col) != B.at(row,col)); } + else { not_zero = false; } + + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } + + i++; + } + } + + return n_nz; + } + + + +template +inline +void +op_find::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + const uword k = X.aux_uword_a; + const uword type = X.aux_uword_b; + + Mat indices; + const uword n_nz = op_find::helper(indices, X.m); + + if(n_nz > 0) + { + if(type == 0) // "first" + { + out = (k > 0 && k <= n_nz) ? indices.rows(0, k-1 ) : indices.rows(0, n_nz-1); + } + else // "last" + { + out = (k > 0 && k <= n_nz) ? indices.rows(n_nz-k, n_nz-1) : indices.rows(0, n_nz-1); + } + } + else + { + out.set_size(0,1); // empty column vector + } + } + + + +// + + + +template +inline +void +op_find_simple::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + Mat indices; + const uword n_nz = op_find::helper(indices, X.m); + + out.steal_mem_col(indices, n_nz); + } + + + +// + + + +template +inline +void +op_find_finite::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "find_finite(): detection of non-finite values is not reliable in fast math mode"); } + + const Proxy P(X.m); + + const uword n_elem = P.get_n_elem(); + + Mat indices(n_elem, 1, arma_nozeros_indicator()); + + uword* indices_mem = indices.memptr(); + uword count = 0; + + if(Proxy::use_at == false) + { + const typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i +inline +void +op_find_nonfinite::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "find_nonfinite(): detection of non-finite values is not reliable in fast math mode"); } + + const Proxy P(X.m); + + const uword n_elem = P.get_n_elem(); + + Mat indices(n_elem, 1, arma_nozeros_indicator()); + + uword* indices_mem = indices.memptr(); + uword count = 0; + + if(Proxy::use_at == false) + { + const typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i +inline +void +op_find_nan::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "find_nan(): detection of non-finite values is not reliable in fast math mode"); } + + const Proxy P(X.m); + + const uword n_elem = P.get_n_elem(); + + Mat indices(n_elem, 1, arma_nozeros_indicator()); + + uword* indices_mem = indices.memptr(); + uword count = 0; + + if(Proxy::use_at == false) + { + const typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i + static inline bool apply_helper(Mat& out, const Proxy& P, const bool ascending_indices); + + template + static inline void apply(Mat& out, const mtOp& in); + }; + + + +template +struct arma_find_unique_packet + { + eT val; + uword index; + }; + + + +template +struct arma_find_unique_comparator + { + arma_inline + bool + operator() (const arma_find_unique_packet& A, const arma_find_unique_packet& B) const + { + return (A.val < B.val); + } + }; + + + +template +struct arma_find_unique_comparator< std::complex > + { + arma_inline + bool + operator() (const arma_find_unique_packet< std::complex >& A, const arma_find_unique_packet< std::complex >& B) const + { + const T A_real = A.val.real(); + const T B_real = B.val.real(); + + return ( (A_real < B_real) ? true : ((A_real == B_real) ? (A.val.imag() < B.val.imag()) : false) ); + } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_find_unique_meat.hpp b/src/armadillo/include/armadillo_bits/op_find_unique_meat.hpp new file mode 100644 index 0000000..92ee322 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_find_unique_meat.hpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_find_unique +//! @{ + + + +template +inline +bool +op_find_unique::apply_helper(Mat& out, const Proxy& P, const bool ascending_indices) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) { out.set_size(0,1); return true; } + if(n_elem == 1) { out.set_size(1,1); out[0] = 0; return true; } + + uvec indices(n_elem, arma_nozeros_indicator()); + + std::vector< arma_find_unique_packet > packet_vec(n_elem); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i comparator; + + std::sort( packet_vec.begin(), packet_vec.end(), comparator ); + + uword* indices_mem = indices.memptr(); + + indices_mem[0] = packet_vec[0].index; + + uword count = 1; + + for(uword i=1; i < n_elem; ++i) + { + const eT diff = packet_vec[i-1].val - packet_vec[i].val; + + if(diff != eT(0)) + { + indices_mem[count] = packet_vec[i].index; + ++count; + } + } + + out.steal_mem_col(indices,count); + + if(ascending_indices) { std::sort(out.begin(), out.end()); } + + return true; + } + + + +template +inline +void +op_find_unique::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + const Proxy P(in.m); + + const bool ascending_indices = (in.aux_uword_a == uword(1)); + + const bool all_non_nan = op_find_unique::apply_helper(out, P, ascending_indices); + + if(all_non_nan == false) + { + arma_debug_check( true, "find_unique(): detected NaN" ); + + out.reset(); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_flip_bones.hpp b/src/armadillo/include/armadillo_bits/op_flip_bones.hpp new file mode 100644 index 0000000..c81eca1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_flip_bones.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_flip +//! @{ + + + +class op_flipud + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static void apply_direct(Mat& out, const Mat& X); + + template + inline static void apply_proxy_noalias(Mat& out, const Proxy& P); + }; + + + + +class op_fliplr + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static void apply_direct(Mat& out, const Mat& X); + + template + inline static void apply_proxy_noalias(Mat& out, const Proxy& P); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_flip_meat.hpp b/src/armadillo/include/armadillo_bits/op_flip_meat.hpp new file mode 100644 index 0000000..470ae92 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_flip_meat.hpp @@ -0,0 +1,341 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_flip +//! @{ + + + +template +inline +void +op_flipud::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_Mat::value) + { + // allow detection of in-place operation + + const unwrap U(in.m); + + op_flipud::apply_direct(out, U.M); + } + else + { + const Proxy P(in.m); + + if(P.is_alias(out)) + { + Mat tmp; + + op_flipud::apply_proxy_noalias(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_flipud::apply_proxy_noalias(out, P); + } + } + } + + + +template +inline +void +op_flipud::apply_direct(Mat& out, const Mat& X) + { + arma_extra_debug_sigprint(); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + const uword X_n_rows_m1 = X_n_rows - 1; + + if(&out != &X) + { + out.set_size(X_n_rows, X_n_cols); + + if(X_n_cols == 1) + { + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + for(uword row=0; row < X_n_rows; ++row) + { + out_mem[X_n_rows_m1 - row] = X_mem[row]; + } + } + else + { + for(uword col=0; col < X_n_cols; ++col) + { + const eT* X_colmem = X.colptr(col); + eT* out_colmem = out.colptr(col); + + for(uword row=0; row < X_n_rows; ++row) + { + out_colmem[X_n_rows_m1 - row] = X_colmem[row]; + } + } + } + } + else // in-place operation + { + const uword N = X_n_rows / 2; + + if(X_n_cols == 1) + { + eT* out_mem = out.memptr(); + + for(uword row=0; row < N; ++row) + { + std::swap(out_mem[X_n_rows_m1 - row], out_mem[row]); + } + } + else + { + for(uword col=0; col < X_n_cols; ++col) + { + eT* out_colmem = out.colptr(col); + + for(uword row=0; row < N; ++row) + { + std::swap(out_colmem[X_n_rows_m1 - row], out_colmem[row]); + } + } + } + } + } + + + +template +inline +void +op_flipud::apply_proxy_noalias(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + typedef typename Proxy::stored_type P_stored_type; + + if(is_Mat::value) + { + const unwrap U(P.Q); + + op_flipud::apply_direct(out, U.M); + + return; + } + + const uword P_n_rows = P.get_n_rows(); + const uword P_n_cols = P.get_n_cols(); + + const uword P_n_rows_m1 = P_n_rows - 1; + + out.set_size(P_n_rows, P_n_cols); + + if( ((T1::is_col) || (P_n_cols == 1)) && (Proxy::use_at == false) ) + { + eT* out_mem = out.memptr(); + + const typename Proxy::ea_type P_ea = P.get_ea(); + + for(uword row=0; row < P_n_rows; ++row) + { + out_mem[P_n_rows_m1 - row] = P_ea[row]; + } + } + else + { + for(uword col=0; col < P_n_cols; ++col) + { + eT* out_colmem = out.colptr(col); + + for(uword row=0; row < P_n_rows; ++row) + { + out_colmem[P_n_rows_m1 - row] = P.at(row, col); + } + } + } + } + + + +// + + + +template +inline +void +op_fliplr::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_Mat::value) + { + // allow detection of in-place operation + + const unwrap U(in.m); + + op_fliplr::apply_direct(out, U.M); + } + else + { + const Proxy P(in.m); + + if(P.is_alias(out)) + { + Mat tmp; + + op_fliplr::apply_proxy_noalias(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_fliplr::apply_proxy_noalias(out, P); + } + } + } + + + +template +inline +void +op_fliplr::apply_direct(Mat& out, const Mat& X) + { + arma_extra_debug_sigprint(); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + const uword X_n_cols_m1 = X_n_cols - 1; + + if(&out != &X) + { + out.set_size(X_n_rows, X_n_cols); + + if(X_n_rows == 1) + { + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[X_n_cols_m1 - col] = X_mem[col]; + } + } + else + { + for(uword col=0; col < X_n_cols; ++col) + { + out.col(X_n_cols_m1 - col) = X.col(col); + } + } + } + else // in-place operation + { + const uword N = X_n_cols / 2; + + if(X_n_rows == 1) + { + eT* out_mem = out.memptr(); + + for(uword col=0; col < N; ++col) + { + std::swap(out_mem[X_n_cols_m1 - col], out_mem[col]); + } + } + else + { + for(uword col=0; col < N; ++col) + { + out.swap_cols(X_n_cols_m1 - col, col); + } + } + } + } + + + +template +inline +void +op_fliplr::apply_proxy_noalias(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + typedef typename Proxy::stored_type P_stored_type; + + if(is_Mat::value) + { + const unwrap U(P.Q); + + op_fliplr::apply_direct(out, U.M); + + return; + } + + const uword P_n_rows = P.get_n_rows(); + const uword P_n_cols = P.get_n_cols(); + + const uword P_n_cols_m1 = P_n_cols - 1; + + out.set_size(P_n_rows, P_n_cols); + + if( ((T1::is_row) || (P_n_rows == 1)) && (Proxy::use_at == false) ) + { + eT* out_mem = out.memptr(); + + const typename Proxy::ea_type P_ea = P.get_ea(); + + for(uword col=0; col < P_n_cols; ++col) + { + out_mem[P_n_cols_m1 - col] = P_ea[col]; + } + } + else + { + for(uword col=0; col < P_n_cols; ++col) + { + eT* out_colmem = out.colptr(P_n_cols_m1 - col); + + for(uword row=0; row < P_n_rows; ++row) + { + out_colmem[row] = P.at(row,col); + } + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_hist_bones.hpp b/src/armadillo/include/armadillo_bits/op_hist_bones.hpp new file mode 100644 index 0000000..c014ba2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_hist_bones.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_hist +//! @{ + + + +class op_hist + : public traits_op_passthru + { + public: + + template + inline static void apply_noalias(Mat& out, const Mat& A, const uword n_bins, const uword dim); + + template + inline static void apply(Mat& out, const mtOp& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_hist_meat.hpp b/src/armadillo/include/armadillo_bits/op_hist_meat.hpp new file mode 100644 index 0000000..04c5ed8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_hist_meat.hpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_hist +//! @{ + + + +template +inline +void +op_hist::apply_noalias(Mat& out, const Mat& A, const uword n_bins, const uword dim) + { + arma_extra_debug_sigprint(); + + arma_debug_check( ((A.is_vec() == false) && (A.is_empty() == false)), "hist(): only vectors are supported when automatically determining bin centers" ); + + if(n_bins == 0) { out.reset(); return; } + + uword A_n_elem = A.n_elem; + const eT* A_mem = A.memptr(); + + eT min_val = priv::most_pos(); + eT max_val = priv::most_neg(); + + uword i,j; + for(i=0, j=1; j < A_n_elem; i+=2, j+=2) + { + const eT val_i = A_mem[i]; + const eT val_j = A_mem[j]; + + if(min_val > val_i) { min_val = val_i; } + if(min_val > val_j) { min_val = val_j; } + + if(max_val < val_i) { max_val = val_i; } + if(max_val < val_j) { max_val = val_j; } + } + + if(i < A_n_elem) + { + const eT val_i = A_mem[i]; + + if(min_val > val_i) { min_val = val_i; } + if(max_val < val_i) { max_val = val_i; } + } + + if(min_val == max_val) + { + min_val -= (n_bins/2); + max_val += (n_bins/2); + } + + if(arma_isfinite(min_val) == false) { min_val = priv::most_neg(); } + if(arma_isfinite(max_val) == false) { max_val = priv::most_pos(); } + + Col c(n_bins, arma_nozeros_indicator()); + eT* c_mem = c.memptr(); + + for(uword ii=0; ii < n_bins; ++ii) + { + c_mem[ii] = (0.5 + ii) / double(n_bins); + } + + c = ((max_val - min_val) * c) + min_val; + + glue_hist::apply_noalias(out, A, c, dim); + } + + + +template +inline +void +op_hist::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + const uword n_bins = X.aux_uword_a; + + const quasi_unwrap U(X.m); + + const uword dim = (T1::is_xvec) ? uword(U.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); + + if(is_non_integral::value) + { + if(U.is_alias(out)) + { + Mat tmp; + + op_hist::apply_noalias(tmp, U.M, n_bins, dim); + + out.steal_mem(tmp); + } + else + { + op_hist::apply_noalias(out, U.M, n_bins, dim); + } + } + else + { + Mat converted = conv_to< Mat >::from(U.M); + + op_hist::apply_noalias(out, converted, n_bins, dim); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_htrans_bones.hpp b/src/armadillo/include/armadillo_bits/op_htrans_bones.hpp new file mode 100644 index 0000000..c10f624 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_htrans_bones.hpp @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_htrans +//! @{ + + +//! 'hermitian transpose' operation + +class op_htrans + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_col; // deliberately swapped + static constexpr bool is_col = T1::is_row; + static constexpr bool is_xvec = T1::is_xvec; + }; + + template + arma_hot inline static void apply_mat_noalias(Mat& out, const Mat& A, const typename arma_not_cx::result* junk = nullptr); + + template + arma_hot inline static void apply_mat_noalias(Mat& out, const Mat& A, const typename arma_cx_only::result* junk = nullptr); + + // + + template + arma_hot inline static void block_worker(std::complex* Y, const std::complex* X, const uword X_n_rows, const uword Y_n_rows, const uword n_rows, const uword n_cols); + + template + arma_hot inline static void apply_mat_noalias_large(Mat< std::complex >& out, const Mat< std::complex >& A); + + // + + template + arma_hot inline static void apply_mat_inplace(Mat& out, const typename arma_not_cx::result* junk = nullptr); + + template + arma_hot inline static void apply_mat_inplace(Mat& out, const typename arma_cx_only::result* junk = nullptr); + + // + + template + inline static void apply_mat(Mat& out, const Mat& A, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply_mat(Mat& out, const Mat& A, const typename arma_cx_only::result* junk = nullptr); + + // + + template + inline static void apply_proxy(Mat& out, const Proxy& P); + + // + + template + inline static void apply_direct(Mat& out, const T1& X); + + template + inline static void apply(Mat& out, const Op& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply(Mat& out, const Op& in, const typename arma_cx_only::result* junk = nullptr); + }; + + + +class op_htrans2 + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_col; // deliberately swapped + static constexpr bool is_col = T1::is_row; + static constexpr bool is_xvec = T1::is_xvec; + }; + + template + inline static void apply(Mat& out, const Op& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply(Mat& out, const Op& in, const typename arma_cx_only::result* junk = nullptr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_htrans_meat.hpp b/src/armadillo/include/armadillo_bits/op_htrans_meat.hpp new file mode 100644 index 0000000..e03893c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_htrans_meat.hpp @@ -0,0 +1,419 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_htrans +//! @{ + + + +template +inline +void +op_htrans::apply_mat_noalias(Mat& out, const Mat& A, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + op_strans::apply_mat_noalias(out, A); + } + + + +template +inline +void +op_htrans::apply_mat_noalias(Mat& out, const Mat& A, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + out.set_size(A_n_cols, A_n_rows); + + if( (A_n_cols == 1) || (A_n_rows == 1) ) + { + const uword n_elem = A.n_elem; + + const eT* A_mem = A.memptr(); + eT* out_mem = out.memptr(); + + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = std::conj(A_mem[i]); + } + } + else + if( (A_n_rows >= 512) && (A_n_cols >= 512) ) + { + op_htrans::apply_mat_noalias_large(out, A); + } + else + { + eT* outptr = out.memptr(); + + for(uword k=0; k < A_n_rows; ++k) + { + const eT* Aptr = &(A.at(k,0)); + + for(uword j=0; j < A_n_cols; ++j) + { + (*outptr) = std::conj(*Aptr); + + Aptr += A_n_rows; + outptr++; + } + } + } + } + + + +template +inline +void +op_htrans::block_worker(std::complex* Y, const std::complex* X, const uword X_n_rows, const uword Y_n_rows, const uword n_rows, const uword n_cols) + { + for(uword row = 0; row < n_rows; ++row) + { + const uword Y_offset = row * Y_n_rows; + + for(uword col = 0; col < n_cols; ++col) + { + const uword X_offset = col * X_n_rows; + + Y[col + Y_offset] = std::conj(X[row + X_offset]); + } + } + } + + + +template +inline +void +op_htrans::apply_mat_noalias_large(Mat< std::complex >& out, const Mat< std::complex >& A) + { + arma_extra_debug_sigprint(); + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + const uword block_size = 64; + + const uword n_rows_base = block_size * (n_rows / block_size); + const uword n_cols_base = block_size * (n_cols / block_size); + + const uword n_rows_extra = n_rows - n_rows_base; + const uword n_cols_extra = n_cols - n_cols_base; + + const std::complex* X = A.memptr(); + std::complex* Y = out.memptr(); + + for(uword row = 0; row < n_rows_base; row += block_size) + { + const uword Y_offset = row * n_cols; + + for(uword col = 0; col < n_cols_base; col += block_size) + { + const uword X_offset = col * n_rows; + + op_htrans::block_worker(&Y[col + Y_offset], &X[row + X_offset], n_rows, n_cols, block_size, block_size); + } + + const uword X_offset = n_cols_base * n_rows; + + op_htrans::block_worker(&Y[n_cols_base + Y_offset], &X[row + X_offset], n_rows, n_cols, block_size, n_cols_extra); + } + + if(n_rows_extra == 0) { return; } + + const uword Y_offset = n_rows_base * n_cols; + + for(uword col = 0; col < n_cols_base; col += block_size) + { + const uword X_offset = col * n_rows; + + op_htrans::block_worker(&Y[col + Y_offset], &X[n_rows_base + X_offset], n_rows, n_cols, n_rows_extra, block_size); + } + + const uword X_offset = n_cols_base * n_rows; + + op_htrans::block_worker(&Y[n_cols_base + Y_offset], &X[n_rows_base + X_offset], n_rows, n_cols, n_rows_extra, n_cols_extra); + } + + + +template +inline +void +op_htrans::apply_mat_inplace(Mat& out, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + op_strans::apply_mat_inplace(out); + } + + + +template +inline +void +op_htrans::apply_mat_inplace(Mat& out, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + + if(n_rows == n_cols) + { + arma_extra_debug_print("doing in-place hermitian transpose of a square matrix"); + + for(uword col=0; col < n_cols; ++col) + { + eT* coldata = out.colptr(col); + + out.at(col,col) = std::conj( out.at(col,col) ); + + for(uword row=(col+1); row < n_rows; ++row) + { + const eT val1 = std::conj(coldata[row]); + const eT val2 = std::conj(out.at(col,row)); + + out.at(col,row) = val1; + coldata[row] = val2; + } + } + } + else + { + Mat tmp; + + op_htrans::apply_mat_noalias(tmp, out); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_htrans::apply_mat(Mat& out, const Mat& A, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + op_strans::apply_mat(out, A); + } + + + +template +inline +void +op_htrans::apply_mat(Mat& out, const Mat& A, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(&out != &A) + { + op_htrans::apply_mat_noalias(out, A); + } + else + { + op_htrans::apply_mat_inplace(out); + } + } + + + +template +inline +void +op_htrans::apply_proxy(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + if( (resolves_to_vector::yes) && (Proxy::use_at == false) ) + { + out.set_size(n_cols, n_rows); + + eT* out_mem = out.memptr(); + + const uword n_elem = P.get_n_elem(); + + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = std::conj(Pea[i]); + } + } + else + { + out.set_size(n_cols, n_rows); + + eT* outptr = out.memptr(); + + for(uword k=0; k < n_rows; ++k) + { + for(uword j=0; j < n_cols; ++j) + { + (*outptr) = std::conj(P.at(k,j)); + + outptr++; + } + } + } + } + + + +template +inline +void +op_htrans::apply_direct(Mat& out, const T1& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // allow detection of in-place transpose + if(is_Mat::value || (arma_config::openmp && Proxy::use_mp)) + { + const unwrap U(X); + + op_htrans::apply_mat(out, U.M); + } + else + { + const Proxy P(X); + + const bool is_alias = P.is_alias(out); + + if(is_Mat::stored_type>::value) + { + const quasi_unwrap::stored_type> U(P.Q); + + if(is_alias) + { + Mat tmp; + + op_htrans::apply_mat_noalias(tmp, U.M); + + out.steal_mem(tmp); + } + else + { + op_htrans::apply_mat_noalias(out, U.M); + } + } + else + { + if(is_alias) + { + Mat tmp; + + op_htrans::apply_proxy(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_htrans::apply_proxy(out, P); + } + } + } + } + + + +template +inline +void +op_htrans::apply(Mat& out, const Op& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + op_strans::apply_direct(out, in.m); + } + + + +template +inline +void +op_htrans::apply(Mat& out, const Op& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + op_htrans::apply_direct(out, in.m); + } + + + +// +// op_htrans2 + + + +template +inline +void +op_htrans2::apply(Mat& out, const Op& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + op_strans::apply_direct(out, in.m); + + arrayops::inplace_mul(out.memptr(), in.aux, out.n_elem); + } + + + +template +inline +void +op_htrans2::apply(Mat& out, const Op& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + op_htrans::apply_direct(out, in.m); + + arrayops::inplace_mul(out.memptr(), in.aux, out.n_elem); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_index_max_bones.hpp b/src/armadillo/include/armadillo_bits/op_index_max_bones.hpp new file mode 100644 index 0000000..d226f22 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_index_max_bones.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_index_max +//! @{ + + +class op_index_max + : public traits_op_xvec + { + public: + + // dense matrices + + template + inline static void apply(Mat& out, const mtOp& in); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim); + + + // cubes + + template + inline static void apply(Cube& out, const mtOpCube& in); + + template + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); + + + // sparse matrices + + template + inline static void apply(Mat& out, const SpBase& expr, const uword dim); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_index_max_meat.hpp b/src/armadillo/include/armadillo_bits/op_index_max_meat.hpp new file mode 100644 index 0000000..5034921 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_index_max_meat.hpp @@ -0,0 +1,433 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_index_max +//! @{ + + + +template +inline +void +op_index_max::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 1), "index_max(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(in.m); + const Mat& X = U.M; + + if(U.is_alias(out) == false) + { + op_index_max::apply_noalias(out, X, dim); + } + else + { + Mat tmp; + + op_index_max::apply_noalias(tmp, X, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_index_max::apply_noalias(Mat& out, const Mat& X, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_index_max::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows == 0) { return; } + + uword* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + op_max::direct_max( X.colptr(col), X_n_rows, out_mem[col] ); + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_index_max::apply(): dim = 1"); + + out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0); + + if(X_n_cols == 0) { return; } + + uword* out_mem = out.memptr(); + + Col tmp(X_n_rows, arma_nozeros_indicator()); + + T* tmp_mem = tmp.memptr(); + + if(is_cx::yes) + { + const eT* col_mem = X.colptr(0); + + for(uword row=0; row < X_n_rows; ++row) + { + tmp_mem[row] = eop_aux::arma_abs(col_mem[row]); + } + } + else + { + arrayops::copy(tmp_mem, (T*)(X.colptr(0)), X_n_rows); + } + + for(uword col=1; col < X_n_cols; ++col) + { + const eT* col_mem = X.colptr(col); + + for(uword row=0; row < X_n_rows; ++row) + { + T& max_val = tmp_mem[row]; + T col_val = (is_cx::yes) ? T(eop_aux::arma_abs(col_mem[row])) : T(access::tmp_real(col_mem[row])); + + if(max_val < col_val) + { + max_val = col_val; + + out_mem[row] = col; + } + } + } + } + } + + + +template +inline +void +op_index_max::apply(Cube& out, const mtOpCube& in) + { + arma_extra_debug_sigprint(); + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 2), "index_max(): parameter 'dim' must be 0 or 1 or 2" ); + + const unwrap_cube U(in.m); + + if(U.is_alias(out) == false) + { + op_index_max::apply_noalias(out, U.M, dim); + } + else + { + Cube tmp; + + op_index_max::apply_noalias(tmp, U.M, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_index_max::apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + const uword X_n_slices = X.n_slices; + + if(dim == 0) + { + arma_extra_debug_print("op_index_max::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices); + + if(out.is_empty() || X.is_empty()) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + uword* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + op_max::direct_max( X.slice_colptr(slice,col), X_n_rows, out_mem[col] ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_index_max::apply(): dim = 1"); + + out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices); + + if(out.is_empty() || X.is_empty()) { return; } + + Col tmp(X_n_rows, arma_nozeros_indicator()); + + eT* tmp_mem = tmp.memptr(); + + for(uword slice=0; slice < X_n_slices; ++slice) + { + uword* out_mem = out.slice_memptr(slice); + + arrayops::copy(tmp_mem, X.slice_colptr(slice,0), X_n_rows); + + for(uword col=1; col < X_n_cols; ++col) + { + const eT* col_mem = X.slice_colptr(slice,col); + + for(uword row=0; row < X_n_rows; ++row) + { + const eT val = col_mem[row]; + + if(val > tmp_mem[row]) + { + tmp_mem[row] = val; + out_mem[row] = col; + } + } + } + } + } + else + if(dim == 2) + { + arma_extra_debug_print("op_index_max::apply(): dim = 2"); + + out.zeros(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0); + + if(out.is_empty() || X.is_empty()) { return; } + + Mat tmp(X.slice_memptr(0), X_n_rows, X_n_cols); // copy slice 0 + + eT* tmp_mem = tmp.memptr(); + uword* out_mem = out.memptr(); + + const uword N = X.n_elem_slice; + + for(uword slice=1; slice < X_n_slices; ++slice) + { + const eT* X_slice_mem = X.slice_memptr(slice); + + for(uword i=0; i < N; ++i) + { + const eT val = X_slice_mem[i]; + + if(val > tmp_mem[i]) + { + tmp_mem[i] = val; + out_mem[i] = slice; + } + } + } + } + } + + + +template +inline +void +op_index_max::apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename get_pod_type::result T; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + const uword X_n_slices = X.n_slices; + + if(dim == 0) + { + arma_extra_debug_print("op_index_max::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices); + + if(out.is_empty() || X.is_empty()) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + uword* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + op_max::direct_max( X.slice_colptr(slice,col), X_n_rows, out_mem[col] ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_index_max::apply(): dim = 1"); + + out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices); + + if(out.is_empty() || X.is_empty()) { return; } + + Col tmp(X_n_rows, arma_nozeros_indicator()); + + T* tmp_mem = tmp.memptr(); + + for(uword slice=0; slice < X_n_slices; ++slice) + { + uword* out_mem = out.slice_memptr(slice); + + const eT* col0_mem = X.slice_colptr(slice,0); + + for(uword row=0; row < X_n_rows; ++row) + { + tmp_mem[row] = std::abs( col0_mem[row] ); + } + + for(uword col=1; col < X_n_cols; ++col) + { + const eT* col_mem = X.slice_colptr(slice,col); + + for(uword row=0; row < X_n_rows; ++row) + { + const T val = std::abs( col_mem[row] ); + + if(val > tmp_mem[row]) + { + tmp_mem[row] = val; + out_mem[row] = col; + } + } + } + } + } + else + if(dim == 2) + { + arma_extra_debug_print("op_index_max::apply(): dim = 2"); + + out.zeros(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0); + + if(out.is_empty() || X.is_empty()) { return; } + + uword* out_mem = out.memptr(); + + Mat tmp(X_n_rows, X_n_cols, arma_nozeros_indicator()); + + T* tmp_mem = tmp.memptr(); + const eT* X_slice0_mem = X.slice_memptr(0); + + const uword N = X.n_elem_slice; + + for(uword i=0; i tmp_mem[i]) + { + tmp_mem[i] = val; + out_mem[i] = slice; + } + } + } + } + } + + + +template +inline +void +op_index_max::apply(Mat& out, const SpBase& expr, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_check( (dim > 1), "index_max(): parameter 'dim' must be 0 or 1" ); + + const unwrap_spmat U(expr.get_ref()); + const SpMat& X = U.M; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_index_max::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows == 0) { return; } + + uword* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[col] = X.col(col).index_max(); + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_index_max::apply(): dim = 1"); + + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0); + + if(X_n_cols == 0) { return; } + + uword* out_mem = out.memptr(); + + const SpMat Xt = X.st(); + + for(uword row=0; row < X_n_rows; ++row) + { + out_mem[row] = Xt.col(row).index_max(); + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_index_min_bones.hpp b/src/armadillo/include/armadillo_bits/op_index_min_bones.hpp new file mode 100644 index 0000000..050b8c0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_index_min_bones.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_index_min +//! @{ + + +class op_index_min + : public traits_op_xvec + { + public: + + // dense matrices + + template + inline static void apply(Mat& out, const mtOp& in); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim); + + + // cubes + + template + inline static void apply(Cube& out, const mtOpCube& in); + + template + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); + + + // sparse matrices + + template + inline static void apply(Mat& out, const SpBase& expr, const uword dim); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_index_min_meat.hpp b/src/armadillo/include/armadillo_bits/op_index_min_meat.hpp new file mode 100644 index 0000000..13162ab --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_index_min_meat.hpp @@ -0,0 +1,433 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_index_min +//! @{ + + + +template +inline +void +op_index_min::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 1), "index_min(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(in.m); + const Mat& X = U.M; + + if(U.is_alias(out) == false) + { + op_index_min::apply_noalias(out, X, dim); + } + else + { + Mat tmp; + + op_index_min::apply_noalias(tmp, X, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_index_min::apply_noalias(Mat& out, const Mat& X, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_index_min::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows == 0) { return; } + + uword* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + op_min::direct_min( X.colptr(col), X_n_rows, out_mem[col] ); + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_index_min::apply(): dim = 1"); + + out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0); + + if(X_n_cols == 0) { return; } + + uword* out_mem = out.memptr(); + + Col tmp(X_n_rows, arma_nozeros_indicator()); + + T* tmp_mem = tmp.memptr(); + + if(is_cx::yes) + { + const eT* col_mem = X.colptr(0); + + for(uword row=0; row < X_n_rows; ++row) + { + tmp_mem[row] = eop_aux::arma_abs(col_mem[row]); + } + } + else + { + arrayops::copy(tmp_mem, (T*)(X.colptr(0)), X_n_rows); + } + + for(uword col=1; col < X_n_cols; ++col) + { + const eT* col_mem = X.colptr(col); + + for(uword row=0; row < X_n_rows; ++row) + { + T& min_val = tmp_mem[row]; + T col_val = (is_cx::yes) ? T(eop_aux::arma_abs(col_mem[row])) : T(access::tmp_real(col_mem[row])); + + if(min_val > col_val) + { + min_val = col_val; + + out_mem[row] = col; + } + } + } + } + } + + + +template +inline +void +op_index_min::apply(Cube& out, const mtOpCube& in) + { + arma_extra_debug_sigprint(); + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 2), "index_min(): parameter 'dim' must be 0 or 1 or 2" ); + + const unwrap_cube U(in.m); + + if(U.is_alias(out) == false) + { + op_index_min::apply_noalias(out, U.M, dim); + } + else + { + Cube tmp; + + op_index_min::apply_noalias(tmp, U.M, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_index_min::apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + const uword X_n_slices = X.n_slices; + + if(dim == 0) + { + arma_extra_debug_print("op_index_min::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices); + + if(out.is_empty() || X.is_empty()) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + uword* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + op_min::direct_min( X.slice_colptr(slice,col), X_n_rows, out_mem[col] ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_index_min::apply(): dim = 1"); + + out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices); + + if(out.is_empty() || X.is_empty()) { return; } + + Col tmp(X_n_rows, arma_nozeros_indicator()); + + eT* tmp_mem = tmp.memptr(); + + for(uword slice=0; slice < X_n_slices; ++slice) + { + uword* out_mem = out.slice_memptr(slice); + + arrayops::copy(tmp_mem, X.slice_colptr(slice,0), X_n_rows); + + for(uword col=1; col < X_n_cols; ++col) + { + const eT* col_mem = X.slice_colptr(slice,col); + + for(uword row=0; row < X_n_rows; ++row) + { + const eT val = col_mem[row]; + + if(val < tmp_mem[row]) + { + tmp_mem[row] = val; + out_mem[row] = col; + } + } + } + } + } + else + if(dim == 2) + { + arma_extra_debug_print("op_index_min::apply(): dim = 2"); + + out.zeros(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0); + + if(out.is_empty() || X.is_empty()) { return; } + + Mat tmp(X.slice_memptr(0), X_n_rows, X_n_cols); // copy slice 0 + + eT* tmp_mem = tmp.memptr(); + uword* out_mem = out.memptr(); + + const uword N = X.n_elem_slice; + + for(uword slice=1; slice < X_n_slices; ++slice) + { + const eT* X_slice_mem = X.slice_memptr(slice); + + for(uword i=0; i < N; ++i) + { + const eT val = X_slice_mem[i]; + + if(val < tmp_mem[i]) + { + tmp_mem[i] = val; + out_mem[i] = slice; + } + } + } + } + } + + + +template +inline +void +op_index_min::apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename get_pod_type::result T; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + const uword X_n_slices = X.n_slices; + + if(dim == 0) + { + arma_extra_debug_print("op_index_min::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices); + + if(out.is_empty() || X.is_empty()) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + uword* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + op_min::direct_min( X.slice_colptr(slice,col), X_n_rows, out_mem[col] ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_index_min::apply(): dim = 1"); + + out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices); + + if(out.is_empty() || X.is_empty()) { return; } + + Col tmp(X_n_rows, arma_nozeros_indicator()); + + T* tmp_mem = tmp.memptr(); + + for(uword slice=0; slice < X_n_slices; ++slice) + { + uword* out_mem = out.slice_memptr(slice); + + const eT* col0_mem = X.slice_colptr(slice,0); + + for(uword row=0; row < X_n_rows; ++row) + { + tmp_mem[row] = std::abs( col0_mem[row] ); + } + + for(uword col=1; col < X_n_cols; ++col) + { + const eT* col_mem = X.slice_colptr(slice,col); + + for(uword row=0; row < X_n_rows; ++row) + { + const T val = std::abs( col_mem[row] ); + + if(val < tmp_mem[row]) + { + tmp_mem[row] = val; + out_mem[row] = col; + } + } + } + } + } + else + if(dim == 2) + { + arma_extra_debug_print("op_index_min::apply(): dim = 2"); + + out.zeros(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0); + + if(out.is_empty() || X.is_empty()) { return; } + + uword* out_mem = out.memptr(); + + Mat tmp(X_n_rows, X_n_cols, arma_nozeros_indicator()); + + T* tmp_mem = tmp.memptr(); + const eT* X_slice0_mem = X.slice_memptr(0); + + const uword N = X.n_elem_slice; + + for(uword i=0; i +inline +void +op_index_min::apply(Mat& out, const SpBase& expr, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_check( (dim > 1), "index_min(): parameter 'dim' must be 0 or 1" ); + + const unwrap_spmat U(expr.get_ref()); + const SpMat& X = U.M; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_index_min::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows == 0) { return; } + + uword* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[col] = X.col(col).index_min(); + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_index_min::apply(): dim = 1"); + + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0); + + if(X_n_cols == 0) { return; } + + uword* out_mem = out.memptr(); + + const SpMat Xt = X.st(); + + for(uword row=0; row < X_n_rows; ++row) + { + out_mem[row] = Xt.col(row).index_min(); + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_inv_gen_bones.hpp b/src/armadillo/include/armadillo_bits/op_inv_gen_bones.hpp new file mode 100644 index 0000000..fe952ed --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_inv_gen_bones.hpp @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_inv_gen +//! @{ + + + +class op_inv_gen_default + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Base& expr, const char* caller_sig); + }; + + + +class op_inv_gen_full + : public traits_op_default + { + public: + + template + struct pos + { + static constexpr uword n2 = row + col*2; + static constexpr uword n3 = row + col*3; + }; + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Base& expr, const char* caller_sig, const uword flags); + + template + arma_cold inline static bool apply_tiny_2x2(Mat& X); + + template + arma_cold inline static bool apply_tiny_3x3(Mat& X); + }; + + + +template +struct op_inv_gen_state + { + uword size = uword(0); + T rcond = T(0); + bool is_diag = false; + bool is_sym = false; + }; + + + +class op_inv_gen_rcond + : public traits_op_default + { + public: + + template + inline static bool apply_direct(Mat& out_inv, op_inv_gen_state& out_state, const Base& expr); + }; + + + +namespace inv_opts + { + struct opts + { + const uword flags; + + inline constexpr explicit opts(const uword in_flags); + + inline const opts operator+(const opts& rhs) const; + }; + + inline + constexpr + opts::opts(const uword in_flags) + : flags(in_flags) + {} + + inline + const opts + opts::operator+(const opts& rhs) const + { + const opts result( flags | rhs.flags ); + + return result; + } + + // The values below (eg. 1u << 1) are for internal Armadillo use only. + // The values can change without notice. + + static constexpr uword flag_none = uword(0 ); + static constexpr uword flag_fast = uword(1u << 0); + static constexpr uword flag_tiny = uword(1u << 0); // deprecated + static constexpr uword flag_allow_approx = uword(1u << 1); + static constexpr uword flag_likely_sympd = uword(1u << 2); // deprecated + static constexpr uword flag_no_sympd = uword(1u << 3); // deprecated + static constexpr uword flag_no_ugly = uword(1u << 4); + + struct opts_none : public opts { inline constexpr opts_none() : opts(flag_none ) {} }; + struct opts_fast : public opts { inline constexpr opts_fast() : opts(flag_fast ) {} }; + struct opts_tiny : public opts { inline constexpr opts_tiny() : opts(flag_tiny ) {} }; + struct opts_allow_approx : public opts { inline constexpr opts_allow_approx() : opts(flag_allow_approx) {} }; + struct opts_likely_sympd : public opts { inline constexpr opts_likely_sympd() : opts(flag_likely_sympd) {} }; + struct opts_no_sympd : public opts { inline constexpr opts_no_sympd() : opts(flag_no_sympd ) {} }; + struct opts_no_ugly : public opts { inline constexpr opts_no_ugly() : opts(flag_no_ugly ) {} }; + + static constexpr opts_none none; + static constexpr opts_fast fast; + static constexpr opts_tiny tiny; + static constexpr opts_allow_approx allow_approx; + static constexpr opts_likely_sympd likely_sympd; + static constexpr opts_no_sympd no_sympd; + static constexpr opts_no_ugly no_ugly; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_inv_gen_meat.hpp b/src/armadillo/include/armadillo_bits/op_inv_gen_meat.hpp new file mode 100644 index 0000000..a7585d7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_inv_gen_meat.hpp @@ -0,0 +1,428 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_inv_gen +//! @{ + + + +template +inline +void +op_inv_gen_default::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + const bool status = op_inv_gen_default::apply_direct(out, X.m, "inv()"); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("inv(): matrix is singular"); + } + } + + + +template +inline +bool +op_inv_gen_default::apply_direct(Mat& out, const Base& expr, const char* caller_sig) + { + arma_extra_debug_sigprint(); + + return op_inv_gen_full::apply_direct(out, expr, caller_sig, uword(0)); + } + + + +// + + + +template +inline +void +op_inv_gen_full::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + const uword flags = X.aux_uword_a; + + const bool status = op_inv_gen_full::apply_direct(out, X.m, "inv()", flags); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("inv(): matrix is singular"); + } + } + + + +template +inline +bool +op_inv_gen_full::apply_direct(Mat& out, const Base& expr, const char* caller_sig, const uword flags) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(has_user_flags == true ) { arma_extra_debug_print("op_inv_gen_full: has_user_flags = true"); } + if(has_user_flags == false) { arma_extra_debug_print("op_inv_gen_full: has_user_flags = false"); } + + const bool fast = has_user_flags && bool(flags & inv_opts::flag_fast ); + const bool allow_approx = has_user_flags && bool(flags & inv_opts::flag_allow_approx); + const bool no_ugly = has_user_flags && bool(flags & inv_opts::flag_no_ugly ); + + if(has_user_flags) + { + arma_extra_debug_print("op_inv_gen_full: enabled flags:"); + + if(fast ) { arma_extra_debug_print("fast"); } + if(allow_approx) { arma_extra_debug_print("allow_approx"); } + if(no_ugly ) { arma_extra_debug_print("no_ugly"); } + + arma_debug_check( (fast && allow_approx), "inv(): options 'fast' and 'allow_approx' are mutually exclusive" ); + arma_debug_check( (fast && no_ugly ), "inv(): options 'fast' and 'no_ugly' are mutually exclusive" ); + arma_debug_check( (no_ugly && allow_approx), "inv(): options 'no_ugly' and 'allow_approx' are mutually exclusive" ); + } + + if(no_ugly) + { + op_inv_gen_state inv_state; + + const bool status = op_inv_gen_rcond::apply_direct(out, inv_state, expr); + + // workaround for bug in gcc 4.8 + const uword local_size = inv_state.size; + const T local_rcond = inv_state.rcond; + + if((status == false) || (local_rcond < ((std::max)(local_size, uword(1)) * std::numeric_limits::epsilon())) || arma_isnan(local_rcond)) { return false; } + + return true; + } + + if(allow_approx) + { + op_inv_gen_state inv_state; + + Mat tmp; + + const bool status = op_inv_gen_rcond::apply_direct(tmp, inv_state, expr); + + // workaround for bug in gcc 4.8 + const uword local_size = inv_state.size; + const T local_rcond = inv_state.rcond; + + if((status == false) || (local_rcond < ((std::max)(local_size, uword(1)) * std::numeric_limits::epsilon())) || arma_isnan(local_rcond)) + { + Mat A = expr.get_ref(); + + if(inv_state.is_diag) { return op_pinv::apply_diag(out, A, T(0) ); } + if(inv_state.is_sym ) { return op_pinv::apply_sym (out, A, T(0), uword(0)); } + + return op_pinv::apply_gen(out, A, T(0), uword(0)); + } + + out.steal_mem(tmp); + + return true; + } + + out = expr.get_ref(); + + arma_debug_check( (out.is_square() == false), caller_sig, ": given matrix must be square sized", [&](){ out.soft_reset(); } ); + + const uword N = out.n_rows; + + if(N == 0) { return true; } + + if(is_cx::no) + { + if(N == 1) + { + const eT a = out[0]; + + out[0] = eT(1) / a; + + return (a != eT(0)); + } + else + if(N == 2) + { + const bool status = op_inv_gen_full::apply_tiny_2x2(out); + + if(status) { return true; } + } + else + if(N == 3) + { + const bool status = op_inv_gen_full::apply_tiny_3x3(out); + + if(status) { return true; } + } + + // fallthrough if optimisation failed + } + + if(is_op_diagmat::value || out.is_diagmat()) + { + arma_extra_debug_print("op_inv_gen_full: detected diagonal matrix"); + + eT* colmem = out.memptr(); + + for(uword i=0; i strip(expr.get_ref()); + + const bool is_triu_expr = strip.do_triu; + const bool is_tril_expr = strip.do_tril; + + const bool is_triu_mat = (is_triu_expr || is_tril_expr) ? false : ( trimat_helper::is_triu(out)); + const bool is_tril_mat = (is_triu_expr || is_tril_expr) ? false : ((is_triu_mat) ? false : trimat_helper::is_tril(out)); + + if(is_triu_expr || is_tril_expr || is_triu_mat || is_tril_mat) + { + return auxlib::inv_tr(out, ((is_triu_expr || is_triu_mat) ? uword(0) : uword(1))); + } + + const bool try_sympd = arma_config::optimise_sym && sym_helper::guess_sympd(out); + + if(try_sympd) + { + arma_extra_debug_print("op_inv_gen_full: attempting sympd optimisation"); + + Mat tmp = out; + + bool sympd_state = false; + + const bool status = auxlib::inv_sympd(tmp, sympd_state); + + if(status) { out.steal_mem(tmp); return true; } + + if((status == false) && (sympd_state == true)) { return false; } + + arma_extra_debug_print("op_inv_gen_full: sympd optimisation failed"); + + // fallthrough if optimisation failed + } + + return auxlib::inv(out); + } + + + +template +inline +bool +op_inv_gen_full::apply_tiny_2x2(Mat& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + // NOTE: assuming matrix X is square sized + + constexpr T det_min = std::numeric_limits::epsilon(); + constexpr T det_max = T(1) / std::numeric_limits::epsilon(); + + eT* Xm = X.memptr(); + + const eT a = Xm[pos<0,0>::n2]; + const eT b = Xm[pos<0,1>::n2]; + const eT c = Xm[pos<1,0>::n2]; + const eT d = Xm[pos<1,1>::n2]; + + const eT det_val = (a*d - b*c); + const T abs_det_val = std::abs(det_val); + + if((abs_det_val < det_min) || (abs_det_val > det_max) || arma_isnan(det_val)) { return false; } + + Xm[pos<0,0>::n2] = d / det_val; + Xm[pos<0,1>::n2] = -b / det_val; + Xm[pos<1,0>::n2] = -c / det_val; + Xm[pos<1,1>::n2] = a / det_val; + + return true; + } + + + +template +inline +bool +op_inv_gen_full::apply_tiny_3x3(Mat& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + // NOTE: assuming matrix X is square sized + + constexpr T det_min = std::numeric_limits::epsilon(); + constexpr T det_max = T(1) / std::numeric_limits::epsilon(); + + Mat Y(3, 3, arma_nozeros_indicator()); + + eT* Xm = X.memptr(); + eT* Ym = Y.memptr(); + + const eT det_val = op_det::apply_tiny_3x3(X); + const T abs_det_val = std::abs(det_val); + + if((abs_det_val < det_min) || (abs_det_val > det_max) || arma_isnan(det_val)) { return false; } + + Ym[pos<0,0>::n3] = (Xm[pos<2,2>::n3]*Xm[pos<1,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<1,2>::n3]) / det_val; + Ym[pos<1,0>::n3] = -(Xm[pos<2,2>::n3]*Xm[pos<1,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<1,2>::n3]) / det_val; + Ym[pos<2,0>::n3] = (Xm[pos<2,1>::n3]*Xm[pos<1,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<1,1>::n3]) / det_val; + + Ym[pos<0,1>::n3] = -(Xm[pos<2,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<0,2>::n3]) / det_val; + Ym[pos<1,1>::n3] = (Xm[pos<2,2>::n3]*Xm[pos<0,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<0,2>::n3]) / det_val; + Ym[pos<2,1>::n3] = -(Xm[pos<2,1>::n3]*Xm[pos<0,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<0,1>::n3]) / det_val; + + Ym[pos<0,2>::n3] = (Xm[pos<1,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<1,1>::n3]*Xm[pos<0,2>::n3]) / det_val; + Ym[pos<1,2>::n3] = -(Xm[pos<1,2>::n3]*Xm[pos<0,0>::n3] - Xm[pos<1,0>::n3]*Xm[pos<0,2>::n3]) / det_val; + Ym[pos<2,2>::n3] = (Xm[pos<1,1>::n3]*Xm[pos<0,0>::n3] - Xm[pos<1,0>::n3]*Xm[pos<0,1>::n3]) / det_val; + + const eT check_val = Xm[pos<0,0>::n3]*Ym[pos<0,0>::n3] + Xm[pos<0,1>::n3]*Ym[pos<1,0>::n3] + Xm[pos<0,2>::n3]*Ym[pos<2,0>::n3]; + + const T max_diff = (is_float::value) ? T(1e-4) : T(1e-10); // empirically determined; may need tuning + + if(std::abs(T(1) - check_val) >= max_diff) { return false; } + + arrayops::copy(Xm, Ym, uword(3*3)); + + return true; + } + + + +template +inline +bool +op_inv_gen_rcond::apply_direct(Mat& out, op_inv_gen_state& out_state, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + out = expr.get_ref(); + out_state.size = out.n_rows; + out_state.rcond = T(0); + + arma_debug_check( (out.is_square() == false), "inv(): given matrix must be square sized", [&](){ out.soft_reset(); } ); + + if(is_op_diagmat::value || out.is_diagmat()) + { + arma_extra_debug_print("op_inv_gen_rcond: detected diagonal matrix"); + + out_state.is_diag = true; + + eT* colmem = out.memptr(); + + T max_abs_src_val = T(0); + T max_abs_inv_val = T(0); + + const uword N = out.n_rows; + + for(uword i=0; i max_abs_src_val) ? abs_src_val : max_abs_src_val; + max_abs_inv_val = (abs_inv_val > max_abs_inv_val) ? abs_inv_val : max_abs_inv_val; + + colmem += N; + } + + out_state.rcond = T(1) / (max_abs_src_val * max_abs_inv_val); + + return true; + } + + const strip_trimat strip(expr.get_ref()); + + const bool is_triu_expr = strip.do_triu; + const bool is_tril_expr = strip.do_tril; + + const bool is_triu_mat = (is_triu_expr || is_tril_expr) ? false : ( trimat_helper::is_triu(out)); + const bool is_tril_mat = (is_triu_expr || is_tril_expr) ? false : ((is_triu_mat) ? false : trimat_helper::is_tril(out)); + + if(is_triu_expr || is_tril_expr || is_triu_mat || is_tril_mat) + { + return auxlib::inv_tr_rcond(out, out_state.rcond, ((is_triu_expr || is_triu_mat) ? uword(0) : uword(1))); + } + + const bool try_sympd = arma_config::optimise_sym && ((auxlib::crippled_lapack(out)) ? false : sym_helper::guess_sympd(out)); + + if(try_sympd) + { + arma_extra_debug_print("op_inv_gen_rcond: attempting sympd optimisation"); + + out_state.is_sym = true; + + Mat tmp = out; + + bool sympd_state = false; + + const bool status = auxlib::inv_sympd_rcond(tmp, sympd_state, out_state.rcond); + + if(status) { out.steal_mem(tmp); return true; } + + if((status == false) && (sympd_state == true)) { return false; } + + arma_extra_debug_print("op_inv_gen_rcond: sympd optimisation failed"); + + // fallthrough if optimisation failed + } + + return auxlib::inv_rcond(out, out_state.rcond); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_inv_spd_bones.hpp b/src/armadillo/include/armadillo_bits/op_inv_spd_bones.hpp new file mode 100644 index 0000000..85a5013 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_inv_spd_bones.hpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_inv_spd +//! @{ + + + +class op_inv_spd_default + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Base& expr); + }; + + + +class op_inv_spd_full + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Base& expr, const uword flags); + + template + arma_cold inline static bool apply_tiny_2x2(Mat& X); + }; + + + +template +struct op_inv_spd_state + { + uword size = uword(0); + T rcond = T(0); + bool is_diag = false; + }; + + + +class op_inv_spd_rcond + : public traits_op_default + { + public: + + template + inline static bool apply_direct(Mat& out_inv, op_inv_spd_state& out_state, const Base& expr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_inv_spd_meat.hpp b/src/armadillo/include/armadillo_bits/op_inv_spd_meat.hpp new file mode 100644 index 0000000..0c60974 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_inv_spd_meat.hpp @@ -0,0 +1,365 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_inv_spd +//! @{ + + + +template +inline +void +op_inv_spd_default::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + const bool status = op_inv_spd_default::apply_direct(out, X.m); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("inv_sympd(): matrix is singular or not positive definite"); + } + } + + + +template +inline +bool +op_inv_spd_default::apply_direct(Mat& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + return op_inv_spd_full::apply_direct(out, expr, uword(0)); + } + + + +// + + + +template +inline +void +op_inv_spd_full::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + const uword flags = X.aux_uword_a; + + const bool status = op_inv_spd_full::apply_direct(out, X.m, flags); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("inv_sympd(): matrix is singular or not positive definite"); + } + } + + + +template +inline +bool +op_inv_spd_full::apply_direct(Mat& out, const Base& expr, const uword flags) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(has_user_flags == true ) { arma_extra_debug_print("op_inv_spd_full: has_user_flags = true"); } + if(has_user_flags == false) { arma_extra_debug_print("op_inv_spd_full: has_user_flags = false"); } + + const bool fast = has_user_flags && bool(flags & inv_opts::flag_fast ); + const bool allow_approx = has_user_flags && bool(flags & inv_opts::flag_allow_approx); + const bool no_ugly = has_user_flags && bool(flags & inv_opts::flag_no_ugly ); + + if(has_user_flags) + { + arma_extra_debug_print("op_inv_spd_full: enabled flags:"); + + if(fast ) { arma_extra_debug_print("fast"); } + if(allow_approx) { arma_extra_debug_print("allow_approx"); } + if(no_ugly ) { arma_extra_debug_print("no_ugly"); } + + arma_debug_check( (fast && allow_approx), "inv_sympd(): options 'fast' and 'allow_approx' are mutually exclusive" ); + arma_debug_check( (fast && no_ugly ), "inv_sympd(): options 'fast' and 'no_ugly' are mutually exclusive" ); + arma_debug_check( (no_ugly && allow_approx), "inv_sympd(): options 'no_ugly' and 'allow_approx' are mutually exclusive" ); + } + + if(no_ugly) + { + op_inv_spd_state inv_state; + + const bool status = op_inv_spd_rcond::apply_direct(out, inv_state, expr); + + // workaround for bug in gcc 4.8 + const uword local_size = inv_state.size; + const T local_rcond = inv_state.rcond; + + if((status == false) || (local_rcond < ((std::max)(local_size, uword(1)) * std::numeric_limits::epsilon())) || arma_isnan(local_rcond)) { return false; } + + return true; + } + + if(allow_approx) + { + op_inv_spd_state inv_state; + + Mat tmp; + + const bool status = op_inv_spd_rcond::apply_direct(tmp, inv_state, expr); + + // workaround for bug in gcc 4.8 + const uword local_size = inv_state.size; + const T local_rcond = inv_state.rcond; + + if((status == false) || (local_rcond < ((std::max)(local_size, uword(1)) * std::numeric_limits::epsilon())) || arma_isnan(local_rcond)) + { + const Mat A = expr.get_ref(); + + if(inv_state.is_diag) { return op_pinv::apply_diag(out, A, T(0)); } + + return op_pinv::apply_sym(out, A, T(0), uword(0)); + } + + out.steal_mem(tmp); + + return true; + } + + out = expr.get_ref(); + + arma_debug_check( (out.is_square() == false), "inv_sympd(): given matrix must be square sized", [&](){ out.soft_reset(); } ); + + if((arma_config::debug) && (arma_config::warn_level > 0)) + { + if(auxlib::rudimentary_sym_check(out) == false) + { + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } + } + else + if((is_cx::yes) && (sym_helper::check_diag_imag(out) == false)) + { + arma_debug_warn_level(1, "inv_sympd(): imaginary components on diagonal are non-zero"); + } + } + + const uword N = out.n_rows; + + if(N == 0) { return true; } + + if(is_cx::no) + { + if(N == 1) + { + const T a = access::tmp_real(out[0]); + + out[0] = eT(T(1) / a); + + return (a > T(0)); + } + else + if(N == 2) + { + const bool status = op_inv_spd_full::apply_tiny_2x2(out); + + if(status) { return true; } + } + + // fallthrough if optimisation failed + } + + if(is_op_diagmat::value || out.is_diagmat()) + { + arma_extra_debug_print("op_inv_spd_full: detected diagonal matrix"); + + eT* colmem = out.memptr(); + + for(uword i=0; i +inline +bool +op_inv_spd_full::apply_tiny_2x2(Mat& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + // NOTE: assuming matrix X is square sized + // NOTE: assuming matrix X is symmetric + // NOTE: assuming matrix X is real + + constexpr T det_min = std::numeric_limits::epsilon(); + constexpr T det_max = T(1) / std::numeric_limits::epsilon(); + + eT* Xm = X.memptr(); + + T a = access::tmp_real(Xm[0]); + T c = access::tmp_real(Xm[1]); + T d = access::tmp_real(Xm[3]); + + const T det_val = (a*d - c*c); + + // positive definite iff all leading principal minors are positive + // a = first leading principal minor (top-left 1x1 submatrix) + // det_val = second leading principal minor (top-left 2x2 submatrix) + + if(a <= T(0)) { return false; } + + // NOTE: since det_min is positive, this also checks whether det_val is positive + if((det_val < det_min) || (det_val > det_max) || arma_isnan(det_val)) { return false; } + + d /= det_val; + c /= det_val; + a /= det_val; + + Xm[0] = d; + Xm[1] = -c; + Xm[2] = -c; + Xm[3] = a; + + return true; + } + + + +// + + + +template +inline +bool +op_inv_spd_rcond::apply_direct(Mat& out, op_inv_spd_state& out_state, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + out = expr.get_ref(); + out_state.size = out.n_rows; + out_state.rcond = T(0); + + arma_debug_check( (out.is_square() == false), "inv_sympd(): given matrix must be square sized", [&](){ out.soft_reset(); } ); + + if((arma_config::debug) && (arma_config::warn_level > 0)) + { + if(auxlib::rudimentary_sym_check(out) == false) + { + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } + } + else + if((is_cx::yes) && (sym_helper::check_diag_imag(out) == false)) + { + arma_debug_warn_level(1, "inv_sympd(): imaginary components on diagonal are non-zero"); + } + } + + if(is_op_diagmat::value || out.is_diagmat()) + { + arma_extra_debug_print("op_inv_spd_rcond: detected diagonal matrix"); + + out_state.is_diag = true; + + eT* colmem = out.memptr(); + + T max_abs_src_val = T(0); + T max_abs_inv_val = T(0); + + const uword N = out.n_rows; + + for(uword i=0; i max_abs_src_val) ? abs_src_val : max_abs_src_val; + max_abs_inv_val = (abs_inv_val > max_abs_inv_val) ? abs_inv_val : max_abs_inv_val; + + colmem += N; + } + + out_state.rcond = T(1) / (max_abs_src_val * max_abs_inv_val); + + return true; + } + + if(auxlib::crippled_lapack(out)) + { + arma_extra_debug_print("op_inv_spd_rcond: workaround for crippled lapack"); + + Mat tmp = out; + + bool sympd_state = false; + + auxlib::inv_sympd(out, sympd_state); + + if(sympd_state == false) { out.soft_reset(); out_state.rcond = T(0); return false; } + + out_state.rcond = auxlib::rcond(tmp); + + if(out_state.rcond == T(0)) { out.soft_reset(); return false; } + + return true; + } + + bool is_sympd_junk = false; + + return auxlib::inv_sympd_rcond(out, is_sympd_junk, out_state.rcond); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_log_det_bones.hpp b/src/armadillo/include/armadillo_bits/op_log_det_bones.hpp new file mode 100644 index 0000000..e2f3daf --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_log_det_bones.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_log_det +//! @{ + + + +class op_log_det + : public traits_op_default + { + public: + + template + inline static bool apply_direct(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr); + + template + inline static bool apply_diagmat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr); + + template + inline static bool apply_trimat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr); + }; + + + +class op_log_det_sympd + : public traits_op_default + { + public: + + template + inline static bool apply_direct(typename T1::pod_type& out_val, const Base& expr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_log_det_meat.hpp b/src/armadillo/include/armadillo_bits/op_log_det_meat.hpp new file mode 100644 index 0000000..7b88859 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_log_det_meat.hpp @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_log_det +//! @{ + + + +template +inline +bool +op_log_det::apply_direct(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + // typedef typename T1::pod_type T; + + if(strip_diagmat::do_diagmat) + { + const strip_diagmat strip(expr.get_ref()); + + return op_log_det::apply_diagmat(out_val, out_sign, strip.M); + } + + if(strip_trimat::do_trimat) + { + const strip_trimat strip(expr.get_ref()); + + return op_log_det::apply_trimat(out_val, out_sign, strip.M); + } + + Mat A(expr.get_ref()); + + arma_debug_check( (A.is_square() == false), "log_det(): given matrix must be square sized" ); + + if(A.is_diagmat()) { return op_log_det::apply_diagmat(out_val, out_sign, A); } + + const bool is_triu = trimat_helper::is_triu(A); + const bool is_tril = is_triu ? false : trimat_helper::is_tril(A); + + if(is_triu || is_tril) { return op_log_det::apply_trimat(out_val, out_sign, A); } + + // const bool try_sympd = arma_config::optimise_sym && sym_helper::guess_sympd(A); + // + // if(try_sympd) + // { + // arma_extra_debug_print("op_log_det: attempting sympd optimisation"); + // + // T out_val_real = T(0); + // + // const bool status = auxlib::log_det_sympd(out_val_real, A); + // + // if(status) + // { + // out_val = eT(out_val_real); + // out_sign = T(1); + // + // return true; + // } + // + // arma_extra_debug_print("op_log_det: sympd optimisation failed"); + // + // // restore A as it's destroyed by auxlib::log_det_sympd() + // A = expr.get_ref(); + // + // // fallthrough to the next return statement + // } + + return auxlib::log_det(out_val, out_sign, A); + } + + + +template +inline +bool +op_log_det::apply_diagmat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const diagmat_proxy A(expr.get_ref()); + + arma_debug_check( (A.n_rows != A.n_cols), "log_det(): given matrix must be square sized" ); + + const uword N = (std::min)(A.n_rows, A.n_cols); + + if(N == 0) + { + out_val = eT(0); + out_sign = T(1); + + return true; + } + + eT x = A[0]; + + T sign = (is_cx::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1); + eT val = (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); + + for(uword i=1; i::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1); + val += (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); + } + + out_val = val; + out_sign = sign; + + return (arma_isnan(out_val) == false); + } + + + +template +inline +bool +op_log_det::apply_trimat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const Proxy P(expr.get_ref()); + + const uword N = P.get_n_rows(); + + arma_debug_check( (N != P.get_n_cols()), "log_det(): given matrix must be square sized" ); + + if(N == 0) + { + out_val = eT(0); + out_sign = T(1); + + return true; + } + + eT x = P.at(0,0); + + T sign = (is_cx::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1); + eT val = (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); + + for(uword i=1; i::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1); + val += (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); + } + + out_val = val; + out_sign = sign; + + return (arma_isnan(out_val) == false); + } + + + +// + + + +template +inline +bool +op_log_det_sympd::apply_direct(typename T1::pod_type& out_val, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + Mat A(expr.get_ref()); + + arma_debug_check( (A.is_square() == false), "log_det_sympd(): given matrix must be square sized" ); + + if((arma_config::debug) && (arma_config::warn_level > 0) && (is_cx::yes) && (sym_helper::check_diag_imag(A) == false)) + { + arma_debug_warn_level(1, "log_det_sympd(): imaginary components on diagonal are non-zero"); + } + + if(is_op_diagmat::value || A.is_diagmat()) + { + arma_extra_debug_print("op_log_det_sympd: detected diagonal matrix"); + + eT* colmem = A.memptr(); + + out_val = T(0); + + const uword N = A.n_rows; + + for(uword i=0; i::no ) { arma_debug_warn_level(1, "log_det_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "log_det_sympd(): given matrix is not hermitian"); } + } + + return auxlib::log_det_sympd(out_val, A); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_logmat_bones.hpp b/src/armadillo/include/armadillo_bits/op_logmat_bones.hpp new file mode 100644 index 0000000..77e967d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_logmat_bones.hpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_logmat +//! @{ + + + +class op_logmat + : public traits_op_default + { + public: + + template + inline static void apply(Mat< std::complex >& out, const mtOp,T1,op_logmat>& in); + + template + inline static bool apply_direct(Mat< std::complex >& out, const Op& expr, const uword); + + template + inline static bool apply_direct(Mat< std::complex >& out, const Base& expr, const uword n_iters); + }; + + + +class op_logmat_cx + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Op& expr, const uword); + + template + inline static bool apply_direct_noalias(Mat& out, const diagmat_proxy& P); + + template + inline static bool apply_direct(Mat& out, const Base& expr, const uword n_iters); + + template + inline static bool apply_common(Mat< std::complex >& out, Mat< std::complex >& S, const uword n_iters); + + + template + inline static bool helper(Mat& S, const uword m); + }; + + + +class op_logmat_sympd + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Base& expr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_logmat_meat.hpp b/src/armadillo/include/armadillo_bits/op_logmat_meat.hpp new file mode 100644 index 0000000..494f747 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_logmat_meat.hpp @@ -0,0 +1,572 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_logmat +//! @{ + + +// Partly based on algorithm 11.9 (inverse scaling and squaring algorithm with Schur decomposition) in: +// Nicholas J. Higham. +// Functions of Matrices: Theory and Computation. +// SIAM, 2008. +// ISBN 978-0-89871-646-7 + + +template +inline +void +op_logmat::apply(Mat< std::complex >& out, const mtOp,T1,op_logmat>& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_logmat::apply_direct(out, in.m, in.aux_uword_a); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("logmat(): transformation failed"); + } + } + + + +template +inline +bool +op_logmat::apply_direct(Mat< std::complex >& out, const Op& expr, const uword) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type T; + + const diagmat_proxy P(expr.m); + + arma_debug_check( (P.n_rows != P.n_cols), "logmat(): given matrix must be square sized" ); + + const uword N = P.n_rows; + + out.zeros(N,N); // aliasing can't happen as op_logmat is defined as cx_mat = op(mat) + + for(uword i=0; i= T(0)) + { + out.at(i,i) = std::log(val); + } + else + { + out.at(i,i) = std::log( std::complex(val) ); + } + } + + return true; + } + + + +template +inline +bool +op_logmat::apply_direct(Mat< std::complex >& out, const Base& expr, const uword n_iters) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_T; + typedef typename std::complex out_T; + + const quasi_unwrap expr_unwrap(expr.get_ref()); + const Mat& A = expr_unwrap.M; + + arma_debug_check( (A.is_square() == false), "logmat(): given matrix must be square sized" ); + + if(A.n_elem == 0) + { + out.reset(); + return true; + } + else + if(A.n_elem == 1) + { + out.set_size(1,1); + out[0] = std::log( std::complex( A[0] ) ); + return true; + } + + if(A.is_diagmat()) + { + arma_extra_debug_print("op_logmat: detected diagonal matrix"); + + const uword N = A.n_rows; + + out.zeros(N,N); // aliasing can't happen as op_logmat is defined as cx_mat = op(mat) + + for(uword i=0; i= in_T(0)) + { + out.at(i,i) = std::log(val); + } + else + { + out.at(i,i) = std::log( out_T(val) ); + } + } + + return true; + } + + const bool try_sympd = arma_config::optimise_sym && sym_helper::guess_sympd(A); + + if(try_sympd) + { + arma_extra_debug_print("op_logmat: attempting sympd optimisation"); + + // if matrix A is sympd, all its eigenvalues are positive + + Col eigval; + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, A, 'd', "logmat()"); + + if(eig_status) + { + // ensure each eigenvalue is > 0 + + const uword N = eigval.n_elem; + const in_T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i >::from( eigvec * diagmat(eigval) * eigvec.t() ); + + return true; + } + } + + arma_extra_debug_print("op_logmat: sympd optimisation failed"); + + // fallthrough if eigen decomposition failed or an eigenvalue is <= 0 + } + + + Mat S(A.n_rows, A.n_cols, arma_nozeros_indicator()); + + const in_T* Amem = A.memptr(); + out_T* Smem = S.memptr(); + + const uword n_elem = A.n_elem; + + for(uword i=0; i( Amem[i] ); + } + + return op_logmat_cx::apply_common(out, S, n_iters); + } + + + +template +inline +void +op_logmat_cx::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_logmat_cx::apply_direct(out, in.m, in.aux_uword_a); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("logmat(): transformation failed"); + } + } + + + +template +inline +bool +op_logmat_cx::apply_direct(Mat& out, const Op& expr, const uword) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const diagmat_proxy P(expr.m); + + bool status = false; + + if(P.is_alias(out)) + { + Mat tmp; + + status = op_logmat_cx::apply_direct_noalias(tmp, P); + + out.steal_mem(tmp); + } + else + { + status = op_logmat_cx::apply_direct_noalias(out, P); + } + + return status; + } + + + +template +inline +bool +op_logmat_cx::apply_direct_noalias(Mat& out, const diagmat_proxy& P) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (P.n_rows != P.n_cols), "logmat(): given matrix must be square sized" ); + + const uword N = P.n_rows; + + out.zeros(N,N); + + for(uword i=0; i +inline +bool +op_logmat_cx::apply_direct(Mat& out, const Base& expr, const uword n_iters) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + Mat S = expr.get_ref(); + + arma_debug_check( (S.n_rows != S.n_cols), "logmat(): given matrix must be square sized" ); + + if(S.n_elem == 0) + { + out.reset(); + return true; + } + else + if(S.n_elem == 1) + { + out.set_size(1,1); + out[0] = std::log(S[0]); + return true; + } + + if(S.is_diagmat()) + { + arma_extra_debug_print("op_logmat_cx: detected diagonal matrix"); + + const uword N = S.n_rows; + + out.zeros(N,N); // aliasing can't happen as S is generated + + for(uword i=0; i eigval; + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, S, 'd', "logmat()"); + + if(eig_status) + { + // ensure each eigenvalue is > 0 + + const uword N = eigval.n_elem; + const T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i +inline +bool +op_logmat_cx::apply_common(Mat< std::complex >& out, Mat< std::complex >& S, const uword n_iters) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + Mat U; + + const bool schur_ok = auxlib::schur(U,S); + + if(schur_ok == false) { arma_extra_debug_print("logmat(): schur decomposition failed"); return false; } + + // NOTE: theta[0] and theta[1] not really used + double theta[] = { 1.10e-5, 1.82e-3, 1.6206284795015624e-2, 5.3873532631381171e-2, 1.1352802267628681e-1, 1.8662860613541288e-1, 2.642960831111435e-1 }; + + const uword N = S.n_rows; + + uword p = 0; + uword m = 6; + + uword iter = 0; + + while(iter < n_iters) + { + const T tau = norm( (S - eye< Mat >(N,N)), 1 ); + + if(tau <= theta[6]) + { + p++; + + uword j1 = 0; + uword j2 = 0; + + for(uword i=2; i<=6; ++i) { if( tau <= theta[i]) { j1 = i; break; } } + for(uword i=2; i<=6; ++i) { if((tau/2.0) <= theta[i]) { j2 = i; break; } } + + // sanity check, for development purposes only + arma_debug_check( (j2 > j1), "internal error: op_logmat::apply_direct(): j2 > j1" ); + + if( ((j1 - j2) <= 1) || (p == 2) ) { m = j1; break; } + } + + const bool sqrtmat_ok = op_sqrtmat_cx::apply_direct(S,S); + + if(sqrtmat_ok == false) { arma_extra_debug_print("logmat(): sqrtmat() failed"); return false; } + + iter++; + } + + if(iter >= n_iters) { arma_debug_warn_level(2, "logmat(): reached max iterations without full convergence"); } + + S.diag() -= eT(1); + + if(m >= 1) + { + const bool helper_ok = op_logmat_cx::helper(S,m); + + if(helper_ok == false) { return false; } + } + + out = U * S * U.t(); + + out *= eT(eop_aux::pow(double(2), double(iter))); + + return true; + } + + + +template +inline +bool +op_logmat_cx::helper(Mat& A, const uword m) + { + arma_extra_debug_sigprint(); + + if(A.internal_has_nonfinite()) { return false; } + + const vec indices = regspace(1,m-1); + + mat tmp(m, m, arma_zeros_indicator()); + + tmp.diag(-1) = indices / sqrt(square(2.0*indices) - 1.0); + tmp.diag(+1) = indices / sqrt(square(2.0*indices) - 1.0); + + vec eigval; + mat eigvec; + + const bool eig_ok = eig_sym_helper(eigval, eigvec, tmp, 'd', "logmat()"); + + if(eig_ok == false) { arma_extra_debug_print("logmat(): eig_sym() failed"); return false; } + + const vec nodes = (eigval + 1.0) / 2.0; + const vec weights = square(eigvec.row(0).t()); + + const uword N = A.n_rows; + + Mat B(N, N, arma_zeros_indicator()); + + Mat X; + + for(uword i=0; i < m; ++i) + { + // B += weights(i) * solve( (nodes(i)*A + eye< Mat >(N,N)), A ); + + //const bool solve_ok = solve( X, (nodes(i)*A + eye< Mat >(N,N)), A, solve_opts::fast ); + const bool solve_ok = solve( X, trimatu(nodes(i)*A + eye< Mat >(N,N)), A, solve_opts::no_approx ); + + if(solve_ok == false) { arma_extra_debug_print("logmat(): solve() failed"); return false; } + + B += weights(i) * X; + } + + A = B; + + return true; + } + + + +template +inline +void +op_logmat_sympd::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_logmat_sympd::apply_direct(out, in.m); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("logmat_sympd(): transformation failed"); + } + } + + + +template +inline +bool +op_logmat_sympd::apply_direct(Mat& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + const unwrap U(expr.get_ref()); + const Mat& X = U.M; + + arma_debug_check( (X.is_square() == false), "logmat_sympd(): given matrix must be square sized" ); + + if((arma_config::debug) && (arma_config::warn_level > 0) && (is_cx::yes) && (sym_helper::check_diag_imag(X) == false)) + { + arma_debug_warn_level(1, "logmat_sympd(): imaginary components on diagonal are non-zero"); + } + + if(is_op_diagmat::value || X.is_diagmat()) + { + arma_extra_debug_print("op_logmat_sympd: detected diagonal matrix"); + + out = X; + + eT* colmem = out.memptr(); + + const uword N = X.n_rows; + + for(uword i=0; i eigval; + Mat eigvec; + + const bool status = eig_sym_helper(eigval, eigvec, X, 'd', "logmat_sympd()"); + + if(status == false) { return false; } + + const uword N = eigval.n_elem; + const T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i + inline static void apply(Mat& out, const Op& in); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); + + + // + // cubes + + template + inline static void apply(Cube& out, const OpCube& in); + + template + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); + + + // + // for non-complex numbers + + template + inline static eT direct_max(const eT* const X, const uword N); + + template + inline static eT direct_max(const eT* const X, const uword N, uword& index_of_max_val); + + template + inline static eT direct_max(const Mat& X, const uword row); + + template + inline static eT max(const subview& X); + + template + inline static typename arma_not_cx::result max(const Base& X); + + template + inline static typename arma_not_cx::result max(const BaseCube& X); + + template + inline static typename arma_not_cx::result max_with_index(const Proxy& P, uword& index_of_max_val); + + template + inline static typename arma_not_cx::result max_with_index(const ProxyCube& P, uword& index_of_max_val); + + + // + // for complex numbers + + template + inline static std::complex direct_max(const std::complex* const X, const uword n_elem); + + template + inline static std::complex direct_max(const std::complex* const X, const uword n_elem, uword& index_of_max_val); + + template + inline static std::complex direct_max(const Mat< std::complex >& X, const uword row); + + template + inline static std::complex max(const subview< std::complex >& X); + + template + inline static typename arma_cx_only::result max(const Base& X); + + template + inline static typename arma_cx_only::result max(const BaseCube& X); + + template + inline static typename arma_cx_only::result max_with_index(const Proxy& P, uword& index_of_max_val); + + template + inline static typename arma_cx_only::result max_with_index(const ProxyCube& P, uword& index_of_max_val); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_max_meat.hpp b/src/armadillo/include/armadillo_bits/op_max_meat.hpp new file mode 100644 index 0000000..34de86b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_max_meat.hpp @@ -0,0 +1,1325 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_max +//! @{ + + + +template +inline +void +op_max::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 1), "max(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(in.m); + const Mat& X = U.M; + + if(U.is_alias(out) == false) + { + op_max::apply_noalias(out, X, dim); + } + else + { + Mat tmp; + + op_max::apply_noalias(tmp, X, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_max::apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_max::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows == 0) { return; } + + eT* out_mem = out.memptr(); + + for(uword col=0; col 0) ? 1 : 0); + + if(X_n_cols == 0) { return; } + + eT* out_mem = out.memptr(); + + arrayops::copy(out_mem, X.colptr(0), X_n_rows); + + for(uword col=1; col out_mem[row]) { out_mem[row] = col_val; } + } + } + } + } + + + +template +inline +void +op_max::apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_max::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows == 0) { return; } + + eT* out_mem = out.memptr(); + + for(uword col=0; col 0) ? 1 : 0); + + if(X_n_cols == 0) { return; } + + eT* out_mem = out.memptr(); + + for(uword row=0; row +inline +void +op_max::apply(Cube& out, const OpCube& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 2), "max(): parameter 'dim' must be 0 or 1 or 2" ); + + const unwrap_cube U(in.m); + + if(U.is_alias(out) == false) + { + op_max::apply_noalias(out, U.M, dim); + } + else + { + Cube tmp; + + op_max::apply_noalias(tmp, U.M, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_max::apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + const uword X_n_slices = X.n_slices; + + if(dim == 0) + { + arma_extra_debug_print("op_max::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices); + + if(X_n_rows == 0) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[col] = op_max::direct_max( X.slice_colptr(slice,col), X_n_rows ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_max::apply(): dim = 1"); + + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices); + + if(X_n_cols == 0) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + arrayops::copy(out_mem, X.slice_colptr(slice,0), X_n_rows); + + for(uword col=1; col < X_n_cols; ++col) + { + const eT* col_mem = X.slice_colptr(slice,col); + + for(uword row=0; row < X_n_rows; ++row) + { + const eT col_val = col_mem[row]; + + if(col_val > out_mem[row]) { out_mem[row] = col_val; } + } + } + } + } + else + if(dim == 2) + { + arma_extra_debug_print("op_max::apply(): dim = 2"); + + out.set_size(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0); + + if(X_n_slices == 0) { return; } + + const uword N = X.n_elem_slice; + + eT* out_mem = out.slice_memptr(0); + + arrayops::copy(out_mem, X.slice_memptr(0), N); + + for(uword slice=1; slice < X_n_slices; ++slice) + { + const eT* X_mem = X.slice_memptr(slice); + + for(uword i=0; i < N; ++i) + { + const eT val = X_mem[i]; + + if(val > out_mem[i]) { out_mem[i] = val; } + } + } + } + } + + + +template +inline +void +op_max::apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + const uword X_n_slices = X.n_slices; + + if(dim == 0) + { + arma_extra_debug_print("op_max::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices); + + if(X_n_rows == 0) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[col] = op_max::direct_max( X.slice_colptr(slice,col), X_n_rows ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_max::apply(): dim = 1"); + + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices); + + if(X_n_cols == 0) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + const Mat tmp('j', X.slice_memptr(slice), X_n_rows, X_n_cols); + + for(uword row=0; row < X_n_rows; ++row) + { + out_mem[row] = op_max::direct_max(tmp, row); + } + } + } + else + if(dim == 2) + { + arma_extra_debug_print("op_max::apply(): dim = 2"); + + out.set_size(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0); + + if(X_n_slices == 0) { return; } + + const uword N = X.n_elem_slice; + + eT* out_mem = out.slice_memptr(0); + + arrayops::copy(out_mem, X.slice_memptr(0), N); + + for(uword slice=1; slice < X_n_slices; ++slice) + { + const eT* X_mem = X.slice_memptr(slice); + + for(uword i=0; i < N; ++i) + { + const eT& val = X_mem[i]; + + if(std::abs(val) > std::abs(out_mem[i])) { out_mem[i] = val; } + } + } + } + } + + + +template +inline +eT +op_max::direct_max(const eT* const X, const uword n_elem) + { + arma_extra_debug_sigprint(); + + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); + + uword i,j; + for(i=0, j=1; j max_val_i) { max_val_i = X_i; } + if(X_j > max_val_j) { max_val_j = X_j; } + } + + if(i < n_elem) + { + const eT X_i = X[i]; + + if(X_i > max_val_i) { max_val_i = X_i; } + } + + return (max_val_i > max_val_j) ? max_val_i : max_val_j; + } + + + +template +inline +eT +op_max::direct_max(const eT* const X, const uword n_elem, uword& index_of_max_val) + { + arma_extra_debug_sigprint(); + + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); + + uword best_index_i = 0; + uword best_index_j = 0; + + uword i,j; + for(i=0, j=1; j max_val_i) { max_val_i = X_i; best_index_i = i; } + if(X_j > max_val_j) { max_val_j = X_j; best_index_j = j; } + } + + if(i < n_elem) + { + const eT X_i = X[i]; + + if(X_i > max_val_i) { max_val_i = X_i; best_index_i = i; } + } + + index_of_max_val = (max_val_i > max_val_j) ? best_index_i : best_index_j; + + return (max_val_i > max_val_j) ? max_val_i : max_val_j; + } + + + +template +inline +eT +op_max::direct_max(const Mat& X, const uword row) + { + arma_extra_debug_sigprint(); + + const uword X_n_cols = X.n_cols; + + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); + + uword i,j; + for(i=0, j=1; j < X_n_cols; i+=2, j+=2) + { + const eT tmp_i = X.at(row,i); + const eT tmp_j = X.at(row,j); + + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } + } + + if(i < X_n_cols) + { + const eT tmp_i = X.at(row,i); + + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + } + + return (max_val_i > max_val_j) ? max_val_i : max_val_j; + } + + + +template +inline +eT +op_max::max(const subview& X) + { + arma_extra_debug_sigprint(); + + if(X.n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(X_n_rows == 1) + { + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); + + const Mat& A = X.m; + + const uword start_row = X.aux_row1; + const uword start_col = X.aux_col1; + + const uword end_col_p1 = start_col + X_n_cols; + + uword i,j; + for(i=start_col, j=start_col+1; j < end_col_p1; i+=2, j+=2) + { + const eT tmp_i = A.at(start_row, i); + const eT tmp_j = A.at(start_row, j); + + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } + } + + if(i < end_col_p1) + { + const eT tmp_i = A.at(start_row, i); + + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + } + + return (max_val_i > max_val_j) ? max_val_i : max_val_j; + } + + eT max_val = priv::most_neg(); + + for(uword col=0; col < X_n_cols; ++col) + { + max_val = (std::max)(max_val, op_max::direct_max(X.colptr(col), X_n_rows)); + } + + return max_val; + } + + + +template +inline +typename arma_not_cx::result +op_max::max(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword i,j; + + for(i=0, j=1; j max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } + } + + if(i < n_elem) + { + const eT tmp_i = A[i]; + + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + if(n_rows == 1) + { + uword i,j; + for(i=0, j=1; j < n_cols; i+=2, j+=2) + { + const eT tmp_i = P.at(0,i); + const eT tmp_j = P.at(0,j); + + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } + } + + if(i < n_cols) + { + const eT tmp_i = P.at(0,i); + + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + uword i,j; + for(i=0, j=1; j < n_rows; i+=2, j+=2) + { + const eT tmp_i = P.at(i,col); + const eT tmp_j = P.at(j,col); + + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } + } + + if(i < n_rows) + { + const eT tmp_i = P.at(i,col); + + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + } + } + } + } + + return (max_val_i > max_val_j) ? max_val_i : max_val_j; + } + + + +template +inline +typename arma_not_cx::result +op_max::max(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const ProxyCube P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + eT max_val = priv::most_neg(); + + if(ProxyCube::use_at == false) + { + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); + + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword i,j; + + for(i=0, j=1; j max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } + } + + if(i < n_elem) + { + const eT tmp_i = A[i]; + + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + } + + max_val = (max_val_i > max_val_j) ? max_val_i : max_val_j; + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + const eT tmp = P.at(row,col,slice); + + if(tmp > max_val) { max_val = tmp; } + } + } + + return max_val; + } + + + +template +inline +typename arma_not_cx::result +op_max::max_with_index(const Proxy& P, uword& index_of_max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + eT best_val = priv::most_neg(); + uword best_index = 0; + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + for(uword i=0; i best_val) { best_val = tmp; best_index = i; } + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + if(n_rows == 1) + { + for(uword i=0; i < n_cols; ++i) + { + const eT tmp = P.at(0,i); + + if(tmp > best_val) { best_val = tmp; best_index = i; } + } + } + else + if(n_cols == 1) + { + for(uword i=0; i < n_rows; ++i) + { + const eT tmp = P.at(i,0); + + if(tmp > best_val) { best_val = tmp; best_index = i; } + } + } + else + { + uword count = 0; + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT tmp = P.at(row,col); + + if(tmp > best_val) { best_val = tmp; best_index = count; } + + ++count; + } + } + } + + index_of_max_val = best_index; + + return best_val; + } + + + +template +inline +typename arma_not_cx::result +op_max::max_with_index(const ProxyCube& P, uword& index_of_max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + eT best_val = priv::most_neg(); + uword best_index = 0; + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + const eT tmp = A[i]; + + if(tmp > best_val) { best_val = tmp; best_index = i; } + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + uword count = 0; + + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + const eT tmp = P.at(row,col,slice); + + if(tmp > best_val) { best_val = tmp; best_index = count; } + + ++count; + } + } + + index_of_max_val = best_index; + + return best_val; + } + + + +template +inline +std::complex +op_max::direct_max(const std::complex* const X, const uword n_elem) + { + arma_extra_debug_sigprint(); + + uword index = 0; + T max_val = priv::most_neg(); + + for(uword i=0; i max_val) + { + max_val = tmp_val; + index = i; + } + } + + return X[index]; + } + + + +template +inline +std::complex +op_max::direct_max(const std::complex* const X, const uword n_elem, uword& index_of_max_val) + { + arma_extra_debug_sigprint(); + + uword index = 0; + T max_val = priv::most_neg(); + + for(uword i=0; i max_val) + { + max_val = tmp_val; + index = i; + } + } + + index_of_max_val = index; + + return X[index]; + } + + + +template +inline +std::complex +op_max::direct_max(const Mat< std::complex >& X, const uword row) + { + arma_extra_debug_sigprint(); + + const uword X_n_cols = X.n_cols; + + uword index = 0; + T max_val = priv::most_neg(); + + for(uword col=0; col max_val) + { + max_val = tmp_val; + index = col; + } + } + + return X.at(row,index); + } + + + +template +inline +std::complex +op_max::max(const subview< std::complex >& X) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + if(X.n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + const Mat& A = X.m; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + const uword start_row = X.aux_row1; + const uword start_col = X.aux_col1; + + const uword end_row_p1 = start_row + X_n_rows; + const uword end_col_p1 = start_col + X_n_cols; + + T max_val = priv::most_neg(); + + uword best_row = 0; + uword best_col = 0; + + if(X_n_rows == 1) + { + best_col = 0; + + for(uword col=start_col; col < end_col_p1; ++col) + { + const T tmp_val = std::abs( A.at(start_row, col) ); + + if(tmp_val > max_val) + { + max_val = tmp_val; + best_col = col; + } + } + + best_row = start_row; + } + else + { + for(uword col=start_col; col < end_col_p1; ++col) + for(uword row=start_row; row < end_row_p1; ++row) + { + const T tmp_val = std::abs( A.at(row, col) ); + + if(tmp_val > max_val) + { + max_val = tmp_val; + best_row = row; + best_col = col; + } + } + } + + return A.at(best_row, best_col); + } + + + +template +inline +typename arma_cx_only::result +op_max::max(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const Proxy P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + T max_val = priv::most_neg(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword index = 0; + + for(uword i=0; i max_val) + { + max_val = tmp; + index = i; + } + } + + return( A[index] ); + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + uword best_row = 0; + uword best_col = 0; + + if(n_rows == 1) + { + for(uword col=0; col < n_cols; ++col) + { + const T tmp = std::abs(P.at(0,col)); + + if(tmp > max_val) + { + max_val = tmp; + best_col = col; + } + } + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const T tmp = std::abs(P.at(row,col)); + + if(tmp > max_val) + { + max_val = tmp; + + best_row = row; + best_col = col; + } + } + } + + return P.at(best_row, best_col); + } + } + + + +template +inline +typename arma_cx_only::result +op_max::max(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const ProxyCube P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + T max_val = priv::most_neg(); + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword index = 0; + + for(uword i=0; i max_val) + { + max_val = tmp; + index = i; + } + } + + return( A[index] ); + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + eT max_val_orig = eT(0); + + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + const eT tmp_orig = P.at(row,col,slice); + const T tmp = std::abs(tmp_orig); + + if(tmp > max_val) + { + max_val = tmp; + max_val_orig = tmp_orig; + } + } + + return max_val_orig; + } + } + + + +template +inline +typename arma_cx_only::result +op_max::max_with_index(const Proxy& P, uword& index_of_max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + T best_val = priv::most_neg(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword best_index = 0; + + for(uword i=0; i best_val) { best_val = tmp; best_index = i; } + } + + index_of_max_val = best_index; + + return( A[best_index] ); + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + uword best_row = 0; + uword best_col = 0; + uword best_index = 0; + + if(n_rows == 1) + { + for(uword col=0; col < n_cols; ++col) + { + const T tmp = std::abs(P.at(0,col)); + + if(tmp > best_val) { best_val = tmp; best_col = col; } + } + + best_index = best_col; + } + else + if(n_cols == 1) + { + for(uword row=0; row < n_rows; ++row) + { + const T tmp = std::abs(P.at(row,0)); + + if(tmp > best_val) { best_val = tmp; best_row = row; } + } + + best_index = best_row; + } + else + { + uword count = 0; + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const T tmp = std::abs(P.at(row,col)); + + if(tmp > best_val) + { + best_val = tmp; + + best_row = row; + best_col = col; + + best_index = count; + } + + ++count; + } + } + + index_of_max_val = best_index; + + return P.at(best_row, best_col); + } + } + + + +template +inline +typename arma_cx_only::result +op_max::max_with_index(const ProxyCube& P, uword& index_of_max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + T best_val = priv::most_neg(); + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword best_index = 0; + + for(uword i=0; i < n_elem; ++i) + { + const T tmp = std::abs(A[i]); + + if(tmp > best_val) { best_val = tmp; best_index = i; } + } + + index_of_max_val = best_index; + + return( A[best_index] ); + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + eT best_val_orig = eT(0); + uword best_index = 0; + uword count = 0; + + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + const eT tmp_orig = P.at(row,col,slice); + const T tmp = std::abs(tmp_orig); + + if(tmp > best_val) + { + best_val = tmp; + best_val_orig = tmp_orig; + best_index = count; + } + + ++count; + } + + index_of_max_val = best_index; + + return best_val_orig; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_mean_bones.hpp b/src/armadillo/include/armadillo_bits/op_mean_bones.hpp new file mode 100644 index 0000000..20a86ae --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_mean_bones.hpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_mean +//! @{ + + +//! Class for finding mean values of a matrix +class op_mean + : public traits_op_xvec + { + public: + + // dense matrices + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static void apply_noalias(Mat& out, const Proxy& P, const uword dim); + + template + inline static void apply_noalias_unwrap(Mat& out, const Proxy& P, const uword dim); + + template + inline static void apply_noalias_proxy(Mat& out, const Proxy& P, const uword dim); + + + // cubes + + template + inline static void apply(Cube& out, const OpCube& in); + + template + inline static void apply_noalias(Cube& out, const ProxyCube& P, const uword dim); + + template + inline static void apply_noalias_unwrap(Cube& out, const ProxyCube& P, const uword dim); + + template + inline static void apply_noalias_proxy(Cube& out, const ProxyCube& P, const uword dim); + + + // + + template + inline static eT direct_mean(const eT* const X, const uword N); + + template + inline static eT direct_mean_robust(const eT* const X, const uword N); + + + // + + template + inline static eT direct_mean(const Mat& X, const uword row); + + template + inline static eT direct_mean_robust(const Mat& X, const uword row); + + + // + + template + inline static eT mean_all(const subview& X); + + template + inline static eT mean_all_robust(const subview& X); + + + // + + template + inline static eT mean_all(const diagview& X); + + template + inline static eT mean_all_robust(const diagview& X); + + + // + + template + inline static typename T1::elem_type mean_all(const Op& X); + + template + inline static typename T1::elem_type mean_all(const Base& X); + + + // + + template + arma_inline static eT robust_mean(const eT A, const eT B); + + template + arma_inline static std::complex robust_mean(const std::complex& A, const std::complex& B); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_mean_meat.hpp b/src/armadillo/include/armadillo_bits/op_mean_meat.hpp new file mode 100644 index 0000000..7e7a49d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_mean_meat.hpp @@ -0,0 +1,713 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_mean +//! @{ + + + +template +inline +void +op_mean::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 1), "mean(): parameter 'dim' must be 0 or 1" ); + + const Proxy P(in.m); + + if(P.is_alias(out) == false) + { + op_mean::apply_noalias(out, P, dim); + } + else + { + Mat tmp; + + op_mean::apply_noalias(tmp, P, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_mean::apply_noalias(Mat& out, const Proxy& P, const uword dim) + { + arma_extra_debug_sigprint(); + + if(is_Mat::stored_type>::value) + { + op_mean::apply_noalias_unwrap(out, P, dim); + } + else + { + op_mean::apply_noalias_proxy(out, P, dim); + } + } + + + +template +inline +void +op_mean::apply_noalias_unwrap(Mat& out, const Proxy& P, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + typedef typename Proxy::stored_type P_stored_type; + + const unwrap tmp(P.Q); + + const typename unwrap::stored_type& X = tmp.M; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows == 0) { return; } + + eT* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[col] = op_mean::direct_mean( X.colptr(col), X_n_rows ); + } + } + else + if(dim == 1) + { + out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0); + + if(X_n_cols == 0) { return; } + + eT* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + const eT* col_mem = X.colptr(col); + + for(uword row=0; row < X_n_rows; ++row) + { + out_mem[row] += col_mem[row]; + } + } + + out /= T(X_n_cols); + + for(uword row=0; row < X_n_rows; ++row) + { + if(arma_isfinite(out_mem[row]) == false) + { + out_mem[row] = op_mean::direct_mean_robust( X, row ); + } + } + } + } + + + +template +inline +void +op_mean::apply_noalias_proxy(Mat& out, const Proxy& P, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const uword P_n_rows = P.get_n_rows(); + const uword P_n_cols = P.get_n_cols(); + + if(dim == 0) + { + out.set_size((P_n_rows > 0) ? 1 : 0, P_n_cols); + + if(P_n_rows == 0) { return; } + + eT* out_mem = out.memptr(); + + for(uword col=0; col < P_n_cols; ++col) + { + eT val1 = eT(0); + eT val2 = eT(0); + + uword i,j; + for(i=0, j=1; j < P_n_rows; i+=2, j+=2) + { + val1 += P.at(i,col); + val2 += P.at(j,col); + } + + if(i < P_n_rows) + { + val1 += P.at(i,col); + } + + out_mem[col] = (val1 + val2) / T(P_n_rows); + } + } + else + if(dim == 1) + { + out.zeros(P_n_rows, (P_n_cols > 0) ? 1 : 0); + + if(P_n_cols == 0) { return; } + + eT* out_mem = out.memptr(); + + for(uword col=0; col < P_n_cols; ++col) + for(uword row=0; row < P_n_rows; ++row) + { + out_mem[row] += P.at(row,col); + } + + out /= T(P_n_cols); + } + + if(out.internal_has_nonfinite()) + { + // TODO: replace with dedicated handling to avoid unwrapping + op_mean::apply_noalias_unwrap(out, P, dim); + } + } + + + +// +// cubes + + + +template +inline +void +op_mean::apply(Cube& out, const OpCube& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 2), "mean(): parameter 'dim' must be 0 or 1 or 2" ); + + const ProxyCube P(in.m); + + if(P.is_alias(out) == false) + { + op_mean::apply_noalias(out, P, dim); + } + else + { + Cube tmp; + + op_mean::apply_noalias(tmp, P, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_mean::apply_noalias(Cube& out, const ProxyCube& P, const uword dim) + { + arma_extra_debug_sigprint(); + + if(is_Cube::stored_type>::value) + { + op_mean::apply_noalias_unwrap(out, P, dim); + } + else + { + op_mean::apply_noalias_proxy(out, P, dim); + } + } + + + +template +inline +void +op_mean::apply_noalias_unwrap(Cube& out, const ProxyCube& P, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + typedef typename ProxyCube::stored_type P_stored_type; + + const unwrap_cube U(P.Q); + + const Cube& X = U.M; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + const uword X_n_slices = X.n_slices; + + if(dim == 0) + { + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices); + + if(X_n_rows == 0) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[col] = op_mean::direct_mean( X.slice_colptr(slice,col), X_n_rows ); + } + } + } + else + if(dim == 1) + { + out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices); + + if(X_n_cols == 0) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + const eT* col_mem = X.slice_colptr(slice,col); + + for(uword row=0; row < X_n_rows; ++row) + { + out_mem[row] += col_mem[row]; + } + } + + const Mat tmp('j', X.slice_memptr(slice), X_n_rows, X_n_cols); + + for(uword row=0; row < X_n_rows; ++row) + { + out_mem[row] /= T(X_n_cols); + + if(arma_isfinite(out_mem[row]) == false) + { + out_mem[row] = op_mean::direct_mean_robust( tmp, row ); + } + } + } + } + else + if(dim == 2) + { + out.zeros(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0); + + if(X_n_slices == 0) { return; } + + eT* out_mem = out.memptr(); + + for(uword slice=0; slice < X_n_slices; ++slice) + { + arrayops::inplace_plus(out_mem, X.slice_memptr(slice), X.n_elem_slice ); + } + + out /= T(X_n_slices); + + podarray tmp(X_n_slices); + + for(uword col=0; col < X_n_cols; ++col) + for(uword row=0; row < X_n_rows; ++row) + { + if(arma_isfinite(out.at(row,col,0)) == false) + { + for(uword slice=0; slice < X_n_slices; ++slice) + { + tmp[slice] = X.at(row,col,slice); + } + + out.at(row,col,0) = op_mean::direct_mean_robust(tmp.memptr(), X_n_slices); + } + } + } + } + + + +template +inline +void +op_mean::apply_noalias_proxy(Cube& out, const ProxyCube& P, const uword dim) + { + arma_extra_debug_sigprint(); + + op_mean::apply_noalias_unwrap(out, P, dim); + + // TODO: implement specialised handling + } + + + + +// + + + +template +inline +eT +op_mean::direct_mean(const eT* const X, const uword n_elem) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const eT result = arrayops::accumulate(X, n_elem) / T(n_elem); + + return arma_isfinite(result) ? result : op_mean::direct_mean_robust(X, n_elem); + } + + + +template +inline +eT +op_mean::direct_mean_robust(const eT* const X, const uword n_elem) + { + arma_extra_debug_sigprint(); + + // use an adapted form of the mean finding algorithm from the running_stat class + + typedef typename get_pod_type::result T; + + uword i,j; + + eT r_mean = eT(0); + + for(i=0, j=1; j +inline +eT +op_mean::direct_mean(const Mat& X, const uword row) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword X_n_cols = X.n_cols; + + eT val = eT(0); + + uword i,j; + for(i=0, j=1; j < X_n_cols; i+=2, j+=2) + { + val += X.at(row,i); + val += X.at(row,j); + } + + if(i < X_n_cols) + { + val += X.at(row,i); + } + + const eT result = val / T(X_n_cols); + + return arma_isfinite(result) ? result : op_mean::direct_mean_robust(X, row); + } + + + +template +inline +eT +op_mean::direct_mean_robust(const Mat& X, const uword row) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword X_n_cols = X.n_cols; + + eT r_mean = eT(0); + + for(uword col=0; col < X_n_cols; ++col) + { + r_mean = r_mean + (X.at(row,col) - r_mean)/T(col+1); + } + + return r_mean; + } + + + +template +inline +eT +op_mean::mean_all(const subview& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + const uword X_n_elem = X.n_elem; + + if(X_n_elem == 0) + { + arma_debug_check(true, "mean(): object has no elements"); + + return Datum::nan; + } + + eT val = eT(0); + + if(X_n_rows == 1) + { + const Mat& A = X.m; + + const uword start_row = X.aux_row1; + const uword start_col = X.aux_col1; + + const uword end_col_p1 = start_col + X_n_cols; + + uword i,j; + for(i=start_col, j=start_col+1; j < end_col_p1; i+=2, j+=2) + { + val += A.at(start_row, i); + val += A.at(start_row, j); + } + + if(i < end_col_p1) + { + val += A.at(start_row, i); + } + } + else + { + for(uword col=0; col < X_n_cols; ++col) + { + val += arrayops::accumulate(X.colptr(col), X_n_rows); + } + } + + const eT result = val / T(X_n_elem); + + return arma_isfinite(result) ? result : op_mean::mean_all_robust(X); + } + + + +template +inline +eT +op_mean::mean_all_robust(const subview& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + const uword start_row = X.aux_row1; + const uword start_col = X.aux_col1; + + const uword end_row_p1 = start_row + X_n_rows; + const uword end_col_p1 = start_col + X_n_cols; + + const Mat& A = X.m; + + + eT r_mean = eT(0); + + if(X_n_rows == 1) + { + uword i=0; + + for(uword col = start_col; col < end_col_p1; ++col, ++i) + { + r_mean = r_mean + (A.at(start_row,col) - r_mean)/T(i+1); + } + } + else + { + uword i=0; + + for(uword col = start_col; col < end_col_p1; ++col) + for(uword row = start_row; row < end_row_p1; ++row, ++i) + { + r_mean = r_mean + (A.at(row,col) - r_mean)/T(i+1); + } + } + + return r_mean; + } + + + +template +inline +eT +op_mean::mean_all(const diagview& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword X_n_elem = X.n_elem; + + if(X_n_elem == 0) + { + arma_debug_check(true, "mean(): object has no elements"); + + return Datum::nan; + } + + eT val = eT(0); + + for(uword i=0; i +inline +eT +op_mean::mean_all_robust(const diagview& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword X_n_elem = X.n_elem; + + eT r_mean = eT(0); + + for(uword i=0; i +inline +typename T1::elem_type +op_mean::mean_all(const Op& X) + { + arma_extra_debug_sigprint(); + + return op_mean::mean_all(X.m); + } + + + +template +inline +typename T1::elem_type +op_mean::mean_all(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + const uword A_n_elem = A.n_elem; + + if(A_n_elem == 0) + { + arma_debug_check(true, "mean(): object has no elements"); + + return Datum::nan; + } + + return op_mean::direct_mean(A.memptr(), A_n_elem); + } + + + +template +arma_inline +eT +op_mean::robust_mean(const eT A, const eT B) + { + return A + (B - A)/eT(2); + } + + + +template +arma_inline +std::complex +op_mean::robust_mean(const std::complex& A, const std::complex& B) + { + return A + (B - A)/T(2); + } + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/op_median_bones.hpp b/src/armadillo/include/armadillo_bits/op_median_bones.hpp new file mode 100644 index 0000000..8212d14 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_median_bones.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_median +//! @{ + + +template +struct arma_cx_median_packet + { + T val; + uword index; + }; + + + +template +arma_inline +bool +operator< (const arma_cx_median_packet& A, const arma_cx_median_packet& B) + { + return (A.val < B.val); + } + + + +class op_median + : public traits_op_xvec + { + public: + + template + inline static void apply(Mat& out, const Op& expr); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); + + // + // + + template + inline static typename T1::elem_type median_vec(const T1& X, const typename arma_not_cx::result* junk = nullptr); + + template + inline static typename T1::elem_type median_vec(const T1& X, const typename arma_cx_only::result* junk = nullptr); + + // + // + + template + inline static eT direct_median(std::vector& X); + + template + inline static void direct_cx_median_index(uword& out_index1, uword& out_index2, std::vector< arma_cx_median_packet >& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_median_meat.hpp b/src/armadillo/include/armadillo_bits/op_median_meat.hpp new file mode 100644 index 0000000..ae80515 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_median_meat.hpp @@ -0,0 +1,338 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_median +//! @{ + + + +template +inline +void +op_median::apply(Mat& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(expr.m); + + const uword dim = expr.aux_uword_a; + + arma_debug_check( U.M.internal_has_nan(), "median(): detected NaN" ); + arma_debug_check( (dim > 1), "median(): parameter 'dim' must be 0 or 1" ); + + if(U.is_alias(out)) + { + Mat tmp; + + op_median::apply_noalias(out, U.M, dim); + + out.steal_mem(tmp); + } + else + { + op_median::apply_noalias(out, U.M, dim); + } + } + + + +template +inline +void +op_median::apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) // in each column + { + arma_extra_debug_print("op_median::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows > 0) + { + std::vector tmp_vec(X_n_rows); + + for(uword col=0; col < X_n_cols; ++col) + { + arrayops::copy( &(tmp_vec[0]), X.colptr(col), X_n_rows ); + + out[col] = op_median::direct_median(tmp_vec); + } + } + } + else + if(dim == 1) // in each row + { + arma_extra_debug_print("op_median::apply(): dim = 1"); + + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0); + + if(X_n_cols > 0) + { + std::vector tmp_vec(X_n_cols); + + for(uword row=0; row < X_n_rows; ++row) + { + for(uword col=0; col < X_n_cols; ++col) { tmp_vec[col] = X.at(row,col); } + + out[row] = op_median::direct_median(tmp_vec); + } + } + } + } + + + +template +inline +void +op_median::apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename get_pod_type::result T; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) // in each column + { + arma_extra_debug_print("op_median::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows > 0) + { + std::vector< arma_cx_median_packet > tmp_vec(X_n_rows); + + for(uword col=0; col 0) ? 1 : 0); + + if(X_n_cols > 0) + { + std::vector< arma_cx_median_packet > tmp_vec(X_n_cols); + + for(uword row=0; row +inline +typename T1::elem_type +op_median::median_vec + ( + const T1& X, + const typename arma_not_cx::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(X); + + const uword n_elem = U.M.n_elem; + + if(n_elem == 0) + { + arma_debug_check(true, "median(): object has no elements"); + + return Datum::nan; + } + + arma_debug_check( U.M.internal_has_nan(), "median(): detected NaN" ); + + std::vector tmp_vec(n_elem); + + arrayops::copy( &(tmp_vec[0]), U.M.memptr(), n_elem ); + + return op_median::direct_median(tmp_vec); + } + + + +template +inline +typename T1::elem_type +op_median::median_vec + ( + const T1& X, + const typename arma_cx_only::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const quasi_unwrap U(X); + + const uword n_elem = U.M.n_elem; + + if(n_elem == 0) + { + arma_debug_check(true, "median(): object has no elements"); + + return Datum::nan; + } + + arma_debug_check( U.M.internal_has_nan(), "median(): detected NaN" ); + + std::vector< arma_cx_median_packet > tmp_vec(n_elem); + + const eT* A = U.M.memptr(); + + for(uword i=0; i +inline +eT +op_median::direct_median(std::vector& X) + { + arma_extra_debug_sigprint(); + + const uword n_elem = uword(X.size()); + const uword half = n_elem/2; + + typename std::vector::iterator first = X.begin(); + typename std::vector::iterator nth = first + half; + typename std::vector::iterator pastlast = X.end(); + + std::nth_element(first, nth, pastlast); + + if((n_elem % 2) == 0) // even number of elements + { + typename std::vector::iterator start = X.begin(); + typename std::vector::iterator pastend = start + half; + + const eT val1 = (*nth); + const eT val2 = (*(std::max_element(start, pastend))); + + return op_mean::robust_mean(val1, val2); + } + else // odd number of elements + { + return (*nth); + } + } + + + +template +inline +void +op_median::direct_cx_median_index + ( + uword& out_index1, + uword& out_index2, + std::vector< arma_cx_median_packet >& X + ) + { + arma_extra_debug_sigprint(); + + typedef arma_cx_median_packet eT; + + const uword n_elem = uword(X.size()); + const uword half = n_elem/2; + + typename std::vector::iterator first = X.begin(); + typename std::vector::iterator nth = first + half; + typename std::vector::iterator pastlast = X.end(); + + std::nth_element(first, nth, pastlast); + + out_index1 = (*nth).index; + + if((n_elem % 2) == 0) // even number of elements + { + typename std::vector::iterator start = X.begin(); + typename std::vector::iterator pastend = start + half; + + out_index2 = (*(std::max_element(start, pastend))).index; + } + else // odd number of elements + { + out_index2 = out_index1; + } + } + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/op_min_bones.hpp b/src/armadillo/include/armadillo_bits/op_min_bones.hpp new file mode 100644 index 0000000..e9f5a62 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_min_bones.hpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_min +//! @{ + + +class op_min + : public traits_op_xvec + { + public: + + // + // dense matrices + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); + + + // + // cubes + + template + inline static void apply(Cube& out, const OpCube& in); + + template + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); + + + // + // for non-complex numbers + + template + inline static eT direct_min(const eT* const X, const uword N); + + template + inline static eT direct_min(const eT* const X, const uword N, uword& index_of_min_val); + + template + inline static eT direct_min(const Mat& X, const uword row); + + template + inline static eT min(const subview& X); + + template + inline static typename arma_not_cx::result min(const Base& X); + + template + inline static typename arma_not_cx::result min(const BaseCube& X); + + template + inline static typename arma_not_cx::result min_with_index(const Proxy& P, uword& index_of_min_val); + + template + inline static typename arma_not_cx::result min_with_index(const ProxyCube& P, uword& index_of_min_val); + + + // + // for complex numbers + + template + inline static std::complex direct_min(const std::complex* const X, const uword n_elem); + + template + inline static std::complex direct_min(const std::complex* const X, const uword n_elem, uword& index_of_min_val); + + template + inline static std::complex direct_min(const Mat< std::complex >& X, const uword row); + + template + inline static std::complex min(const subview< std::complex >& X); + + template + inline static typename arma_cx_only::result min(const Base& X); + + template + inline static typename arma_cx_only::result min(const BaseCube& X); + + template + inline static typename arma_cx_only::result min_with_index(const Proxy& P, uword& index_of_min_val); + + template + inline static typename arma_cx_only::result min_with_index(const ProxyCube& P, uword& index_of_min_val); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_min_meat.hpp b/src/armadillo/include/armadillo_bits/op_min_meat.hpp new file mode 100644 index 0000000..9879185 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_min_meat.hpp @@ -0,0 +1,1325 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_min +//! @{ + + + +template +inline +void +op_min::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 1), "min(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(in.m); + const Mat& X = U.M; + + if(U.is_alias(out) == false) + { + op_min::apply_noalias(out, X, dim); + } + else + { + Mat tmp; + + op_min::apply_noalias(tmp, X, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_min::apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_min::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows == 0) { return; } + + eT* out_mem = out.memptr(); + + for(uword col=0; col 0) ? 1 : 0); + + if(X_n_cols == 0) { return; } + + eT* out_mem = out.memptr(); + + arrayops::copy(out_mem, X.colptr(0), X_n_rows); + + for(uword col=1; col +inline +void +op_min::apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_min::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows == 0) { return; } + + eT* out_mem = out.memptr(); + + for(uword col=0; col 0) ? 1 : 0); + + if(X_n_cols == 0) { return; } + + eT* out_mem = out.memptr(); + + for(uword row=0; row +inline +void +op_min::apply(Cube& out, const OpCube& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 2), "min(): parameter 'dim' must be 0 or 1 or 2" ); + + const unwrap_cube U(in.m); + + if(U.is_alias(out) == false) + { + op_min::apply_noalias(out, U.M, dim); + } + else + { + Cube tmp; + + op_min::apply_noalias(tmp, U.M, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_min::apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + const uword X_n_slices = X.n_slices; + + if(dim == 0) + { + arma_extra_debug_print("op_min::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices); + + if(X_n_rows == 0) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[col] = op_min::direct_min( X.slice_colptr(slice,col), X_n_rows ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_min::apply(): dim = 1"); + + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices); + + if(X_n_cols == 0) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + arrayops::copy(out_mem, X.slice_colptr(slice,0), X_n_rows); + + for(uword col=1; col < X_n_cols; ++col) + { + const eT* col_mem = X.slice_colptr(slice,col); + + for(uword row=0; row < X_n_rows; ++row) + { + const eT col_val = col_mem[row]; + + if(col_val < out_mem[row]) { out_mem[row] = col_val; } + } + } + } + } + else + if(dim == 2) + { + arma_extra_debug_print("op_min::apply(): dim = 2"); + + out.set_size(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0); + + if(X_n_slices == 0) { return; } + + const uword N = X.n_elem_slice; + + eT* out_mem = out.slice_memptr(0); + + arrayops::copy(out_mem, X.slice_memptr(0), N); + + for(uword slice=1; slice < X_n_slices; ++slice) + { + const eT* X_mem = X.slice_memptr(slice); + + for(uword i=0; i < N; ++i) + { + const eT val = X_mem[i]; + + if(val < out_mem[i]) { out_mem[i] = val; } + } + } + } + } + + + +template +inline +void +op_min::apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + const uword X_n_slices = X.n_slices; + + if(dim == 0) + { + arma_extra_debug_print("op_min::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices); + + if(X_n_rows == 0) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[col] = op_min::direct_min( X.slice_colptr(slice,col), X_n_rows ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_min::apply(): dim = 1"); + + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices); + + if(X_n_cols == 0) { return; } + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + const Mat tmp('j', X.slice_memptr(slice), X_n_rows, X_n_cols); + + for(uword row=0; row < X_n_rows; ++row) + { + out_mem[row] = op_min::direct_min(tmp, row); + } + } + } + else + if(dim == 2) + { + arma_extra_debug_print("op_min::apply(): dim = 2"); + + out.set_size(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0); + + if(X_n_slices == 0) { return; } + + const uword N = X.n_elem_slice; + + eT* out_mem = out.slice_memptr(0); + + arrayops::copy(out_mem, X.slice_memptr(0), N); + + for(uword slice=1; slice < X_n_slices; ++slice) + { + const eT* X_mem = X.slice_memptr(slice); + + for(uword i=0; i < N; ++i) + { + const eT& val = X_mem[i]; + + if(std::abs(val) < std::abs(out_mem[i])) { out_mem[i] = val; } + } + } + } + } + + + +template +inline +eT +op_min::direct_min(const eT* const X, const uword n_elem) + { + arma_extra_debug_sigprint(); + + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); + + uword i,j; + for(i=0, j=1; j +inline +eT +op_min::direct_min(const eT* const X, const uword n_elem, uword& index_of_min_val) + { + arma_extra_debug_sigprint(); + + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); + + uword best_index_i = 0; + uword best_index_j = 0; + + uword i,j; + for(i=0, j=1; j +inline +eT +op_min::direct_min(const Mat& X, const uword row) + { + arma_extra_debug_sigprint(); + + const uword X_n_cols = X.n_cols; + + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); + + uword i,j; + for(i=0, j=1; j < X_n_cols; i+=2, j+=2) + { + const eT tmp_i = X.at(row,i); + const eT tmp_j = X.at(row,j); + + if(tmp_i < min_val_i) { min_val_i = tmp_i; } + if(tmp_j < min_val_j) { min_val_j = tmp_j; } + } + + if(i < X_n_cols) + { + const eT tmp_i = X.at(row,i); + + if(tmp_i < min_val_i) { min_val_i = tmp_i; } + } + + return (min_val_i < min_val_j) ? min_val_i : min_val_j; + } + + + +template +inline +eT +op_min::min(const subview& X) + { + arma_extra_debug_sigprint(); + + if(X.n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(X_n_rows == 1) + { + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); + + const Mat& A = X.m; + + const uword start_row = X.aux_row1; + const uword start_col = X.aux_col1; + + const uword end_col_p1 = start_col + X_n_cols; + + uword i,j; + for(i=start_col, j=start_col+1; j < end_col_p1; i+=2, j+=2) + { + const eT tmp_i = A.at(start_row, i); + const eT tmp_j = A.at(start_row, j); + + if(tmp_i < min_val_i) { min_val_i = tmp_i; } + if(tmp_j < min_val_j) { min_val_j = tmp_j; } + } + + if(i < end_col_p1) + { + const eT tmp_i = A.at(start_row, i); + + if(tmp_i < min_val_i) { min_val_i = tmp_i; } + } + + return (min_val_i < min_val_j) ? min_val_i : min_val_j; + } + + eT min_val = priv::most_pos(); + + for(uword col=0; col < X_n_cols; ++col) + { + min_val = (std::min)(min_val, op_min::direct_min(X.colptr(col), X_n_rows)); + } + + return min_val; + } + + + +template +inline +typename arma_not_cx::result +op_min::min(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword i,j; + + for(i=0, j=1; j +inline +typename arma_not_cx::result +op_min::min(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const ProxyCube P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + eT min_val = priv::most_pos(); + + if(ProxyCube::use_at == false) + { + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); + + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword i,j; + + for(i=0, j=1; j +inline +typename arma_not_cx::result +op_min::min_with_index(const Proxy& P, uword& index_of_min_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + eT best_val = priv::most_pos(); + uword best_index = 0; + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + for(uword i=0; i +inline +typename arma_not_cx::result +op_min::min_with_index(const ProxyCube& P, uword& index_of_min_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + eT best_val = priv::most_pos(); + uword best_index = 0; + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + const eT tmp = A[i]; + + if(tmp < best_val) { best_val = tmp; best_index = i; } + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + uword count = 0; + + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + const eT tmp = P.at(row,col,slice); + + if(tmp < best_val) { best_val = tmp; best_index = count; } + + ++count; + } + } + + index_of_min_val = best_index; + + return best_val; + } + + + +template +inline +std::complex +op_min::direct_min(const std::complex* const X, const uword n_elem) + { + arma_extra_debug_sigprint(); + + uword index = 0; + T min_val = priv::most_pos(); + + for(uword i=0; i +inline +std::complex +op_min::direct_min(const std::complex* const X, const uword n_elem, uword& index_of_min_val) + { + arma_extra_debug_sigprint(); + + uword index = 0; + T min_val = priv::most_pos(); + + for(uword i=0; i +inline +std::complex +op_min::direct_min(const Mat< std::complex >& X, const uword row) + { + arma_extra_debug_sigprint(); + + const uword X_n_cols = X.n_cols; + + uword index = 0; + T min_val = priv::most_pos(); + + for(uword col=0; col +inline +std::complex +op_min::min(const subview< std::complex >& X) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + if(X.n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + const Mat& A = X.m; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + const uword start_row = X.aux_row1; + const uword start_col = X.aux_col1; + + const uword end_row_p1 = start_row + X_n_rows; + const uword end_col_p1 = start_col + X_n_cols; + + T min_val = priv::most_pos(); + + uword best_row = 0; + uword best_col = 0; + + if(X_n_rows == 1) + { + best_col = 0; + + for(uword col=start_col; col < end_col_p1; ++col) + { + const T tmp_val = std::abs( A.at(start_row, col) ); + + if(tmp_val < min_val) + { + min_val = tmp_val; + best_col = col; + } + } + + best_row = start_row; + } + else + { + for(uword col=start_col; col < end_col_p1; ++col) + for(uword row=start_row; row < end_row_p1; ++row) + { + const T tmp_val = std::abs( A.at(row, col) ); + + if(tmp_val < min_val) + { + min_val = tmp_val; + best_row = row; + best_col = col; + } + } + } + + return A.at(best_row, best_col); + } + + + +template +inline +typename arma_cx_only::result +op_min::min(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const Proxy P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + T min_val = priv::most_pos(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword index = 0; + + for(uword i=0; i +inline +typename arma_cx_only::result +op_min::min(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const ProxyCube P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + T min_val = priv::most_pos(); + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword index = 0; + + for(uword i=0; i +inline +typename arma_cx_only::result +op_min::min_with_index(const Proxy& P, uword& index_of_min_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + T best_val = priv::most_pos(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword best_index = 0; + + for(uword i=0; i +inline +typename arma_cx_only::result +op_min::min_with_index(const ProxyCube& P, uword& index_of_min_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + T best_val = priv::most_pos(); + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + uword best_index = 0; + + for(uword i=0; i < n_elem; ++i) + { + const T tmp = std::abs(A[i]); + + if(tmp < best_val) { best_val = tmp; best_index = i; } + } + + index_of_min_val = best_index; + + return( A[best_index] ); + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + eT best_val_orig = eT(0); + uword best_index = 0; + uword count = 0; + + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + const eT tmp_orig = P.at(row,col,slice); + const T tmp = std::abs(tmp_orig); + + if(tmp < best_val) + { + best_val = tmp; + best_val_orig = tmp_orig; + best_index = count; + } + + ++count; + } + + index_of_min_val = best_index; + + return best_val_orig; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_misc_bones.hpp b/src/armadillo/include/armadillo_bits/op_misc_bones.hpp new file mode 100644 index 0000000..5fd1571 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_misc_bones.hpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_misc +//! @{ + + + +class op_real + : public traits_op_passthru + { + public: + + template + inline static void apply( Mat& out, const mtOp& X); + + template + inline static void apply( Cube& out, const mtOpCube& X); + }; + + + +class op_imag + : public traits_op_passthru + { + public: + + template + inline static void apply( Mat& out, const mtOp& X); + + template + inline static void apply( Cube& out, const mtOpCube& X); + }; + + + +class op_abs + : public traits_op_passthru + { + public: + + template + inline static void apply( Mat& out, const mtOp& X); + + template + inline static void apply( Cube& out, const mtOpCube& X); + }; + + + +class op_arg + : public traits_op_passthru + { + public: + + template + inline static void apply( Mat& out, const mtOp& X); + + template + inline static void apply( Cube& out, const mtOpCube& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_misc_meat.hpp b/src/armadillo/include/armadillo_bits/op_misc_meat.hpp new file mode 100644 index 0000000..d1c2f36 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_misc_meat.hpp @@ -0,0 +1,404 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_misc +//! @{ + + + +template +inline +void +op_real::apply( Mat& out, const mtOp& X ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const Proxy P(X.m); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + out.set_size(n_rows, n_cols); + + T* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + const uword n_elem = P.get_n_elem(); + ea_type A = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = std::real( A[i] ); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + *out_mem = std::real( P.at(row,col) ); + out_mem++; + } + } + } + + + +template +inline +void +op_real::apply( Cube& out, const mtOpCube& X ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const ProxyCube P(X.m); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + T* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + const uword n_elem = P.get_n_elem(); + ea_type A = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = std::real( A[i] ); + } + } + else + { + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + *out_mem = std::real( P.at(row,col,slice) ); + out_mem++; + } + } + } + + + +template +inline +void +op_imag::apply( Mat& out, const mtOp& X ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const Proxy P(X.m); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + out.set_size(n_rows, n_cols); + + T* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + const uword n_elem = P.get_n_elem(); + ea_type A = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = std::imag( A[i] ); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + *out_mem = std::imag( P.at(row,col) ); + out_mem++; + } + } + } + + + +template +inline +void +op_imag::apply( Cube& out, const mtOpCube& X ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const ProxyCube P(X.m); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + T* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + const uword n_elem = P.get_n_elem(); + ea_type A = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = std::imag( A[i] ); + } + } + else + { + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + *out_mem = std::imag( P.at(row,col,slice) ); + out_mem++; + } + } + } + + + +template +inline +void +op_abs::apply( Mat& out, const mtOp& X ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const Proxy P(X.m); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + out.set_size(n_rows, n_cols); + + T* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + const uword n_elem = P.get_n_elem(); + ea_type A = P.get_ea(); + + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = std::abs( A[i] ); + } + } + #else + { + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = std::abs( A[i] ); + } + } + #endif + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + *out_mem = std::abs( P.at(row,col) ); + out_mem++; + } + } + } + + + +template +inline +void +op_abs::apply( Cube& out, const mtOpCube& X ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const ProxyCube P(X.m); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + T* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + const uword n_elem = P.get_n_elem(); + ea_type A = P.get_ea(); + + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = std::abs( A[i] ); + } + } + #else + { + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = std::abs( A[i] ); + } + } + #endif + } + else + { + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + *out_mem = std::abs( P.at(row,col,slice) ); + out_mem++; + } + } + } + + + +template +inline +void +op_arg::apply( Mat& out, const mtOp& X ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const Proxy P(X.m); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + out.set_size(n_rows, n_cols); + + T* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + const uword n_elem = P.get_n_elem(); + ea_type A = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = arma_arg::eval( A[i] ); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + *out_mem = arma_arg::eval( P.at(row,col) ); + out_mem++; + } + } + } + + + +template +inline +void +op_arg::apply( Cube& out, const mtOpCube& X ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const ProxyCube P(X.m); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + T* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + const uword n_elem = P.get_n_elem(); + ea_type A = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + out_mem[i] = arma_arg::eval( A[i] ); + } + } + else + { + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + *out_mem = arma_arg::eval( P.at(row,col,slice) ); + out_mem++; + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_nonzeros_bones.hpp b/src/armadillo/include/armadillo_bits/op_nonzeros_bones.hpp new file mode 100644 index 0000000..8c3fd65 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_nonzeros_bones.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_nonzeros +//! @{ + + + +class op_nonzeros + : public traits_op_col + { + public: + + // for dense matrices + + template + static inline void apply_noalias(Mat& out, const Proxy& P); + + template + static inline void apply(Mat& out, const Op& X); + }; + + + +class op_nonzeros_spmat + : public traits_op_col + { + public: + + template + static inline void apply(Mat& out, const SpToDOp& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_nonzeros_meat.hpp b/src/armadillo/include/armadillo_bits/op_nonzeros_meat.hpp new file mode 100644 index 0000000..8cf32fa --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_nonzeros_meat.hpp @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_nonzeros +//! @{ + + + +template +inline +void +op_nonzeros::apply_noalias(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword N_max = P.get_n_elem(); + + Mat tmp(N_max, 1, arma_nozeros_indicator()); + + eT* tmp_mem = tmp.memptr(); + + uword N_nz = 0; + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i +inline +void +op_nonzeros::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(X.m); + + if(P.get_n_elem() == 0) { out.set_size(0,1); return; } + + if(P.is_alias(out)) + { + Mat out2; + + op_nonzeros::apply_noalias(out2, P); + + out.steal_mem(out2); + } + else + { + op_nonzeros::apply_noalias(out, P); + } + } + + + +template +inline +void +op_nonzeros_spmat::apply(Mat& out, const SpToDOp& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy P(X.m); + + const uword N = P.get_n_nonzero(); + + out.set_size(N,1); + + if(N == 0) { return; } + + if(is_SpMat::stored_type>::value) + { + const unwrap_spmat::stored_type> U(P.Q); + + arrayops::copy(out.memptr(), U.M.values, N); + + return; + } + + if(is_SpSubview::stored_type>::value) + { + const SpSubview& sv = reinterpret_cast< const SpSubview& >(P.Q); + + if(sv.n_rows == sv.m.n_rows) + { + const SpMat& m = sv.m; + const uword col = sv.aux_col1; + + arrayops::copy(out.memptr(), &(m.values[ m.col_ptrs[col] ]), N); + + return; + } + } + + eT* out_mem = out.memptr(); + + typename SpProxy::const_iterator_type it = P.begin(); + + for(uword i=0; i +struct norm2est_randu_filler + { + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + inline norm2est_randu_filler(); + + inline void fill(eT* mem, const uword N); + }; + + +template +struct norm2est_randu_filler< std::complex > + { + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + inline norm2est_randu_filler(); + + inline void fill(std::complex* mem, const uword N); + }; + + + +class op_norm2est + : public traits_op_default + { + public: + + template inline static typename T1::pod_type norm2est(const Base& X, const typename T1::pod_type tolerance, const uword max_iter); + template inline static typename T1::pod_type norm2est(const SpBase& X, const typename T1::pod_type tolerance, const uword max_iter); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_norm2est_meat.hpp b/src/armadillo/include/armadillo_bits/op_norm2est_meat.hpp new file mode 100644 index 0000000..b809d13 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_norm2est_meat.hpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_norm2est +//! @{ + + + +template +inline +norm2est_randu_filler::norm2est_randu_filler() + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type local_seed_type; + + local_engine.seed(local_seed_type(123)); + + typedef typename std::uniform_real_distribution::param_type local_param_type; + + local_u_distr.param(local_param_type(-1.0, +1.0)); + } + + +template +inline +void +norm2est_randu_filler::fill(eT* mem, const uword N) + { + arma_extra_debug_sigprint(); + + for(uword i=0; i +inline +norm2est_randu_filler< std::complex >::norm2est_randu_filler() + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type local_seed_type; + + local_engine.seed(local_seed_type(123)); + + typedef typename std::uniform_real_distribution::param_type local_param_type; + + local_u_distr.param(local_param_type(-1.0, +1.0)); + } + + +template +inline +void +norm2est_randu_filler< std::complex >::fill(std::complex* mem, const uword N) + { + arma_extra_debug_sigprint(); + + for(uword i=0; i& mem_i = mem[i]; + + mem_i.real( T(local_u_distr(local_engine)) ); + mem_i.imag( T(local_u_distr(local_engine)) ); + } + } + + + +// +// +// + + + +template +inline +typename T1::pod_type +op_norm2est::norm2est + ( + const Base& X, + const typename T1::pod_type tolerance, + const uword max_iter + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + arma_debug_check( (tolerance < T(0)), "norm2est(): parameter 'tolerance' must be > 0" ); + arma_debug_check( (max_iter == uword(0)), "norm2est(): parameter 'max_iter' must be > 0" ); + + const T tol = (tolerance == T(0)) ? T(1e-6) : T(tolerance); + + const quasi_unwrap U(X.get_ref()); + const Mat& A = U.M; + + if(A.n_elem == 0) { return T(0); } + + if(A.internal_has_nonfinite()) { arma_debug_warn_level(1, "norm2est(): given matrix has non-finite elements"); } + + if((A.n_rows == 1) || (A.n_cols == 1)) { return op_norm::vec_norm_2( Proxy< Mat >(A) ); } + + norm2est_randu_filler randu_filler; + + Col x(A.n_rows, fill::none); + Col y(A.n_cols, fill::none); + + randu_filler.fill(y.memptr(), y.n_elem); + + T est_old = 0; + T est_cur = 0; + + for(uword i=0; i >(x) ); + + if(x_norm == T(0) || (arma_isfinite(x_norm) == false) || (x.internal_has_nonfinite())) + { + randu_filler.fill(x.memptr(), x.n_elem); + + x_norm = op_norm::vec_norm_2( Proxy< Col >(x) ); + } + + if(x_norm != T(0)) { x /= x_norm; } + + y = A.t() * x; + + est_old = est_cur; + est_cur = op_norm::vec_norm_2( Proxy< Col >(y) ); + + arma_extra_debug_print(arma_str::format("norm2est(): est_old: %e") % est_old); + arma_extra_debug_print(arma_str::format("norm2est(): est_cur: %e") % est_cur); + + if(arma_isfinite(est_cur) == false) { return est_old; } + + if( ((std::abs)(est_cur - est_old)) <= (tol * (std::max)(est_cur,est_old)) ) { break; } + } + + return est_cur; + } + + + +// +// +// + + + +template +inline +typename T1::pod_type +op_norm2est::norm2est + ( + const SpBase& X, + const typename T1::pod_type tolerance, + const uword max_iter + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + arma_debug_check( (tolerance < T(0)), "norm2est(): parameter 'tolerance' must be > 0" ); + arma_debug_check( (max_iter == uword(0)), "norm2est(): parameter 'max_iter' must be > 0" ); + + const T tol = (tolerance == T(0)) ? T(1e-6) : T(tolerance); + + const unwrap_spmat U(X.get_ref()); + const SpMat& A = U.M; + + if(A.n_nonzero == 0) { return T(0); } + + if(A.internal_has_nonfinite()) { arma_debug_warn_level(1, "norm2est(): given matrix has non-finite elements"); } + + if((A.n_rows == 1) || (A.n_cols == 1)) { return spop_norm::vec_norm_k(A.values, A.n_nonzero, 2); } + + norm2est_randu_filler randu_filler; + + Mat x(A.n_rows, 1, fill::none); + Mat y(A.n_cols, 1, fill::none); + + randu_filler.fill(y.memptr(), y.n_elem); + + T est_old = 0; + T est_cur = 0; + + for(uword i=0; i >(x) ); + + if(x_norm == T(0) || (arma_isfinite(x_norm) == false) || (x.internal_has_nonfinite())) + { + randu_filler.fill(x.memptr(), x.n_elem); + + x_norm = op_norm::vec_norm_2( Proxy< Mat >(x) ); + } + + if(x_norm != T(0)) { x /= x_norm; } + + y = A.t() * x; + + est_old = est_cur; + est_cur = op_norm::vec_norm_2( Proxy< Mat >(y) ); + + arma_extra_debug_print(arma_str::format("norm2est(): est_old: %e") % est_old); + arma_extra_debug_print(arma_str::format("norm2est(): est_cur: %e") % est_cur); + + if(arma_isfinite(est_cur) == false) { return est_old; } + + if( ((std::abs)(est_cur - est_old)) <= (tol * (std::max)(est_cur,est_old)) ) { break; } + } + + return est_cur; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_norm_bones.hpp b/src/armadillo/include/armadillo_bits/op_norm_bones.hpp new file mode 100644 index 0000000..f402338 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_norm_bones.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_norm +//! @{ + + +class op_norm + : public traits_op_default + { + public: + + template arma_hot inline static typename T1::pod_type vec_norm_1(const Proxy& P, const typename arma_not_cx::result* junk = nullptr); + template arma_hot inline static typename T1::pod_type vec_norm_1(const Proxy& P, const typename arma_cx_only::result* junk = nullptr); + template arma_hot inline static eT vec_norm_1_direct_std(const Mat& X); + template arma_hot inline static eT vec_norm_1_direct_mem(const uword N, const eT* A); + + template arma_hot inline static typename T1::pod_type vec_norm_2(const Proxy& P, const typename arma_not_cx::result* junk = nullptr); + template arma_hot inline static typename T1::pod_type vec_norm_2(const Proxy& P, const typename arma_cx_only::result* junk = nullptr); + template arma_hot inline static eT vec_norm_2_direct_std(const Mat& X); + template arma_hot inline static eT vec_norm_2_direct_mem(const uword N, const eT* A); + template arma_hot inline static eT vec_norm_2_direct_robust(const Mat& X); + + template arma_hot inline static typename T1::pod_type vec_norm_k(const Proxy& P, const int k); + + template arma_hot inline static typename T1::pod_type vec_norm_max(const Proxy& P); + template arma_hot inline static typename T1::pod_type vec_norm_min(const Proxy& P); + + template inline static typename get_pod_type::result mat_norm_1(const Mat& X); + template inline static typename get_pod_type::result mat_norm_2(const Mat& X); + + template inline static typename get_pod_type::result mat_norm_inf(const Mat& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_norm_meat.hpp b/src/armadillo/include/armadillo_bits/op_norm_meat.hpp new file mode 100644 index 0000000..2f50f14 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_norm_meat.hpp @@ -0,0 +1,905 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_norm +//! @{ + + + +template +inline +typename T1::pod_type +op_norm::vec_norm_1(const Proxy& P, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const bool use_direct_mem = (is_Mat::stored_type>::value) || (is_subview_col::stored_type>::value) || (arma_config::openmp && Proxy::use_mp); + + if(use_direct_mem) + { + const quasi_unwrap::stored_type> tmp(P.Q); + + return op_norm::vec_norm_1_direct_std(tmp.M); + } + + typedef typename T1::pod_type T; + + T acc = T(0); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type A = P.get_ea(); + + const uword N = P.get_n_elem(); + + T acc1 = T(0); + T acc2 = T(0); + + uword i,j; + for(i=0, j=1; j +inline +typename T1::pod_type +op_norm::vec_norm_1(const Proxy& P, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + T acc = T(0); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type A = P.get_ea(); + + const uword N = P.get_n_elem(); + + for(uword i=0; i& X = A[i]; + + const T a = X.real(); + const T b = X.imag(); + + acc += std::sqrt( (a*a) + (b*b) ); + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + if(n_rows == 1) + { + for(uword col=0; col& X = P.at(0,col); + + const T a = X.real(); + const T b = X.imag(); + + acc += std::sqrt( (a*a) + (b*b) ); + } + } + else + { + for(uword col=0; col& X = P.at(row,col); + + const T a = X.real(); + const T b = X.imag(); + + acc += std::sqrt( (a*a) + (b*b) ); + } + } + } + + if( (acc != T(0)) && arma_isfinite(acc) ) + { + return acc; + } + else + { + arma_extra_debug_print("op_norm::vec_norm_1(): detected possible underflow or overflow"); + + const quasi_unwrap::stored_type> R(P.Q); + + const uword N = R.M.n_elem; + const eT* R_mem = R.M.memptr(); + + T max_val = priv::most_neg(); + + for(uword i=0; i& X = R_mem[i]; + + const T a = std::abs(X.real()); + const T b = std::abs(X.imag()); + + if(a > max_val) { max_val = a; } + if(b > max_val) { max_val = b; } + } + + if(max_val == T(0)) { return T(0); } + + T alt_acc = T(0); + + for(uword i=0; i& X = R_mem[i]; + + const T a = X.real() / max_val; + const T b = X.imag() / max_val; + + alt_acc += std::sqrt( (a*a) + (b*b) ); + } + + return ( alt_acc * max_val ); + } + } + + + +template +inline +eT +op_norm::vec_norm_1_direct_std(const Mat& X) + { + arma_extra_debug_sigprint(); + + const uword N = X.n_elem; + const eT* A = X.memptr(); + + if(N < uword(32)) + { + return op_norm::vec_norm_1_direct_mem(N,A); + } + else + { + #if defined(ARMA_USE_ATLAS) + { + return atlas::cblas_asum(N,A); + } + #elif defined(ARMA_USE_BLAS) + { + return blas::asum(N,A); + } + #else + { + return op_norm::vec_norm_1_direct_mem(N,A); + } + #endif + } + } + + + +template +inline +eT +op_norm::vec_norm_1_direct_mem(const uword N, const eT* A) + { + arma_extra_debug_sigprint(); + + #if (defined(ARMA_SIMPLE_LOOPS) || defined(__FAST_MATH__)) + { + eT acc1 = eT(0); + + if(memory::is_aligned(A)) + { + memory::mark_as_aligned(A); + + for(uword i=0; i +inline +typename T1::pod_type +op_norm::vec_norm_2(const Proxy& P, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const bool use_direct_mem = (is_Mat::stored_type>::value) || (is_subview_col::stored_type>::value) || (arma_config::openmp && Proxy::use_mp); + + if(use_direct_mem) + { + const quasi_unwrap::stored_type> tmp(P.Q); + + return op_norm::vec_norm_2_direct_std(tmp.M); + } + + typedef typename T1::pod_type T; + + T acc = T(0); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type A = P.get_ea(); + + const uword N = P.get_n_elem(); + + T acc1 = T(0); + T acc2 = T(0); + + uword i,j; + + for(i=0, j=1; j::stored_type> tmp(P.Q); + + return op_norm::vec_norm_2_direct_robust(tmp.M); + } + } + + + +template +inline +typename T1::pod_type +op_norm::vec_norm_2(const Proxy& P, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + T acc = T(0); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type A = P.get_ea(); + + const uword N = P.get_n_elem(); + + for(uword i=0; i& X = A[i]; + + const T a = X.real(); + const T b = X.imag(); + + acc += (a*a) + (b*b); + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + if(n_rows == 1) + { + for(uword col=0; col& X = P.at(0,col); + + const T a = X.real(); + const T b = X.imag(); + + acc += (a*a) + (b*b); + } + } + else + { + for(uword col=0; col& X = P.at(row,col); + + const T a = X.real(); + const T b = X.imag(); + + acc += (a*a) + (b*b); + } + } + } + + const T sqrt_acc = std::sqrt(acc); + + if( (sqrt_acc != T(0)) && arma_isfinite(sqrt_acc) ) + { + return sqrt_acc; + } + else + { + arma_extra_debug_print("op_norm::vec_norm_2(): detected possible underflow or overflow"); + + const quasi_unwrap::stored_type> R(P.Q); + + const uword N = R.M.n_elem; + const eT* R_mem = R.M.memptr(); + + T max_val = priv::most_neg(); + + for(uword i=0; i max_val) { max_val = val_i; } + } + + if(max_val == T(0)) { return T(0); } + + T alt_acc = T(0); + + for(uword i=0; i +inline +eT +op_norm::vec_norm_2_direct_std(const Mat& X) + { + arma_extra_debug_sigprint(); + + const uword N = X.n_elem; + const eT* A = X.memptr(); + + eT result; + + if(N < uword(32)) + { + result = op_norm::vec_norm_2_direct_mem(N,A); + } + else + { + #if defined(ARMA_USE_ATLAS) + { + result = atlas::cblas_nrm2(N,A); + } + #elif defined(ARMA_USE_BLAS) + { + result = blas::nrm2(N,A); + } + #else + { + result = op_norm::vec_norm_2_direct_mem(N,A); + } + #endif + } + + if( (result != eT(0)) && arma_isfinite(result) ) + { + return result; + } + else + { + arma_extra_debug_print("op_norm::vec_norm_2_direct_std(): detected possible underflow or overflow"); + + return op_norm::vec_norm_2_direct_robust(X); + } + } + + + +template +inline +eT +op_norm::vec_norm_2_direct_mem(const uword N, const eT* A) + { + arma_extra_debug_sigprint(); + + eT acc; + + #if (defined(ARMA_SIMPLE_LOOPS) || defined(__FAST_MATH__)) + { + eT acc1 = eT(0); + + if(memory::is_aligned(A)) + { + memory::mark_as_aligned(A); + + for(uword i=0; i +inline +eT +op_norm::vec_norm_2_direct_robust(const Mat& X) + { + arma_extra_debug_sigprint(); + + const uword N = X.n_elem; + const eT* A = X.memptr(); + + eT max_val = priv::most_neg(); + + uword j; + + for(j=1; j max_val) { max_val = val_i; } + if(val_j > max_val) { max_val = val_j; } + } + + if((j-1) < N) + { + const eT val_i = std::abs(*A); + + if(val_i > max_val) { max_val = val_i; } + } + + if(max_val == eT(0)) { return eT(0); } + + const eT* B = X.memptr(); + + eT acc1 = eT(0); + eT acc2 = eT(0); + + for(j=1; j +inline +typename T1::pod_type +op_norm::vec_norm_k(const Proxy& P, const int k) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + T acc = T(0); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type A = P.get_ea(); + + const uword N = P.get_n_elem(); + + for(uword i=0; i +inline +typename T1::pod_type +op_norm::vec_norm_max(const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const uword N = P.get_n_elem(); + + T max_val = (N != 1) ? priv::most_neg() : std::abs(P[0]); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type A = P.get_ea(); + + uword i,j; + for(i=0, j=1; j +inline +typename T1::pod_type +op_norm::vec_norm_min(const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const uword N = P.get_n_elem(); + + T min_val = (N != 1) ? priv::most_pos() : std::abs(P[0]); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type A = P.get_ea(); + + uword i,j; + for(i=0, j=1; j tmp_i) { min_val = tmp_i; } + if(min_val > tmp_j) { min_val = tmp_j; } + } + + if(i < N) + { + const T tmp_i = std::abs(A[i]); + + if(min_val > tmp_i) { min_val = tmp_i; } + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + if(n_rows != 1) + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const T tmp = std::abs(P.at(row,col)); + + if(min_val > tmp) { min_val = tmp; } + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + const T tmp = std::abs(P.at(0,col)); + + if(min_val > tmp) { min_val = tmp; } + } + } + } + + return min_val; + } + + + +template +inline +typename get_pod_type::result +op_norm::mat_norm_1(const Mat& X) + { + arma_extra_debug_sigprint(); + + // TODO: this can be sped up with a dedicated implementation + return as_scalar( max( sum(abs(X), 0), 1) ); + } + + + +template +inline +typename get_pod_type::result +op_norm::mat_norm_2(const Mat& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(X.internal_has_nonfinite()) { arma_debug_warn_level(1, "norm(): given matrix has non-finite elements"); } + + Col S; + svd(S, X); + + return (S.n_elem > 0) ? S[0] : T(0); + } + + + +template +inline +typename get_pod_type::result +op_norm::mat_norm_inf(const Mat& X) + { + arma_extra_debug_sigprint(); + + // TODO: this can be sped up with a dedicated implementation + return as_scalar( max( sum(abs(X), 1), 0) ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_normalise_bones.hpp b/src/armadillo/include/armadillo_bits/op_normalise_bones.hpp new file mode 100644 index 0000000..4b1932c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_normalise_bones.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_normalise +//! @{ + + + +class op_normalise_vec + : public traits_op_passthru + { + public: + + template inline static void apply(Mat& out, const Op& in); + }; + + + +class op_normalise_mat + : public traits_op_default + { + public: + + template inline static void apply(Mat& out, const Op& in); + + template inline static void apply(Mat& out, const Mat& A, const uword p, const uword dim); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_normalise_meat.hpp b/src/armadillo/include/armadillo_bits/op_normalise_meat.hpp new file mode 100644 index 0000000..e56c390 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_normalise_meat.hpp @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_normalise +//! @{ + + + +template +inline +void +op_normalise_vec::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const uword p = in.aux_uword_a; + + arma_debug_check( (p == 0), "normalise(): unsupported vector norm type" ); + + const quasi_unwrap U(in.m); + + const T norm_val_a = norm(U.M, p); + const T norm_val_b = (norm_val_a != T(0)) ? norm_val_a : T(1); + + if(quasi_unwrap::has_subview && U.is_alias(out)) + { + Mat tmp = U.M / norm_val_b; + + out.steal_mem(tmp); + } + else + { + out = U.M / norm_val_b; + } + } + + + +template +inline +void +op_normalise_mat::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword p = in.aux_uword_a; + const uword dim = in.aux_uword_b; + + arma_debug_check( (p == 0), "normalise(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "normalise(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(in.m); + + if(quasi_unwrap::has_subview && U.is_alias(out)) + { + Mat out2; + + op_normalise_mat::apply(out2, U.M, p, dim); + + out.steal_mem(out2); + } + else + { + op_normalise_mat::apply(out, U.M, p, dim); + } + } + + + +template +inline +void +op_normalise_mat::apply(Mat& out, const Mat& A, const uword p, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + out.copy_size(A); + + if(A.n_elem == 0) { return; } + + if(dim == 0) + { + const uword n_cols = A.n_cols; + + for(uword i=0; i norm_vals(n_rows); + + T* norm_vals_mem = norm_vals.memptr(); + + for(uword i=0; i + inline static void apply(Mat& out, const Op& expr); + + template + inline static bool apply_direct(Mat& out, const Base& expr, typename T1::pod_type tol); + }; + + + +class op_null + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& expr); + + template + inline static bool apply_direct(Mat& out, const Base& expr, typename T1::pod_type tol); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_orth_null_meat.hpp b/src/armadillo/include/armadillo_bits/op_orth_null_meat.hpp new file mode 100644 index 0000000..4a776ba --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_orth_null_meat.hpp @@ -0,0 +1,181 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_orth_null +//! @{ + + + +template +inline +void +op_orth::apply(Mat& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const T tol = access::tmp_real(expr.aux); + + const bool status = op_orth::apply_direct(out, expr.m, tol); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("orth(): svd failed"); + } + } + + + +template +inline +bool +op_orth::apply_direct(Mat& out, const Base& expr, typename T1::pod_type tol) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + arma_debug_check((tol < T(0)), "orth(): tolerance must be >= 0"); + + Mat A(expr.get_ref()); + + Mat U; + Col< T> s; + Mat V; + + const bool status = auxlib::svd_dc(U, s, V, A); + + V.reset(); + + if(status == false) { return false; } + + if(s.is_empty()) { out.reset(); return true; } + + const uword s_n_elem = s.n_elem; + const T* s_mem = s.memptr(); + + // set tolerance to default if it hasn't been specified + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * s_mem[0] * std::numeric_limits::epsilon(); } + + uword count = 0; + + for(uword i=0; i < s_n_elem; ++i) { count += (s_mem[i] > tol) ? uword(1) : uword(0); } + + if(count > 0) + { + out = U.head_cols(count); // out *= eT(-1); + } + else + { + out.set_size(A.n_rows, 0); + } + + return true; + } + + + +// + + + +template +inline +void +op_null::apply(Mat& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const T tol = access::tmp_real(expr.aux); + + const bool status = op_null::apply_direct(out, expr.m, tol); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("null(): svd failed"); + } + } + + + +template +inline +bool +op_null::apply_direct(Mat& out, const Base& expr, typename T1::pod_type tol) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + arma_debug_check((tol < T(0)), "null(): tolerance must be >= 0"); + + Mat A(expr.get_ref()); + + Mat U; + Col< T> s; + Mat V; + + const bool status = auxlib::svd_dc(U, s, V, A); + + U.reset(); + + if(status == false) { return false; } + + if(s.is_empty()) { out.reset(); return true; } + + const uword s_n_elem = s.n_elem; + const T* s_mem = s.memptr(); + + // set tolerance to default if it hasn't been specified + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * s_mem[0] * std::numeric_limits::epsilon(); } + + uword count = 0; + + for(uword i=0; i < s_n_elem; ++i) { count += (s_mem[i] > tol) ? uword(1) : uword(0); } + + if(count < A.n_cols) + { + out = V.tail_cols(A.n_cols - count); + + const uword out_n_elem = out.n_elem; + eT* out_mem = out.memptr(); + + for(uword i=0; i::epsilon()) { out_mem[i] = eT(0); } + } + } + else + { + out.set_size(A.n_cols, 0); + } + + return true; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_pinv_bones.hpp b/src/armadillo/include/armadillo_bits/op_pinv_bones.hpp new file mode 100644 index 0000000..bf83ddc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_pinv_bones.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_pinv +//! @{ + + + +class op_pinv_default + : public traits_op_default + { + public: + + template inline static void apply(Mat& out, const Op& in); + + template inline static bool apply_direct(Mat& out, const Base& expr); + }; + + + +class op_pinv + : public traits_op_default + { + public: + + template inline static void apply(Mat& out, const Op& in); + + template inline static bool apply_direct(Mat& out, const Base& expr, typename T1::pod_type tol, const uword method_id); + + template inline static bool apply_diag(Mat& out, const Mat& A, typename get_pod_type::result tol); + + template inline static bool apply_sym (Mat& out, const Mat& A, typename get_pod_type::result tol, const uword method_id); + + template inline static bool apply_gen (Mat& out, Mat& A, typename get_pod_type::result tol, const uword method_id); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_pinv_meat.hpp b/src/armadillo/include/armadillo_bits/op_pinv_meat.hpp new file mode 100644 index 0000000..326a0be --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_pinv_meat.hpp @@ -0,0 +1,313 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_pinv +//! @{ + + + +template +inline +void +op_pinv_default::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_pinv_default::apply_direct(out, in.m); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("pinv(): svd failed"); + } + } + + + +template +inline +bool +op_pinv_default::apply_direct(Mat& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + constexpr T tol = T(0); + constexpr uword method_id = uword(0); + + return op_pinv::apply_direct(out, expr, tol, method_id); + } + + + +// + + + +template +inline +void +op_pinv::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const T tol = access::tmp_real(in.aux); + const uword method_id = in.aux_uword_a; + + const bool status = op_pinv::apply_direct(out, in.m, tol, method_id); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("pinv(): svd failed"); + } + } + + + +template +inline +bool +op_pinv::apply_direct(Mat& out, const Base& expr, typename T1::pod_type tol, const uword method_id) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0"); + + // method_id = 0 -> default setting + // method_id = 1 -> use standard algorithm + // method_id = 2 -> use divide and conquer algorithm + + Mat A(expr.get_ref()); + + if(A.is_empty()) { out.set_size(A.n_cols,A.n_rows); return true; } + + if(is_op_diagmat::value || A.is_diagmat()) + { + arma_extra_debug_print("op_pinv: detected diagonal matrix"); + + return op_pinv::apply_diag(out, A, tol); + } + + bool do_sym = false; + + const bool is_sym_size_ok = (A.n_rows == A.n_cols) && (A.n_rows > (is_cx::yes ? uword(20) : uword(40))); + + if( (is_sym_size_ok) && (arma_config::optimise_sym) && (auxlib::crippled_lapack(A) == false) ) + { + bool is_approx_sym = false; + bool is_approx_sympd = false; + + sym_helper::analyse_matrix(is_approx_sym, is_approx_sympd, A); + + do_sym = ((is_cx::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd)); + } + + if(do_sym) + { + arma_extra_debug_print("op_pinv: symmetric/hermitian optimisation"); + + return op_pinv::apply_sym(out, A, tol, method_id); + } + + return op_pinv::apply_gen(out, A, tol, method_id); + } + + + +template +inline +bool +op_pinv::apply_diag(Mat& out, const Mat& A, typename get_pod_type::result tol) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + out.zeros(A.n_cols, A.n_rows); + + const uword N = (std::min)(A.n_rows, A.n_cols); + + podarray diag_abs_vals(N); + + T max_abs_Aii = T(0); + + for(uword i=0; i max_abs_Aii) ? abs_Aii : max_abs_Aii; + } + + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * max_abs_Aii * std::numeric_limits::epsilon(); } + + for(uword i=0; i= tol) + { + const eT Aii = A.at(i,i); + + if(Aii != eT(0)) { out.at(i,i) = eT(eT(1) / Aii); } + } + } + + return true; + } + + + +template +inline +bool +op_pinv::apply_sym(Mat& out, const Mat& A, typename get_pod_type::result tol, const uword method_id) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + Col< T> eigval; + Mat eigvec; + + const bool status = ((method_id == uword(0)) || (method_id == uword(2))) ? auxlib::eig_sym_dc(eigval, eigvec, A) : auxlib::eig_sym(eigval, eigvec, A); + + if(status == false) { return false; } + + if(eigval.n_elem == 0) { out.zeros(A.n_cols, A.n_rows); return true; } + + Col abs_eigval = arma::abs(eigval); + + const uvec indices = sort_index(abs_eigval, "descend"); + + abs_eigval = abs_eigval.elem(indices); + eigval = eigval.elem(indices); + eigvec = eigvec.cols(indices); + + // set tolerance to default if it hasn't been specified + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * abs_eigval[0] * std::numeric_limits::epsilon(); } + + uword count = 0; + + for(uword i=0; i < abs_eigval.n_elem; ++i) { count += (abs_eigval[i] >= tol) ? uword(1) : uword(0); } + + if(count == 0) { out.zeros(A.n_cols, A.n_rows); return true; } + + Col eigval2(count, arma_nozeros_indicator()); + + uword count2 = 0; + + for(uword i=0; i < eigval.n_elem; ++i) + { + const T abs_val = abs_eigval[i]; + const T val = eigval[i]; + + if(abs_val >= tol) { eigval2[count2] = (val != T(0)) ? T(T(1) / val) : T(0); ++count2; } + } + + const Mat eigvec_use(eigvec.memptr(), eigvec.n_rows, count, false); + + out = (eigvec_use * diagmat(eigval2)).eval() * eigvec_use.t(); + + return true; + } + + + + +template +inline +bool +op_pinv::apply_gen(Mat& out, Mat& A, typename get_pod_type::result tol, const uword method_id) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + // economical SVD decomposition + Mat U; + Col< T> s; + Mat V; + + if(n_cols > n_rows) { A = trans(A); } + + const bool status = ((method_id == uword(0)) || (method_id == uword(2))) ? auxlib::svd_dc_econ(U, s, V, A) : auxlib::svd_econ(U, s, V, A, 'b'); + + if(status == false) { return false; } + + // set tolerance to default if it hasn't been specified + if( (tol == T(0)) && (s.n_elem > 0) ) { tol = (std::max)(n_rows, n_cols) * s[0] * std::numeric_limits::epsilon(); } + + uword count = 0; + + for(uword i=0; i < s.n_elem; ++i) { count += (s[i] >= tol) ? uword(1) : uword(0); } + + if(count == 0) { out.zeros(n_cols, n_rows); return true; } + + Col s2(count, arma_nozeros_indicator()); + + uword count2 = 0; + + for(uword i=0; i < s.n_elem; ++i) + { + const T val = s[i]; + + if(val >= tol) { s2[count2] = (val > T(0)) ? T(T(1) / val) : T(0); ++count2; } + } + + const Mat U_use(U.memptr(), U.n_rows, count, false); + const Mat V_use(V.memptr(), V.n_rows, count, false); + + Mat tmp; + + if(n_rows >= n_cols) + { + // out = ( (V.n_cols > count) ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( (U.n_cols > count) ? U.cols(0,count-1) : U ); + + tmp = V_use * diagmat(s2); + + out = tmp * trans(U_use); + } + else + { + // out = ( (U.n_cols > count) ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( (V.n_cols > count) ? V.cols(0,count-1) : V ); + + tmp = U_use * diagmat(s2); + + out = tmp * trans(V_use); + } + + return true; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_powmat_bones.hpp b/src/armadillo/include/armadillo_bits/op_powmat_bones.hpp new file mode 100644 index 0000000..021522b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_powmat_bones.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_powmat +//! @{ + + + +class op_powmat + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& expr); + + template + inline static bool apply_direct(Mat& out, const Base& X, const uword y, const bool y_neg); + + template + inline static void apply_direct_positive(Mat& out, const Mat& X, const uword y); + }; + + + +class op_powmat_cx + : public traits_op_default + { + public: + + template + inline static void apply(Mat< std::complex >& out, const mtOp,T1,op_powmat_cx>& expr); + + template + inline static bool apply_direct(Mat< std::complex >& out, const Base& X, const typename T1::pod_type y); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_powmat_meat.hpp b/src/armadillo/include/armadillo_bits/op_powmat_meat.hpp new file mode 100644 index 0000000..323db32 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_powmat_meat.hpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_powmat +//! @{ + + +template +inline +void +op_powmat::apply(Mat& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + const uword y = expr.aux_uword_a; + const bool y_neg = (expr.aux_uword_b == uword(1)); + + const bool status = op_powmat::apply_direct(out, expr.m, y, y_neg); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("powmat(): transformation failed"); + } + } + + + +template +inline +bool +op_powmat::apply_direct(Mat& out, const Base& X, const uword y, const bool y_neg) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(y_neg) + { + if(y == uword(1)) + { + return op_inv_gen_default::apply_direct(out, X.get_ref(), "powmat()"); + } + else + { + Mat X_inv; + + const bool inv_status = op_inv_gen_default::apply_direct(X_inv, X.get_ref(), "powmat()"); + + if(inv_status == false) { return false; } + + op_powmat::apply_direct_positive(out, X_inv, y); + } + } + else + { + const quasi_unwrap U(X.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "powmat(): given matrix must be square sized" ); + + op_powmat::apply_direct_positive(out, U.M, y); + } + + return true; + } + + + +template +inline +void +op_powmat::apply_direct_positive(Mat& out, const Mat& X, const uword y) + { + arma_extra_debug_sigprint(); + + const uword N = X.n_rows; + + if(y == uword(0)) { out.eye(N,N); return; } + if(y == uword(1)) { out = X; return; } + + if(X.is_diagmat()) + { + arma_extra_debug_print("op_powmat: detected diagonal matrix"); + + podarray tmp(N); // use temporary array in case we have aliasing + + for(uword i=0; i tmp = X*X; out = X*tmp; } + else if(y == uword(4)) { const Mat tmp = X*X; out = tmp*tmp; } + else if(y == uword(5)) { const Mat tmp = X*X; out = X*tmp*tmp; } + else + { + Mat tmp = X; + + out = X; + + uword z = y-1; + + while(z > 0) + { + if(z & 1) { out = tmp * out; } + + z /= uword(2); + + if(z > 0) { tmp = tmp * tmp; } + } + } + } + } + + + +template +inline +void +op_powmat_cx::apply(Mat< std::complex >& out, const mtOp,T1,op_powmat_cx>& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type in_T; + + const in_T y = std::real(expr.aux_out_eT); + + const bool status = op_powmat_cx::apply_direct(out, expr.m, y); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("powmat(): transformation failed"); + } + } + + + +template +inline +bool +op_powmat_cx::apply_direct(Mat< std::complex >& out, const Base& X, const typename T1::pod_type y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_eT; + typedef typename T1::pod_type in_T; + typedef std::complex out_eT; + + if( y == in_T(int(y)) ) + { + arma_extra_debug_print("op_powmat_cx::apply_direct(): integer exponent detected; redirecting to op_powmat"); + + const uword y_val = (y < int(0)) ? uword(-y) : uword(y); + const bool y_neg = (y < int(0)); + + Mat tmp; + + const bool status = op_powmat::apply_direct(tmp, X.get_ref(), y_val, y_neg); + + if(status == false) { return false; } + + out = conv_to< Mat >::from(tmp); + + return true; + } + + const quasi_unwrap U(X.get_ref()); + const Mat& A = U.M; + + arma_debug_check( (A.is_square() == false), "powmat(): given matrix must be square sized" ); + + const uword N = A.n_rows; + + if(A.is_diagmat()) + { + arma_extra_debug_print("op_powmat_cx: detected diagonal matrix"); + + podarray tmp(N); // use temporary array in case we have aliasing + + for(uword i=0; i(A.at(i,i)), y) ; } + + out.zeros(N,N); + + for(uword i=0; i eigval; + Mat eigvec; + + const bool eig_status = eig_sym(eigval, eigvec, A); + + if(eig_status) + { + eigval = pow(eigval, y); + + const Mat tmp = diagmat(eigval) * eigvec.t(); + + out = conv_to< Mat >::from(eigvec * tmp); + + return true; + } + + arma_extra_debug_print("op_powmat_cx: sympd optimisation failed"); + + // fallthrough if optimisation failed + } + + bool powmat_status = false; + + Col eigval; + Mat eigvec; + + const bool eig_status = eig_gen(eigval, eigvec, A); + + if(eig_status) + { + eigval = pow(eigval, y); + + Mat eigvec_t = trans(eigvec); + Mat tmp = diagmat(conj(eigval)) * eigvec_t; + + const bool solve_status = auxlib::solve_square_fast(out, eigvec_t, tmp); + + if(solve_status) { out = trans(out); powmat_status = true; } + } + + return powmat_status; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_princomp_bones.hpp b/src/armadillo/include/armadillo_bits/op_princomp_bones.hpp new file mode 100644 index 0000000..4d6abaa --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_princomp_bones.hpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_princomp +//! @{ + + + +class op_princomp + : public traits_op_default + { + public: + + template + inline static bool + direct_princomp + ( + Mat& coeff_out, + Mat& score_out, + Col& latent_out, + Col& tsquared_out, + const Base& X + ); + + + template + inline static bool + direct_princomp + ( + Mat& coeff_out, + Mat& score_out, + Col& latent_out, + const Base& X + ); + + template + inline static bool + direct_princomp + ( + Mat& coeff_out, + Mat& score_out, + const Base& X + ); + + template + inline static bool + direct_princomp + ( + Mat& coeff_out, + const Base& X + ); + + template + inline static void + apply(Mat& out, const Op& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_princomp_meat.hpp b/src/armadillo/include/armadillo_bits/op_princomp_meat.hpp new file mode 100644 index 0000000..db6f83f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_princomp_meat.hpp @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_princomp +//! @{ + + + +//! \brief +//! principal component analysis -- 4 arguments version +//! computation is done via singular value decomposition +//! coeff_out -> principal component coefficients +//! score_out -> projected samples +//! latent_out -> eigenvalues of principal vectors +//! tsquared_out -> Hotelling's T^2 statistic +template +inline +bool +op_princomp::direct_princomp + ( + Mat& coeff_out, + Mat& score_out, + Col& latent_out, + Col& tsquared_out, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const unwrap_check Y( X.get_ref(), score_out ); + const Mat& in = Y.M; + + const uword n_rows = in.n_rows; + const uword n_cols = in.n_cols; + + if(n_rows > 1) // more than one sample + { + // subtract the mean - use score_out as temporary matrix + score_out = in; score_out.each_row() -= mean(in); + + // singular value decomposition + Mat U; + Col< T> s; + + const bool svd_ok = (n_rows >= n_cols) ? svd_econ(U, s, coeff_out, score_out) : svd(U, s, coeff_out, score_out); + + if(svd_ok == false) { return false; } + + // normalize the eigenvalues + s /= std::sqrt( double(n_rows - 1) ); + + // project the samples to the principals + score_out *= coeff_out; + + if(n_rows <= n_cols) // number of samples is less than their dimensionality + { + score_out.cols(n_rows-1,n_cols-1).zeros(); + + Col s_tmp(n_cols, arma_zeros_indicator()); + + s_tmp.rows(0,n_rows-2) = s.rows(0,n_rows-2); + s = s_tmp; + + // compute the Hotelling's T-squared + s_tmp.rows(0,n_rows-2) = T(1) / s_tmp.rows(0,n_rows-2); + + const Mat S = score_out * diagmat(Col(s_tmp)); + tsquared_out = sum(S%S,1); + } + else + { + // compute the Hotelling's T-squared + // TODO: replace with more robust approach + const Mat S = score_out * diagmat(Col( T(1) / s)); + tsquared_out = sum(S%S,1); + } + + // compute the eigenvalues of the principal vectors + latent_out = s%s; + } + else // 0 or 1 samples + { + coeff_out.eye(n_cols, n_cols); + + score_out.copy_size(in); + score_out.zeros(); + + latent_out.set_size(n_cols); + latent_out.zeros(); + + tsquared_out.set_size(n_rows); + tsquared_out.zeros(); + } + + return true; + } + + + +//! \brief +//! principal component analysis -- 3 arguments version +//! computation is done via singular value decomposition +//! coeff_out -> principal component coefficients +//! score_out -> projected samples +//! latent_out -> eigenvalues of principal vectors +template +inline +bool +op_princomp::direct_princomp + ( + Mat& coeff_out, + Mat& score_out, + Col& latent_out, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const unwrap_check Y( X.get_ref(), score_out ); + const Mat& in = Y.M; + + const uword n_rows = in.n_rows; + const uword n_cols = in.n_cols; + + if(n_rows > 1) // more than one sample + { + // subtract the mean - use score_out as temporary matrix + score_out = in; score_out.each_row() -= mean(in); + + // singular value decomposition + Mat U; + Col< T> s; + + const bool svd_ok = (n_rows >= n_cols) ? svd_econ(U, s, coeff_out, score_out) : svd(U, s, coeff_out, score_out); + + if(svd_ok == false) { return false; } + + // normalize the eigenvalues + s /= std::sqrt( double(n_rows - 1) ); + + // project the samples to the principals + score_out *= coeff_out; + + if(n_rows <= n_cols) // number of samples is less than their dimensionality + { + score_out.cols(n_rows-1,n_cols-1).zeros(); + + Col s_tmp(n_cols, arma_zeros_indicator()); + + s_tmp.rows(0,n_rows-2) = s.rows(0,n_rows-2); + s = s_tmp; + } + + // compute the eigenvalues of the principal vectors + latent_out = s%s; + } + else // 0 or 1 samples + { + coeff_out.eye(n_cols, n_cols); + + score_out.copy_size(in); + score_out.zeros(); + + latent_out.set_size(n_cols); + latent_out.zeros(); + } + + return true; + } + + + +//! \brief +//! principal component analysis -- 2 arguments version +//! computation is done via singular value decomposition +//! coeff_out -> principal component coefficients +//! score_out -> projected samples +template +inline +bool +op_princomp::direct_princomp + ( + Mat& coeff_out, + Mat& score_out, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const unwrap_check Y( X.get_ref(), score_out ); + const Mat& in = Y.M; + + const uword n_rows = in.n_rows; + const uword n_cols = in.n_cols; + + if(n_rows > 1) // more than one sample + { + // subtract the mean - use score_out as temporary matrix + score_out = in; score_out.each_row() -= mean(in); + + // singular value decomposition + Mat U; + Col< T> s; + + const bool svd_ok = (n_rows >= n_cols) ? svd_econ(U, s, coeff_out, score_out) : svd(U, s, coeff_out, score_out); + + if(svd_ok == false) { return false; } + + // project the samples to the principals + score_out *= coeff_out; + + if(n_rows <= n_cols) // number of samples is less than their dimensionality + { + score_out.cols(n_rows-1,n_cols-1).zeros(); + } + } + else // 0 or 1 samples + { + coeff_out.eye(n_cols, n_cols); + score_out.copy_size(in); + score_out.zeros(); + } + + return true; + } + + + +//! \brief +//! principal component analysis -- 1 argument version +//! computation is done via singular value decomposition +//! coeff_out -> principal component coefficients +template +inline +bool +op_princomp::direct_princomp + ( + Mat& coeff_out, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const unwrap Y( X.get_ref() ); + const Mat& in = Y.M; + + if(in.n_elem != 0) + { + Mat tmp = in; tmp.each_row() -= mean(in); + + // singular value decomposition + Mat U; + Col< T> s; + + const bool svd_ok = (in.n_rows >= in.n_cols) ? svd_econ(U, s, coeff_out, tmp) : svd(U, s, coeff_out, tmp); + + if(svd_ok == false) { return false; } + } + else + { + coeff_out.eye(in.n_cols, in.n_cols); + } + + return true; + } + + + +template +inline +void +op_princomp::apply + ( + Mat& out, + const Op& in + ) + { + arma_extra_debug_sigprint(); + + const bool status = op_princomp::direct_princomp(out, in.m); + + if(status == false) + { + out.soft_reset(); + + arma_stop_runtime_error("princomp(): decomposition failed"); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_prod_bones.hpp b/src/armadillo/include/armadillo_bits/op_prod_bones.hpp new file mode 100644 index 0000000..790401c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_prod_bones.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_prod +//! @{ + + +class op_prod + : public traits_op_xvec + { + public: + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim); + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static eT prod(const subview& S); + + template + inline static typename T1::elem_type prod(const Base& X); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_prod_meat.hpp b/src/armadillo/include/armadillo_bits/op_prod_meat.hpp new file mode 100644 index 0000000..c1b9577 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_prod_meat.hpp @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_prod +//! @{ + + +template +inline +void +op_prod::apply_noalias(Mat& out, const Mat& X, const uword dim) + { + arma_extra_debug_sigprint(); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) // traverse across rows (ie. find the product in each column) + { + out.set_size(1, X_n_cols); + + eT* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[col] = arrayops::product(X.colptr(col), X_n_rows); + } + } + else // traverse across columns (ie. find the product in each row) + { + out.ones(X_n_rows, 1); + + eT* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + const eT* X_col_mem = X.colptr(col); + + for(uword row=0; row < X_n_rows; ++row) + { + out_mem[row] *= X_col_mem[row]; + } + } + } + } + + + +template +inline +void +op_prod::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + + arma_debug_check( (dim > 1), "prod(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_prod::apply_noalias(tmp, U.M, dim); + + out.steal_mem(tmp); + } + else + { + op_prod::apply_noalias(out, U.M, dim); + } + } + + + +template +inline +eT +op_prod::prod(const subview& X) + { + arma_extra_debug_sigprint(); + + eT val = eT(1); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(X_n_rows == 1) + { + const Mat& A = X.m; + + const uword start_row = X.aux_row1; + const uword start_col = X.aux_col1; + + const uword end_col_p1 = start_col + X_n_cols; + + uword i,j; + for(i=start_col, j=start_col+1; j < end_col_p1; i+=2, j+=2) + { + val *= A.at(start_row, i); + val *= A.at(start_row, j); + } + + if(i < end_col_p1) + { + val *= A.at(start_row, i); + } + } + else + { + for(uword col=0; col < X_n_cols; ++col) + { + val *= arrayops::product( X.colptr(col), X_n_rows ); + } + } + + return val; + } + + + +template +inline +typename T1::elem_type +op_prod::prod(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(X.get_ref()); + + eT val = eT(1); + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + const ea_type A = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + uword i,j; + for(i=0, j=1; j < n_elem; i+=2, j+=2) + { + val *= A[i]; + val *= A[j]; + } + + if(i < n_elem) + { + val *= A[i]; + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + if(n_rows == 1) + { + uword i,j; + for(i=0, j=1; j < n_cols; i+=2, j+=2) + { + val *= P.at(0,i); + val *= P.at(0,j); + } + + if(i < n_cols) + { + val *= P.at(0,i); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + uword i,j; + for(i=0, j=1; j < n_rows; i+=2, j+=2) + { + val *= P.at(i,col); + val *= P.at(j,col); + } + + if(i < n_rows) + { + val *= P.at(i,col); + } + } + } + } + + return val; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_range_bones.hpp b/src/armadillo/include/armadillo_bits/op_range_bones.hpp new file mode 100644 index 0000000..5745624 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_range_bones.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_range +//! @{ + + +class op_range + : public traits_op_xvec + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim); + + template + inline static typename T1::elem_type vector_range(const T1& expr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_range_meat.hpp b/src/armadillo/include/armadillo_bits/op_range_meat.hpp new file mode 100644 index 0000000..a3e66c2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_range_meat.hpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_range +//! @{ + + + +template +inline +void +op_range::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 1), "range(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(in.m); + const Mat& X = U.M; + + if(U.is_alias(out) == false) + { + op_range::apply_noalias(out, X, dim); + } + else + { + Mat tmp; + + op_range::apply_noalias(tmp, X, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_range::apply_noalias(Mat& out, const Mat& X, const uword dim) + { + arma_extra_debug_sigprint(); + + // TODO: replace with dedicated implementation which finds min and max at the same time + out = max(X,dim) - min(X,dim); + } + + + +template +inline +typename T1::elem_type +op_range::vector_range(const T1& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(expr); + const Mat& X = U.M; + + const eT* X_mem = X.memptr(); + const uword N = X.n_elem; + + if(N == 0) + { + arma_debug_check(true, "range(): object has no elements"); + + return Datum::nan; + } + + // TODO: replace with dedicated implementation which finds min and max at the same time + return op_max::direct_max(X_mem, N) - op_min::direct_min(X_mem, N); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_rank_bones.hpp b/src/armadillo/include/armadillo_bits/op_rank_bones.hpp new file mode 100644 index 0000000..f0c4a07 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_rank_bones.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_rank +//! @{ + + + +class op_rank + : public traits_op_default + { + public: + + template inline static bool apply(uword& out, const Base& expr, const typename T1::pod_type tol); + + template inline static bool apply_gen(uword& out, Mat& A, typename get_pod_type::result tol); + + template inline static bool apply_sym(uword& out, Mat& A, typename get_pod_type::result tol); + + template inline static bool apply_diag(uword& out, Mat& A, typename get_pod_type::result tol); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_rank_meat.hpp b/src/armadillo/include/armadillo_bits/op_rank_meat.hpp new file mode 100644 index 0000000..ef00dd3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_rank_meat.hpp @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_rank +//! @{ + + + +template +inline +bool +op_rank::apply(uword& out, const Base& expr, const typename T1::pod_type tol) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + Mat A(expr.get_ref()); + + if(A.is_empty()) { out = uword(0); return true; } + + if(is_op_diagmat::value || A.is_diagmat()) + { + arma_extra_debug_print("op_rank::apply(): detected diagonal matrix"); + + return op_rank::apply_diag(out, A, tol); + } + + bool do_sym = false; + + if((arma_config::optimise_sym) && (auxlib::crippled_lapack(A) == false) && (A.n_rows >= (is_cx::yes ? uword(64) : uword(128)))) + { + bool is_approx_sym = false; + bool is_approx_sympd = false; + + sym_helper::analyse_matrix(is_approx_sym, is_approx_sympd, A); + + do_sym = (is_cx::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd); + } + + if(do_sym) + { + arma_extra_debug_print("op_rank::apply(): symmetric/hermitian optimisation"); + + return op_rank::apply_sym(out, A, tol); + } + + return op_rank::apply_gen(out, A, tol); + } + + + +template +inline +bool +op_rank::apply_diag(uword& out, Mat& A, typename get_pod_type::result tol) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword N = (std::min)(A.n_rows, A.n_cols); + + podarray diag_abs_vals(N); + + T max_abs_Aii = T(0); + + for(uword i=0; i max_abs_Aii) ? abs_Aii : max_abs_Aii; + } + + // set tolerance to default if it hasn't been specified + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * max_abs_Aii * std::numeric_limits::epsilon(); } + + uword count = 0; + + for(uword i=0; i tol) ? uword(1) : uword(0); } + + out = count; + + return true; + } + + + +template +inline +bool +op_rank::apply_sym(uword& out, Mat& A, typename get_pod_type::result tol) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.is_square() == false) { out = uword(0); return false; } + + Col v; + + const bool status = auxlib::eig_sym(v, A); + + if(status == false) { out = uword(0); return false; } + + const uword v_n_elem = v.n_elem; + T* v_mem = v.memptr(); + + if(v_n_elem == 0) { out = uword(0); return true; } + + T max_abs_v = T(0); + + for(uword i=0; i < v_n_elem; ++i) { const T val = std::abs(v_mem[i]); v_mem[i] = val; if(val > max_abs_v) { max_abs_v = val; } } + + // set tolerance to default if it hasn't been specified + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * max_abs_v * std::numeric_limits::epsilon(); } + + uword count = 0; + + for(uword i=0; i < v_n_elem; ++i) { count += (v_mem[i] > tol) ? uword(1) : uword(0); } + + out = count; + + return true; + } + + + +template +inline +bool +op_rank::apply_gen(uword& out, Mat& A, typename get_pod_type::result tol) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + Col s; + + const bool status = auxlib::svd_dc(s, A); + + if(status == false) { out = uword(0); return false; } + + const uword s_n_elem = s.n_elem; + const T* s_mem = s.memptr(); + + if(s_n_elem == 0) { out = uword(0); return true; } + + // set tolerance to default if it hasn't been specified + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * s_mem[0] * std::numeric_limits::epsilon(); } + + uword count = 0; + + for(uword i=0; i < s_n_elem; ++i) { count += (s_mem[i] > tol) ? uword(1) : uword(0); } + + out = count; + + return true; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_rcond_bones.hpp b/src/armadillo/include/armadillo_bits/op_rcond_bones.hpp new file mode 100644 index 0000000..88697e8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_rcond_bones.hpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_rcond +//! @{ + + +class op_rcond + : public traits_op_default + { + public: + + template static inline typename T1::pod_type apply(const Base& X); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_rcond_meat.hpp b/src/armadillo/include/armadillo_bits/op_rcond_meat.hpp new file mode 100644 index 0000000..48123bd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_rcond_meat.hpp @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_rcond +//! @{ + + + +template +inline +typename T1::pod_type +op_rcond::apply(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(strip_trimat::do_trimat) + { + const strip_trimat S(X.get_ref()); + + const quasi_unwrap::stored_type> U(S.M); + + arma_debug_check( (U.M.is_square() == false), "rcond(): matrix must be square sized" ); + + const uword layout = (S.do_triu) ? uword(0) : uword(1); + + return auxlib::rcond_trimat(U.M, layout); + } + + Mat A = X.get_ref(); + + arma_debug_check( (A.is_square() == false), "rcond(): matrix must be square sized" ); + + if(A.is_empty()) { return Datum::inf; } + + if(is_op_diagmat::value || A.is_diagmat()) + { + arma_extra_debug_print("op_rcond::apply(): detected diagonal matrix"); + + const eT* colmem = A.memptr(); + const uword N = A.n_rows; + + T abs_min = Datum::inf; + T abs_max = T(0); + + for(uword i=0; i abs_max) ? abs_val : abs_max; + + colmem += N; + } + + if((abs_min == T(0)) || (abs_max == T(0))) { return T(0); } + + return T(abs_min / abs_max); + } + + const bool is_triu = trimat_helper::is_triu(A); + const bool is_tril = (is_triu) ? false : trimat_helper::is_tril(A); + + if(is_triu || is_tril) + { + const uword layout = (is_triu) ? uword(0) : uword(1); + + return auxlib::rcond_trimat(A, layout); + } + + const bool try_sympd = arma_config::optimise_sym && (auxlib::crippled_lapack(A) ? false : sym_helper::guess_sympd(A)); + + if(try_sympd) + { + arma_extra_debug_print("op_rcond::apply(): attempting sympd optimisation"); + + bool calc_ok = false; + + const T out_val = auxlib::rcond_sympd(A, calc_ok); + + if(calc_ok) { return out_val; } + + arma_extra_debug_print("op_rcond::apply(): sympd optimisation failed"); + + // auxlib::rcond_sympd() may have failed because A isn't really sympd + // restore A, as auxlib::rcond_sympd() may have destroyed it + A = X.get_ref(); + // fallthrough to the next return statement + } + + return auxlib::rcond(A); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_relational_bones.hpp b/src/armadillo/include/armadillo_bits/op_relational_bones.hpp new file mode 100644 index 0000000..0e8ecab --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_relational_bones.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_relational +//! @{ + + + +class op_rel_lt_pre + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + + template + inline static void apply(Cube& out, const mtOpCube& X); + }; + + + +class op_rel_lt_post + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + + template + inline static void apply(Cube& out, const mtOpCube& X); + }; + + + +class op_rel_gt_pre + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + + template + inline static void apply(Cube& out, const mtOpCube& X); + }; + + + +class op_rel_gt_post + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + + template + inline static void apply(Cube& out, const mtOpCube& X); + }; + + + +class op_rel_lteq_pre + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + + template + inline static void apply(Cube& out, const mtOpCube& X); + }; + + + +class op_rel_lteq_post + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + + template + inline static void apply(Cube& out, const mtOpCube& X); + }; + + + +class op_rel_gteq_pre + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + + template + inline static void apply(Cube& out, const mtOpCube& X); + }; + + + +class op_rel_gteq_post + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + + template + inline static void apply(Cube& out, const mtOpCube& X); + }; + + + +class op_rel_eq + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + + template + inline static void apply(Cube& out, const mtOpCube& X); + }; + + + +class op_rel_noteq + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + + template + inline static void apply(Cube& out, const mtOpCube& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_relational_meat.hpp b/src/armadillo/include/armadillo_bits/op_relational_meat.hpp new file mode 100644 index 0000000..6c7344b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_relational_meat.hpp @@ -0,0 +1,510 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_relational +//! @{ + + +#undef operator_rel + +#undef arma_applier_mat_pre +#undef arma_applier_mat_post + +#undef arma_applier_cube_pre +#undef arma_applier_cube_post + + +#define arma_applier_mat_pre(operator_rel) \ + {\ + typedef typename T1::elem_type eT;\ + typedef typename Proxy::ea_type ea_type;\ + \ + const eT val = X.aux;\ + \ + const Proxy P(X.m);\ + \ + const uword n_rows = P.get_n_rows();\ + const uword n_cols = P.get_n_cols();\ + \ + const bool bad_alias = ( Proxy::has_subview && P.is_alias(out) );\ + \ + if(bad_alias == false)\ + {\ + out.set_size(n_rows, n_cols);\ + \ + uword* out_mem = out.memptr();\ + \ + if(Proxy::use_at == false)\ + {\ + ea_type PA = P.get_ea();\ + const uword n_elem = out.n_elem;\ + \ + for(uword i=0; i tmp(P.Q);\ + \ + out = (val) operator_rel (tmp);\ + }\ + } + + + +#define arma_applier_mat_post(operator_rel) \ + {\ + typedef typename T1::elem_type eT;\ + typedef typename Proxy::ea_type ea_type;\ + \ + const eT val = X.aux;\ + \ + const Proxy P(X.m);\ + \ + const uword n_rows = P.get_n_rows();\ + const uword n_cols = P.get_n_cols();\ + \ + const bool bad_alias = ( Proxy::has_subview && P.is_alias(out) );\ + \ + if(bad_alias == false)\ + {\ + out.set_size(n_rows, n_cols);\ + \ + uword* out_mem = out.memptr();\ + \ + if(Proxy::use_at == false)\ + {\ + ea_type PA = P.get_ea();\ + const uword n_elem = out.n_elem;\ + \ + for(uword i=0; i tmp(P.Q);\ + \ + out = (tmp) operator_rel (val);\ + }\ + } + + + +#define arma_applier_cube_pre(operator_rel) \ + {\ + typedef typename T1::elem_type eT;\ + typedef typename ProxyCube::ea_type ea_type;\ + \ + const eT val = X.aux;\ + \ + const ProxyCube P(X.m);\ + \ + const uword n_rows = P.get_n_rows();\ + const uword n_cols = P.get_n_cols();\ + const uword n_slices = P.get_n_slices();\ + \ + const bool bad_alias = ( ProxyCube::has_subview && P.is_alias(out) );\ + \ + if(bad_alias == false)\ + {\ + out.set_size(n_rows, n_cols, n_slices);\ + \ + uword* out_mem = out.memptr();\ + \ + if(ProxyCube::use_at == false)\ + {\ + ea_type PA = P.get_ea();\ + const uword n_elem = out.n_elem;\ + \ + for(uword i=0; i::stored_type> tmp(P.Q);\ + \ + out = (val) operator_rel (tmp.M);\ + }\ + } + + + +#define arma_applier_cube_post(operator_rel) \ + {\ + typedef typename T1::elem_type eT;\ + typedef typename ProxyCube::ea_type ea_type;\ + \ + const eT val = X.aux;\ + \ + const ProxyCube P(X.m);\ + \ + const uword n_rows = P.get_n_rows();\ + const uword n_cols = P.get_n_cols();\ + const uword n_slices = P.get_n_slices();\ + \ + const bool bad_alias = ( ProxyCube::has_subview && P.is_alias(out) );\ + \ + if(bad_alias == false)\ + {\ + out.set_size(n_rows, n_cols, n_slices);\ + \ + uword* out_mem = out.memptr();\ + \ + if(ProxyCube::use_at == false)\ + {\ + ea_type PA = P.get_ea();\ + const uword n_elem = out.n_elem;\ + \ + for(uword i=0; i::stored_type> tmp(P.Q);\ + \ + out = (tmp.M) operator_rel (val);\ + }\ + } + + + +template +inline +void +op_rel_lt_pre::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + arma_applier_mat_pre( < ); + } + + + +template +inline +void +op_rel_gt_pre::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + arma_applier_mat_pre( > ); + } + + + +template +inline +void +op_rel_lteq_pre::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + arma_applier_mat_pre( <= ); + } + + + +template +inline +void +op_rel_gteq_pre::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + arma_applier_mat_pre( >= ); + } + + + +template +inline +void +op_rel_lt_post::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + arma_applier_mat_post( < ); + } + + + +template +inline +void +op_rel_gt_post::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + arma_applier_mat_post( > ); + } + + + +template +inline +void +op_rel_lteq_post::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + arma_applier_mat_post( <= ); + } + + + +template +inline +void +op_rel_gteq_post::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + arma_applier_mat_post( >= ); + } + + + +template +inline +void +op_rel_eq::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + arma_applier_mat_post( == ); + } + + + +template +inline +void +op_rel_noteq::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + arma_applier_mat_post( != ); + } + + + +// +// +// + + + +template +inline +void +op_rel_lt_pre::apply(Cube& out, const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_applier_cube_pre( < ); + } + + + +template +inline +void +op_rel_gt_pre::apply(Cube& out, const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_applier_cube_pre( > ); + } + + + +template +inline +void +op_rel_lteq_pre::apply(Cube& out, const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_applier_cube_pre( <= ); + } + + + +template +inline +void +op_rel_gteq_pre::apply(Cube& out, const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_applier_cube_pre( >= ); + } + + + +template +inline +void +op_rel_lt_post::apply(Cube& out, const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_applier_cube_post( < ); + } + + + +template +inline +void +op_rel_gt_post::apply(Cube& out, const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_applier_cube_post( > ); + } + + + +template +inline +void +op_rel_lteq_post::apply(Cube& out, const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_applier_cube_post( <= ); + } + + + +template +inline +void +op_rel_gteq_post::apply(Cube& out, const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_applier_cube_post( >= ); + } + + + +template +inline +void +op_rel_eq::apply(Cube& out, const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_applier_cube_post( == ); + } + + + +template +inline +void +op_rel_noteq::apply(Cube& out, const mtOpCube& X) + { + arma_extra_debug_sigprint(); + + arma_applier_cube_post( != ); + } + + + +#undef arma_applier_mat_pre +#undef arma_applier_mat_post + +#undef arma_applier_cube_pre +#undef arma_applier_cube_post + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_repelem_bones.hpp b/src/armadillo/include/armadillo_bits/op_repelem_bones.hpp new file mode 100644 index 0000000..52d20a4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_repelem_bones.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_repelem +//! @{ + + + +class op_repelem + : public traits_op_default + { + public: + + template inline static void apply_noalias(Mat& out, const obj& X, const uword copies_per_row, const uword copies_per_col); + + template inline static void apply(Mat& out, const Op& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_repelem_meat.hpp b/src/armadillo/include/armadillo_bits/op_repelem_meat.hpp new file mode 100644 index 0000000..38bfe36 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_repelem_meat.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_repelem +//! @{ + + + +template +inline +void +op_repelem::apply_noalias(Mat& out, const obj& X, const uword copies_per_row, const uword copies_per_col) + { + arma_extra_debug_sigprint(); + + typedef typename obj::elem_type eT; + + const uword X_n_rows = obj::is_row ? uword(1) : X.n_rows; + const uword X_n_cols = obj::is_col ? uword(1) : X.n_cols; + + out.set_size(X_n_rows * copies_per_row, X_n_cols * copies_per_col); + + if(out.n_elem == 0) { return; } + + for(uword col=0; col < X_n_cols; ++col) + { + const uword out_col_offset = col * copies_per_col; + + eT* out_colptr_first = out.colptr(out_col_offset); + + for(uword row=0; row < X_n_rows; ++row) + { + const uword out_row_offset = row * copies_per_row; + + const eT copy_value = X.at(row, col); + + for(uword row_copy=0; row_copy < copies_per_row; ++row_copy) + { + out_colptr_first[out_row_offset + row_copy] = copy_value; + } + + if(copies_per_col != 1) + { + for(uword col_copy=1; col_copy < copies_per_col; ++col_copy) + { + eT* out_colptr = out.colptr(out_col_offset + col_copy); + + arrayops::copy(&out_colptr[out_row_offset], &out_colptr_first[out_row_offset], copies_per_row); + } + } + } + } + } + + + +template +inline +void +op_repelem::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword copies_per_row = in.aux_uword_a; + const uword copies_per_col = in.aux_uword_b; + + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_repelem::apply_noalias(tmp, U.M, copies_per_row, copies_per_col); + + out.steal_mem(tmp); + } + else + { + op_repelem::apply_noalias(out, U.M, copies_per_row, copies_per_col); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_repmat_bones.hpp b/src/armadillo/include/armadillo_bits/op_repmat_bones.hpp new file mode 100644 index 0000000..100179a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_repmat_bones.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_repmat +//! @{ + + + +class op_repmat + : public traits_op_default + { + public: + + template inline static void apply_noalias(Mat& out, const obj& X, const uword copies_per_row, const uword copies_per_col); + + template inline static void apply(Mat& out, const Op& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_repmat_meat.hpp b/src/armadillo/include/armadillo_bits/op_repmat_meat.hpp new file mode 100644 index 0000000..1460350 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_repmat_meat.hpp @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_repmat +//! @{ + + + +template +inline +void +op_repmat::apply_noalias(Mat& out, const obj& X, const uword copies_per_row, const uword copies_per_col) + { + arma_extra_debug_sigprint(); + + typedef typename obj::elem_type eT; + + const uword X_n_rows = obj::is_row ? uword(1) : X.n_rows; + const uword X_n_cols = obj::is_col ? uword(1) : X.n_cols; + + out.set_size(X_n_rows * copies_per_row, X_n_cols * copies_per_col); + + const uword out_n_rows = out.n_rows; + const uword out_n_cols = out.n_cols; + + // if( (out_n_rows > 0) && (out_n_cols > 0) ) + // { + // for(uword col = 0; col < out_n_cols; col += X_n_cols) + // for(uword row = 0; row < out_n_rows; row += X_n_rows) + // { + // out.submat(row, col, row+X_n_rows-1, col+X_n_cols-1) = X; + // } + // } + + if( (out_n_rows > 0) && (out_n_cols > 0) ) + { + if(copies_per_row != 1) + { + for(uword col_copy=0; col_copy < copies_per_col; ++col_copy) + { + const uword out_col_offset = X_n_cols * col_copy; + + for(uword col=0; col < X_n_cols; ++col) + { + eT* out_colptr = out.colptr(col + out_col_offset); + const eT* X_colptr = X.colptr(col); + + for(uword row_copy=0; row_copy < copies_per_row; ++row_copy) + { + const uword out_row_offset = X_n_rows * row_copy; + + arrayops::copy( &out_colptr[out_row_offset], X_colptr, X_n_rows ); + } + } + } + } + else + { + for(uword col_copy=0; col_copy < copies_per_col; ++col_copy) + { + const uword out_col_offset = X_n_cols * col_copy; + + for(uword col=0; col < X_n_cols; ++col) + { + eT* out_colptr = out.colptr(col + out_col_offset); + const eT* X_colptr = X.colptr(col); + + arrayops::copy( out_colptr, X_colptr, X_n_rows ); + } + } + } + } + + } + + + +template +inline +void +op_repmat::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword copies_per_row = in.aux_uword_a; + const uword copies_per_col = in.aux_uword_b; + + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_repmat::apply_noalias(tmp, U.M, copies_per_row, copies_per_col); + + out.steal_mem(tmp); + } + else + { + op_repmat::apply_noalias(out, U.M, copies_per_row, copies_per_col); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_reshape_bones.hpp b/src/armadillo/include/armadillo_bits/op_reshape_bones.hpp new file mode 100644 index 0000000..b27f22b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_reshape_bones.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_reshape +//! @{ + + + +class op_reshape + : public traits_op_default + { + public: + + template inline static void apply(Mat& out, const Op& in); + + template inline static void apply_mat_inplace(Mat& A, const uword new_n_rows, const uword new_n_cols); + + template inline static void apply_mat_noalias(Mat& out, const Mat& A, const uword new_n_rows, const uword new_n_cols); + + template inline static void apply_proxy_noalias(Mat& out, const Proxy& P, const uword new_n_rows, const uword new_n_cols); + + // + + template inline static void apply(Cube& out, const OpCube& in); + + template inline static void apply_cube_inplace(Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + + template inline static void apply_cube_noalias(Cube& out, const Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_reshape_meat.hpp b/src/armadillo/include/armadillo_bits/op_reshape_meat.hpp new file mode 100644 index 0000000..1846dc8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_reshape_meat.hpp @@ -0,0 +1,246 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_reshape +//! @{ + + + +template +inline +void +op_reshape::apply(Mat& actual_out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword new_n_rows = in.aux_uword_a; + const uword new_n_cols = in.aux_uword_b; + + if(is_Mat::value || (arma_config::openmp && Proxy::use_mp)) + { + const unwrap U(in.m); + const Mat& A = U.M; + + if(&actual_out == &A) + { + op_reshape::apply_mat_inplace(actual_out, new_n_rows, new_n_cols); + } + else + { + op_reshape::apply_mat_noalias(actual_out, A, new_n_rows, new_n_cols); + } + } + else + { + const Proxy P(in.m); + + const bool is_alias = P.is_alias(actual_out); + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + if(is_Mat::stored_type>::value) + { + const quasi_unwrap::stored_type> U(P.Q); + + op_reshape::apply_mat_noalias(out, U.M, new_n_rows, new_n_cols); + } + else + { + op_reshape::apply_proxy_noalias(out, P, new_n_rows, new_n_cols); + } + + if(is_alias) { actual_out.steal_mem(tmp); } + } + } + + + +template +inline +void +op_reshape::apply_mat_inplace(Mat& A, const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + const uword new_n_elem = new_n_rows * new_n_cols; + + if(A.n_elem == new_n_elem) { A.set_size(new_n_rows, new_n_cols); return; } + + Mat B; + + op_reshape::apply_mat_noalias(B, A, new_n_rows, new_n_cols); + + A.steal_mem(B); + } + + + +template +inline +void +op_reshape::apply_mat_noalias(Mat& out, const Mat& A, const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + out.set_size(new_n_rows, new_n_cols); + + const uword n_elem_to_copy = (std::min)(A.n_elem, out.n_elem); + + eT* out_mem = out.memptr(); + + arrayops::copy( out_mem, A.memptr(), n_elem_to_copy ); + + if(n_elem_to_copy < out.n_elem) + { + const uword n_elem_leftover = out.n_elem - n_elem_to_copy; + + arrayops::fill_zeros(&(out_mem[n_elem_to_copy]), n_elem_leftover); + } + } + + + +template +inline +void +op_reshape::apply_proxy_noalias(Mat& out, const Proxy& P, const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + out.set_size(new_n_rows, new_n_cols); + + const uword n_elem_to_copy = (std::min)(P.get_n_elem(), out.n_elem); + + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i < n_elem_to_copy; ++i) { out_mem[i] = Pea[i]; } + } + else + { + uword i = 0; + + const uword P_n_rows = P.get_n_rows(); + const uword P_n_cols = P.get_n_cols(); + + for(uword col=0; col < P_n_cols; ++col) + for(uword row=0; row < P_n_rows; ++row) + { + if(i >= n_elem_to_copy) { goto nested_loop_end; } + + out_mem[i] = P.at(row,col); + + ++i; + } + + nested_loop_end: ; + } + + if(n_elem_to_copy < out.n_elem) + { + const uword n_elem_leftover = out.n_elem - n_elem_to_copy; + + arrayops::fill_zeros(&(out_mem[n_elem_to_copy]), n_elem_leftover); + } + } + + + +template +inline +void +op_reshape::apply(Cube& out, const OpCube& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube U(in.m); + const Cube& A = U.M; + + const uword new_n_rows = in.aux_uword_a; + const uword new_n_cols = in.aux_uword_b; + const uword new_n_slices = in.aux_uword_c; + + if(&out == &A) + { + op_reshape::apply_cube_inplace(out, new_n_rows, new_n_cols, new_n_slices); + } + else + { + op_reshape::apply_cube_noalias(out, A, new_n_rows, new_n_cols, new_n_slices); + } + } + + + +template +inline +void +op_reshape::apply_cube_inplace(Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + const uword new_n_elem = new_n_rows * new_n_cols * new_n_slices; + + if(A.n_elem == new_n_elem) { A.set_size(new_n_rows, new_n_cols, new_n_slices); return; } + + Cube B; + + op_reshape::apply_cube_noalias(B, A, new_n_rows, new_n_cols, new_n_slices); + + A.steal_mem(B); + } + + + +template +inline +void +op_reshape::apply_cube_noalias(Cube& out, const Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + out.set_size(new_n_rows, new_n_cols, new_n_slices); + + const uword n_elem_to_copy = (std::min)(A.n_elem, out.n_elem); + + eT* out_mem = out.memptr(); + + arrayops::copy( out_mem, A.memptr(), n_elem_to_copy ); + + if(n_elem_to_copy < out.n_elem) + { + const uword n_elem_leftover = out.n_elem - n_elem_to_copy; + + arrayops::fill_zeros(&(out_mem[n_elem_to_copy]), n_elem_leftover); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_resize_bones.hpp b/src/armadillo/include/armadillo_bits/op_resize_bones.hpp new file mode 100644 index 0000000..d33273f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_resize_bones.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_resize +//! @{ + + + +class op_resize + : public traits_op_default + { + public: + + template inline static void apply(Mat& out, const Op& in); + + template inline static void apply_mat_inplace(Mat& A, const uword new_n_rows, const uword new_n_cols); + + template inline static void apply_mat_noalias(Mat& out, const Mat& A, const uword new_n_rows, const uword new_n_cols); + + // + + template inline static void apply(Cube& out, const OpCube& in); + + template inline static void apply_cube_inplace(Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + + template inline static void apply_cube_noalias(Cube& out, const Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_resize_meat.hpp b/src/armadillo/include/armadillo_bits/op_resize_meat.hpp new file mode 100644 index 0000000..b18a163 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_resize_meat.hpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_resize +//! @{ + + + +template +inline +void +op_resize::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword new_n_rows = in.aux_uword_a; + const uword new_n_cols = in.aux_uword_b; + + const unwrap tmp(in.m); + const Mat& A = tmp.M; + + if(&out == &A) + { + op_resize::apply_mat_inplace(out, new_n_rows, new_n_cols); + } + else + { + op_resize::apply_mat_noalias(out, A, new_n_rows, new_n_cols); + } + } + + + +template +inline +void +op_resize::apply_mat_inplace(Mat& A, const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + if( (A.n_rows == new_n_rows) && (A.n_cols == new_n_cols) ) { return; } + + if(A.is_empty()) { A.zeros(new_n_rows, new_n_cols); return; } + + Mat B; + + op_resize::apply_mat_noalias(B, A, new_n_rows, new_n_cols); + + A.steal_mem(B); + } + + + +template +inline +void +op_resize::apply_mat_noalias(Mat& out, const Mat& A, const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + out.set_size(new_n_rows, new_n_cols); + + if( (new_n_rows > A.n_rows) || (new_n_cols > A.n_cols) ) { out.zeros(); } + + if( (out.n_elem > 0) && (A.n_elem > 0) ) + { + const uword end_row = (std::min)(new_n_rows, A.n_rows) - 1; + const uword end_col = (std::min)(new_n_cols, A.n_cols) - 1; + + out.submat(0, 0, end_row, end_col) = A.submat(0, 0, end_row, end_col); + } + } + + + +// + + + +template +inline +void +op_resize::apply(Cube& out, const OpCube& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword new_n_rows = in.aux_uword_a; + const uword new_n_cols = in.aux_uword_b; + const uword new_n_slices = in.aux_uword_c; + + const unwrap_cube tmp(in.m); + const Cube& A = tmp.M; + + if(&out == &A) + { + op_resize::apply_cube_inplace(out, new_n_rows, new_n_cols, new_n_slices); + } + else + { + op_resize::apply_cube_noalias(out, A, new_n_rows, new_n_cols, new_n_slices); + } + } + + + +template +inline +void +op_resize::apply_cube_inplace(Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + if( (A.n_rows == new_n_rows) && (A.n_cols == new_n_cols) && (A.n_slices == new_n_slices) ) { return; } + + if(A.is_empty()) { A.zeros(new_n_rows, new_n_cols, new_n_slices); return; } + + Cube B; + + op_resize::apply_cube_noalias(B, A, new_n_rows, new_n_cols, new_n_slices); + + A.steal_mem(B); + } + + + +template +inline +void +op_resize::apply_cube_noalias(Cube& out, const Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + out.set_size(new_n_rows, new_n_cols, new_n_slices); + + if( (new_n_rows > A.n_rows) || (new_n_cols > A.n_cols) || (new_n_slices > A.n_slices) ) { out.zeros(); } + + if( (out.n_elem > 0) && (A.n_elem > 0) ) + { + const uword end_row = (std::min)(new_n_rows, A.n_rows) - 1; + const uword end_col = (std::min)(new_n_cols, A.n_cols) - 1; + const uword end_slice = (std::min)(new_n_slices, A.n_slices) - 1; + + out.subcube(0, 0, 0, end_row, end_col, end_slice) = A.subcube(0, 0, 0, end_row, end_col, end_slice); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_reverse_bones.hpp b/src/armadillo/include/armadillo_bits/op_reverse_bones.hpp new file mode 100644 index 0000000..8ec6217 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_reverse_bones.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_reverse +//! @{ + + + +class op_reverse + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_reverse_vec + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_reverse_meat.hpp b/src/armadillo/include/armadillo_bits/op_reverse_meat.hpp new file mode 100644 index 0000000..4edafa8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_reverse_meat.hpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_reverse +//! @{ + + + +template +inline +void +op_reverse::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + + arma_debug_check( (dim > 1), "reverse(): parameter 'dim' must be 0 or 1" ); + + if(is_Mat::value) + { + // allow detection of in-place operation + + const unwrap U(in.m); + + if(dim == 0) { op_flipud::apply_direct(out, U.M); } + if(dim == 1) { op_fliplr::apply_direct(out, U.M); } + } + else + { + const Proxy P(in.m); + + if(P.is_alias(out)) + { + Mat tmp; + + if(dim == 0) { op_flipud::apply_proxy_noalias(tmp, P); } + if(dim == 1) { op_fliplr::apply_proxy_noalias(tmp, P); } + + out.steal_mem(tmp); + } + else + { + if(dim == 0) { op_flipud::apply_proxy_noalias(out, P); } + if(dim == 1) { op_fliplr::apply_proxy_noalias(out, P); } + } + } + } + + + +template +inline +void +op_reverse_vec::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_Mat::value) + { + // allow detection of in-place operation + + const unwrap U(in.m); + + if((T1::is_xvec) ? bool(U.M.is_rowvec()) : bool(T1::is_row)) + { + op_fliplr::apply_direct(out, U.M); + } + else + { + op_flipud::apply_direct(out, U.M); + } + } + else + { + const Proxy P(in.m); + + if(P.is_alias(out)) + { + Mat tmp; + + if((T1::is_xvec) ? bool(P.get_n_rows() == 1) : bool(T1::is_row)) + { + op_fliplr::apply_proxy_noalias(tmp, P); + } + else + { + op_flipud::apply_proxy_noalias(tmp, P); + } + + out.steal_mem(tmp); + } + else + { + if((T1::is_xvec) ? bool(P.get_n_rows() == 1) : bool(T1::is_row)) + { + op_fliplr::apply_proxy_noalias(out, P); + } + else + { + op_flipud::apply_proxy_noalias(out, P); + } + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_roots_bones.hpp b/src/armadillo/include/armadillo_bits/op_roots_bones.hpp new file mode 100644 index 0000000..6007d19 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_roots_bones.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_roots +//! @{ + + + +class op_roots + : public traits_op_col + { + public: + + template + inline static void apply(Mat< std::complex >& out, const mtOp, T1, op_roots>& expr); + + template + inline static bool apply_direct(Mat< std::complex >& out, const Base& X); + + template + inline static bool apply_noalias(Mat< std::complex::result> >& out, const Mat& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_roots_meat.hpp b/src/armadillo/include/armadillo_bits/op_roots_meat.hpp new file mode 100644 index 0000000..1e09120 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_roots_meat.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_roots +//! @{ + + + +template +inline +void +op_roots::apply(Mat< std::complex >& out, const mtOp, T1, op_roots>& expr) + { + arma_extra_debug_sigprint(); + + const bool status = op_roots::apply_direct(out, expr.m); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("roots(): eigen decomposition failed"); + } + } + + + +template +inline +bool +op_roots::apply_direct(Mat< std::complex >& out, const Base& X) + { + arma_extra_debug_sigprint(); + + typedef std::complex out_eT; + + const quasi_unwrap U(X.get_ref()); + + bool status = false; + + if(U.is_alias(out)) + { + Mat tmp; + + status = op_roots::apply_noalias(tmp, U.M); + + out.steal_mem(tmp); + } + else + { + status = op_roots::apply_noalias(out, U.M); + } + + return status; + } + + + +template +inline +bool +op_roots::apply_noalias(Mat< std::complex::result> >& out, const Mat& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + typedef std::complex::result> out_eT; + + arma_debug_check( (X.is_vec() == false), "roots(): given object must be a vector" ); + + if(X.internal_has_nonfinite()) { return false; } + + // treat X as a column vector + + const Col Y( const_cast(X.memptr()), X.n_elem, false, false); + + const T Y_max = (Y.is_empty() == false) ? T(max(abs(Y))) : T(0); + + if(Y_max == T(0)) { out.set_size(1,0); return true; } + + const uvec indices = find( Y / Y_max ); + + const uword n_tail_zeros = (indices.n_elem > 0) ? uword( (Y.n_elem-1) - indices[indices.n_elem-1] ) : uword(0); + + const Col Z = Y.subvec( indices[0], indices[indices.n_elem-1] ); + + if(Z.n_elem >= uword(2)) + { + Mat tmp; + + if(Z.n_elem == uword(2)) + { + tmp.set_size(1,1); + + tmp[0] = -Z[1] / Z[0]; + } + else + { + tmp = diagmat(ones< Col >(Z.n_elem - 2), -1); + + tmp.row(0) = strans(-Z.subvec(1, Z.n_elem-1) / Z[0]); + } + + Mat junk; + + const bool status = auxlib::eig_gen(out, junk, false, tmp); + + if(status == false) { return false; } + + if(n_tail_zeros > 0) + { + out.resize(out.n_rows + n_tail_zeros, 1); + } + } + else + { + out.zeros(n_tail_zeros,1); + } + + return true; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_row_as_mat_bones.hpp b/src/armadillo/include/armadillo_bits/op_row_as_mat_bones.hpp new file mode 100644 index 0000000..a843092 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_row_as_mat_bones.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_row_as_mat +//! @{ + + +class op_row_as_mat + : public traits_op_default + { + public: + + template inline static void apply(Mat& out, const CubeToMatOp& expr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_row_as_mat_meat.hpp b/src/armadillo/include/armadillo_bits/op_row_as_mat_meat.hpp new file mode 100644 index 0000000..751d8d8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_row_as_mat_meat.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_row_as_mat +//! @{ + + + +template +inline +void +op_row_as_mat::apply(Mat& out, const CubeToMatOp& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube U(expr.m); + const Cube& A = U.M; + + const uword in_row = expr.aux_uword; + + arma_debug_check_bounds( (in_row >= A.n_rows), "Cube::row_as_mat(): index out of bounds" ); + + const uword A_n_cols = A.n_cols; + const uword A_n_rows = A.n_rows; + const uword A_n_slices = A.n_slices; + + out.set_size(A_n_slices, A_n_cols); + + for(uword s=0; s < A_n_slices; ++s) + { + const eT* A_mem = &(A.at(in_row, 0, s)); + eT* out_mem = &(out.at(s,0)); + + for(uword c=0; c < A_n_cols; ++c) + { + (*out_mem) = (*A_mem); + + A_mem += A_n_rows; + out_mem += A_n_slices; + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_shift_bones.hpp b/src/armadillo/include/armadillo_bits/op_shift_bones.hpp new file mode 100644 index 0000000..74e4902 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_shift_bones.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_shift +//! @{ + + + +class op_shift_vec + : public traits_op_passthru + { + public: + + template inline static void apply(Mat& out, const Op& in); + }; + + + +class op_shift + : public traits_op_default + { + public: + + template inline static void apply_noalias(Mat& out, const Mat& X, const uword len, const uword neg, const uword dim); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_shift_meat.hpp b/src/armadillo/include/armadillo_bits/op_shift_meat.hpp new file mode 100644 index 0000000..b369b5d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_shift_meat.hpp @@ -0,0 +1,181 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_shift +//! @{ + + + +template +inline +void +op_shift_vec::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(in.m); + + const uword len = in.aux_uword_a; + const uword neg = in.aux_uword_b; + + const uword dim = (T1::is_xvec) ? uword(U.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); + + if(U.is_alias(out)) + { + Mat tmp; + + op_shift::apply_noalias(tmp, U.M, len, neg, dim); + + out.steal_mem(tmp); + } + else + { + op_shift::apply_noalias(out, U.M, len, neg, dim); + } + } + + + +template +inline +void +op_shift::apply_noalias(Mat& out, const Mat& X, const uword len, const uword neg, const uword dim) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ((dim == 0) && (len >= X.n_rows)), "shift(): shift amount out of bounds" ); + arma_debug_check_bounds( ((dim == 1) && (len >= X.n_cols)), "shift(): shift amount out of bounds" ); + + out.copy_size(X); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + if(neg == 0) + { + for(uword col=0; col < X_n_cols; ++col) + { + eT* out_ptr = out.colptr(col); + const eT* X_ptr = X.colptr(col); + + for(uword out_row=len, row=0; row < (X_n_rows - len); ++row, ++out_row) + { + out_ptr[out_row] = X_ptr[row]; + } + + for(uword out_row=0, row=(X_n_rows - len); row < X_n_rows; ++row, ++out_row) + { + out_ptr[out_row] = X_ptr[row]; + } + } + } + else + if(neg == 1) + { + for(uword col=0; col < X_n_cols; ++col) + { + eT* out_ptr = out.colptr(col); + const eT* X_ptr = X.colptr(col); + + for(uword out_row=0, row=len; row < X_n_rows; ++row, ++out_row) + { + out_ptr[out_row] = X_ptr[row]; + } + + for(uword out_row=(X_n_rows-len), row=0; row < len; ++row, ++out_row) + { + out_ptr[out_row] = X_ptr[row]; + } + } + } + } + else + if(dim == 1) + { + if(neg == 0) + { + if(X_n_rows == 1) + { + eT* out_ptr = out.memptr(); + const eT* X_ptr = X.memptr(); + + for(uword out_col=len, col=0; col < (X_n_cols - len); ++col, ++out_col) + { + out_ptr[out_col] = X_ptr[col]; + } + + for(uword out_col=0, col=(X_n_cols - len); col < X_n_cols; ++col, ++out_col) + { + out_ptr[out_col] = X_ptr[col]; + } + } + else + { + for(uword out_col=len, col=0; col < (X_n_cols - len); ++col, ++out_col) + { + arrayops::copy( out.colptr(out_col), X.colptr(col), X_n_rows ); + } + + for(uword out_col=0, col=(X_n_cols - len); col < X_n_cols; ++col, ++out_col) + { + arrayops::copy( out.colptr(out_col), X.colptr(col), X_n_rows ); + } + } + } + else + if(neg == 1) + { + if(X_n_rows == 1) + { + eT* out_ptr = out.memptr(); + const eT* X_ptr = X.memptr(); + + for(uword out_col=0, col=len; col < X_n_cols; ++col, ++out_col) + { + out_ptr[out_col] = X_ptr[col]; + } + + for(uword out_col=(X_n_cols-len), col=0; col < len; ++col, ++out_col) + { + out_ptr[out_col] = X_ptr[col]; + } + } + else + { + for(uword out_col=0, col=len; col < X_n_cols; ++col, ++out_col) + { + arrayops::copy( out.colptr(out_col), X.colptr(col), X_n_rows ); + } + + for(uword out_col=(X_n_cols-len), col=0; col < len; ++col, ++out_col) + { + arrayops::copy( out.colptr(out_col), X.colptr(col), X_n_rows ); + } + } + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_shuffle_bones.hpp b/src/armadillo/include/armadillo_bits/op_shuffle_bones.hpp new file mode 100644 index 0000000..8150d13 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_shuffle_bones.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_shuffle +//! @{ + + + +class op_shuffle + : public traits_op_default + { + public: + + template inline static void apply_direct(Mat& out, const Mat& X, const uword dim); + + template inline static void apply(Mat& out, const Op& in); + }; + + + +class op_shuffle_vec + : public traits_op_passthru + { + public: + + template inline static void apply(Mat& out, const Op& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_shuffle_meat.hpp b/src/armadillo/include/armadillo_bits/op_shuffle_meat.hpp new file mode 100644 index 0000000..ecfc2f6 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_shuffle_meat.hpp @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_shuffle +//! @{ + + + +template +inline +void +op_shuffle::apply_direct(Mat& out, const Mat& X, const uword dim) + { + arma_extra_debug_sigprint(); + + if(X.is_empty()) { out.copy_size(X); return; } + + const uword N = (dim == 0) ? X.n_rows : X.n_cols; + + // see op_sort_index_bones.hpp for the definition of arma_sort_index_packet + // and the associated comparison functor + + typedef arma_sort_index_packet packet; + + std::vector packet_vec(N); + + for(uword i=0; i()); + packet_vec[i].index = i; + } + + arma_sort_index_helper_ascend comparator; + + std::sort( packet_vec.begin(), packet_vec.end(), comparator ); + + const bool is_alias = (&out == &X); + + if(X.is_vec() == false) + { + if(is_alias == false) + { + arma_extra_debug_print("op_shuffle::apply(): matrix"); + + out.copy_size(X); + + if(dim == 0) + { + for(uword i=0; i 1) // ie. column vector + { + for(uword i=0; i 1) // ie. row vector + { + for(uword i=0; i 1) // ie. column vector + { + for(uword i=0; i 1) // ie. row vector + { + for(uword i=0; i +inline +void +op_shuffle::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const unwrap U(in.m); + + const uword dim = in.aux_uword_a; + + arma_debug_check( (dim > 1), "shuffle(): parameter 'dim' must be 0 or 1" ); + + op_shuffle::apply_direct(out, U.M, dim); + } + + + +template +inline +void +op_shuffle_vec::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const unwrap U(in.m); + + const uword dim = (T1::is_xvec) ? uword(U.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); + + op_shuffle::apply_direct(out, U.M, dim); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sort_bones.hpp b/src/armadillo/include/armadillo_bits/op_sort_bones.hpp new file mode 100644 index 0000000..37449fe --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sort_bones.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sort +//! @{ + + + +class op_sort + : public traits_op_default + { + public: + + template + inline static void copy_row(eT* X, const Mat& A, const uword row); + + template + inline static void copy_row(Mat& A, const eT* X, const uword row); + + template + inline static void direct_sort(eT* X, const uword N, const uword sort_type = 0); + + template + inline static void direct_sort_ascending(eT* X, const uword N); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword sort_type, const uword dim); + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_sort_vec + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sort_index_bones.hpp b/src/armadillo/include/armadillo_bits/op_sort_index_bones.hpp new file mode 100644 index 0000000..7229ed5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sort_index_bones.hpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sort_index +//! @{ + + + +class op_sort_index + : public traits_op_col + { + public: + + template + static inline bool apply_noalias(Mat& out, const Proxy& P, const uword sort_type); + + template + static inline void apply(Mat& out, const mtOp& in); + }; + + + +class op_stable_sort_index + : public traits_op_col + { + public: + + template + static inline bool apply_noalias(Mat& out, const Proxy& P, const uword sort_type); + + template + static inline void apply(Mat& out, const mtOp& in); + }; + + + +template +struct arma_sort_index_packet + { + eT val; + uword index; + }; + + + +template +struct arma_sort_index_helper_ascend + { + arma_inline + bool + operator() (const arma_sort_index_packet& A, const arma_sort_index_packet& B) const + { + return (A.val < B.val); + } + }; + + + +template +struct arma_sort_index_helper_descend + { + arma_inline + bool + operator() (const arma_sort_index_packet& A, const arma_sort_index_packet& B) const + { + return (A.val > B.val); + } + }; + + + +template +struct arma_sort_index_helper_ascend< std::complex > + { + typedef typename std::complex eT; + + inline + bool + operator() (const arma_sort_index_packet& A, const arma_sort_index_packet& B) const + { + return (std::abs(A.val) < std::abs(B.val)); + } + + // inline + // bool + // operator() (const arma_sort_index_packet& A, const arma_sort_index_packet& B) const + // { + // const T abs_A_val = std::abs(A.val); + // const T abs_B_val = std::abs(B.val); + // + // return ( (abs_A_val != abs_B_val) ? (abs_A_val < abs_B_val) : (std::arg(A.val) < std::arg(B.val)) ); + // } + }; + + + +template +struct arma_sort_index_helper_descend< std::complex > + { + typedef typename std::complex eT; + + inline + bool + operator() (const arma_sort_index_packet& A, const arma_sort_index_packet& B) const + { + return (std::abs(A.val) > std::abs(B.val)); + } + + // inline + // bool + // operator() (const arma_sort_index_packet& A, const arma_sort_index_packet& B) const + // { + // const T abs_A_val = std::abs(A.val); + // const T abs_B_val = std::abs(B.val); + // + // return ( (abs_A_val != abs_B_val) ? (abs_A_val > abs_B_val) : (std::arg(A.val) > std::arg(B.val)) ); + // } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sort_index_meat.hpp b/src/armadillo/include/armadillo_bits/op_sort_index_meat.hpp new file mode 100644 index 0000000..dfee901 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sort_index_meat.hpp @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sort_index +//! @{ + + + +template +inline +bool +arma_sort_index_helper(Mat& out, const Proxy& P, const uword sort_type) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + + out.set_size(n_elem, 1); + + std::vector< arma_sort_index_packet > packet_vec(n_elem); + + if(Proxy::use_at == false) + { + for(uword i=0; i comparator; + + if(sort_stable == false) + { + std::sort( packet_vec.begin(), packet_vec.end(), comparator ); + } + else + { + std::stable_sort( packet_vec.begin(), packet_vec.end(), comparator ); + } + } + else + { + // descend + + arma_sort_index_helper_descend comparator; + + if(sort_stable == false) + { + std::sort( packet_vec.begin(), packet_vec.end(), comparator ); + } + else + { + std::stable_sort( packet_vec.begin(), packet_vec.end(), comparator ); + } + } + + uword* out_mem = out.memptr(); + + for(uword i=0; i +inline +bool +op_sort_index::apply_noalias(Mat& out, const Proxy& P, const uword sort_type) + { + arma_extra_debug_sigprint(); + + return arma_sort_index_helper(out, P, sort_type); + } + + + +template +inline +void +op_sort_index::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + const Proxy P(in.m); + + if(P.get_n_elem() == 0) { out.set_size(0,1); return; } + + const uword sort_type = in.aux_uword_a; + + bool all_non_nan = false; + + if(P.is_alias(out)) + { + Mat out2; + + all_non_nan = op_sort_index::apply_noalias(out2, P, sort_type); + + out.steal_mem(out2); + } + else + { + all_non_nan = op_sort_index::apply_noalias(out, P, sort_type); + } + + arma_debug_check( (all_non_nan == false), "sort_index(): detected NaN" ); + } + + + +template +inline +bool +op_stable_sort_index::apply_noalias(Mat& out, const Proxy& P, const uword sort_type) + { + arma_extra_debug_sigprint(); + + return arma_sort_index_helper(out, P, sort_type); + } + + + +template +inline +void +op_stable_sort_index::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + const Proxy P(in.m); + + if(P.get_n_elem() == 0) { out.set_size(0,1); return; } + + const uword sort_type = in.aux_uword_a; + + bool all_non_nan = false; + + if(P.is_alias(out)) + { + Mat out2; + + all_non_nan = op_stable_sort_index::apply_noalias(out2, P, sort_type); + + out.steal_mem(out2); + } + else + { + all_non_nan = op_stable_sort_index::apply_noalias(out, P, sort_type); + } + + arma_debug_check( (all_non_nan == false), "stable_sort_index(): detected NaN" ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sort_meat.hpp b/src/armadillo/include/armadillo_bits/op_sort_meat.hpp new file mode 100644 index 0000000..7a0ec6b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sort_meat.hpp @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sort +//! @{ + + + +template +inline +void +op_sort::direct_sort(eT* X, const uword n_elem, const uword sort_type) + { + arma_extra_debug_sigprint(); + + if(sort_type == 0) + { + arma_lt_comparator comparator; + + std::sort(&X[0], &X[n_elem], comparator); + } + else + { + arma_gt_comparator comparator; + + std::sort(&X[0], &X[n_elem], comparator); + } + } + + + +template +inline +void +op_sort::direct_sort_ascending(eT* X, const uword n_elem) + { + arma_extra_debug_sigprint(); + + arma_lt_comparator comparator; + + std::sort(&X[0], &X[n_elem], comparator); + } + + + +template +inline +void +op_sort::copy_row(eT* X, const Mat& A, const uword row) + { + const uword N = A.n_cols; + + uword i,j; + + for(i=0, j=1; j +inline +void +op_sort::copy_row(Mat& A, const eT* X, const uword row) + { + const uword N = A.n_cols; + + uword i,j; + + for(i=0, j=1; j +inline +void +op_sort::apply_noalias(Mat& out, const Mat& X, const uword sort_type, const uword dim) + { + arma_extra_debug_sigprint(); + + if((X.n_rows * X.n_cols) <= 1) { out = X; return; } + + if(dim == 0) // sort the contents of each column + { + arma_extra_debug_print("op_sort::apply(): dim = 0"); + + out = X; + + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + + for(uword col=0; col < n_cols; ++col) + { + op_sort::direct_sort( out.colptr(col), n_rows, sort_type ); + } + } + else + if(dim == 1) // sort the contents of each row + { + if(X.n_rows == 1) // a row vector + { + arma_extra_debug_print("op_sort::apply(): dim = 1, vector specific"); + + out = X; + op_sort::direct_sort(out.memptr(), out.n_elem, sort_type); + } + else // not a row vector + { + arma_extra_debug_print("op_sort::apply(): dim = 1, generic"); + + out.copy_size(X); + + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + + podarray tmp_array(n_cols); + + for(uword row=0; row < n_rows; ++row) + { + op_sort::copy_row(tmp_array.memptr(), X, row); + + op_sort::direct_sort( tmp_array.memptr(), n_cols, sort_type ); + + op_sort::copy_row(out, tmp_array.memptr(), row); + } + } + } + } + + + +template +inline +void +op_sort::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(in.m); + const Mat& X = U.M; + + const uword sort_type = in.aux_uword_a; + const uword dim = in.aux_uword_b; + + arma_debug_check( (sort_type > 1), "sort(): parameter 'sort_type' must be 0 or 1" ); + arma_debug_check( (dim > 1), "sort(): parameter 'dim' must be 0 or 1" ); + arma_debug_check( (X.internal_has_nan()), "sort(): detected NaN" ); + + if(U.is_alias(out)) + { + Mat tmp; + + op_sort::apply_noalias(tmp, X, sort_type, dim); + + out.steal_mem(tmp); + } + else + { + op_sort::apply_noalias(out, X, sort_type, dim); + } + } + + + +template +inline +void +op_sort_vec::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap U(in.m); // not using quasi_unwrap, to ensure there is no aliasing with subviews + const Mat& X = U.M; + + const uword sort_type = in.aux_uword_a; + + arma_debug_check( (sort_type > 1), "sort(): parameter 'sort_type' must be 0 or 1" ); + arma_debug_check( (X.internal_has_nan()), "sort(): detected NaN" ); + + out = X; // not checking for aliasing, to allow inplace sorting of vectors + + if(out.n_elem <= 1) { return; } + + eT* out_mem = out.memptr(); + + eT* start_ptr = out_mem; + eT* endp1_ptr = &out_mem[out.n_elem]; + + if(sort_type == 0) + { + arma_lt_comparator comparator; + + std::sort(start_ptr, endp1_ptr, comparator); + } + else + { + arma_gt_comparator comparator; + + std::sort(start_ptr, endp1_ptr, comparator); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sp_minus_bones.hpp b/src/armadillo/include/armadillo_bits/op_sp_minus_bones.hpp new file mode 100644 index 0000000..c0134ed --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sp_minus_bones.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sp_minus +//! @{ + + + +// Subtract a sparse object from a scalar; the output will be a dense object. +class op_sp_minus_pre + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const SpToDOp& in); + + // force apply into sparse matrix + template + inline static void apply(SpMat& out, const SpToDOp& in); + + // used for the optimization of sparse % (scalar - sparse) + template + inline static void apply_inside_schur(SpMat& out, const T2& x, const SpToDOp& y); + + // used for the optimization of sparse / (scalar - sparse) + template + inline static void apply_inside_div(SpMat& out, const T2& x, const SpToDOp& y); + }; + + + +// Subtract a scalar from a sparse object; the output will be a dense object. +class op_sp_minus_post + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const SpToDOp& in); + + // force apply into sparse matrix + template + inline static void apply(SpMat& out, const SpToDOp& in); + + // used for the optimization of sparse % (sparse - scalar) + template + inline static void apply_inside_schur(SpMat& out, const T2& x, const SpToDOp& y); + + // used for the optimization of sparse / (sparse - scalar) + template + inline static void apply_inside_div(SpMat& out, const T2& x, const SpToDOp& y); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sp_minus_meat.hpp b/src/armadillo/include/armadillo_bits/op_sp_minus_meat.hpp new file mode 100644 index 0000000..f8151a8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sp_minus_meat.hpp @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sp_minus +//! @{ + + +// scalar - SpBase +template +inline +void +op_sp_minus_pre::apply(Mat& out, const SpToDOp& in) + { + arma_extra_debug_sigprint(); + + // Note that T1 will be a sparse type, so we use SpProxy. + const SpProxy proxy(in.m); + + out.set_size(proxy.get_n_rows(), proxy.get_n_cols()); + out.fill(in.aux); + + typename SpProxy::const_iterator_type it = proxy.begin(); + typename SpProxy::const_iterator_type it_end = proxy.end(); + + for(; it != it_end; ++it) + { + out.at(it.row(), it.col()) -= (*it); + } + } + + + +// force apply into SpMat +template +inline +void +op_sp_minus_pre::apply(SpMat& out, const SpToDOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // Note that T1 will be a sparse type, so we use SpProxy. + const SpProxy proxy(in.m); + + const uword n_rows = proxy.get_n_rows(); + const uword n_cols = proxy.get_n_cols(); + + out.set_size(n_rows, n_cols); + + const eT k = in.aux; + + for(uword c = 0; c < n_cols; ++c) + for(uword r = 0; r < n_rows; ++r) + { + out.at(r, c) = k - proxy.at(r, c); + } + } + + + +// used for the optimization of sparse % (scalar - sparse) +template +inline +void +op_sp_minus_pre::apply_inside_schur(SpMat& out, const T2& x, const SpToDOp& y) + { + arma_extra_debug_sigprint(); + + const SpProxy proxy2(x); + const SpProxy proxy3(y.m); + + arma_debug_assert_same_size(proxy2.get_n_rows(), proxy2.get_n_cols(), proxy3.get_n_rows(), proxy3.get_n_cols(), "element-wise multiplication"); + + out.zeros(proxy2.get_n_rows(), proxy2.get_n_cols()); + + typename SpProxy::const_iterator_type it = proxy2.begin(); + typename SpProxy::const_iterator_type it_end = proxy2.end(); + + const eT k = y.aux; + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + out.at(it_row, it_col) = (*it) * (k - proxy3.at(it_row, it_col)); + } + } + + + +// used for the optimization of sparse / (scalar - sparse) +template +inline +void +op_sp_minus_pre::apply_inside_div(SpMat& out, const T2& x, const SpToDOp& y) + { + arma_extra_debug_sigprint(); + + const SpProxy proxy2(x); + const SpProxy proxy3(y.m); + + arma_debug_assert_same_size(proxy2.get_n_rows(), proxy2.get_n_cols(), proxy3.get_n_rows(), proxy3.get_n_cols(), "element-wise multiplication"); + + out.zeros(proxy2.get_n_rows(), proxy2.get_n_cols()); + + typename SpProxy::const_iterator_type it = proxy2.begin(); + typename SpProxy::const_iterator_type it_end = proxy2.end(); + + const eT k = y.aux; + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + out.at(it_row, it_col) = (*it) / (k - proxy3.at(it_row, it_col)); + } + } + + + +// SpBase - scalar +template +inline +void +op_sp_minus_post::apply(Mat& out, const SpToDOp& in) + { + arma_extra_debug_sigprint(); + + // Note that T1 will be a sparse type, so we use SpProxy. + const SpProxy proxy(in.m); + + out.set_size(proxy.get_n_rows(), proxy.get_n_cols()); + out.fill(-in.aux); + + typename SpProxy::const_iterator_type it = proxy.begin(); + typename SpProxy::const_iterator_type it_end = proxy.end(); + + for(; it != it_end; ++it) + { + out.at(it.row(), it.col()) += (*it); + } + } + + + +// force apply into sparse matrix +template +inline +void +op_sp_minus_post::apply(SpMat& out, const SpToDOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // Note that T1 will be a sparse type, so we use SpProxy. + const SpProxy proxy(in.m); + + const uword n_rows = proxy.get_n_rows(); + const uword n_cols = proxy.get_n_cols(); + + out.set_size(n_rows, n_cols); + + const eT k = in.aux; + + for(uword c = 0; c < n_cols; ++c) + for(uword r = 0; r < n_rows; ++r) + { + out.at(r, c) = proxy.at(r, c) - k; + } + } + + + +// used for the optimization of sparse % (sparse - scalar) +template +inline +void +op_sp_minus_post::apply_inside_schur(SpMat& out, const T2& x, const SpToDOp& y) + { + arma_extra_debug_sigprint(); + + const SpProxy proxy2(x); + const SpProxy proxy3(y.m); + + arma_debug_assert_same_size(proxy2.get_n_rows(), proxy2.get_n_cols(), proxy3.get_n_rows(), proxy3.get_n_cols(), "element-wise multiplication"); + + out.zeros(proxy2.get_n_rows(), proxy2.get_n_cols()); + + typename SpProxy::const_iterator_type it = proxy2.begin(); + typename SpProxy::const_iterator_type it_end = proxy2.end(); + + const eT k = y.aux; + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + out.at(it_row, it_col) = (*it) * (proxy3.at(it_row, it_col) - k); + } + } + + + +// used for the optimization of sparse / (sparse - scalar) +template +inline +void +op_sp_minus_post::apply_inside_div(SpMat& out, const T2& x, const SpToDOp& y) + { + arma_extra_debug_sigprint(); + + const SpProxy proxy2(x); + const SpProxy proxy3(y.m); + + arma_debug_assert_same_size(proxy2.get_n_rows(), proxy2.get_n_cols(), proxy3.get_n_rows(), proxy3.get_n_cols(), "element-wise multiplication"); + + out.zeros(proxy2.get_n_rows(), proxy2.get_n_cols()); + + typename SpProxy::const_iterator_type it = proxy2.begin(); + typename SpProxy::const_iterator_type it_end = proxy2.end(); + + const eT k = y.aux; + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + out.at(it_row, it_col) = (*it) / (proxy3.at(it_row, it_col) - k); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sp_plus_bones.hpp b/src/armadillo/include/armadillo_bits/op_sp_plus_bones.hpp new file mode 100644 index 0000000..d3977a1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sp_plus_bones.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sp_plus +//! @{ + + + +// Add a scalar to a sparse matrix; this will return a dense matrix. +class op_sp_plus + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const SpToDOp& in); + + // force apply into an SpMat<> + template + inline static void apply(SpMat& out, const SpToDOp& in); + + // used for the optimization of sparse % (sparse + scalar) + template + inline static void apply_inside_schur(SpMat& out, const T2& x, const SpToDOp& y); + + // used for the optimization of sparse / (sparse + scalar) + template + inline static void apply_inside_div(SpMat& out, const T2& x, const SpToDOp& y); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sp_plus_meat.hpp b/src/armadillo/include/armadillo_bits/op_sp_plus_meat.hpp new file mode 100644 index 0000000..723d968 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sp_plus_meat.hpp @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sp_plus +//! @{ + + +template +inline +void +op_sp_plus::apply(Mat& out, const SpToDOp& in) + { + arma_extra_debug_sigprint(); + + // Note that T1 will be a sparse type, so we use SpProxy. + const SpProxy proxy(in.m); + + out.set_size(proxy.get_n_rows(), proxy.get_n_cols()); + out.fill(in.aux); + + typename SpProxy::const_iterator_type it = proxy.begin(); + typename SpProxy::const_iterator_type it_end = proxy.end(); + + for(; it != it_end; ++it) + { + out.at(it.row(), it.col()) += (*it); + } + } + + + +// force apply into sparse matrix +template +inline +void +op_sp_plus::apply(SpMat& out, const SpToDOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // Note that T1 will be a sparse type, so we use SpProxy. + const SpProxy proxy(in.m); + + const uword n_rows = proxy.get_n_rows(); + const uword n_cols = proxy.get_n_cols(); + + out.set_size(n_rows, n_cols); + + const eT k = in.aux; + + // We have to loop over all the elements. + for(uword c = 0; c < n_cols; ++c) + for(uword r = 0; r < n_rows; ++r) + { + out.at(r, c) = proxy.at(r, c) + k; + } + } + + + +// used for the optimization of sparse % (sparse + scalar) +template +inline +void +op_sp_plus::apply_inside_schur(SpMat& out, const T2& x, const SpToDOp& y) + { + arma_extra_debug_sigprint(); + + const SpProxy proxy2(x); + const SpProxy proxy3(y.m); + + arma_debug_assert_same_size(proxy2.get_n_rows(), proxy2.get_n_cols(), proxy3.get_n_rows(), proxy3.get_n_cols(), "element-wise multiplication"); + + out.zeros(proxy2.get_n_rows(), proxy2.get_n_cols()); + + typename SpProxy::const_iterator_type it = proxy2.begin(); + typename SpProxy::const_iterator_type it_end = proxy2.end(); + + const eT k = y.aux; + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + out.at(it_row, it_col) = (*it) * (proxy3.at(it_row, it_col) + k); + } + } + + + +// used for the optimization of sparse / (sparse + scalar) +template +inline +void +op_sp_plus::apply_inside_div(SpMat& out, const T2& x, const SpToDOp& y) + { + arma_extra_debug_sigprint(); + + const SpProxy proxy2(x); + const SpProxy proxy3(y.m); + + arma_debug_assert_same_size(proxy2.get_n_rows(), proxy2.get_n_cols(), proxy3.get_n_rows(), proxy3.get_n_cols(), "element-wise division"); + + out.zeros(proxy2.get_n_rows(), proxy2.get_n_cols()); + + typename SpProxy::const_iterator_type it = proxy2.begin(); + typename SpProxy::const_iterator_type it_end = proxy2.end(); + + const eT k = y.aux; + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + out.at(it_row, it_col) = (*it) / (proxy3.at(it_row, it_col) + k); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sqrtmat_bones.hpp b/src/armadillo/include/armadillo_bits/op_sqrtmat_bones.hpp new file mode 100644 index 0000000..a63ae4b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sqrtmat_bones.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sqrtmat +//! @{ + + + +class op_sqrtmat + : public traits_op_default + { + public: + + template + inline static void apply(Mat< std::complex >& out, const mtOp,T1,op_sqrtmat>& in); + + template + inline static bool apply_direct(Mat< std::complex >& out, const Op& expr); + + template + inline static bool apply_direct(Mat< std::complex >& out, const Base& expr); + }; + + + +class op_sqrtmat_cx + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Op& expr); + + template + inline static bool apply_direct_noalias(Mat& out, const diagmat_proxy& P); + + template + inline static bool apply_direct(Mat& out, const Base& expr); + + template + inline static bool helper(Mat< std::complex >& S); + }; + + + +class op_sqrtmat_sympd + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Base& expr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sqrtmat_meat.hpp b/src/armadillo/include/armadillo_bits/op_sqrtmat_meat.hpp new file mode 100644 index 0000000..3c2fae5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sqrtmat_meat.hpp @@ -0,0 +1,549 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sqrtmat +//! @{ + + +//! implementation partly based on: +//! N. J. Higham. +//! A New sqrtm for Matlab. +//! Numerical Analysis Report No. 336, January 1999. +//! Department of Mathematics, University of Manchester. +//! ISSN 1360-1725 +//! http://www.maths.manchester.ac.uk/~higham/narep/narep336.ps.gz + + +template +inline +void +op_sqrtmat::apply(Mat< std::complex >& out, const mtOp,T1,op_sqrtmat>& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_sqrtmat::apply_direct(out, in.m); + + if(status == false) + { + arma_debug_warn_level(3, "sqrtmat(): given matrix is singular; may not have a square root"); + } + } + + + +template +inline +bool +op_sqrtmat::apply_direct(Mat< std::complex >& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type T; + + const diagmat_proxy P(expr.m); + + arma_debug_check( (P.n_rows != P.n_cols), "sqrtmat(): given matrix must be square sized" ); + + const uword N = P.n_rows; + + out.zeros(N,N); + + bool singular = false; + + for(uword i=0; i= T(0)) + { + singular = (singular || (val == T(0))); + + out.at(i,i) = std::sqrt(val); + } + else + { + out.at(i,i) = std::sqrt( std::complex(val) ); + } + } + + return (singular) ? false : true; + } + + + +template +inline +bool +op_sqrtmat::apply_direct(Mat< std::complex >& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_T; + typedef typename std::complex out_T; + + const quasi_unwrap expr_unwrap(expr.get_ref()); + const Mat& A = expr_unwrap.M; + + arma_debug_check( (A.is_square() == false), "sqrtmat(): given matrix must be square sized" ); + + if(A.n_elem == 0) + { + out.reset(); + return true; + } + else + if(A.n_elem == 1) + { + out.set_size(1,1); + out[0] = std::sqrt( std::complex( A[0] ) ); + return true; + } + + if(A.is_diagmat()) + { + arma_extra_debug_print("op_sqrtmat: detected diagonal matrix"); + + const uword N = A.n_rows; + + out.zeros(N,N); // aliasing can't happen as op_sqrtmat is defined as cx_mat = op(mat) + + for(uword i=0; i= in_T(0)) + { + out.at(i,i) = std::sqrt(val); + } + else + { + out.at(i,i) = std::sqrt( out_T(val) ); + } + } + + return true; + } + + const bool try_sympd = arma_config::optimise_sym && sym_helper::guess_sympd(A); + + if(try_sympd) + { + arma_extra_debug_print("op_sqrtmat: attempting sympd optimisation"); + + // if matrix A is sympd, all its eigenvalues are positive + + Col eigval; + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, A, 'd', "sqrtmat()"); + + if(eig_status) + { + // ensure each eigenvalue is > 0 + + const uword N = eigval.n_elem; + const in_T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i >::from( eigvec * diagmat(eigval) * eigvec.t() ); + + return true; + } + } + + arma_extra_debug_print("op_sqrtmat: sympd optimisation failed"); + + // fallthrough if eigen decomposition failed or an eigenvalue is <= 0 + } + + + Mat U; + Mat S(A.n_rows, A.n_cols, arma_nozeros_indicator()); + + const in_T* Amem = A.memptr(); + out_T* Smem = S.memptr(); + + const uword n_elem = A.n_elem; + + for(uword i=0; i( Amem[i] ); + } + + const bool schur_ok = auxlib::schur(U,S); + + if(schur_ok == false) + { + arma_extra_debug_print("sqrtmat(): schur decomposition failed"); + out.soft_reset(); + return false; + } + + const bool status = op_sqrtmat_cx::helper(S); + + const Mat X = U*S; + + S.reset(); + + out = X*U.t(); + + return status; + } + + + +template +inline +void +op_sqrtmat_cx::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_sqrtmat_cx::apply_direct(out, in.m); + + if(status == false) + { + arma_debug_warn_level(3, "sqrtmat(): given matrix is singular; may not have a square root"); + } + } + + + +template +inline +bool +op_sqrtmat_cx::apply_direct(Mat& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const diagmat_proxy P(expr.m); + + bool status = false; + + if(P.is_alias(out)) + { + Mat tmp; + + status = op_sqrtmat_cx::apply_direct_noalias(tmp, P); + + out.steal_mem(tmp); + } + else + { + status = op_sqrtmat_cx::apply_direct_noalias(out, P); + } + + return status; + } + + + +template +inline +bool +op_sqrtmat_cx::apply_direct_noalias(Mat& out, const diagmat_proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_check( (P.n_rows != P.n_cols), "sqrtmat(): given matrix must be square sized" ); + + const uword N = P.n_rows; + + out.zeros(N,N); + + const eT zero = eT(0); + + bool singular = false; + + for(uword i=0; i +inline +bool +op_sqrtmat_cx::apply_direct(Mat& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + Mat U; + Mat S = expr.get_ref(); + + arma_debug_check( (S.n_rows != S.n_cols), "sqrtmat(): given matrix must be square sized" ); + + if(S.n_elem == 0) + { + out.reset(); + return true; + } + else + if(S.n_elem == 1) + { + out.set_size(1,1); + out[0] = std::sqrt(S[0]); + return true; + } + + if(S.is_diagmat()) + { + arma_extra_debug_print("op_sqrtmat_cx: detected diagonal matrix"); + + const uword N = S.n_rows; + + out.zeros(N,N); // aliasing can't happen as S is generated + + for(uword i=0; i eigval; + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, S, 'd', "sqrtmat()"); + + if(eig_status) + { + // ensure each eigenvalue is > 0 + + const uword N = eigval.n_elem; + const T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i X = U*S; + + S.reset(); + + out = X*U.t(); + + return status; + } + + + +template +inline +bool +op_sqrtmat_cx::helper(Mat< std::complex >& S) + { + typedef typename std::complex eT; + + if(S.is_empty()) { return true; } + + const uword N = S.n_rows; + + const eT zero = eT(0); + + eT& S_00 = S[0]; + + bool singular = (S_00 == zero); + + S_00 = std::sqrt(S_00); + + for(uword j=1; j < N; ++j) + { + eT* S_j = S.colptr(j); + + eT& S_jj = S_j[j]; + + singular = (singular || (S_jj == zero)); + + S_jj = std::sqrt(S_jj); + + for(uword ii=0; ii <= (j-1); ++ii) + { + const uword i = (j-1) - ii; + + const eT* S_i = S.colptr(i); + + //S_j[i] /= (S_i[i] + S_j[j]); + S_j[i] /= (S_i[i] + S_jj); + + for(uword k=0; k < i; ++k) + { + S_j[k] -= S_i[k] * S_j[i]; + } + } + } + + return (singular) ? false : true; + } + + + +template +inline +void +op_sqrtmat_sympd::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_sqrtmat_sympd::apply_direct(out, in.m); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("sqrtmat_sympd(): transformation failed"); + } + } + + + +template +inline +bool +op_sqrtmat_sympd::apply_direct(Mat& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const unwrap U(expr.get_ref()); + const Mat& X = U.M; + + arma_debug_check( (X.is_square() == false), "sqrtmat_sympd(): given matrix must be square sized" ); + + if((arma_config::debug) && (is_cx::yes) && (sym_helper::check_diag_imag(X) == false)) + { + arma_debug_warn_level(1, "sqrtmat_sympd(): imaginary components on the diagonal are non-zero"); + } + + if(is_op_diagmat::value || X.is_diagmat()) + { + arma_extra_debug_print("op_sqrtmat_sympd: detected diagonal matrix"); + + out = X; + + eT* colmem = out.memptr(); + + const uword N = X.n_rows; + + for(uword i=0; i eigval; + Mat eigvec; + + const bool status = eig_sym_helper(eigval, eigvec, X, 'd', "sqrtmat_sympd()"); + + if(status == false) { return false; } + + const uword N = eigval.n_elem; + const T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i + inline static void apply(Mat& out, const mtOp& in); + + template + inline static void apply_noalias(Mat::result>& out, const Mat& X, const uword norm_type, const uword dim); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_stddev_meat.hpp b/src/armadillo/include/armadillo_bits/op_stddev_meat.hpp new file mode 100644 index 0000000..83724bc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_stddev_meat.hpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_stddev +//! @{ + + + +template +inline +void +op_stddev::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type out_eT; + + const uword norm_type = in.aux_uword_a; + const uword dim = in.aux_uword_b; + + arma_debug_check( (norm_type > 1), "stddev(): parameter 'norm_type' must be 0 or 1" ); + arma_debug_check( (dim > 1), "stddev(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_stddev::apply_noalias(tmp, U.M, norm_type, dim); + + out.steal_mem(tmp); + } + else + { + op_stddev::apply_noalias(out, U.M, norm_type, dim); + } + } + + + +template +inline +void +op_stddev::apply_noalias(Mat::result>& out, const Mat& X, const uword norm_type, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_stddev::apply_noalias(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows > 0) + { + out_eT* out_mem = out.memptr(); + + for(uword col=0; col 0) ? 1 : 0); + + if(X_n_cols > 0) + { + podarray dat(X_n_cols); + + in_eT* dat_mem = dat.memptr(); + out_eT* out_mem = out.memptr(); + + for(uword row=0; row + struct traits + { + static constexpr bool is_row = T1::is_col; // deliberately swapped + static constexpr bool is_col = T1::is_row; + static constexpr bool is_xvec = T1::is_xvec; + }; + + template + struct pos + { + static constexpr uword n2 = (do_flip == false) ? (row + col*2) : (col + row*2); + static constexpr uword n3 = (do_flip == false) ? (row + col*3) : (col + row*3); + static constexpr uword n4 = (do_flip == false) ? (row + col*4) : (col + row*4); + }; + + template + arma_cold inline static void apply_mat_noalias_tinysq(Mat& out, const TA& A); + + template + arma_hot inline static void block_worker(eT* Y, const eT* X, const uword X_n_rows, const uword Y_n_rows, const uword n_rows, const uword n_cols); + + template + arma_hot inline static void apply_mat_noalias_large(Mat& out, const Mat& A); + + template + arma_hot inline static void apply_mat_noalias(Mat& out, const TA& A); + + template + arma_hot inline static void apply_mat_inplace(Mat& out); + + template + inline static void apply_mat(Mat& out, const TA& A); + + template + inline static void apply_proxy(Mat& out, const Proxy& P); + + template + inline static void apply_direct(Mat& out, const T1& X); + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_strans_cube + { + public: + + template + inline static void apply_noalias(Cube& out, const Cube& X); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_strans_meat.hpp b/src/armadillo/include/armadillo_bits/op_strans_meat.hpp new file mode 100644 index 0000000..ed02d3b --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_strans_meat.hpp @@ -0,0 +1,465 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_strans +//! @{ + + + +//! for tiny square matrices (size <= 4x4) +template +inline +void +op_strans::apply_mat_noalias_tinysq(Mat& out, const TA& A) + { + const eT* Am = A.memptr(); + eT* outm = out.memptr(); + + switch(A.n_rows) + { + case 1: + { + outm[0] = Am[0]; + } + break; + + case 2: + { + outm[pos::n2] = Am[pos::n2]; + outm[pos::n2] = Am[pos::n2]; + + outm[pos::n2] = Am[pos::n2]; + outm[pos::n2] = Am[pos::n2]; + } + break; + + case 3: + { + outm[pos::n3] = Am[pos::n3]; + outm[pos::n3] = Am[pos::n3]; + outm[pos::n3] = Am[pos::n3]; + + outm[pos::n3] = Am[pos::n3]; + outm[pos::n3] = Am[pos::n3]; + outm[pos::n3] = Am[pos::n3]; + + outm[pos::n3] = Am[pos::n3]; + outm[pos::n3] = Am[pos::n3]; + outm[pos::n3] = Am[pos::n3]; + } + break; + + case 4: + { + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + outm[pos::n4] = Am[pos::n4]; + } + break; + + default: + ; + } + + } + + + +template +inline +void +op_strans::block_worker(eT* Y, const eT* X, const uword X_n_rows, const uword Y_n_rows, const uword n_rows, const uword n_cols) + { + for(uword row = 0; row < n_rows; ++row) + { + const uword Y_offset = row * Y_n_rows; + + for(uword col = 0; col < n_cols; ++col) + { + const uword X_offset = col * X_n_rows; + + Y[col + Y_offset] = X[row + X_offset]; + } + } + } + + + +template +inline +void +op_strans::apply_mat_noalias_large(Mat& out, const Mat& A) + { + arma_extra_debug_sigprint(); + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + const uword block_size = 64; + + const uword n_rows_base = block_size * (n_rows / block_size); + const uword n_cols_base = block_size * (n_cols / block_size); + + const uword n_rows_extra = n_rows - n_rows_base; + const uword n_cols_extra = n_cols - n_cols_base; + + const eT* X = A.memptr(); + eT* Y = out.memptr(); + + for(uword row = 0; row < n_rows_base; row += block_size) + { + const uword Y_offset = row * n_cols; + + for(uword col = 0; col < n_cols_base; col += block_size) + { + const uword X_offset = col * n_rows; + + op_strans::block_worker(&Y[col + Y_offset], &X[row + X_offset], n_rows, n_cols, block_size, block_size); + } + + const uword X_offset = n_cols_base * n_rows; + + op_strans::block_worker(&Y[n_cols_base + Y_offset], &X[row + X_offset], n_rows, n_cols, block_size, n_cols_extra); + } + + if(n_rows_extra == 0) { return; } + + const uword Y_offset = n_rows_base * n_cols; + + for(uword col = 0; col < n_cols_base; col += block_size) + { + const uword X_offset = col * n_rows; + + op_strans::block_worker(&Y[col + Y_offset], &X[n_rows_base + X_offset], n_rows, n_cols, n_rows_extra, block_size); + } + + const uword X_offset = n_cols_base * n_rows; + + op_strans::block_worker(&Y[n_cols_base + Y_offset], &X[n_rows_base + X_offset], n_rows, n_cols, n_rows_extra, n_cols_extra); + } + + + +//! Immediate transpose of a dense matrix +template +inline +void +op_strans::apply_mat_noalias(Mat& out, const TA& A) + { + arma_extra_debug_sigprint(); + + const uword A_n_cols = A.n_cols; + const uword A_n_rows = A.n_rows; + + out.set_size(A_n_cols, A_n_rows); + + if( (TA::is_row) || (TA::is_col) || (A_n_cols == 1) || (A_n_rows == 1) ) + { + arrayops::copy( out.memptr(), A.memptr(), A.n_elem ); + } + else + { + if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) ) + { + op_strans::apply_mat_noalias_tinysq(out, A); + } + else + if( (A_n_rows >= 512) && (A_n_cols >= 512) ) + { + op_strans::apply_mat_noalias_large(out, A); + } + else + { + eT* outptr = out.memptr(); + + for(uword k=0; k < A_n_rows; ++k) + { + const eT* Aptr = &(A.at(k,0)); + + uword j; + for(j=1; j < A_n_cols; j+=2) + { + const eT tmp_i = (*Aptr); Aptr += A_n_rows; + const eT tmp_j = (*Aptr); Aptr += A_n_rows; + + (*outptr) = tmp_i; outptr++; + (*outptr) = tmp_j; outptr++; + } + + if((j-1) < A_n_cols) + { + (*outptr) = (*Aptr); outptr++;; + } + } + } + } + } + + + +template +inline +void +op_strans::apply_mat_inplace(Mat& out) + { + arma_extra_debug_sigprint(); + + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + + if(n_rows == n_cols) + { + arma_extra_debug_print("op_strans::apply(): doing in-place transpose of a square matrix"); + + const uword N = n_rows; + + for(uword k=0; k < N; ++k) + { + eT* colptr = &(out.at(k,k)); + eT* rowptr = colptr; + + colptr++; + rowptr += N; + + uword j; + + for(j=(k+2); j < N; j+=2) + { + std::swap( (*rowptr), (*colptr) ); rowptr += N; colptr++; + std::swap( (*rowptr), (*colptr) ); rowptr += N; colptr++; + } + + if((j-1) < N) + { + std::swap( (*rowptr), (*colptr) ); + } + } + } + else + { + Mat tmp; + + op_strans::apply_mat_noalias(tmp, out); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_strans::apply_mat(Mat& out, const TA& A) + { + arma_extra_debug_sigprint(); + + if(&out != &A) + { + op_strans::apply_mat_noalias(out, A); + } + else + { + op_strans::apply_mat_inplace(out); + } + } + + + +template +inline +void +op_strans::apply_proxy(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + if( (resolves_to_vector::yes) && (Proxy::use_at == false) ) + { + out.set_size(n_cols, n_rows); + + eT* out_mem = out.memptr(); + + const uword n_elem = P.get_n_elem(); + + typename Proxy::ea_type Pea = P.get_ea(); + + uword i,j; + for(i=0, j=1; j < n_elem; i+=2, j+=2) + { + const eT tmp_i = Pea[i]; + const eT tmp_j = Pea[j]; + + out_mem[i] = tmp_i; + out_mem[j] = tmp_j; + } + + if(i < n_elem) + { + out_mem[i] = Pea[i]; + } + } + else // general matrix transpose + { + out.set_size(n_cols, n_rows); + + eT* outptr = out.memptr(); + + for(uword k=0; k < n_rows; ++k) + { + uword j; + for(j=1; j < n_cols; j+=2) + { + const uword i = j-1; + + const eT tmp_i = P.at(k,i); + const eT tmp_j = P.at(k,j); + + (*outptr) = tmp_i; outptr++; + (*outptr) = tmp_j; outptr++; + } + + const uword i = j-1; + + if(i < n_cols) + { + (*outptr) = P.at(k,i); outptr++; + } + } + } + } + + + +template +inline +void +op_strans::apply_direct(Mat& out, const T1& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // allow detection of in-place transpose + if(is_Mat::value || (arma_config::openmp && Proxy::use_mp)) + { + const unwrap U(X); + + op_strans::apply_mat(out, U.M); + } + else + { + const Proxy P(X); + + const bool is_alias = P.is_alias(out); + + if(is_Mat::stored_type>::value) + { + const quasi_unwrap::stored_type> U(P.Q); + + if(is_alias) + { + Mat tmp; + + op_strans::apply_mat_noalias(tmp, U.M); + + out.steal_mem(tmp); + } + else + { + op_strans::apply_mat_noalias(out, U.M); + } + } + else + { + if(is_alias) + { + Mat tmp; + + op_strans::apply_proxy(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_strans::apply_proxy(out, P); + } + } + } + } + + + +template +inline +void +op_strans::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + op_strans::apply_direct(out, in.m); + } + + + +// +// +// + + + +template +inline +void +op_strans_cube::apply_noalias(Cube& out, const Cube& X) + { + out.set_size(X.n_cols, X.n_rows, X.n_slices); + + for(uword s=0; s < X.n_slices; ++s) + { + Mat out_slice( out.slice_memptr(s), X.n_cols, X.n_rows, false, true ); + + const Mat X_slice( const_cast(X.slice_memptr(s)), X.n_rows, X.n_cols, false, true ); + + op_strans::apply_mat_noalias(out_slice, X_slice); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sum_bones.hpp b/src/armadillo/include/armadillo_bits/op_sum_bones.hpp new file mode 100644 index 0000000..a72c786 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sum_bones.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sum +//! @{ + + +class op_sum + : public traits_op_xvec + { + public: + + // dense matrices + + template + arma_hot inline static void apply(Mat& out, const Op& in); + + template + arma_hot inline static void apply_noalias(Mat& out, const Proxy& P, const uword dim); + + template + arma_hot inline static void apply_noalias_unwrap(Mat& out, const Proxy& P, const uword dim); + + template + arma_hot inline static void apply_noalias_proxy(Mat& out, const Proxy& P, const uword dim); + + + // cubes + + template + arma_hot inline static void apply(Cube& out, const OpCube& in); + + template + arma_hot inline static void apply_noalias(Cube& out, const ProxyCube& P, const uword dim); + + template + arma_hot inline static void apply_noalias_unwrap(Cube& out, const ProxyCube& P, const uword dim); + + template + arma_hot inline static void apply_noalias_proxy(Cube& out, const ProxyCube& P, const uword dim); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_sum_meat.hpp b/src/armadillo/include/armadillo_bits/op_sum_meat.hpp new file mode 100644 index 0000000..f759daa --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_sum_meat.hpp @@ -0,0 +1,430 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_sum +//! @{ + + + +template +inline +void +op_sum::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 1), "sum(): parameter 'dim' must be 0 or 1" ); + + const Proxy P(in.m); + + if(P.is_alias(out) == false) + { + op_sum::apply_noalias(out, P, dim); + } + else + { + Mat tmp; + + op_sum::apply_noalias(tmp, P, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_sum::apply_noalias(Mat& out, const Proxy& P, const uword dim) + { + arma_extra_debug_sigprint(); + + if(is_Mat::stored_type>::value || (arma_config::openmp && Proxy::use_mp)) + { + op_sum::apply_noalias_unwrap(out, P, dim); + } + else + { + op_sum::apply_noalias_proxy(out, P, dim); + } + } + + + +template +inline +void +op_sum::apply_noalias_unwrap(Mat& out, const Proxy& P, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + typedef typename Proxy::stored_type P_stored_type; + + const unwrap tmp(P.Q); + + const typename unwrap::stored_type& X = tmp.M; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + const uword out_n_rows = (dim == 0) ? uword(1) : X_n_rows; + const uword out_n_cols = (dim == 0) ? X_n_cols : uword(1); + + out.set_size(out_n_rows, out_n_cols); + + if(X.n_elem == 0) { out.zeros(); return; } + + const eT* X_colptr = X.memptr(); + eT* out_mem = out.memptr(); + + if(dim == 0) + { + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[col] = arrayops::accumulate( X_colptr, X_n_rows ); + + X_colptr += X_n_rows; + } + } + else + { + arrayops::copy(out_mem, X_colptr, X_n_rows); + + X_colptr += X_n_rows; + + for(uword col=1; col < X_n_cols; ++col) + { + arrayops::inplace_plus( out_mem, X_colptr, X_n_rows ); + + X_colptr += X_n_rows; + } + } + } + + + +template +inline +void +op_sum::apply_noalias_proxy(Mat& out, const Proxy& P, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword P_n_rows = P.get_n_rows(); + const uword P_n_cols = P.get_n_cols(); + + const uword out_n_rows = (dim == 0) ? uword(1) : P_n_rows; + const uword out_n_cols = (dim == 0) ? P_n_cols : uword(1); + + out.set_size(out_n_rows, out_n_cols); + + if(P.get_n_elem() == 0) { out.zeros(); return; } + + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) + { + if(dim == 0) + { + uword count = 0; + + for(uword col=0; col < P_n_cols; ++col) + { + eT val1 = eT(0); + eT val2 = eT(0); + + uword j; + for(j=1; j < P_n_rows; j+=2) + { + val1 += P[count]; ++count; + val2 += P[count]; ++count; + } + + if((j-1) < P_n_rows) + { + val1 += P[count]; ++count; + } + + out_mem[col] = (val1 + val2); + } + } + else + { + uword count = 0; + + for(uword row=0; row < P_n_rows; ++row) + { + out_mem[row] = P[count]; ++count; + } + + for(uword col=1; col < P_n_cols; ++col) + for(uword row=0; row < P_n_rows; ++row) + { + out_mem[row] += P[count]; ++count; + } + } + } + else + { + if(dim == 0) + { + for(uword col=0; col < P_n_cols; ++col) + { + eT val1 = eT(0); + eT val2 = eT(0); + + uword i,j; + for(i=0, j=1; j < P_n_rows; i+=2, j+=2) + { + val1 += P.at(i,col); + val2 += P.at(j,col); + } + + if(i < P_n_rows) + { + val1 += P.at(i,col); + } + + out_mem[col] = (val1 + val2); + } + } + else + { + for(uword row=0; row < P_n_rows; ++row) + { + out_mem[row] = P.at(row,0); + } + + for(uword col=1; col < P_n_cols; ++col) + for(uword row=0; row < P_n_rows; ++row) + { + out_mem[row] += P.at(row,col); + } + } + } + } + + + +// +// cubes + + + +template +inline +void +op_sum::apply(Cube& out, const OpCube& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 2), "sum(): parameter 'dim' must be 0 or 1 or 2" ); + + const ProxyCube P(in.m); + + if(P.is_alias(out) == false) + { + op_sum::apply_noalias(out, P, dim); + } + else + { + Cube tmp; + + op_sum::apply_noalias(tmp, P, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_sum::apply_noalias(Cube& out, const ProxyCube& P, const uword dim) + { + arma_extra_debug_sigprint(); + + if(is_Cube::stored_type>::value || (arma_config::openmp && ProxyCube::use_mp)) + { + op_sum::apply_noalias_unwrap(out, P, dim); + } + else + { + op_sum::apply_noalias_proxy(out, P, dim); + } + } + + + +template +inline +void +op_sum::apply_noalias_unwrap(Cube& out, const ProxyCube& P, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + typedef typename ProxyCube::stored_type P_stored_type; + + const unwrap_cube tmp(P.Q); + + const Cube& X = tmp.M; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + const uword X_n_slices = X.n_slices; + + if(dim == 0) + { + out.set_size(1, X_n_cols, X_n_slices); + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + out_mem[col] = arrayops::accumulate( X.slice_colptr(slice,col), X_n_rows ); + } + } + } + else + if(dim == 1) + { + out.zeros(X_n_rows, 1, X_n_slices); + + for(uword slice=0; slice < X_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < X_n_cols; ++col) + { + arrayops::inplace_plus( out_mem, X.slice_colptr(slice,col), X_n_rows ); + } + } + } + else + if(dim == 2) + { + out.zeros(X_n_rows, X_n_cols, 1); + + eT* out_mem = out.memptr(); + + for(uword slice=0; slice < X_n_slices; ++slice) + { + arrayops::inplace_plus(out_mem, X.slice_memptr(slice), X.n_elem_slice ); + } + } + } + + + +template +inline +void +op_sum::apply_noalias_proxy(Cube& out, const ProxyCube& P, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword P_n_rows = P.get_n_rows(); + const uword P_n_cols = P.get_n_cols(); + const uword P_n_slices = P.get_n_slices(); + + if(dim == 0) + { + out.set_size(1, P_n_cols, P_n_slices); + + for(uword slice=0; slice < P_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < P_n_cols; ++col) + { + eT val1 = eT(0); + eT val2 = eT(0); + + uword i,j; + for(i=0, j=1; j < P_n_rows; i+=2, j+=2) + { + val1 += P.at(i,col,slice); + val2 += P.at(j,col,slice); + } + + if(i < P_n_rows) + { + val1 += P.at(i,col,slice); + } + + out_mem[col] = (val1 + val2); + } + } + } + else + if(dim == 1) + { + out.zeros(P_n_rows, 1, P_n_slices); + + for(uword slice=0; slice < P_n_slices; ++slice) + { + eT* out_mem = out.slice_memptr(slice); + + for(uword col=0; col < P_n_cols; ++col) + for(uword row=0; row < P_n_rows; ++row) + { + out_mem[row] += P.at(row,col,slice); + } + } + } + else + if(dim == 2) + { + out.zeros(P_n_rows, P_n_cols, 1); + + for(uword slice=0; slice < P_n_slices; ++slice) + { + for(uword col=0; col < P_n_cols; ++col) + { + eT* out_mem = out.slice_colptr(0,col); + + for(uword row=0; row < P_n_rows; ++row) + { + out_mem[row] += P.at(row,col,slice); + } + } + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_symmat_bones.hpp b/src/armadillo/include/armadillo_bits/op_symmat_bones.hpp new file mode 100644 index 0000000..1e07f0e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_symmat_bones.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_symmat +//! @{ + + + +class op_symmatu + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_symmatl + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_symmatu_cx + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_symmatl_cx + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_symmat_meat.hpp b/src/armadillo/include/armadillo_bits/op_symmat_meat.hpp new file mode 100644 index 0000000..52731bd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_symmat_meat.hpp @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_symmat +//! @{ + + + +template +inline +void +op_symmatu::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap tmp(in.m); + const Mat& A = tmp.M; + + arma_debug_check( (A.is_square() == false), "symmatu(): given matrix must be square sized" ); + + const uword N = A.n_rows; + + if(&out != &A) + { + out.copy_size(A); + + // upper triangular: copy the diagonal and the elements above the diagonal + + for(uword i=0; i +inline +void +op_symmatl::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap tmp(in.m); + const Mat& A = tmp.M; + + arma_debug_check( (A.is_square() == false), "symmatl(): given matrix must be square sized" ); + + const uword N = A.n_rows; + + if(&out != &A) + { + out.copy_size(A); + + // lower triangular: copy the diagonal and the elements below the diagonal + + for(uword i=0; i +inline +void +op_symmatu_cx::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap tmp(in.m); + const Mat& A = tmp.M; + + arma_debug_check( (A.is_square() == false), "symmatu(): given matrix must be square sized" ); + + const uword N = A.n_rows; + + const bool do_conj = (in.aux_uword_b == 1); + + if(&out != &A) + { + out.copy_size(A); + + // upper triangular: copy the diagonal and the elements above the diagonal + + for(uword i=0; i +inline +void +op_symmatl_cx::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap tmp(in.m); + const Mat& A = tmp.M; + + arma_debug_check( (A.is_square() == false), "symmatl(): given matrix must be square sized" ); + + const uword N = A.n_rows; + + const bool do_conj = (in.aux_uword_b == 1); + + if(&out != &A) + { + out.copy_size(A); + + // lower triangular: copy the diagonal and the elements below the diagonal + + for(uword i=0; i + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_toeplitz_c + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_toeplitz_meat.hpp b/src/armadillo/include/armadillo_bits/op_toeplitz_meat.hpp new file mode 100644 index 0000000..a06c481 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_toeplitz_meat.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_toeplitz +//! @{ + + + +template +inline +void +op_toeplitz::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_check tmp(in.m, out); + const Mat& X = tmp.M; + + arma_debug_check( ((X.is_vec() == false) && (X.is_empty() == false)), "toeplitz(): given object must be a vector" ); + + const uword N = X.n_elem; + const eT* X_mem = X.memptr(); + + out.set_size(N,N); + + for(uword col=0; col < N; ++col) + { + eT* col_mem = out.colptr(col); + + uword i; + + i = col; + for(uword row=0; row < col; ++row, --i) { col_mem[row] = X_mem[i]; } + + i = 0; + for(uword row=col; row < N; ++row, ++i) { col_mem[row] = X_mem[i]; } + } + } + + + +template +inline +void +op_toeplitz_c::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_check tmp(in.m, out); + const Mat& X = tmp.M; + + arma_debug_check( ((X.is_vec() == false) && (X.is_empty() == false)), "circ_toeplitz(): given object must be a vector" ); + + const uword N = X.n_elem; + const eT* X_mem = X.memptr(); + + out.set_size(N,N); + + if(X.is_rowvec()) + { + for(uword row=0; row < N; ++row) + { + uword i; + + i = row; + for(uword col=0; col < row; ++col, --i) { out.at(row,col) = X_mem[N-i]; } + + i = 0; + for(uword col=row; col < N; ++col, ++i) { out.at(row,col) = X_mem[i]; } + } + } + else + { + for(uword col=0; col < N; ++col) + { + eT* col_mem = out.colptr(col); + + uword i; + + i = col; + for(uword row=0; row < col; ++row, --i) { col_mem[row] = X_mem[N-i]; } + + i = 0; + for(uword row=col; row < N; ++row, ++i) { col_mem[row] = X_mem[i]; } + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_trimat_bones.hpp b/src/armadillo/include/armadillo_bits/op_trimat_bones.hpp new file mode 100644 index 0000000..f500cbd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_trimat_bones.hpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_trimat +//! @{ + + + +// NOTE: don't split op_trimat into seperate op_trimatu and op_trimatl classes, +// NOTE: as several instances elsewhere rely on trimatu() and trimatl() producing the same type +class op_trimat + : public traits_op_default + { + public: + + template + inline static void fill_zeros(Mat& A, const bool upper); + + // + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static void apply_unwrap(Mat& out, const Mat& A, const bool upper); + + template + inline static void apply_proxy(Mat& out, const Proxy& P, const bool upper); + }; + + + +class op_trimatu_ext + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static void fill_zeros(Mat& A, const uword row_offset, const uword col_offset); + }; + + + +class op_trimatl_ext + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static void fill_zeros(Mat& A, const uword row_offset, const uword col_offset); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_trimat_meat.hpp b/src/armadillo/include/armadillo_bits/op_trimat_meat.hpp new file mode 100644 index 0000000..7922515 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_trimat_meat.hpp @@ -0,0 +1,381 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_trimat +//! @{ + + + +template +inline +void +op_trimat::fill_zeros(Mat& out, const bool upper) + { + arma_extra_debug_sigprint(); + + const uword N = out.n_rows; + + if(upper) + { + // upper triangular: set all elements below the diagonal to zero + + for(uword i=0; i +inline +void +op_trimat::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const bool upper = (in.aux_uword_a == 0); + + // allow detection of in-place operation + if(is_Mat::value || (arma_config::openmp && Proxy::use_mp)) + { + const unwrap U(in.m); + + op_trimat::apply_unwrap(out, U.M, upper); + } + else + { + const Proxy P(in.m); + + const bool is_alias = P.is_alias(out); + + if(is_Mat::stored_type>::value) + { + const quasi_unwrap::stored_type> U(P.Q); + + if(is_alias) + { + Mat tmp; + + op_trimat::apply_unwrap(tmp, U.M, upper); + + out.steal_mem(tmp); + } + else + { + op_trimat::apply_unwrap(out, U.M, upper); + } + } + else + { + if(is_alias) + { + Mat tmp; + + op_trimat::apply_proxy(tmp, P, upper); + + out.steal_mem(tmp); + } + else + { + op_trimat::apply_proxy(out, P, upper); + } + } + } + } + + + +template +inline +void +op_trimat::apply_unwrap(Mat& out, const Mat& A, const bool upper) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square sized" ); + + if(&out != &A) + { + out.copy_size(A); + + const uword N = A.n_rows; + + if(upper) + { + // upper triangular: copy the diagonal and the elements above the diagonal + for(uword i=0; i +inline +void +op_trimat::apply_proxy(Mat& out, const Proxy& P, const bool upper) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (P.get_n_rows() != P.get_n_cols()), "trimatu()/trimatl(): given matrix must be square sized" ); + + const uword N = P.get_n_rows(); + + out.set_size(N,N); + + if(upper) + { + for(uword j=0; j < N; ++j) + for(uword i=0; i < (j+1); ++i) + { + out.at(i,j) = P.at(i,j); + } + } + else + { + for(uword j=0; j +inline +void +op_trimatu_ext::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap tmp(in.m); + const Mat& A = tmp.M; + + arma_debug_check( (A.is_square() == false), "trimatu(): given matrix must be square sized" ); + + const uword row_offset = in.aux_uword_a; + const uword col_offset = in.aux_uword_b; + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatu(): requested diagonal is out of bounds" ); + + if(&out != &A) + { + out.copy_size(A); + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + for(uword i=0; i < n_cols; ++i) + { + const uword col = i + col_offset; + + if(i < N) + { + const uword end_row = i + row_offset; + + for(uword row=0; row <= end_row; ++row) + { + out.at(row,col) = A.at(row,col); + } + } + else + { + if(col < n_cols) + { + arrayops::copy(out.colptr(col), A.colptr(col), n_rows); + } + } + } + } + + op_trimatu_ext::fill_zeros(out, row_offset, col_offset); + } + + + +template +inline +void +op_trimatu_ext::fill_zeros(Mat& out, const uword row_offset, const uword col_offset) + { + arma_extra_debug_sigprint(); + + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + for(uword col=0; col < col_offset; ++col) + { + arrayops::fill_zeros(out.colptr(col), n_rows); + } + + for(uword i=0; i < N; ++i) + { + const uword start_row = i + row_offset + 1; + const uword col = i + col_offset; + + for(uword row=start_row; row < n_rows; ++row) + { + out.at(row,col) = eT(0); + } + } + } + + + +// + + + +template +inline +void +op_trimatl_ext::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap tmp(in.m); + const Mat& A = tmp.M; + + arma_debug_check( (A.is_square() == false), "trimatl(): given matrix must be square sized" ); + + const uword row_offset = in.aux_uword_a; + const uword col_offset = in.aux_uword_b; + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatl(): requested diagonal is out of bounds" ); + + if(&out != &A) + { + out.copy_size(A); + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + for(uword col=0; col < col_offset; ++col) + { + arrayops::copy( out.colptr(col), A.colptr(col), n_rows ); + } + + for(uword i=0; i +inline +void +op_trimatl_ext::fill_zeros(Mat& out, const uword row_offset, const uword col_offset) + { + arma_extra_debug_sigprint(); + + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + for(uword i=0; i < n_cols; ++i) + { + const uword col = i + col_offset; + + if(i < N) + { + const uword end_row = i + row_offset; + + for(uword row=0; row < end_row; ++row) + { + out.at(row,col) = eT(0); + } + } + else + { + if(col < n_cols) + { + arrayops::fill_zeros(out.colptr(col), n_rows); + } + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_unique_bones.hpp b/src/armadillo/include/armadillo_bits/op_unique_bones.hpp new file mode 100644 index 0000000..7e7bb69 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_unique_bones.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_unique +//! @{ + + + +class op_unique + : public traits_op_col + { + public: + + template + inline static bool apply_helper(Mat& out, const Proxy& P, const bool P_is_row); + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_unique_vec + : public traits_op_passthru + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +template +struct arma_unique_comparator + { + arma_inline + bool + operator() (const eT a, const eT b) const + { + return ( a < b ); + } + }; + + + +template +struct arma_unique_comparator< std::complex > + { + arma_inline + bool + operator() (const std::complex& a, const std::complex& b) const + { + const T a_real = a.real(); + const T b_real = b.real(); + + return ( (a_real < b_real) ? true : ((a_real == b_real) ? (a.imag() < b.imag()) : false) ); + } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_unique_meat.hpp b/src/armadillo/include/armadillo_bits/op_unique_meat.hpp new file mode 100644 index 0000000..1605ea7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_unique_meat.hpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_unique +//! @{ + + + +template +inline +bool +op_unique::apply_helper(Mat& out, const Proxy& P, const bool P_is_row) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + if(P_is_row) + { + out.set_size(1,0); + } + else + { + out.set_size(0,1); + } + + return true; + } + + if(n_elem == 1) + { + const eT tmp = (Proxy::use_at) ? P.at(0,0) : P[0]; + + out.set_size(1, 1); + + out[0] = tmp; + + return true; + } + + Mat X(n_elem, 1, arma_nozeros_indicator()); + + eT* X_mem = X.memptr(); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i comparator; + + std::sort( X.begin(), X.end(), comparator ); + + uword N_unique = 1; + + for(uword i=1; i < n_elem; ++i) + { + const eT a = X_mem[i-1]; + const eT b = X_mem[i ]; + + const eT diff = a - b; + + if(diff != eT(0)) { ++N_unique; } + } + + if(P_is_row) + { + out.set_size(1, N_unique); + } + else + { + out.set_size(N_unique, 1); + } + + eT* out_mem = out.memptr(); + + if(n_elem > 0) { (*out_mem) = X_mem[0]; out_mem++; } + + for(uword i=1; i < n_elem; ++i) + { + const eT a = X_mem[i-1]; + const eT b = X_mem[i ]; + + const eT diff = a - b; + + if(diff != eT(0)) { (*out_mem) = b; out_mem++; } + } + + return true; + } + + + +template +inline +void +op_unique::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const Proxy P(in.m); + + const bool all_non_nan = op_unique::apply_helper(out, P, false); + + arma_debug_check( (all_non_nan == false), "unique(): detected NaN" ); + } + + + +template +inline +void +op_unique_vec::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const Proxy P(in.m); + + const bool P_is_row = (T1::is_xvec) ? bool(P.get_n_rows() == 1) : bool(T1::is_row); + + const bool all_non_nan = op_unique::apply_helper(out, P, P_is_row); + + arma_debug_check( (all_non_nan == false), "unique(): detected NaN" ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_var_bones.hpp b/src/armadillo/include/armadillo_bits/op_var_bones.hpp new file mode 100644 index 0000000..ee13bd4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_var_bones.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_var +//! @{ + + + +class op_var + : public traits_op_xvec + { + public: + + template + inline static void apply(Mat& out, const mtOp& in); + + template + inline static void apply_noalias(Mat::result>& out, const Mat& X, const uword norm_type, const uword dim); + + // + + template + inline static typename get_pod_type::result var_vec(const subview_col& X, const uword norm_type = 0); + + template + inline static typename get_pod_type::result var_vec(const subview_row& X, const uword norm_type = 0); + + template + inline static typename T1::pod_type var_vec(const Base& X, const uword norm_type = 0); + + + // + + template + inline static eT direct_var(const eT* const X, const uword N, const uword norm_type = 0); + + template + inline static eT direct_var_robust(const eT* const X, const uword N, const uword norm_type = 0); + + + // + + template + inline static T direct_var(const std::complex* const X, const uword N, const uword norm_type = 0); + + template + inline static T direct_var_robust(const std::complex* const X, const uword N, const uword norm_type = 0); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_var_meat.hpp b/src/armadillo/include/armadillo_bits/op_var_meat.hpp new file mode 100644 index 0000000..49a252d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_var_meat.hpp @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_var +//! @{ + + + +template +inline +void +op_var::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type out_eT; + + const uword norm_type = in.aux_uword_a; + const uword dim = in.aux_uword_b; + + arma_debug_check( (norm_type > 1), "var(): parameter 'norm_type' must be 0 or 1" ); + arma_debug_check( (dim > 1), "var(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_var::apply_noalias(tmp, U.M, norm_type, dim); + + out.steal_mem(tmp); + } + else + { + op_var::apply_noalias(out, U.M, norm_type, dim); + } + } + + + +template +inline +void +op_var::apply_noalias(Mat::result>& out, const Mat& X, const uword norm_type, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_var::apply_noalias(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows > 0) + { + out_eT* out_mem = out.memptr(); + + for(uword col=0; col 0) ? 1 : 0); + + if(X_n_cols > 0) + { + podarray dat(X_n_cols); + + in_eT* dat_mem = dat.memptr(); + out_eT* out_mem = out.memptr(); + + for(uword row=0; row +inline +typename T1::pod_type +op_var::var_vec(const Base& X, const uword norm_type) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (norm_type > 1), "var(): parameter 'norm_type' must be 0 or 1" ); + + const quasi_unwrap U(X.get_ref()); + + return op_var::direct_var(U.M.memptr(), U.M.n_elem, norm_type); + } + + + +template +inline +typename get_pod_type::result +op_var::var_vec(const subview_col& X, const uword norm_type) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (norm_type > 1), "var(): parameter 'norm_type' must be 0 or 1" ); + + return op_var::direct_var(X.colptr(0), X.n_rows, norm_type); + } + + + + +template +inline +typename get_pod_type::result +op_var::var_vec(const subview_row& X, const uword norm_type) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (norm_type > 1), "var(): parameter 'norm_type' must be 0 or 1" ); + + const Mat& A = X.m; + + const uword start_row = X.aux_row1; + const uword start_col = X.aux_col1; + + const uword end_col_p1 = start_col + X.n_cols; + + podarray tmp(X.n_elem); + eT* tmp_mem = tmp.memptr(); + + for(uword i=0, col=start_col; col < end_col_p1; ++col, ++i) + { + tmp_mem[i] = A.at(start_row, col); + } + + return op_var::direct_var(tmp.memptr(), tmp.n_elem, norm_type); + } + + + +//! find the variance of an array +template +inline +eT +op_var::direct_var(const eT* const X, const uword n_elem, const uword norm_type) + { + arma_extra_debug_sigprint(); + + if(n_elem >= 2) + { + const eT acc1 = op_mean::direct_mean(X, n_elem); + + eT acc2 = eT(0); + eT acc3 = eT(0); + + uword i,j; + + for(i=0, j=1; j +inline +eT +op_var::direct_var_robust(const eT* const X, const uword n_elem, const uword norm_type) + { + arma_extra_debug_sigprint(); + + if(n_elem > 1) + { + eT r_mean = X[0]; + eT r_var = eT(0); + + for(uword i=1; i +inline +T +op_var::direct_var(const std::complex* const X, const uword n_elem, const uword norm_type) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + if(n_elem >= 2) + { + const eT acc1 = op_mean::direct_mean(X, n_elem); + + T acc2 = T(0); + eT acc3 = eT(0); + + for(uword i=0; i +inline +T +op_var::direct_var_robust(const std::complex* const X, const uword n_elem, const uword norm_type) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + if(n_elem > 1) + { + eT r_mean = X[0]; + T r_var = T(0); + + for(uword i=1; i + inline static void apply(Mat& out, const mtOp& in); + + template + inline static void apply_noalias(Mat::result>& out, const Mat& X, const uword k, const uword dim); + + template + inline static void apply_rawmem(typename get_pod_type::result& out_val, const in_eT* mem, const uword N, const uword k); + }; + + +class op_vecnorm_ext + : public traits_op_xvec + { + public: + + template + inline static void apply(Mat& out, const mtOp& in); + + template + inline static void apply_noalias(Mat::result>& out, const Mat& X, const uword method_id, const uword dim); + + template + inline static void apply_rawmem(typename get_pod_type::result& out_val, const in_eT* mem, const uword N, const uword method_id); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_vecnorm_meat.hpp b/src/armadillo/include/armadillo_bits/op_vecnorm_meat.hpp new file mode 100644 index 0000000..7f9f664 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_vecnorm_meat.hpp @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_vecnorm +//! @{ + + + +template +inline +void +op_vecnorm::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_eT; + typedef typename T1::pod_type out_eT; + + const quasi_unwrap U(in.m); + const Mat& X = U.M; + + const uword k = in.aux_uword_a; + const uword dim = in.aux_uword_b; + + arma_debug_check( (k == 0), "vecnorm(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "vecnorm(): parameter 'dim' must be 0 or 1" ); + + if(U.is_alias(out)) + { + Mat tmp; + + op_vecnorm::apply_noalias(tmp, X, k, dim); + + out.steal_mem(tmp); + } + else + { + op_vecnorm::apply_noalias(out, X, k, dim); + } + } + + + + +template +inline +void +op_vecnorm::apply_noalias(Mat::result>& out, const Mat& X, const uword k, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_vecnorm::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows > 0) + { + out_eT* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + op_vecnorm::apply_rawmem( out_mem[col], X.colptr(col), X_n_rows, k ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_vecnorm::apply(): dim = 1"); + + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0); + + if(X_n_cols > 0) + { + podarray dat(X_n_cols); + + in_eT* dat_mem = dat.memptr(); + out_eT* out_mem = out.memptr(); + + for(uword row=0; row < X_n_rows; ++row) + { + dat.copy_row(X, row); + + op_vecnorm::apply_rawmem( out_mem[row], dat_mem, X_n_cols, k ); + } + } + } + } + + + +template +inline +void +op_vecnorm::apply_rawmem(typename get_pod_type::result& out_val, const in_eT* mem, const uword N, const uword k) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + + const Col tmp(const_cast(mem), N, false, false); + + const Proxy< Col > P(tmp); + + if(P.get_n_elem() == 0) { out_val = out_eT(0); return; } + + if(k == uword(1)) { out_val = op_norm::vec_norm_1(P); return; } + if(k == uword(2)) { out_val = op_norm::vec_norm_2(P); return; } + + out_val = op_norm::vec_norm_k(P, int(k)); + } + + + +// + + + +template +inline +void +op_vecnorm_ext::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_eT; + typedef typename T1::pod_type out_eT; + + const quasi_unwrap U(in.m); + const Mat& X = U.M; + + const uword method_id = in.aux_uword_a; + const uword dim = in.aux_uword_b; + + arma_debug_check( (method_id == 0), "vecnorm(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "vecnorm(): parameter 'dim' must be 0 or 1" ); + + if(U.is_alias(out)) + { + Mat tmp; + + op_vecnorm_ext::apply_noalias(tmp, X, method_id, dim); + + out.steal_mem(tmp); + } + else + { + op_vecnorm_ext::apply_noalias(out, X, method_id, dim); + } + } + + + + +template +inline +void +op_vecnorm_ext::apply_noalias(Mat::result>& out, const Mat& X, const uword method_id, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_vecnorm_ext::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows > 0) + { + out_eT* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + op_vecnorm_ext::apply_rawmem( out_mem[col], X.colptr(col), X_n_rows, method_id ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_vecnorm_ext::apply(): dim = 1"); + + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0); + + if(X_n_cols > 0) + { + podarray dat(X_n_cols); + + in_eT* dat_mem = dat.memptr(); + out_eT* out_mem = out.memptr(); + + for(uword row=0; row < X_n_rows; ++row) + { + dat.copy_row(X, row); + + op_vecnorm_ext::apply_rawmem( out_mem[row], dat_mem, X_n_cols, method_id ); + } + } + } + } + + + +template +inline +void +op_vecnorm_ext::apply_rawmem(typename get_pod_type::result& out_val, const in_eT* mem, const uword N, const uword method_id) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + + const Col tmp(const_cast(mem), N, false, false); + + const Proxy< Col > P(tmp); + + if(P.get_n_elem() == 0) { out_val = out_eT(0); return; } + + if(method_id == uword(1)) { out_val = op_norm::vec_norm_max(P); return; } + if(method_id == uword(2)) { out_val = op_norm::vec_norm_min(P); return; } + + out_val = out_eT(0); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_vectorise_bones.hpp b/src/armadillo/include/armadillo_bits/op_vectorise_bones.hpp new file mode 100644 index 0000000..91f7df7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_vectorise_bones.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_vectorise +//! @{ + + + +class op_vectorise_col + : public traits_op_col + { + public: + + template inline static void apply(Mat& out, const Op& in); + + template inline static void apply_direct(Mat& out, const T1& expr); + + template inline static void apply_subview(Mat& out, const subview& sv); + + template inline static void apply_proxy(Mat& out, const Proxy& P); + }; + + + +class op_vectorise_row + : public traits_op_row + { + public: + + template inline static void apply(Mat& out, const Op& in); + + template inline static void apply_direct(Mat& out, const T1& expr); + + template inline static void apply_proxy(Mat& out, const Proxy& P); + }; + + + +class op_vectorise_all + : public traits_op_xvec + { + public: + + template inline static void apply(Mat& out, const Op& in); + }; + + + +class op_vectorise_cube_col + : public traits_op_col + { + public: + + template inline static void apply(Mat& out, const CubeToMatOp& in); + + template inline static void apply_subview(Mat& out, const subview_cube& sv); + + template inline static void apply_unwrap(Mat& out, const T1& expr); + + template inline static void apply_proxy(Mat& out, const T1& expr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_vectorise_meat.hpp b/src/armadillo/include/armadillo_bits/op_vectorise_meat.hpp new file mode 100644 index 0000000..c0f278c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_vectorise_meat.hpp @@ -0,0 +1,463 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_vectorise +//! @{ + + + +template +inline +void +op_vectorise_col::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + op_vectorise_col::apply_direct(out, in.m); + } + + + +template +inline +void +op_vectorise_col::apply_direct(Mat& out, const T1& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // allow detection of in-place operation + if(is_Mat::value || (arma_config::openmp && Proxy::use_mp)) + { + const unwrap U(expr); + + if(&out == &(U.M)) + { + // output matrix is the same as the input matrix + + out.set_size(out.n_elem, 1); // set_size() doesn't destroy data as long as the number of elements in the matrix remains the same + } + else + { + out.set_size(U.M.n_elem, 1); + + arrayops::copy(out.memptr(), U.M.memptr(), U.M.n_elem); + } + } + else + if(is_subview::value) + { + const subview& sv = reinterpret_cast< const subview& >(expr); + + if(&out == &(sv.m)) + { + Mat tmp; + + op_vectorise_col::apply_subview(tmp, sv); + + out.steal_mem(tmp); + } + else + { + op_vectorise_col::apply_subview(out, sv); + } + } + else + { + const Proxy P(expr); + + const bool is_alias = P.is_alias(out); + + if(is_Mat::stored_type>::value) + { + const quasi_unwrap::stored_type> U(P.Q); + + if(is_alias) + { + Mat tmp(U.M.memptr(), U.M.n_elem, 1); + + out.steal_mem(tmp); + } + else + { + out.set_size(U.M.n_elem, 1); + + arrayops::copy(out.memptr(), U.M.memptr(), U.M.n_elem); + } + } + else + { + if(is_alias) + { + Mat tmp; + + op_vectorise_col::apply_proxy(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_vectorise_col::apply_proxy(out, P); + } + } + } + } + + + +template +inline +void +op_vectorise_col::apply_subview(Mat& out, const subview& sv) + { + arma_extra_debug_sigprint(); + + const uword sv_n_rows = sv.n_rows; + const uword sv_n_cols = sv.n_cols; + + out.set_size(sv.n_elem, 1); + + eT* out_mem = out.memptr(); + + for(uword col=0; col < sv_n_cols; ++col) + { + arrayops::copy(out_mem, sv.colptr(col), sv_n_rows); + + out_mem += sv_n_rows; + } + } + + + +template +inline +void +op_vectorise_col::apply_proxy(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword N = P.get_n_elem(); + + out.set_size(N, 1); + + eT* outmem = out.memptr(); + + if(Proxy::use_at == false) + { + // TODO: add handling of aligned access ? + + typename Proxy::ea_type A = P.get_ea(); + + uword i,j; + + for(i=0, j=1; j < N; i+=2, j+=2) + { + const eT tmp_i = A[i]; + const eT tmp_j = A[j]; + + outmem[i] = tmp_i; + outmem[j] = tmp_j; + } + + if(i < N) + { + outmem[i] = A[i]; + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + if(n_rows == 1) + { + for(uword i=0; i < n_cols; ++i) + { + outmem[i] = P.at(0,i); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + *outmem = P.at(row,col); + outmem++; + } + } + } + } + + + +template +inline +void +op_vectorise_row::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + op_vectorise_row::apply_direct(out, in.m); + } + + + +template +inline +void +op_vectorise_row::apply_direct(Mat& out, const T1& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(expr); + + if(P.is_alias(out)) + { + Mat tmp; + + op_vectorise_row::apply_proxy(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_vectorise_row::apply_proxy(out, P); + } + } + + + +template +inline +void +op_vectorise_row::apply_proxy(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_elem = P.get_n_elem(); + + out.set_size(1, n_elem); + + eT* outmem = out.memptr(); + + if(n_cols == 1) + { + if(is_Mat::stored_type>::value) + { + const unwrap::stored_type> tmp(P.Q); + + arrayops::copy(out.memptr(), tmp.M.memptr(), n_elem); + } + else + { + for(uword i=0; i < n_elem; ++i) { outmem[i] = P.at(i,0); } + } + } + else + { + for(uword row=0; row < n_rows; ++row) + { + uword i,j; + + for(i=0, j=1; j < n_cols; i+=2, j+=2) + { + const eT tmp_i = P.at(row,i); + const eT tmp_j = P.at(row,j); + + *outmem = tmp_i; outmem++; + *outmem = tmp_j; outmem++; + } + + if(i < n_cols) + { + *outmem = P.at(row,i); outmem++; + } + } + } + } + + + +template +inline +void +op_vectorise_all::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const uword dim = in.aux_uword_a; + + if(dim == 0) + { + op_vectorise_col::apply_direct(out, in.m); + } + else + { + op_vectorise_row::apply_direct(out, in.m); + } + } + + + +// + + + +template +inline +void +op_vectorise_cube_col::apply(Mat& out, const CubeToMatOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_same_type< T1, subview_cube >::yes) + { + op_vectorise_cube_col::apply_subview(out, reinterpret_cast< const subview_cube& >(in.m)); + } + else + { + if(is_Cube::value || (arma_config::openmp && ProxyCube::use_mp)) + { + op_vectorise_cube_col::apply_unwrap(out, in.m); + } + else + { + op_vectorise_cube_col::apply_proxy(out, in.m); + } + } + } + + + +template +inline +void +op_vectorise_cube_col::apply_subview(Mat& out, const subview_cube& sv) + { + arma_extra_debug_sigprint(); + + const uword sv_nr = sv.n_rows; + const uword sv_nc = sv.n_cols; + const uword sv_ns = sv.n_slices; + + out.set_size(sv.n_elem, 1); + + eT* out_mem = out.memptr(); + + for(uword s=0; s < sv_ns; ++s) + for(uword c=0; c < sv_nc; ++c) + { + arrayops::copy(out_mem, sv.slice_colptr(s,c), sv_nr); + + out_mem += sv_nr; + } + } + + + +template +inline +void +op_vectorise_cube_col::apply_unwrap(Mat& out, const T1& expr) + { + arma_extra_debug_sigprint(); + + const unwrap_cube U(expr); + + out.set_size(U.M.n_elem, 1); + + arrayops::copy(out.memptr(), U.M.memptr(), U.M.n_elem); + } + + + +template +inline +void +op_vectorise_cube_col::apply_proxy(Mat& out, const T1& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const ProxyCube P(expr); + + if(is_Cube::stored_type>::value) + { + op_vectorise_cube_col::apply_unwrap(out, P.Q); + + return; + } + + const uword N = P.get_n_elem(); + + out.set_size(N, 1); + + eT* outmem = out.memptr(); + + if(ProxyCube::use_at == false) + { + typename ProxyCube::ea_type A = P.get_ea(); + + uword i,j; + + for(i=0, j=1; j < N; i+=2, j+=2) + { + const eT tmp_i = A[i]; + const eT tmp_j = A[j]; + + outmem[i] = tmp_i; + outmem[j] = tmp_j; + } + + if(i < N) + { + outmem[i] = A[i]; + } + } + else + { + const uword nr = P.get_n_rows(); + const uword nc = P.get_n_cols(); + const uword ns = P.get_n_slices(); + + for(uword s=0; s < ns; ++s) + for(uword c=0; c < nc; ++c) + for(uword r=0; r < nr; ++r) + { + *outmem = P.at(r,c,s); + outmem++; + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_wishrnd_bones.hpp b/src/armadillo/include/armadillo_bits/op_wishrnd_bones.hpp new file mode 100644 index 0000000..b85a72d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_wishrnd_bones.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_wishrnd +//! @{ + + +class op_wishrnd + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& expr); + + template + inline static bool apply_direct(Mat& out, const Base& X, const typename T1::elem_type df, const uword mode); + + template + inline static bool apply_noalias_mode1(Mat& out, const Mat& S, const eT df); + + template + inline static bool apply_noalias_mode2(Mat& out, const Mat& D, const eT df); + }; + + + +class op_iwishrnd + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& expr); + + template + inline static bool apply_direct(Mat& out, const Base& X, const typename T1::elem_type df, const uword mode); + + template + inline static bool apply_noalias_mode1(Mat& out, const Mat& T, const eT df); + + template + inline static bool apply_noalias_mode2(Mat& out, const Mat& Dinv, const eT df); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/op_wishrnd_meat.hpp b/src/armadillo/include/armadillo_bits/op_wishrnd_meat.hpp new file mode 100644 index 0000000..44fa77d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/op_wishrnd_meat.hpp @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_wishrnd +//! @{ + + +// implementation based on: +// Yu-Cheng Ku and Peter Bloomfield. +// Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX. +// Oxmetrics User Conference, 2010. + + +template +inline +void +op_wishrnd::apply(Mat& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const eT df = expr.aux; + const uword mode = expr.aux_uword_a; + + const bool status = op_wishrnd::apply_direct(out, expr.m, df, mode); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("wishrnd(): given matrix is not symmetric positive definite"); + } + } + + + +template +inline +bool +op_wishrnd::apply_direct(Mat& out, const Base& X, const typename T1::elem_type df, const uword mode) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(X.get_ref()); + + bool status = false; + + if(U.is_alias(out)) + { + Mat tmp; + + if(mode == 1) { status = op_wishrnd::apply_noalias_mode1(tmp, U.M, df); } + if(mode == 2) { status = op_wishrnd::apply_noalias_mode2(tmp, U.M, df); } + + out.steal_mem(tmp); + } + else + { + if(mode == 1) { status = op_wishrnd::apply_noalias_mode1(out, U.M, df); } + if(mode == 2) { status = op_wishrnd::apply_noalias_mode2(out, U.M, df); } + } + + return status; + } + + + +template +inline +bool +op_wishrnd::apply_noalias_mode1(Mat& out, const Mat& S, const eT df) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (S.is_square() == false), "wishrnd(): given matrix must be square sized" ); + + if(S.is_empty()) { out.reset(); return true; } + + if(auxlib::rudimentary_sym_check(S) == false) { return false; } + + Mat D; + + const bool status = op_chol::apply_direct(D, S, 0); + + if(status == false) { return false; } + + return op_wishrnd::apply_noalias_mode2(out, D, df); + } + + + +template +inline +bool +op_wishrnd::apply_noalias_mode2(Mat& out, const Mat& D, const eT df) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (df <= eT(0)), "df must be greater than zero" ); + arma_debug_check( (D.is_square() == false), "wishrnd(): given matrix must be square sized" ); + + if(D.is_empty()) { out.reset(); return true; } + + const uword N = D.n_rows; + + if(df < eT(N)) + { + arma_extra_debug_print("simple generator"); + + const uword df_floor = uword(std::floor(df)); + + const Mat tmp = (randn< Mat >(df_floor, N)) * D; + + out = tmp.t() * tmp; + } + else + { + arma_extra_debug_print("standard generator"); + + op_chi2rnd_varying_df chi2rnd_generator; + + Mat A(N, N, arma_zeros_indicator()); + + for(uword i=0; i::fill( A.colptr(i), i ); + } + + const Mat tmp = A * D; + + A.reset(); + + out = tmp.t() * tmp; + } + + return true; + } + + + +// + + + +template +inline +void +op_iwishrnd::apply(Mat& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const eT df = expr.aux; + const uword mode = expr.aux_uword_a; + + const bool status = op_iwishrnd::apply_direct(out, expr.m, df, mode); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("iwishrnd(): given matrix is not symmetric positive definite and/or df is too low"); + } + } + + + +template +inline +bool +op_iwishrnd::apply_direct(Mat& out, const Base& X, const typename T1::elem_type df, const uword mode) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap U(X.get_ref()); + + bool status = false; + + if(U.is_alias(out)) + { + Mat tmp; + + if(mode == 1) { status = op_iwishrnd::apply_noalias_mode1(tmp, U.M, df); } + if(mode == 2) { status = op_iwishrnd::apply_noalias_mode2(tmp, U.M, df); } + + out.steal_mem(tmp); + } + else + { + if(mode == 1) { status = op_iwishrnd::apply_noalias_mode1(out, U.M, df); } + if(mode == 2) { status = op_iwishrnd::apply_noalias_mode2(out, U.M, df); } + } + + return status; + } + + + +template +inline +bool +op_iwishrnd::apply_noalias_mode1(Mat& out, const Mat& T, const eT df) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (T.is_square() == false), "iwishrnd(): given matrix must be square sized" ); + + if(T.is_empty()) { out.reset(); return true; } + + if(auxlib::rudimentary_sym_check(T) == false) { return false; } + + Mat Tinv; + Mat Dinv; + + const bool inv_status = auxlib::inv_sympd(Tinv, T); + + if(inv_status == false) { return false; } + + const bool chol_status = op_chol::apply_direct(Dinv, Tinv, 0); + + if(chol_status == false) { return false; } + + return op_iwishrnd::apply_noalias_mode2(out, Dinv, df); + } + + + +template +inline +bool +op_iwishrnd::apply_noalias_mode2(Mat& out, const Mat& Dinv, const eT df) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (df <= eT(0)), "df must be greater than zero" ); + arma_debug_check( (Dinv.is_square() == false), "iwishrnd(): given matrix must be square sized" ); + + if(Dinv.is_empty()) { out.reset(); return true; } + + Mat tmp; + + const bool wishrnd_status = op_wishrnd::apply_noalias_mode2(tmp, Dinv, df); + + if(wishrnd_status == false) { return false; } + + const bool inv_status1 = auxlib::inv_sympd(out, tmp); + + const bool inv_status2 = (inv_status1) ? bool(true) : bool(auxlib::inv(out, tmp)); + + if(inv_status2 == false) { return false; } + + return true; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_cube_div.hpp b/src/armadillo/include/armadillo_bits/operator_cube_div.hpp new file mode 100644 index 0000000..58ff3a0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_cube_div.hpp @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_cube_div +//! @{ + + + +//! BaseCube / scalar +template +arma_inline +const eOpCube +operator/ + ( + const BaseCube& X, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + return eOpCube(X.get_ref(), k); + } + + + +//! scalar / BaseCube +template +arma_inline +const eOpCube +operator/ + ( + const typename T1::elem_type k, + const BaseCube& X + ) + { + arma_extra_debug_sigprint(); + + return eOpCube(X.get_ref(), k); + } + + + +//! complex scalar / non-complex BaseCube (experimental) +template +arma_inline +const mtOpCube, T1, op_cx_scalar_div_pre> +operator/ + ( + const std::complex& k, + const BaseCube& X + ) + { + arma_extra_debug_sigprint(); + + return mtOpCube, T1, op_cx_scalar_div_pre>('j', X.get_ref(), k); + } + + + +//! non-complex BaseCube / complex scalar (experimental) +template +arma_inline +const mtOpCube, T1, op_cx_scalar_div_post> +operator/ + ( + const BaseCube& X, + const std::complex& k + ) + { + arma_extra_debug_sigprint(); + + return mtOpCube, T1, op_cx_scalar_div_post>('j', X.get_ref(), k); + } + + + +//! element-wise division of BaseCube objects with same element type +template +arma_inline +const eGlueCube +operator/ + ( + const BaseCube& X, + const BaseCube& Y + ) + { + arma_extra_debug_sigprint(); + + return eGlueCube(X.get_ref(), Y.get_ref()); + } + + + +//! element-wise division of BaseCube objects with different element types +template +inline +const mtGlueCube::result, T1, T2, glue_mixed_div> +operator/ + ( + const BaseCube< typename force_different_type::T1_result, T1>& X, + const BaseCube< typename force_different_type::T2_result, T2>& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +template +arma_inline +Cube +operator/ + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each1_aux::operator_div(X, Y.get_ref()); + } + + + +template +arma_inline +Cube +operator/ + ( + const Base& X, + const subview_cube_each1& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each1_aux::operator_div(X.get_ref(), Y); + } + + + +template +arma_inline +Cube +operator/ + ( + const subview_cube_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each2_aux::operator_div(X, Y.get_ref()); + } + + + +template +arma_inline +Cube +operator/ + ( + const Base& X, + const subview_cube_each2& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each2_aux::operator_div(X.get_ref(), Y); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_cube_minus.hpp b/src/armadillo/include/armadillo_bits/operator_cube_minus.hpp new file mode 100644 index 0000000..53cb414 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_cube_minus.hpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_cube_minus +//! @{ + + + +//! unary - +template +arma_inline +const eOpCube +operator- + ( + const BaseCube& X + ) + { + arma_extra_debug_sigprint(); + + return eOpCube(X.get_ref()); + } + + + +//! BaseCube - scalar +template +arma_inline +const eOpCube +operator- + ( + const BaseCube& X, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + return eOpCube(X.get_ref(), k); + } + + + +//! scalar - BaseCube +template +arma_inline +const eOpCube +operator- + ( + const typename T1::elem_type k, + const BaseCube& X + ) + { + arma_extra_debug_sigprint(); + + return eOpCube(X.get_ref(), k); + } + + + +//! complex scalar - non-complex BaseCube (experimental) +template +arma_inline +const mtOpCube, T1, op_cx_scalar_minus_pre> +operator- + ( + const std::complex& k, + const BaseCube& X + ) + { + arma_extra_debug_sigprint(); + + return mtOpCube, T1, op_cx_scalar_minus_pre>('j', X.get_ref(), k); + } + + + +//! non-complex BaseCube - complex scalar (experimental) +template +arma_inline +const mtOpCube, T1, op_cx_scalar_minus_post> +operator- + ( + const BaseCube& X, + const std::complex& k + ) + { + arma_extra_debug_sigprint(); + + return mtOpCube, T1, op_cx_scalar_minus_post>('j', X.get_ref(), k); + } + + + +//! subtraction of BaseCube objects with same element type +template +arma_inline +const eGlueCube +operator- + ( + const BaseCube& X, + const BaseCube& Y + ) + { + arma_extra_debug_sigprint(); + + return eGlueCube(X.get_ref(), Y.get_ref()); + } + + + +//! subtraction of BaseCube objects with different element types +template +inline +const mtGlueCube::result, T1, T2, glue_mixed_minus> +operator- + ( + const BaseCube< typename force_different_type::T1_result, T1>& X, + const BaseCube< typename force_different_type::T2_result, T2>& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +template +arma_inline +Cube +operator- + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each1_aux::operator_minus(X, Y.get_ref()); + } + + + +template +arma_inline +Cube +operator- + ( + const Base& X, + const subview_cube_each1& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each1_aux::operator_minus(X.get_ref(), Y); + } + + + +template +arma_inline +Cube +operator- + ( + const subview_cube_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each2_aux::operator_minus(X, Y.get_ref()); + } + + + +template +arma_inline +Cube +operator- + ( + const Base& X, + const subview_cube_each2& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each2_aux::operator_minus(X.get_ref(), Y); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_cube_plus.hpp b/src/armadillo/include/armadillo_bits/operator_cube_plus.hpp new file mode 100644 index 0000000..fb360fe --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_cube_plus.hpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_cube_plus +//! @{ + + + +//! unary plus operation (does nothing, but is required for completeness) +template +arma_inline +const BaseCube& +operator+ + ( + const BaseCube& X + ) + { + arma_extra_debug_sigprint(); + + return X; + } + + + +//! BaseCube + scalar +template +arma_inline +const eOpCube +operator+ + ( + const BaseCube& X, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + return eOpCube(X.get_ref(), k); + } + + + +//! scalar + BaseCube +template +arma_inline +const eOpCube +operator+ + ( + const typename T1::elem_type k, + const BaseCube& X + ) + { + arma_extra_debug_sigprint(); + + return eOpCube(X.get_ref(), k); + } + + + +//! non-complex BaseCube + complex scalar (experimental) +template +arma_inline +const mtOpCube, T1, op_cx_scalar_plus> +operator+ + ( + const BaseCube& X, + const std::complex& k + ) + { + arma_extra_debug_sigprint(); + + return mtOpCube, T1, op_cx_scalar_plus>('j', X.get_ref(), k); + } + + + +//! complex scalar + non-complex BaseCube (experimental) +template +arma_inline +const mtOpCube, T1, op_cx_scalar_plus> +operator+ + ( + const std::complex& k, + const BaseCube& X + ) + { + arma_extra_debug_sigprint(); + + return mtOpCube, T1, op_cx_scalar_plus>('j', X.get_ref(), k); // NOTE: order is swapped + } + + + +//! addition of BaseCube objects with same element type +template +arma_inline +const eGlueCube +operator+ + ( + const BaseCube& X, + const BaseCube& Y + ) + { + arma_extra_debug_sigprint(); + + return eGlueCube(X.get_ref(), Y.get_ref()); + } + + + +//! addition of BaseCube objects with different element types +template +inline +const mtGlueCube::result, T1, T2, glue_mixed_plus> +operator+ + ( + const BaseCube< typename force_different_type::T1_result, T1>& X, + const BaseCube< typename force_different_type::T2_result, T2>& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +template +arma_inline +Cube +operator+ + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each1_aux::operator_plus(X, Y.get_ref()); + } + + + +template +arma_inline +Cube +operator+ + ( + const Base& X, + const subview_cube_each1& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each1_aux::operator_plus(Y, X.get_ref()); // NOTE: swapped order + } + + + +template +arma_inline +Cube +operator+ + ( + const subview_cube_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each2_aux::operator_plus(X, Y.get_ref()); + } + + + +template +arma_inline +Cube +operator+ + ( + const Base& X, + const subview_cube_each2& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each2_aux::operator_plus(Y, X.get_ref()); // NOTE: swapped order + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_cube_relational.hpp b/src/armadillo/include/armadillo_bits/operator_cube_relational.hpp new file mode 100644 index 0000000..8270d0f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_cube_relational.hpp @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_cube_relational +//! @{ + + + +// < : lt +// > : gt +// <= : lteq +// >= : gteq +// == : eq +// != : noteq +// && : and +// || : or + + + +template +inline +const mtGlueCube +operator< +(const BaseCube::result,T1>& X, const BaseCube::result,T2>& Y) + { + arma_extra_debug_sigprint(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +template +inline +const mtGlueCube +operator> +(const BaseCube::result,T1>& X, const BaseCube::result,T2>& Y) + { + arma_extra_debug_sigprint(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +template +inline +const mtGlueCube +operator<= +(const BaseCube::result,T1>& X, const BaseCube::result,T2>& Y) + { + arma_extra_debug_sigprint(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +template +inline +const mtGlueCube +operator>= +(const BaseCube::result,T1>& X, const BaseCube::result,T2>& Y) + { + arma_extra_debug_sigprint(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +template +inline +const mtGlueCube +operator== +(const BaseCube& X, const BaseCube& Y) + { + arma_extra_debug_sigprint(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +template +inline +const mtGlueCube +operator!= +(const BaseCube& X, const BaseCube& Y) + { + arma_extra_debug_sigprint(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +template +inline +const mtGlueCube +operator&& +(const BaseCube::result,T1>& X, const BaseCube::result,T2>& Y) + { + arma_extra_debug_sigprint(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +template +inline +const mtGlueCube +operator|| +(const BaseCube::result,T1>& X, const BaseCube::result,T2>& Y) + { + arma_extra_debug_sigprint(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +// +// +// + + + +template +inline +const mtOpCube +operator< +(const typename arma_not_cx::result val, const BaseCube::result,T1>& X) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +template +inline +const mtOpCube +operator< +(const BaseCube::result,T1>& X, const typename arma_not_cx::result val) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +template +inline +const mtOpCube +operator> +(const typename arma_not_cx::result val, const BaseCube::result,T1>& X) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +template +inline +const mtOpCube +operator> +(const BaseCube::result,T1>& X, const typename arma_not_cx::result val) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +template +inline +const mtOpCube +operator<= +(const typename arma_not_cx::result val, const BaseCube::result,T1>& X) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +template +inline +const mtOpCube +operator<= +(const BaseCube::result,T1>& X, const typename arma_not_cx::result val) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +template +inline +const mtOpCube +operator>= +(const typename arma_not_cx::result val, const BaseCube::result,T1>& X) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +template +inline +const mtOpCube +operator>= +(const BaseCube::result,T1>& X, const typename arma_not_cx::result val) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +template +inline +const mtOpCube +operator== +(const typename T1::elem_type val, const BaseCube& X) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +template +inline +const mtOpCube +operator== +(const BaseCube& X, const typename T1::elem_type val) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +template +inline +const mtOpCube +operator!= +(const typename T1::elem_type val, const BaseCube& X) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +template +inline +const mtOpCube +operator!= +(const BaseCube& X, const typename T1::elem_type val) + { + arma_extra_debug_sigprint(); + + return mtOpCube(X.get_ref(), val); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_cube_schur.hpp b/src/armadillo/include/armadillo_bits/operator_cube_schur.hpp new file mode 100644 index 0000000..21b7ee1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_cube_schur.hpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_cube_schur +//! @{ + + +// operator %, which we define it to do a schur product (element-wise multiplication) + + +//! element-wise multiplication of BaseCube objects with same element type +template +arma_inline +const eGlueCube +operator% + ( + const BaseCube& X, + const BaseCube& Y + ) + { + arma_extra_debug_sigprint(); + + return eGlueCube(X.get_ref(), Y.get_ref()); + } + + + +//! element-wise multiplication of BaseCube objects with different element types +template +inline +const mtGlueCube::result, T1, T2, glue_mixed_schur> +operator% + ( + const BaseCube< typename force_different_type::T1_result, T1>& X, + const BaseCube< typename force_different_type::T2_result, T2>& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtGlueCube( X.get_ref(), Y.get_ref() ); + } + + + +template +arma_inline +Cube +operator% + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each1_aux::operator_schur(X, Y.get_ref()); + } + + + +template +arma_inline +Cube +operator% + ( + const Base& X, + const subview_cube_each1& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each1_aux::operator_schur(Y, X.get_ref()); // NOTE: swapped order + } + + + +template +arma_inline +Cube +operator% + ( + const subview_cube_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each2_aux::operator_schur(X, Y.get_ref()); + } + + + +template +arma_inline +Cube +operator% + ( + const Base& X, + const subview_cube_each2& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each2_aux::operator_schur(Y, X.get_ref()); // NOTE: swapped order + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_cube_times.hpp b/src/armadillo/include/armadillo_bits/operator_cube_times.hpp new file mode 100644 index 0000000..0b9cf76 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_cube_times.hpp @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_cube_times +//! @{ + + + +//! BaseCube * scalar +template +arma_inline +const eOpCube +operator* + ( + const BaseCube& X, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + return eOpCube(X.get_ref(), k); + } + + + +//! scalar * BaseCube +template +arma_inline +const eOpCube +operator* + ( + const typename T1::elem_type k, + const BaseCube& X + ) + { + arma_extra_debug_sigprint(); + + return eOpCube(X.get_ref(), k); + } + + + +//! non-complex BaseCube * complex scalar (experimental) +template +arma_inline +const mtOpCube, T1, op_cx_scalar_times> +operator* + ( + const BaseCube& X, + const std::complex& k + ) + { + arma_extra_debug_sigprint(); + + return mtOpCube, T1, op_cx_scalar_times>('j', X.get_ref(), k); + } + + + +//! complex scalar * non-complex BaseCube (experimental) +template +arma_inline +const mtOpCube, T1, op_cx_scalar_times> +operator* + ( + const std::complex& k, + const BaseCube& X + ) + { + arma_extra_debug_sigprint(); + + return mtOpCube, T1, op_cx_scalar_times>('j', X.get_ref(), k); + } + + + +template +arma_inline +Cube +operator* + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each1_aux::operator_times(X, Y.get_ref()); + } + + + +template +arma_inline +Cube +operator* + ( + const Base& X, + const subview_cube_each1& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_cube_each1_aux::operator_times(X.get_ref(), Y); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_div.hpp b/src/armadillo/include/armadillo_bits/operator_div.hpp new file mode 100644 index 0000000..4f17fdf --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_div.hpp @@ -0,0 +1,382 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_div +//! @{ + + + +//! Base / scalar +template +arma_inline +typename +enable_if2< is_arma_type::value, const eOp< T1, eop_scalar_div_post> >::result +operator/ + ( + const T1& X, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + return eOp(X, k); + } + + + +//! scalar / Base +template +arma_inline +typename +enable_if2< is_arma_type::value, const eOp< T1, eop_scalar_div_pre> >::result +operator/ + ( + const typename T1::elem_type k, + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return eOp(X, k); + } + + + +//! complex scalar / non-complex Base +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_cx::no), + const mtOp, T1, op_cx_scalar_div_pre> + >::result +operator/ + ( + const std::complex& k, + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_cx_scalar_div_pre>('j', X, k); + } + + + +//! non-complex Base / complex scalar +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_cx::no), + const mtOp, T1, op_cx_scalar_div_post> + >::result +operator/ + ( + const T1& X, + const std::complex& k + ) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_cx_scalar_div_post>('j', X, k); + } + + + +//! element-wise division of Base objects with same element type +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && is_same_type::value), + const eGlue + >::result +operator/ + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + return eGlue(X, Y); + } + + + +//! element-wise division of Base objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && (is_same_type::no)), + const mtGlue::result, T1, T2, glue_mixed_div> + >::result +operator/ + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtGlue( X, Y ); + } + + + +//! element-wise division of sparse matrix by scalar +template +inline +typename +enable_if2< is_arma_sparse_type::value, SpMat >::result +operator/ + ( + const T1& X, + const typename T1::elem_type y + ) + { + arma_extra_debug_sigprint(); + + SpMat result(X); + + result /= y; + + return result; + } + + + +//! element-wise division of one sparse and one dense object +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_type::value && is_same_type::value), + SpMat + >::result +operator/ + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy pa(x); + const Proxy pb(y); + + const uword n_rows = pa.get_n_rows(); + const uword n_cols = pa.get_n_cols(); + + arma_debug_assert_same_size(n_rows, n_cols, pb.get_n_rows(), pb.get_n_cols(), "element-wise division"); + + uword new_n_nonzero = 0; + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT val = pa.at(row,col) / pb.at(row, col); + + if(val != eT(0)) + { + ++new_n_nonzero; + } + } + + SpMat result(arma_reserve_indicator(), n_rows, n_cols, new_n_nonzero); + + uword cur_pos = 0; + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT val = pa.at(row,col) / pb.at(row, col); + + if(val != eT(0)) + { + access::rw(result.values[cur_pos]) = val; + access::rw(result.row_indices[cur_pos]) = row; + ++access::rw(result.col_ptrs[col + 1]); + ++cur_pos; + } + } + + // Fix column pointers + for(uword col = 1; col <= result.n_cols; ++col) + { + access::rw(result.col_ptrs[col]) += result.col_ptrs[col - 1]; + } + + return result; + } + + + +//! optimization: element-wise division of sparse / (sparse +/- scalar) +template +inline +typename +enable_if2 + < + ( + is_arma_sparse_type::value && is_arma_sparse_type::value && + is_same_type::yes && + (is_same_type::value || + is_same_type::value || + is_same_type::value) + ), + SpMat + >::result +operator/ + ( + const T1& x, + const SpToDOp& y + ) + { + arma_extra_debug_sigprint(); + + SpMat out; + + op_type::apply_inside_div(out, x, y); + + return out; + } + + + +//! element-wise division of one dense and one sparse object +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_sparse_type::value && is_same_type::value), + Mat + >::result +operator/ + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy pa(x); + const SpProxy pb(y); + + const uword n_rows = pa.get_n_rows(); + const uword n_cols = pa.get_n_cols(); + + arma_debug_assert_same_size(n_rows, n_cols, pb.get_n_rows(), pb.get_n_cols(), "element-wise division"); + + Mat result(n_rows, n_cols, arma_nozeros_indicator()); + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + result.at(row, col) = pa.at(row, col) / pb.at(row, col); + } + + return result; + } + + + +template +arma_inline +Mat +operator/ + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each1_aux::operator_div(X, Y.get_ref()); + } + + + +template +arma_inline +Mat +operator/ + ( + const Base& X, + const subview_each1& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each1_aux::operator_div(X.get_ref(), Y); + } + + + +template +arma_inline +Mat +operator/ + ( + const subview_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each2_aux::operator_div(X, Y.get_ref()); + } + + + +template +arma_inline +Mat +operator/ + ( + const Base& X, + const subview_each2& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each2_aux::operator_div(X.get_ref(), Y); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_minus.hpp b/src/armadillo/include/armadillo_bits/operator_minus.hpp new file mode 100644 index 0000000..42047a7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_minus.hpp @@ -0,0 +1,570 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_minus +//! @{ + + + +//! unary - +template +arma_inline +typename +enable_if2< is_arma_type::value, const eOp >::result +operator- +(const T1& X) + { + arma_extra_debug_sigprint(); + + return eOp(X); + } + + + +//! Base - scalar +template +arma_inline +typename +enable_if2< is_arma_type::value, const eOp >::result +operator- + ( + const T1& X, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + return eOp(X, k); + } + + + +//! scalar - Base +template +arma_inline +typename +enable_if2< is_arma_type::value, const eOp >::result +operator- + ( + const typename T1::elem_type k, + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return eOp(X, k); + } + + + +//! complex scalar - non-complex Base +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_cx::no), + const mtOp, T1, op_cx_scalar_minus_pre> + >::result +operator- + ( + const std::complex& k, + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_cx_scalar_minus_pre>('j', X, k); + } + + + +//! non-complex Base - complex scalar +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_cx::no), + const mtOp, T1, op_cx_scalar_minus_post> + >::result +operator- + ( + const T1& X, + const std::complex& k + ) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_cx_scalar_minus_post>('j', X, k); + } + + + +//! subtraction of Base objects with same element type +template +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_same_type::value, + const eGlue + >::result +operator- + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + return eGlue(X, Y); + } + + + +//! subtraction of Base objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && (is_same_type::no)), + const mtGlue::result, T1, T2, glue_mixed_minus> + >::result +operator- + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtGlue( X, Y ); + } + + + +//! unary "-" for sparse objects +template +inline +typename +enable_if2 + < + is_arma_sparse_type::value && is_signed::value, + SpOp + >::result +operator- +(const T1& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + return SpOp(X, eT(-1)); + } + + + +//! subtraction of two sparse objects +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && is_same_type::value), + const SpGlue + >::result +operator- + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + return SpGlue(X,Y); + } + + + +//! subtraction of one sparse and one dense object +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_type::value && is_same_type::value), + Mat + >::result +operator- + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + const SpProxy pa(x); + + Mat result(-y); + + arma_debug_assert_same_size( pa.get_n_rows(), pa.get_n_cols(), result.n_rows, result.n_cols, "subtraction" ); + + typename SpProxy::const_iterator_type it = pa.begin(); + typename SpProxy::const_iterator_type it_end = pa.end(); + + while(it != it_end) + { + result.at(it.row(), it.col()) += (*it); + ++it; + } + + return result; + } + + + +//! subtraction of one dense and one sparse object +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_sparse_type::value && is_same_type::value), + Mat + >::result +operator- + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + Mat result(x); + + const SpProxy pb(y); + + arma_debug_assert_same_size( result.n_rows, result.n_cols, pb.get_n_rows(), pb.get_n_cols(), "subtraction" ); + + typename SpProxy::const_iterator_type it = pb.begin(); + typename SpProxy::const_iterator_type it_end = pb.end(); + + while(it != it_end) + { + result.at(it.row(), it.col()) -= (*it); + ++it; + } + + return result; + } + + + +//! subtraction of two sparse objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && is_same_type::no), + const mtSpGlue< typename promote_type::result, T1, T2, spglue_minus_mixed > + >::result +operator- + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtSpGlue( X, Y ); + } + + + +//! subtraction of sparse and non-sparse objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_type::value && is_same_type::no), + Mat< typename promote_type::result > + >::result +operator- + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + Mat< typename promote_type::result > out; + + spglue_minus_mixed::sparse_minus_dense(out, x, y); + + return out; + } + + + +//! subtraction of sparse and non-sparse objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_sparse_type::value && is_same_type::no), + Mat< typename promote_type::result > + >::result +operator- + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + Mat< typename promote_type::result > out; + + spglue_minus_mixed::dense_minus_sparse(out, x, y); + + return out; + } + + + +//! sparse - scalar +template +arma_inline +typename +enable_if2< is_arma_sparse_type::value, const SpToDOp >::result +operator- + ( + const T1& X, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + return SpToDOp(X, k); + } + + + +//! scalar - sparse +template +arma_inline +typename +enable_if2< is_arma_sparse_type::value, const SpToDOp >::result +operator- + ( + const typename T1::elem_type k, + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return SpToDOp(X, k); + } + + + +// TODO: this is an uncommon use case; remove? +//! multiple applications of add/subtract scalars can be condensed +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && + (is_same_type::value || + is_same_type::value)), + const SpToDOp + >::result +operator- + ( + const SpToDOp& x, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + const typename T1::elem_type aux = (is_same_type::value) ? -x.aux : x.aux; + + return SpToDOp(x.m, aux + k); + } + + + +// TODO: this is an uncommon use case; remove? +//! multiple applications of add/subtract scalars can be condensed +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && + (is_same_type::value || + is_same_type::value)), + const SpToDOp + >::result +operator- + ( + const typename T1::elem_type k, + const SpToDOp& x + ) + { + arma_extra_debug_sigprint(); + + const typename T1::elem_type aux = (is_same_type::value) ? -x.aux : x.aux; + + return SpToDOp(x.m, k + aux); + } + + + +// TODO: this is an uncommon use case; remove? +//! multiple applications of add/subtract scalars can be condensed +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && + is_same_type::value), + const SpToDOp + >::result +operator- + ( + const SpToDOp& x, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + return SpToDOp(x.m, x.aux - k); + } + + + +// TODO: this is an uncommon use case; remove? +//! multiple applications of add/subtract scalars can be condensed +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && + is_same_type::value), + const SpToDOp + >::result +operator- + ( + const typename T1::elem_type k, + const SpToDOp& x + ) + { + arma_extra_debug_sigprint(); + + return SpToDOp(x.m, k - x.aux); + } + + + +template +arma_inline +Mat +operator- + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each1_aux::operator_minus(X, Y.get_ref()); + } + + + +template +arma_inline +Mat +operator- + ( + const Base& X, + const subview_each1& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each1_aux::operator_minus(X.get_ref(), Y); + } + + + +template +arma_inline +Mat +operator- + ( + const subview_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each2_aux::operator_minus(X, Y.get_ref()); + } + + + +template +arma_inline +Mat +operator- + ( + const Base& X, + const subview_each2& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each2_aux::operator_minus(X.get_ref(), Y); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_ostream.hpp b/src/armadillo/include/armadillo_bits/operator_ostream.hpp new file mode 100644 index 0000000..ce9a9cb --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_ostream.hpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_ostream +//! @{ + + + +template +inline +std::ostream& +operator<< (std::ostream& o, const Base& X) + { + arma_extra_debug_sigprint(); + + const unwrap tmp(X.get_ref()); + + arma_ostream::print(o, tmp.M, true); + + return o; + } + + + +template +inline +std::ostream& +operator<< (std::ostream& o, const SpBase& X) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp(X.get_ref()); + + arma_ostream::print(o, tmp.M, true); + + return o; + } + + + +template +inline +std::ostream& +operator<< (std::ostream& o, const SpValProxy& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + o << eT(X); + + return o; + } + + + +template +inline +std::ostream& +operator<< (std::ostream& o, const MapMat_val& X) + { + arma_extra_debug_sigprint(); + + o << eT(X); + + return o; + } + + + +template +inline +std::ostream& +operator<< (std::ostream& o, const SpMat_MapMat_val& X) + { + arma_extra_debug_sigprint(); + + o << eT(X); + + return o; + } + + + +template +inline +std::ostream& +operator<< (std::ostream& o, const SpSubview_MapMat_val& X) + { + arma_extra_debug_sigprint(); + + o << eT(X); + + return o; + } + + + +template +inline +std::ostream& +operator<< (std::ostream& o, const BaseCube& X) + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp(X.get_ref()); + + arma_ostream::print(o, tmp.M, true); + + return o; + } + + + +//! Print the contents of a field to the specified stream. +template +inline +std::ostream& +operator<< (std::ostream& o, const field& X) + { + arma_extra_debug_sigprint(); + + arma_ostream::print(o, X); + + return o; + } + + + +//! Print the contents of a subfield to the specified stream +template +inline +std::ostream& +operator<< (std::ostream& o, const subview_field& X) + { + arma_extra_debug_sigprint(); + + arma_ostream::print(o, X); + + return o; + } + + + +inline +std::ostream& +operator<< (std::ostream& o, const SizeMat& S) + { + arma_extra_debug_sigprint(); + + arma_ostream::print(o, S); + + return o; + } + + + +inline +std::ostream& +operator<< (std::ostream& o, const SizeCube& S) + { + arma_extra_debug_sigprint(); + + arma_ostream::print(o, S); + + return o; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_plus.hpp b/src/armadillo/include/armadillo_bits/operator_plus.hpp new file mode 100644 index 0000000..3cae597 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_plus.hpp @@ -0,0 +1,540 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_plus +//! @{ + + + +//! unary plus operation (does nothing, but is required for completeness) +template +arma_inline +typename enable_if2< is_arma_type::value, const T1& >::result +operator+ +(const T1& X) + { + arma_extra_debug_sigprint(); + + return X; + } + + + +//! Base + scalar +template +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +operator+ +(const T1& X, const typename T1::elem_type k) + { + arma_extra_debug_sigprint(); + + return eOp(X, k); + } + + + +//! scalar + Base +template +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +operator+ +(const typename T1::elem_type k, const T1& X) + { + arma_extra_debug_sigprint(); + + return eOp(X, k); // NOTE: order is swapped + } + + + +//! non-complex Base + complex scalar +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_cx::no), + const mtOp, T1, op_cx_scalar_plus> + >::result +operator+ + ( + const T1& X, + const std::complex& k + ) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_cx_scalar_plus>('j', X, k); + } + + + +//! complex scalar + non-complex Base +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_cx::no), + const mtOp, T1, op_cx_scalar_plus> + >::result +operator+ + ( + const std::complex& k, + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_cx_scalar_plus>('j', X, k); // NOTE: order is swapped + } + + + +//! addition of user-accessible Armadillo objects with same element type +template +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_same_type::value, + const eGlue + >::result +operator+ + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + return eGlue(X, Y); + } + + + +//! addition of user-accessible Armadillo objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && (is_same_type::no)), + const mtGlue::result, T1, T2, glue_mixed_plus> + >::result +operator+ + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtGlue( X, Y ); + } + + + +//! addition of two sparse objects +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && is_same_type::value), + SpGlue + >::result +operator+ + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + return SpGlue(x, y); + } + + + +//! addition of one dense and one sparse object +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_sparse_type::value && is_same_type::value), + Mat + >::result +operator+ + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + Mat result(x); + + const SpProxy pb(y); + + arma_debug_assert_same_size( result.n_rows, result.n_cols, pb.get_n_rows(), pb.get_n_cols(), "addition" ); + + typename SpProxy::const_iterator_type it = pb.begin(); + typename SpProxy::const_iterator_type it_end = pb.end(); + + while(it != it_end) + { + result.at(it.row(), it.col()) += (*it); + ++it; + } + + return result; + } + + + +//! addition of one sparse and one dense object +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_type::value && is_same_type::value), + Mat + >::result +operator+ + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + const SpProxy pa(x); + + Mat result(y); + + arma_debug_assert_same_size( pa.get_n_rows(), pa.get_n_cols(), result.n_rows, result.n_cols, "addition" ); + + typename SpProxy::const_iterator_type it = pa.begin(); + typename SpProxy::const_iterator_type it_end = pa.end(); + + while(it != it_end) + { + result.at(it.row(), it.col()) += (*it); + ++it; + } + + return result; + } + + + +//! addition of two sparse objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && is_same_type::no), + const mtSpGlue< typename promote_type::result, T1, T2, spglue_plus_mixed > + >::result +operator+ + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtSpGlue( X, Y ); + } + + + +//! addition of sparse and non-sparse objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_sparse_type::value && is_same_type::no), + Mat< typename promote_type::result > + >::result +operator+ + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + Mat< typename promote_type::result > out; + + spglue_plus_mixed::dense_plus_sparse(out, x, y); + + return out; + } + + + +//! addition of sparse and non-sparse objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_type::value && is_same_type::no), + Mat< typename promote_type::result > + >::result +operator+ + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + Mat< typename promote_type::result > out; + + // Just call the other order (these operations are commutative) + // TODO: if there is a matrix size mismatch, the debug assert will print the matrix sizes in wrong order + spglue_plus_mixed::dense_plus_sparse(out, y, x); + + return out; + } + + + +//! addition of sparse object with scalar +template +inline +typename enable_if2< is_arma_sparse_type::value, const SpToDOp >::result +operator+ + ( + const T1& X, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + return SpToDOp(X, k); + } + + + +template +inline +typename enable_if2< is_arma_sparse_type::value, const SpToDOp >::result +operator+ + ( + const typename T1::elem_type k, + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return SpToDOp(X, k); // NOTE: swapped order + } + + + +// TODO: this is an uncommon use case; remove? +//! multiple applications of add/subtract scalars can be condensed +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && + (is_same_type::value || + is_same_type::value)), + const SpToDOp + >::result +operator+ + ( + const SpToDOp& x, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + const typename T1::elem_type aux = (is_same_type::value) ? x.aux : -x.aux; + + return SpToDOp(x.m, aux + k); + } + + + +// TODO: this is an uncommon use case; remove? +//! multiple applications of add/subtract scalars can be condensed +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && + is_same_type::value), + const SpToDOp + >::result +operator+ + ( + const SpToDOp& x, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + return SpToDOp(x.m, x.aux + k); + } + + + +// TODO: this is an uncommon use case; remove? +//! multiple applications of add/subtract scalars can be condensed +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && + (is_same_type::value || + is_same_type::value)), + const SpToDOp + >::result +operator+ + ( + const typename T1::elem_type k, + const SpToDOp& x + ) + { + arma_extra_debug_sigprint(); + + const typename T1::elem_type aux = (is_same_type::value) ? x.aux : -x.aux; + + return SpToDOp(x.m, aux + k); + } + + + +// TODO: this is an uncommon use case; remove? +//! multiple applications of add/subtract scalars can be condensed +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && + is_same_type::value), + const SpToDOp + >::result +operator+ + ( + const typename T1::elem_type k, + const SpToDOp& x + ) + { + arma_extra_debug_sigprint(); + + return SpToDOp(x.m, x.aux + k); + } + + + + +template +arma_inline +Mat +operator+ + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each1_aux::operator_plus(X, Y.get_ref()); + } + + + +template +arma_inline +Mat +operator+ + ( + const Base& X, + const subview_each1& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each1_aux::operator_plus(Y, X.get_ref()); // NOTE: swapped order + } + + + +template +arma_inline +Mat +operator+ + ( + const subview_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each2_aux::operator_plus(X, Y.get_ref()); + } + + + +template +arma_inline +Mat +operator+ + ( + const Base& X, + const subview_each2& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each2_aux::operator_plus(Y, X.get_ref()); // NOTE: swapped order + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_relational.hpp b/src/armadillo/include/armadillo_bits/operator_relational.hpp new file mode 100644 index 0000000..7313fdc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_relational.hpp @@ -0,0 +1,483 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_relational +//! @{ + + +// < : lt +// > : gt +// <= : lteq +// >= : gteq +// == : eq +// != : noteq +// && : and +// || : or + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && (is_cx::no) && (is_cx::no)), + const mtGlue + >::result +operator< +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return mtGlue( X, Y ); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && (is_cx::no) && (is_cx::no)), + const mtGlue + >::result +operator> +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return mtGlue( X, Y ); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && (is_cx::no) && (is_cx::no)), + const mtGlue + >::result +operator<= +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return mtGlue( X, Y ); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && (is_cx::no) && (is_cx::no)), + const mtGlue + >::result +operator>= +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return mtGlue( X, Y ); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value), + const mtGlue + >::result +operator== +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return mtGlue( X, Y ); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value), + const mtGlue + >::result +operator!= +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return mtGlue( X, Y ); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && (is_cx::no) && (is_cx::no)), + const mtGlue + >::result +operator&& +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return mtGlue( X, Y ); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && (is_cx::no) && (is_cx::no)), + const mtGlue + >::result +operator|| +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return mtGlue( X, Y ); + } + + + +// +// +// + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx::no)), + const mtOp + >::result +operator< +(const typename T1::elem_type val, const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx::no)), + const mtOp + >::result +operator< +(const T1& X, const typename T1::elem_type val) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx::no)), + const mtOp + >::result +operator> +(const typename T1::elem_type val, const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx::no)), + const mtOp + >::result +operator> +(const T1& X, const typename T1::elem_type val) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx::no)), + const mtOp + >::result +operator<= +(const typename T1::elem_type val, const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx::no)), + const mtOp + >::result +operator<= +(const T1& X, const typename T1::elem_type val) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx::no)), + const mtOp + >::result +operator>= +(const typename T1::elem_type val, const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_type::value && (is_cx::no)), + const mtOp + >::result +operator>= +(const T1& X, const typename T1::elem_type val) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +template +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +operator== +(const typename T1::elem_type val, const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +template +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +operator== +(const T1& X, const typename T1::elem_type val) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +template +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +operator!= +(const typename T1::elem_type val, const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +template +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +operator!= +(const T1& X, const typename T1::elem_type val) + { + arma_extra_debug_sigprint(); + + return mtOp(X, val); + } + + + +// + + + +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && (is_cx::no) && (is_cx::no)), + const mtSpGlue + >::result +operator< +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + // TODO: ensure T1::elem_type and T2::elem_type are the same + + return mtSpGlue( X, Y ); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && (is_cx::no) && (is_cx::no)), + const mtSpGlue + >::result +operator> +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + // TODO: ensure T1::elem_type and T2::elem_type are the same + + return mtSpGlue( X, Y ); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && (is_cx::no) && (is_cx::no)), + const mtSpGlue + >::result +operator&& +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + // TODO: ensure T1::elem_type and T2::elem_type are the same + + return mtSpGlue( X, Y ); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && (is_cx::no) && (is_cx::no)), + const mtSpGlue + >::result +operator|| +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + // TODO: ensure T1::elem_type and T2::elem_type are the same + + return mtSpGlue( X, Y ); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_schur.hpp b/src/armadillo/include/armadillo_bits/operator_schur.hpp new file mode 100644 index 0000000..2acfdbc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_schur.hpp @@ -0,0 +1,366 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_schur +//! @{ + + +// operator %, which we define it to do a schur product (element-wise multiplication) + + +//! element-wise multiplication of user-accessible Armadillo objects with same element type +template +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_same_type::value, + const eGlue + >::result +operator% + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + return eGlue(X, Y); + } + + + +//! element-wise multiplication of user-accessible Armadillo objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && (is_same_type::no)), + const mtGlue::result, T1, T2, glue_mixed_schur> + >::result +operator% + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtGlue( X, Y ); + } + + + +//! element-wise multiplication of two sparse matrices +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && is_same_type::value), + SpGlue + >::result +operator% + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + return SpGlue(x, y); + } + + + +//! element-wise multiplication of one dense and one sparse object +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_sparse_type::value && is_same_type::value), + SpMat + >::result +operator% + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + SpMat out; + + spglue_schur_misc::dense_schur_sparse(out, x, y); + + return out; + } + + + +//! element-wise multiplication of one sparse and one dense object +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_type::value && is_same_type::value), + SpMat + >::result +operator% + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + SpMat out; + + // Just call the other order (these operations are commutative) + // TODO: if there is a matrix size mismatch, the debug assert will print the matrix sizes in wrong order + spglue_schur_misc::dense_schur_sparse(out, y, x); + + return out; + } + + + +//! optimization: sparse % (sparse +/- scalar) can be done without forming the dense result of the (sparse +/- scalar) term +template +inline +typename +enable_if2 + < + ( + is_arma_sparse_type::value && is_arma_sparse_type::value && + is_same_type::yes && + (is_same_type::value || + is_same_type::value || + is_same_type::value) + ), + SpMat + >::result +operator% + ( + const T1& x, + const SpToDOp& y + ) + { + arma_extra_debug_sigprint(); + + SpMat out; + + op_type::apply_inside_schur(out, x, y); + + return out; + } + + + +//! optimization: (sparse +/- scalar) % sparse can be done without forming the dense result of the (sparse +/- scalar) term +template +inline +typename +enable_if2 + < + ( + is_arma_sparse_type::value && is_arma_sparse_type::value && + is_same_type::yes && + (is_same_type::value || + is_same_type::value || + is_same_type::value) + ), + SpMat + >::result +operator% + ( + const SpToDOp& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + SpMat out; + + // Just call the other order (these operations are commutative) + // TODO: if there is a matrix size mismatch, the debug assert will print the matrix sizes in wrong order + op_type::apply_inside_schur(out, y, x); + + return out; + } + + + +//! element-wise multiplication of two sparse objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && is_same_type::no), + const mtSpGlue< typename promote_type::result, T1, T2, spglue_schur_mixed > + >::result +operator% + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtSpGlue( X, Y ); + } + + + +//! element-wise multiplication of one dense and one sparse object with different element types +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_sparse_type::value && is_same_type::no), + SpMat< typename promote_type::result > + >::result +operator% + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + SpMat< typename promote_type::result > out; + + spglue_schur_mixed::dense_schur_sparse(out, x, y); + + return out; + } + + + +//! element-wise multiplication of one sparse and one dense object with different element types +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_type::value && is_same_type::no), + SpMat< typename promote_type::result > + >::result +operator% + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + SpMat< typename promote_type::result > out; + + // Just call the other order (these operations are commutative) + // TODO: if there is a matrix size mismatch, the debug assert will print the matrix sizes in wrong order + spglue_schur_mixed::dense_schur_sparse(out, y, x); + + return out; + } + + + +template +inline +Mat +operator% + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each1_aux::operator_schur(X, Y.get_ref()); + } + + + +template +arma_inline +Mat +operator% + ( + const Base& X, + const subview_each1& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each1_aux::operator_schur(Y, X.get_ref()); // NOTE: swapped order + } + + + +template +inline +Mat +operator% + ( + const subview_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each2_aux::operator_schur(X, Y.get_ref()); + } + + + +template +arma_inline +Mat +operator% + ( + const Base& X, + const subview_each2& Y + ) + { + arma_extra_debug_sigprint(); + + return subview_each2_aux::operator_schur(Y, X.get_ref()); // NOTE: swapped order + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/operator_times.hpp b/src/armadillo/include/armadillo_bits/operator_times.hpp new file mode 100644 index 0000000..861166c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/operator_times.hpp @@ -0,0 +1,482 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup operator_times +//! @{ + + + +//! Base * scalar +template +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +operator* +(const T1& X, const typename T1::elem_type k) + { + arma_extra_debug_sigprint(); + + return eOp(X,k); + } + + + +//! scalar * Base +template +arma_inline +typename enable_if2< is_arma_type::value, const eOp >::result +operator* +(const typename T1::elem_type k, const T1& X) + { + arma_extra_debug_sigprint(); + + return eOp(X,k); // NOTE: order is swapped + } + + + +//! non-complex Base * complex scalar +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_cx::no), + const mtOp, T1, op_cx_scalar_times> + >::result +operator* + ( + const T1& X, + const std::complex& k + ) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_cx_scalar_times>('j', X, k); + } + + + +//! complex scalar * non-complex Base +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_cx::no), + const mtOp, T1, op_cx_scalar_times> + >::result +operator* + ( + const std::complex& k, + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return mtOp, T1, op_cx_scalar_times>('j', X, k); + } + + + +//! scalar * trans(T1) +template +arma_inline +const Op +operator* +(const typename T1::elem_type k, const Op& X) + { + arma_extra_debug_sigprint(); + + return Op(X.m, k); + } + + + +//! trans(T1) * scalar +template +arma_inline +const Op +operator* +(const Op& X, const typename T1::elem_type k) + { + arma_extra_debug_sigprint(); + + return Op(X.m, k); + } + + + +//! Base * diagmat +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_same_type::value), + const Glue, glue_times_diag> + >::result +operator* +(const T1& X, const Op& Y) + { + arma_extra_debug_sigprint(); + + return Glue, glue_times_diag>(X, Y); + } + + + +//! diagmat * Base +template +arma_inline +typename +enable_if2 + < + (is_arma_type::value && is_same_type::value), + const Glue, T2, glue_times_diag> + >::result +operator* +(const Op& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return Glue, T2, glue_times_diag>(X, Y); + } + + + +//! diagmat * diagmat +template +inline +Mat< typename promote_type::result > +operator* +(const Op& X, const Op& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + const diagmat_proxy A(X.m); + const diagmat_proxy B(Y.m); + + arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + Mat out(A.n_rows, B.n_cols, arma_zeros_indicator()); + + const uword A_length = (std::min)(A.n_rows, A.n_cols); + const uword B_length = (std::min)(B.n_rows, B.n_cols); + + const uword N = (std::min)(A_length, B_length); + + for(uword i=0; i::apply( A[i] ) * upgrade_val::apply( B[i] ); + } + + return out; + } + + + +//! multiplication of Base objects with same element type +template +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_same_type::value, + const Glue + >::result +operator* +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + return Glue(X, Y); + } + + + +//! multiplication of Base objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_type::value && (is_same_type::no)), + const mtGlue< typename promote_type::result, T1, T2, glue_mixed_times > + >::result +operator* + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtGlue( X, Y ); + } + + + +//! sparse multiplied by scalar +template +inline +typename +enable_if2 + < + is_arma_sparse_type::value, + SpOp + >::result +operator* + ( + const T1& X, + const typename T1::elem_type k + ) + { + arma_extra_debug_sigprint(); + + return SpOp(X, k); + } + + + +template +inline +typename +enable_if2 + < + is_arma_sparse_type::value, + SpOp + >::result +operator* + ( + const typename T1::elem_type k, + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return SpOp(X, k); + } + + + +//! non-complex sparse * complex scalar +template +arma_inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_cx::no), + const mtSpOp, T1, spop_cx_scalar_times> + >::result +operator* + ( + const T1& X, + const std::complex& k + ) + { + arma_extra_debug_sigprint(); + + return mtSpOp, T1, spop_cx_scalar_times>('j', X, k); + } + + + +//! complex scalar * non-complex sparse +template +arma_inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_cx::no), + const mtSpOp, T1, spop_cx_scalar_times> + >::result +operator* + ( + const std::complex& k, + const T1& X + ) + { + arma_extra_debug_sigprint(); + + return mtSpOp, T1, spop_cx_scalar_times>('j', X, k); + } + + + +//! multiplication of two sparse objects +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && is_same_type::value), + const SpGlue + >::result +operator* + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + return SpGlue(x, y); + } + + + +//! multiplication of one sparse and one dense object +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_type::value && is_same_type::value), + const SpToDGlue + >::result +operator* + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + return SpToDGlue(x, y); + } + + + +//! multiplication of one dense and one sparse object +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_sparse_type::value && is_same_type::value), + const SpToDGlue + >::result +operator* + ( + const T1& x, + const T2& y + ) + { + arma_extra_debug_sigprint(); + + return SpToDGlue(x, y); + } + + + +//! multiplication of two sparse objects with different element types +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && (is_same_type::no)), + const mtSpGlue< typename promote_type::result, T1, T2, spglue_times_mixed > + >::result +operator* + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + return mtSpGlue( X, Y ); + } + + + +//! multiplication of one sparse and one dense object with different element types +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_type::value && is_same_type::no), + Mat< typename promote_type::result > + >::result +operator* + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + Mat< typename promote_type::result > out; + + glue_times_sparse_dense::apply_mixed(out, X, Y); + + return out; + } + + + +//! multiplication of one dense and one sparse object with different element types +template +inline +typename +enable_if2 + < + (is_arma_type::value && is_arma_sparse_type::value && is_same_type::no), + Mat< typename promote_type::result > + >::result +operator* + ( + const T1& X, + const T2& Y + ) + { + arma_extra_debug_sigprint(); + + Mat< typename promote_type::result > out; + + glue_times_dense_sparse::apply_mixed(out, X, Y); + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/podarray_bones.hpp b/src/armadillo/include/armadillo_bits/podarray_bones.hpp new file mode 100644 index 0000000..9aa2cf1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/podarray_bones.hpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup podarray +//! @{ + + + +struct podarray_prealloc_n_elem + { + static constexpr uword val = 16; + }; + + + +//! A lightweight array for POD types. For internal use only! +template +class podarray + { + public: + + arma_aligned const uword n_elem; //!< number of elements held + arma_aligned eT* mem; //!< pointer to memory used by the object + + + protected: + //! internal memory, to avoid calling the 'new' operator for small amounts of memory. + arma_align_mem eT mem_local[ podarray_prealloc_n_elem::val ]; + + + public: + + inline ~podarray(); + inline podarray(); + + inline podarray (const podarray& x); + inline const podarray& operator=(const podarray& x); + + arma_inline explicit podarray(const uword new_N); + + template + inline explicit podarray(const uword new_N, const arma_initmode_indicator&); + + arma_inline eT& operator[] (const uword i); + arma_inline eT operator[] (const uword i) const; + + arma_inline eT& operator() (const uword i); + arma_inline eT operator() (const uword i) const; + + inline void set_min_size(const uword min_n_elem); + + inline void set_size(const uword new_n_elem); + inline void reset(); + + + inline void fill(const eT val); + + inline void zeros(); + inline void zeros(const uword new_n_elem); + + arma_inline eT* memptr(); + arma_inline const eT* memptr() const; + + inline void copy_row(const Mat& A, const uword row); + + + protected: + + inline void init_cold(const uword new_n_elem); + inline void init_warm(const uword new_n_elem); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/podarray_meat.hpp b/src/armadillo/include/armadillo_bits/podarray_meat.hpp new file mode 100644 index 0000000..2eb62f2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/podarray_meat.hpp @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup podarray +//! @{ + + +template +inline +podarray::~podarray() + { + arma_extra_debug_sigprint_this(this); + + if(n_elem > podarray_prealloc_n_elem::val ) + { + memory::release( mem ); + } + } + + + +template +inline +podarray::podarray() + : n_elem(0) + , mem (0) + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +podarray::podarray(const podarray& x) + : n_elem(x.n_elem) + { + arma_extra_debug_sigprint(); + + const uword x_n_elem = x.n_elem; + + init_cold(x_n_elem); + + arrayops::copy( memptr(), x.memptr(), x_n_elem ); + } + + + +template +inline +const podarray& +podarray::operator=(const podarray& x) + { + arma_extra_debug_sigprint(); + + if(this != &x) + { + const uword x_n_elem = x.n_elem; + + init_warm(x_n_elem); + + arrayops::copy( memptr(), x.memptr(), x_n_elem ); + } + + return *this; + } + + + +template +arma_inline +podarray::podarray(const uword new_n_elem) + : n_elem(new_n_elem) + { + arma_extra_debug_sigprint_this(this); + + init_cold(new_n_elem); + } + + + +template +template +inline +podarray::podarray(const uword new_n_elem, const arma_initmode_indicator&) + : n_elem(new_n_elem) + { + arma_extra_debug_sigprint_this(this); + + init_cold(new_n_elem); + + if(do_zeros) + { + arma_extra_debug_print("podarray::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +template +arma_inline +eT +podarray::operator[] (const uword i) const + { + return mem[i]; + } + + + +template +arma_inline +eT& +podarray::operator[] (const uword i) + { + return access::rw(mem[i]); + } + + + +template +arma_inline +eT +podarray::operator() (const uword i) const + { + arma_debug_check_bounds( (i >= n_elem), "podarray::operator(): index out of bounds" ); + + return mem[i]; + } + + + +template +arma_inline +eT& +podarray::operator() (const uword i) + { + arma_debug_check_bounds( (i >= n_elem), "podarray::operator(): index out of bounds" ); + + return access::rw(mem[i]); + } + + + +template +inline +void +podarray::set_min_size(const uword min_n_elem) + { + arma_extra_debug_sigprint(); + + if(min_n_elem > n_elem) { init_warm(min_n_elem); } + } + + + +template +inline +void +podarray::set_size(const uword new_n_elem) + { + arma_extra_debug_sigprint(); + + init_warm(new_n_elem); + } + + + +template +inline +void +podarray::reset() + { + arma_extra_debug_sigprint(); + + init_warm(0); + } + + + +template +inline +void +podarray::fill(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_set(memptr(), val, n_elem); + } + + + +template +inline +void +podarray::zeros() + { + arma_extra_debug_sigprint(); + + arrayops::fill_zeros(memptr(), n_elem); + } + + + +template +inline +void +podarray::zeros(const uword new_n_elem) + { + arma_extra_debug_sigprint(); + + init_warm(new_n_elem); + + arrayops::fill_zeros(memptr(), n_elem); + } + + + +template +arma_inline +eT* +podarray::memptr() + { + return mem; + } + + + +template +arma_inline +const eT* +podarray::memptr() const + { + return mem; + } + + + +template +inline +void +podarray::copy_row(const Mat& A, const uword row) + { + arma_extra_debug_sigprint(); + + // note: this function assumes that the podarray has been set to the correct size beforehand + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + const eT* A_mem = &(A.at(row,0)); + eT* out_mem = memptr(); + + for(uword i=0; i < n_cols; ++i) + { + out_mem[i] = (*A_mem); + + A_mem += n_rows; + } + } + + + +template +inline +void +podarray::init_cold(const uword new_n_elem) + { + arma_extra_debug_sigprint(); + + mem = (new_n_elem <= podarray_prealloc_n_elem::val) ? mem_local : memory::acquire(new_n_elem); + } + + + +template +inline +void +podarray::init_warm(const uword new_n_elem) + { + arma_extra_debug_sigprint(); + + if(n_elem == new_n_elem) { return; } + + if(n_elem > podarray_prealloc_n_elem::val) { memory::release( mem ); } + + mem = (new_n_elem <= podarray_prealloc_n_elem::val) ? mem_local : memory::acquire(new_n_elem); + + access::rw(n_elem) = new_n_elem; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/promote_type.hpp b/src/armadillo/include/armadillo_bits/promote_type.hpp new file mode 100644 index 0000000..d53eb32 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/promote_type.hpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup promote_type +//! @{ + + +template +struct is_promotable + { + static constexpr bool value = false; + typedef T1 result; + }; + + +struct is_promotable_ok + { + static constexpr bool value = true; + }; + + +template struct is_promotable : public is_promotable_ok { typedef T result; }; +template struct is_promotable, T> : public is_promotable_ok { typedef std::complex result; }; + +template<> struct is_promotable, std::complex> : public is_promotable_ok { typedef std::complex result; }; +template<> struct is_promotable, float> : public is_promotable_ok { typedef std::complex result; }; +template<> struct is_promotable, double> : public is_promotable_ok { typedef std::complex result; }; + + +template struct is_promotable, u64> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable, s64> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable, ulng_t> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable, slng_t> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable, s32> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable, u32> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable, s16> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable, u16> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable, s8> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable, u8> : public is_promotable_ok { typedef std::complex result; }; + + +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; // float ? +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; // float ? +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; // float ? +template<> struct is_promotable : public is_promotable_ok { typedef u32 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; // float ? +template<> struct is_promotable : public is_promotable_ok { typedef u32 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s16 result; }; // s32 ? +template<> struct is_promotable : public is_promotable_ok { typedef s16 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s16 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s16 result; }; // s32 ? +template<> struct is_promotable : public is_promotable_ok { typedef u16 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s8 result; }; // s16 ? + + + + +// +// mirrored versions + +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; + +template<> struct is_promotable, std::complex> : public is_promotable_ok { typedef std::complex result; }; +template<> struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template<> struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; + +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; + + +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; +template<> struct is_promotable : public is_promotable_ok { typedef double result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; +template<> struct is_promotable : public is_promotable_ok { typedef float result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; // float ? +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; // float ? +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; // float ? +template<> struct is_promotable : public is_promotable_ok { typedef u32 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; // float ? +template<> struct is_promotable : public is_promotable_ok { typedef u32 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s16 result; }; // s32 ? +template<> struct is_promotable : public is_promotable_ok { typedef s16 result; }; +template<> struct is_promotable : public is_promotable_ok { typedef s16 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s16 result; }; // s32 ? +template<> struct is_promotable : public is_promotable_ok { typedef u16 result; }; + +template<> struct is_promotable : public is_promotable_ok { typedef s8 result; }; // s16 ? + + + + + +template +struct promote_type + { + inline static void check() + { + arma_type_check(( is_promotable::value == false )); + } + + typedef typename is_promotable::result result; + }; + + + +template +struct eT_promoter + { + typedef typename promote_type::result eT; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/restrictors.hpp b/src/armadillo/include/armadillo_bits/restrictors.hpp new file mode 100644 index 0000000..019a5f4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/restrictors.hpp @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup restrictors +//! @{ + + + +// structures for template based restrictions of input/output arguments +// (part of the SFINAE approach) +// http://en.wikipedia.org/wiki/SFINAE + + +template struct arma_scalar_only { }; + +template<> struct arma_scalar_only< u8 > { typedef u8 result; }; +template<> struct arma_scalar_only< s8 > { typedef s8 result; }; +template<> struct arma_scalar_only< u16 > { typedef u16 result; }; +template<> struct arma_scalar_only< s16 > { typedef s16 result; }; +template<> struct arma_scalar_only< u32 > { typedef u32 result; }; +template<> struct arma_scalar_only< s32 > { typedef s32 result; }; +template<> struct arma_scalar_only< u64 > { typedef u64 result; }; +template<> struct arma_scalar_only< s64 > { typedef s64 result; }; +template<> struct arma_scalar_only< ulng_t > { typedef ulng_t result; }; +template<> struct arma_scalar_only< slng_t > { typedef slng_t result; }; +template<> struct arma_scalar_only< float > { typedef float result; }; +template<> struct arma_scalar_only< double > { typedef double result; }; +template<> struct arma_scalar_only< cx_float > { typedef cx_float result; }; +template<> struct arma_scalar_only< cx_double > { typedef cx_double result; }; + + + +template struct arma_integral_only { }; + +template<> struct arma_integral_only< u8 > { typedef u8 result; }; +template<> struct arma_integral_only< s8 > { typedef s8 result; }; +template<> struct arma_integral_only< u16 > { typedef u16 result; }; +template<> struct arma_integral_only< s16 > { typedef s16 result; }; +template<> struct arma_integral_only< u32 > { typedef u32 result; }; +template<> struct arma_integral_only< s32 > { typedef s32 result; }; +template<> struct arma_integral_only< u64 > { typedef u64 result; }; +template<> struct arma_integral_only< s64 > { typedef s64 result; }; +template<> struct arma_integral_only< ulng_t > { typedef ulng_t result; }; +template<> struct arma_integral_only< slng_t > { typedef slng_t result; }; + + + +template struct arma_unsigned_integral_only { }; + +template<> struct arma_unsigned_integral_only< u8 > { typedef u8 result; }; +template<> struct arma_unsigned_integral_only< u16 > { typedef u16 result; }; +template<> struct arma_unsigned_integral_only< u32 > { typedef u32 result; }; +template<> struct arma_unsigned_integral_only< u64 > { typedef u64 result; }; +template<> struct arma_unsigned_integral_only< ulng_t > { typedef ulng_t result; }; + + + +template struct arma_signed_integral_only { }; + +template<> struct arma_signed_integral_only< s8 > { typedef s8 result; }; +template<> struct arma_signed_integral_only< s16 > { typedef s16 result; }; +template<> struct arma_signed_integral_only< s32 > { typedef s32 result; }; +template<> struct arma_signed_integral_only< s64 > { typedef s64 result; }; +template<> struct arma_signed_integral_only< slng_t > { typedef slng_t result; }; + + + +template struct arma_signed_only { }; + +template<> struct arma_signed_only< s8 > { typedef s8 result; }; +template<> struct arma_signed_only< s16 > { typedef s16 result; }; +template<> struct arma_signed_only< s32 > { typedef s32 result; }; +template<> struct arma_signed_only< s64 > { typedef s64 result; }; +template<> struct arma_signed_only< slng_t > { typedef slng_t result; }; +template<> struct arma_signed_only< float > { typedef float result; }; +template<> struct arma_signed_only< double > { typedef double result; }; +template<> struct arma_signed_only< cx_float > { typedef cx_float result; }; +template<> struct arma_signed_only< cx_double > { typedef cx_double result; }; + + + +template struct arma_real_only { }; + +template<> struct arma_real_only< float > { typedef float result; }; +template<> struct arma_real_only< double > { typedef double result; }; + + +template struct arma_real_or_cx_only { }; + +template<> struct arma_real_or_cx_only< float > { typedef float result; }; +template<> struct arma_real_or_cx_only< double > { typedef double result; }; +template<> struct arma_real_or_cx_only< cx_float > { typedef cx_float result; }; +template<> struct arma_real_or_cx_only< cx_double > { typedef cx_double result; }; + + + +template struct arma_cx_only { }; + +template<> struct arma_cx_only< cx_float > { typedef cx_float result; }; +template<> struct arma_cx_only< cx_double > { typedef cx_double result; }; + + + +template struct arma_not_cx { typedef T result; }; +template struct arma_not_cx< std::complex > { }; + + + +template struct arma_blas_type_only { }; + +template<> struct arma_blas_type_only< float > { typedef float result; }; +template<> struct arma_blas_type_only< double > { typedef double result; }; +template<> struct arma_blas_type_only< cx_float > { typedef cx_float result; }; +template<> struct arma_blas_type_only< cx_double > { typedef cx_double result; }; + + + +template struct arma_not_blas_type { typedef T result; }; + +template<> struct arma_not_blas_type< float > { }; +template<> struct arma_not_blas_type< double > { }; +template<> struct arma_not_blas_type< cx_float > { }; +template<> struct arma_not_blas_type< cx_double > { }; + + + +template struct arma_op_rel_only { }; + +template<> struct arma_op_rel_only< op_rel_lt_pre > { typedef int result; }; +template<> struct arma_op_rel_only< op_rel_lt_post > { typedef int result; }; +template<> struct arma_op_rel_only< op_rel_gt_pre > { typedef int result; }; +template<> struct arma_op_rel_only< op_rel_gt_post > { typedef int result; }; +template<> struct arma_op_rel_only< op_rel_lteq_pre > { typedef int result; }; +template<> struct arma_op_rel_only< op_rel_lteq_post > { typedef int result; }; +template<> struct arma_op_rel_only< op_rel_gteq_pre > { typedef int result; }; +template<> struct arma_op_rel_only< op_rel_gteq_post > { typedef int result; }; +template<> struct arma_op_rel_only< op_rel_eq > { typedef int result; }; +template<> struct arma_op_rel_only< op_rel_noteq > { typedef int result; }; + + + +template struct arma_not_op_rel { typedef int result; }; + +template<> struct arma_not_op_rel< op_rel_lt_pre > { }; +template<> struct arma_not_op_rel< op_rel_lt_post > { }; +template<> struct arma_not_op_rel< op_rel_gt_pre > { }; +template<> struct arma_not_op_rel< op_rel_gt_post > { }; +template<> struct arma_not_op_rel< op_rel_lteq_pre > { }; +template<> struct arma_not_op_rel< op_rel_lteq_post > { }; +template<> struct arma_not_op_rel< op_rel_gteq_pre > { }; +template<> struct arma_not_op_rel< op_rel_gteq_post > { }; +template<> struct arma_not_op_rel< op_rel_eq > { }; +template<> struct arma_not_op_rel< op_rel_noteq > { }; + + + +template struct arma_glue_rel_only { }; + +template<> struct arma_glue_rel_only< glue_rel_lt > { typedef int result; }; +template<> struct arma_glue_rel_only< glue_rel_gt > { typedef int result; }; +template<> struct arma_glue_rel_only< glue_rel_lteq > { typedef int result; }; +template<> struct arma_glue_rel_only< glue_rel_gteq > { typedef int result; }; +template<> struct arma_glue_rel_only< glue_rel_eq > { typedef int result; }; +template<> struct arma_glue_rel_only< glue_rel_noteq > { typedef int result; }; +template<> struct arma_glue_rel_only< glue_rel_and > { typedef int result; }; +template<> struct arma_glue_rel_only< glue_rel_or > { typedef int result; }; + + + +template struct arma_Mat_Col_Row_only { }; + +template struct arma_Mat_Col_Row_only< Mat > { typedef Mat result; }; +template struct arma_Mat_Col_Row_only< Col > { typedef Col result; }; +template struct arma_Mat_Col_Row_only< Row > { typedef Row result; }; + + + +template struct arma_Cube_only { }; +template struct arma_Cube_only< Cube > { typedef Cube result; }; + + +template struct arma_SpMat_SpCol_SpRow_only { }; + +template struct arma_SpMat_SpCol_SpRow_only< SpMat > { typedef SpMat result; }; +template struct arma_SpMat_SpCol_SpRow_only< SpCol > { typedef SpCol result; }; +template struct arma_SpMat_SpCol_SpRow_only< SpRow > { typedef SpRow result; }; + + + +template struct enable_if { }; +template<> struct enable_if { typedef int result; }; + + +template struct enable_if2 { }; +template< typename result_type > struct enable_if2 { typedef result_type result; }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/running_stat_bones.hpp b/src/armadillo/include/armadillo_bits/running_stat_bones.hpp new file mode 100644 index 0000000..bd25e1d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/running_stat_bones.hpp @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup running_stat +//! @{ + + + +template +class arma_counter + { + public: + + inline ~arma_counter(); + inline arma_counter(); + + inline const arma_counter& operator++(); + inline void operator++(int); + + inline void reset(); + inline eT value() const; + inline eT value_plus_1() const; + inline eT value_minus_1() const; + + + private: + + arma_aligned eT d_count; + arma_aligned uword i_count; + }; + + + +//! Class for keeping statistics of a continuously sampled process / signal. +//! Useful if the storage of individual samples is not necessary or desired. +//! Also useful if the number of samples is not known beforehand or exceeds +//! available memory. +template +class running_stat + { + public: + + typedef typename get_pod_type::result T; + + + inline ~running_stat(); + inline running_stat(); + + inline void operator() (const T sample); + inline void operator() (const std::complex& sample); + + inline void reset(); + + inline eT mean() const; + + inline T var (const uword norm_type = 0) const; + inline T stddev(const uword norm_type = 0) const; + + inline eT min() const; + inline eT max() const; + inline eT range() const; + + inline T count() const; + + // + // + + private: + + arma_aligned arma_counter counter; + + arma_aligned eT r_mean; + arma_aligned T r_var; + + arma_aligned eT min_val; + arma_aligned eT max_val; + + arma_aligned T min_val_norm; + arma_aligned T max_val_norm; + + + friend class running_stat_aux; + }; + + + +class running_stat_aux + { + public: + + template + inline static void update_stats(running_stat& x, const eT sample, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void update_stats(running_stat& x, const std::complex& sample, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void update_stats(running_stat& x, const typename eT::value_type sample, const typename arma_cx_only::result* junk = nullptr); + + template + inline static void update_stats(running_stat& x, const eT& sample, const typename arma_cx_only::result* junk = nullptr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/running_stat_meat.hpp b/src/armadillo/include/armadillo_bits/running_stat_meat.hpp new file mode 100644 index 0000000..35e6ba8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/running_stat_meat.hpp @@ -0,0 +1,463 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup running_stat +//! @{ + + + +template +inline +arma_counter::~arma_counter() + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +arma_counter::arma_counter() + : d_count( eT(0)) + , i_count(uword(0)) + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +const arma_counter& +arma_counter::operator++() + { + if(i_count < ARMA_MAX_UWORD) + { + i_count++; + } + else + { + d_count += eT(ARMA_MAX_UWORD); + i_count = 1; + } + + return *this; + } + + + +template +inline +void +arma_counter::operator++(int) + { + operator++(); + } + + + +template +inline +void +arma_counter::reset() + { + d_count = eT(0); + i_count = uword(0); + } + + + +template +inline +eT +arma_counter::value() const + { + return d_count + eT(i_count); + } + + + +template +inline +eT +arma_counter::value_plus_1() const + { + if(i_count < ARMA_MAX_UWORD) + { + return d_count + eT(i_count + 1); + } + else + { + return d_count + eT(ARMA_MAX_UWORD) + eT(1); + } + } + + + +template +inline +eT +arma_counter::value_minus_1() const + { + if(i_count > 0) + { + return d_count + eT(i_count - 1); + } + else + { + return d_count - eT(1); + } + } + + + +// + + + +template +inline +running_stat::~running_stat() + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +running_stat::running_stat() + : r_mean ( eT(0)) + , r_var (typename running_stat::T(0)) + , min_val ( eT(0)) + , max_val ( eT(0)) + , min_val_norm(typename running_stat::T(0)) + , max_val_norm(typename running_stat::T(0)) + { + arma_extra_debug_sigprint_this(this); + } + + + +//! update statistics to reflect new sample +template +inline +void +running_stat::operator() (const typename running_stat::T sample) + { + arma_extra_debug_sigprint(); + + if( arma_isfinite(sample) == false ) + { + arma_debug_warn_level(3, "running_stat: sample ignored as it is non-finite" ); + return; + } + + running_stat_aux::update_stats(*this, sample); + } + + + +//! update statistics to reflect new sample (version for complex numbers) +template +inline +void +running_stat::operator() (const std::complex< typename running_stat::T >& sample) + { + arma_extra_debug_sigprint(); + + if( arma_isfinite(sample) == false ) + { + arma_debug_warn_level(3, "running_stat: sample ignored as it is non-finite" ); + return; + } + + running_stat_aux::update_stats(*this, sample); + } + + + +//! set all statistics to zero +template +inline +void +running_stat::reset() + { + arma_extra_debug_sigprint(); + + // typedef typename running_stat::T T; + + counter.reset(); + + r_mean = eT(0); + r_var = T(0); + + min_val = eT(0); + max_val = eT(0); + + min_val_norm = T(0); + max_val_norm = T(0); + } + + + +//! mean or average value +template +inline +eT +running_stat::mean() const + { + arma_extra_debug_sigprint(); + + return r_mean; + } + + + +//! variance +template +inline +typename running_stat::T +running_stat::var(const uword norm_type) const + { + arma_extra_debug_sigprint(); + + const T N = counter.value(); + + if(N > T(1)) + { + if(norm_type == 0) + { + return r_var; + } + else + { + const T N_minus_1 = counter.value_minus_1(); + return (N_minus_1/N) * r_var; + } + } + else + { + return T(0); + } + } + + + +//! standard deviation +template +inline +typename running_stat::T +running_stat::stddev(const uword norm_type) const + { + arma_extra_debug_sigprint(); + + return std::sqrt( (*this).var(norm_type) ); + } + + + +//! minimum value +template +inline +eT +running_stat::min() const + { + arma_extra_debug_sigprint(); + + return min_val; + } + + + +//! maximum value +template +inline +eT +running_stat::max() const + { + arma_extra_debug_sigprint(); + + return max_val; + } + + + +template +inline +eT +running_stat::range() const + { + arma_extra_debug_sigprint(); + + return (max_val - min_val); + } + + + +//! number of samples so far +template +inline +typename get_pod_type::result +running_stat::count() const + { + arma_extra_debug_sigprint(); + + return counter.value(); + } + + + +//! update statistics to reflect new sample (version for non-complex numbers, non-complex sample) +template +inline +void +running_stat_aux::update_stats(running_stat& x, const eT sample, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename running_stat::T T; + + const T N = x.counter.value(); + + if(N > T(0)) + { + if(sample < x.min_val) + { + x.min_val = sample; + } + + if(sample > x.max_val) + { + x.max_val = sample; + } + + const T N_plus_1 = x.counter.value_plus_1(); + const T N_minus_1 = x.counter.value_minus_1(); + + // note: variance has to be updated before the mean + + const eT tmp = sample - x.r_mean; + + x.r_var = N_minus_1/N * x.r_var + (tmp*tmp)/N_plus_1; + + x.r_mean = x.r_mean + (sample - x.r_mean)/N_plus_1; + //x.r_mean = (N/N_plus_1)*x.r_mean + sample/N_plus_1; + //x.r_mean = (x.r_mean + sample/N) * N/N_plus_1; + } + else + { + x.r_mean = sample; + x.min_val = sample; + x.max_val = sample; + + // r_var is initialised to zero + // in the constructor and reset() + } + + x.counter++; + } + + + +//! update statistics to reflect new sample (version for non-complex numbers, complex sample) +template +inline +void +running_stat_aux::update_stats(running_stat& x, const std::complex& sample, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + running_stat_aux::update_stats(x, std::real(sample)); + } + + + +//! update statistics to reflect new sample (version for complex numbers, non-complex sample) +template +inline +void +running_stat_aux::update_stats(running_stat& x, const typename eT::value_type sample, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename eT::value_type T; + + running_stat_aux::update_stats(x, std::complex(sample)); + } + + + +//! alter statistics to reflect new sample (version for complex numbers, complex sample) +template +inline +void +running_stat_aux::update_stats(running_stat& x, const eT& sample, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename eT::value_type T; + + const T sample_norm = std::norm(sample); + const T N = x.counter.value(); + + if(N > T(0)) + { + if(sample_norm < x.min_val_norm) + { + x.min_val_norm = sample_norm; + x.min_val = sample; + } + + if(sample_norm > x.max_val_norm) + { + x.max_val_norm = sample_norm; + x.max_val = sample; + } + + const T N_plus_1 = x.counter.value_plus_1(); + const T N_minus_1 = x.counter.value_minus_1(); + + x.r_var = N_minus_1/N * x.r_var + std::norm(sample - x.r_mean)/N_plus_1; + + x.r_mean = x.r_mean + (sample - x.r_mean)/N_plus_1; + //x.r_mean = (N/N_plus_1)*x.r_mean + sample/N_plus_1; + //x.r_mean = (x.r_mean + sample/N) * N/N_plus_1; + } + else + { + x.r_mean = sample; + x.min_val = sample; + x.max_val = sample; + x.min_val_norm = sample_norm; + x.max_val_norm = sample_norm; + + // r_var is initialised to zero + // in the constructor and reset() + } + + x.counter++; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/running_stat_vec_bones.hpp b/src/armadillo/include/armadillo_bits/running_stat_vec_bones.hpp new file mode 100644 index 0000000..13b076c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/running_stat_vec_bones.hpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup running_stat_vec +//! @{ + + +template struct rsv_get_elem_type_worker { }; +template struct rsv_get_elem_type_worker { typedef obj_type result; }; +template struct rsv_get_elem_type_worker { typedef typename obj_type::elem_type result; }; + +template struct rsv_get_elem_type { typedef typename rsv_get_elem_type_worker::value>::result elem_type; }; + + +template struct rsv_get_return_type1_worker { }; +template struct rsv_get_return_type1_worker { typedef Mat result; }; +template struct rsv_get_return_type1_worker { typedef obj_type result; }; + +template struct rsv_get_return_type1 { typedef typename rsv_get_return_type1_worker::value>::result return_type1; }; + + +template struct rsv_get_return_type2 { }; +template struct rsv_get_return_type2< Mat > { typedef Mat::result> return_type2; }; +template struct rsv_get_return_type2< Row > { typedef Row::result> return_type2; }; +template struct rsv_get_return_type2< Col > { typedef Col::result> return_type2; }; + + +//! Class for keeping statistics of a continuously sampled process / signal. +//! Useful if the storage of individual samples is not necessary or desired. +//! Also useful if the number of samples is not known beforehand or exceeds +//! available memory. +template +class running_stat_vec + { + public: + + // voodoo for compatibility with old user code + typedef typename rsv_get_elem_type::elem_type eT; + + typedef typename get_pod_type::result T; + + typedef typename rsv_get_return_type1::return_type1 return_type1; + typedef typename rsv_get_return_type2::return_type2 return_type2; + + inline ~running_stat_vec(); + inline running_stat_vec(const bool in_calc_cov = false); // TODO: investigate char* overload, eg. "calc_cov", "no_calc_cov" + + inline running_stat_vec(const running_stat_vec& in_rsv); + + inline running_stat_vec& operator=(const running_stat_vec& in_rsv); + + template inline void operator() (const Base< T, T1>& X); + template inline void operator() (const Base, T1>& X); + + inline void reset(); + + inline const return_type1& mean() const; + + inline const return_type2& var (const uword norm_type = 0); + inline return_type2 stddev(const uword norm_type = 0) const; + inline const Mat& cov (const uword norm_type = 0); + + inline const return_type1& min() const; + inline const return_type1& max() const; + inline return_type1 range() const; + + inline T count() const; + + // + // + + private: + + const bool calc_cov; + + arma_aligned arma_counter counter; + + arma_aligned return_type1 r_mean; + arma_aligned return_type2 r_var; + arma_aligned Mat r_cov; + + arma_aligned return_type1 min_val; + arma_aligned return_type1 max_val; + + arma_aligned Mat< T> min_val_norm; + arma_aligned Mat< T> max_val_norm; + + arma_aligned return_type2 r_var_dummy; + arma_aligned Mat r_cov_dummy; + + arma_aligned Mat tmp1; + arma_aligned Mat tmp2; + + friend class running_stat_vec_aux; + }; + + + +class running_stat_vec_aux + { + public: + + template + inline static void + update_stats + ( + running_stat_vec& x, + const Mat::eT>& sample, + const typename arma_not_cx::eT>::result* junk = nullptr + ); + + template + inline static void + update_stats + ( + running_stat_vec& x, + const Mat::T > >& sample, + const typename arma_not_cx::eT>::result* junk = nullptr + ); + + template + inline static void + update_stats + ( + running_stat_vec& x, + const Mat< typename running_stat_vec::T >& sample, + const typename arma_cx_only::eT>::result* junk = nullptr + ); + + template + inline static void + update_stats + ( + running_stat_vec& x, + const Mat::eT>& sample, + const typename arma_cx_only::eT>::result* junk = nullptr + ); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/running_stat_vec_meat.hpp b/src/armadillo/include/armadillo_bits/running_stat_vec_meat.hpp new file mode 100644 index 0000000..370fcf7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/running_stat_vec_meat.hpp @@ -0,0 +1,636 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup running_stat_vec +//! @{ + + + +template +inline +running_stat_vec::~running_stat_vec() + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +running_stat_vec::running_stat_vec(const bool in_calc_cov) + : calc_cov(in_calc_cov) + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +running_stat_vec::running_stat_vec(const running_stat_vec& in_rsv) + : calc_cov (in_rsv.calc_cov) + , counter (in_rsv.counter) + , r_mean (in_rsv.r_mean) + , r_var (in_rsv.r_var) + , r_cov (in_rsv.r_cov) + , min_val (in_rsv.min_val) + , max_val (in_rsv.max_val) + , min_val_norm(in_rsv.min_val_norm) + , max_val_norm(in_rsv.max_val_norm) + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +running_stat_vec& +running_stat_vec::operator=(const running_stat_vec& in_rsv) + { + arma_extra_debug_sigprint(); + + access::rw(calc_cov) = in_rsv.calc_cov; + + counter = in_rsv.counter; + r_mean = in_rsv.r_mean; + r_var = in_rsv.r_var; + r_cov = in_rsv.r_cov; + min_val = in_rsv.min_val; + max_val = in_rsv.max_val; + min_val_norm = in_rsv.min_val_norm; + max_val_norm = in_rsv.max_val_norm; + + return *this; + } + + + +//! update statistics to reflect new sample +template +template +inline +void +running_stat_vec::operator() (const Base::T, T1>& X) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(X.get_ref()); + const Mat& sample = tmp.M; + + if( sample.is_empty() ) + { + return; + } + + if( sample.internal_has_nonfinite() ) + { + arma_debug_warn_level(3, "running_stat_vec: sample ignored as it has non-finite elements"); + return; + } + + running_stat_vec_aux::update_stats(*this, sample); + } + + + +template +template +inline +void +running_stat_vec::operator() (const Base< std::complex::T>, T1>& X) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(X.get_ref()); + + const Mat< std::complex >& sample = tmp.M; + + if( sample.is_empty() ) + { + return; + } + + if( sample.internal_has_nonfinite() ) + { + arma_debug_warn_level(3, "running_stat_vec: sample ignored as it has non-finite elements"); + return; + } + + running_stat_vec_aux::update_stats(*this, sample); + } + + + +//! set all statistics to zero +template +inline +void +running_stat_vec::reset() + { + arma_extra_debug_sigprint(); + + counter.reset(); + + r_mean.reset(); + r_var.reset(); + r_cov.reset(); + + min_val.reset(); + max_val.reset(); + + min_val_norm.reset(); + max_val_norm.reset(); + + r_var_dummy.reset(); + r_cov_dummy.reset(); + + tmp1.reset(); + tmp2.reset(); + } + + + +//! mean or average value +template +inline +const typename running_stat_vec::return_type1& +running_stat_vec::mean() const + { + arma_extra_debug_sigprint(); + + return r_mean; + } + + + +//! variance +template +inline +const typename running_stat_vec::return_type2& +running_stat_vec::var(const uword norm_type) + { + arma_extra_debug_sigprint(); + + const T N = counter.value(); + + if(N > T(1)) + { + if(norm_type == 0) + { + return r_var; + } + else + { + const T N_minus_1 = counter.value_minus_1(); + + r_var_dummy = (N_minus_1/N) * r_var; + + return r_var_dummy; + } + } + else + { + r_var_dummy.zeros(r_mean.n_rows, r_mean.n_cols); + + return r_var_dummy; + } + + } + + + +//! standard deviation +template +inline +typename running_stat_vec::return_type2 +running_stat_vec::stddev(const uword norm_type) const + { + arma_extra_debug_sigprint(); + + const T N = counter.value(); + + if(N > T(1)) + { + if(norm_type == 0) + { + return sqrt(r_var); + } + else + { + const T N_minus_1 = counter.value_minus_1(); + + return sqrt( (N_minus_1/N) * r_var ); + } + } + else + { + typedef typename running_stat_vec::return_type2 out_type; + return out_type(); + } + } + + + +//! covariance +template +inline +const Mat< typename running_stat_vec::eT >& +running_stat_vec::cov(const uword norm_type) + { + arma_extra_debug_sigprint(); + + if(calc_cov) + { + const T N = counter.value(); + + if(N > T(1)) + { + if(norm_type == 0) + { + return r_cov; + } + else + { + const T N_minus_1 = counter.value_minus_1(); + + r_cov_dummy = (N_minus_1/N) * r_cov; + + return r_cov_dummy; + } + } + else + { + const uword out_size = (std::max)(r_mean.n_rows, r_mean.n_cols); + + r_cov_dummy.zeros(out_size, out_size); + + return r_cov_dummy; + } + } + else + { + r_cov_dummy.reset(); + + return r_cov_dummy; + } + + } + + + +//! vector with minimum values +template +inline +const typename running_stat_vec::return_type1& +running_stat_vec::min() const + { + arma_extra_debug_sigprint(); + + return min_val; + } + + + +//! vector with maximum values +template +inline +const typename running_stat_vec::return_type1& +running_stat_vec::max() const + { + arma_extra_debug_sigprint(); + + return max_val; + } + + + +template +inline +typename running_stat_vec::return_type1 +running_stat_vec::range() const + { + arma_extra_debug_sigprint(); + + return (max_val - min_val); + } + + + +//! number of samples so far +template +inline +typename running_stat_vec::T +running_stat_vec::count() const + { + arma_extra_debug_sigprint(); + + return counter.value(); + } + + + +// + + + +//! update statistics to reflect new sample (version for non-complex numbers) +template +inline +void +running_stat_vec_aux::update_stats + ( + running_stat_vec& x, + const Mat::eT>& sample, + const typename arma_not_cx::eT>::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename running_stat_vec::eT eT; + typedef typename running_stat_vec::T T; + + const T N = x.counter.value(); + + if(N > T(0)) + { + arma_debug_assert_same_size(x.r_mean, sample, "running_stat_vec(): dimensionality mismatch"); + + const uword n_elem = sample.n_elem; + const eT* sample_mem = sample.memptr(); + eT* r_mean_mem = x.r_mean.memptr(); + T* r_var_mem = x.r_var.memptr(); + eT* min_val_mem = x.min_val.memptr(); + eT* max_val_mem = x.max_val.memptr(); + + const T N_plus_1 = x.counter.value_plus_1(); + const T N_minus_1 = x.counter.value_minus_1(); + + if(x.calc_cov) + { + Mat& tmp1 = x.tmp1; + Mat& tmp2 = x.tmp2; + + tmp1 = sample - x.r_mean; + + if(sample.n_cols == 1) + { + tmp2 = tmp1*trans(tmp1); + } + else + { + tmp2 = trans(tmp1)*tmp1; + } + + x.r_cov *= (N_minus_1/N); + x.r_cov += tmp2 / N_plus_1; + } + + + for(uword i=0; i max_val_mem[i]) + { + max_val_mem[i] = val; + } + + const eT r_mean_val = r_mean_mem[i]; + const eT tmp = val - r_mean_val; + + r_var_mem[i] = N_minus_1/N * r_var_mem[i] + (tmp*tmp)/N_plus_1; + + r_mean_mem[i] = r_mean_val + (val - r_mean_val)/N_plus_1; + } + } + else + { + arma_debug_check( (sample.is_vec() == false), "running_stat_vec(): given sample must be a vector" ); + + x.r_mean.set_size(sample.n_rows, sample.n_cols); + + x.r_var.zeros(sample.n_rows, sample.n_cols); + + if(x.calc_cov) + { + x.r_cov.zeros(sample.n_elem, sample.n_elem); + } + + x.min_val.set_size(sample.n_rows, sample.n_cols); + x.max_val.set_size(sample.n_rows, sample.n_cols); + + + const uword n_elem = sample.n_elem; + const eT* sample_mem = sample.memptr(); + eT* r_mean_mem = x.r_mean.memptr(); + eT* min_val_mem = x.min_val.memptr(); + eT* max_val_mem = x.max_val.memptr(); + + + for(uword i=0; i +inline +void +running_stat_vec_aux::update_stats + ( + running_stat_vec& x, + const Mat::T > >& sample, + const typename arma_not_cx::eT>::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename running_stat_vec::eT eT; + + running_stat_vec_aux::update_stats(x, conv_to< Mat >::from(sample)); + } + + + +//! update statistics to reflect new sample (version for complex numbers, non-complex sample) +template +inline +void +running_stat_vec_aux::update_stats + ( + running_stat_vec& x, + const Mat::T >& sample, + const typename arma_cx_only::eT>::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename running_stat_vec::eT eT; + + running_stat_vec_aux::update_stats(x, conv_to< Mat >::from(sample)); + } + + + +//! alter statistics to reflect new sample (version for complex numbers, complex sample) +template +inline +void +running_stat_vec_aux::update_stats + ( + running_stat_vec& x, + const Mat::eT>& sample, + const typename arma_cx_only::eT>::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename running_stat_vec::eT eT; + typedef typename running_stat_vec::T T; + + const T N = x.counter.value(); + + if(N > T(0)) + { + arma_debug_assert_same_size(x.r_mean, sample, "running_stat_vec(): dimensionality mismatch"); + + const uword n_elem = sample.n_elem; + const eT* sample_mem = sample.memptr(); + eT* r_mean_mem = x.r_mean.memptr(); + T* r_var_mem = x.r_var.memptr(); + eT* min_val_mem = x.min_val.memptr(); + eT* max_val_mem = x.max_val.memptr(); + T* min_val_norm_mem = x.min_val_norm.memptr(); + T* max_val_norm_mem = x.max_val_norm.memptr(); + + const T N_plus_1 = x.counter.value_plus_1(); + const T N_minus_1 = x.counter.value_minus_1(); + + if(x.calc_cov) + { + Mat& tmp1 = x.tmp1; + Mat& tmp2 = x.tmp2; + + tmp1 = sample - x.r_mean; + + if(sample.n_cols == 1) + { + tmp2 = arma::conj(tmp1)*strans(tmp1); + } + else + { + tmp2 = trans(tmp1)*tmp1; //tmp2 = strans(conj(tmp1))*tmp1; + } + + x.r_cov *= (N_minus_1/N); + x.r_cov += tmp2 / N_plus_1; + } + + + for(uword i=0; i max_val_norm_mem[i]) + { + max_val_norm_mem[i] = val_norm; + max_val_mem[i] = val; + } + + const eT& r_mean_val = r_mean_mem[i]; + + r_var_mem[i] = N_minus_1/N * r_var_mem[i] + std::norm(val - r_mean_val)/N_plus_1; + + r_mean_mem[i] = r_mean_val + (val - r_mean_val)/N_plus_1; + } + } + else + { + arma_debug_check( (sample.is_vec() == false), "running_stat_vec(): given sample must be a vector" ); + + x.r_mean.set_size(sample.n_rows, sample.n_cols); + + x.r_var.zeros(sample.n_rows, sample.n_cols); + + if(x.calc_cov) + { + x.r_cov.zeros(sample.n_elem, sample.n_elem); + } + + x.min_val.set_size(sample.n_rows, sample.n_cols); + x.max_val.set_size(sample.n_rows, sample.n_cols); + + x.min_val_norm.set_size(sample.n_rows, sample.n_cols); + x.max_val_norm.set_size(sample.n_rows, sample.n_cols); + + + const uword n_elem = sample.n_elem; + const eT* sample_mem = sample.memptr(); + eT* r_mean_mem = x.r_mean.memptr(); + eT* min_val_mem = x.min_val.memptr(); + eT* max_val_mem = x.max_val.memptr(); + T* min_val_norm_mem = x.min_val_norm.memptr(); + T* max_val_norm_mem = x.max_val_norm.memptr(); + + for(uword i=0; i + inline static bool eigs_sym(Col& eigval, Mat& eigvec, const SpBase& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); + + template + inline static bool eigs_sym(Col& eigval, Mat& eigvec, const SpBase& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts); + + template + inline static bool eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); + + template + inline static bool eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts); + + template + inline static bool eigs_sym_arpack(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eT sigma, const eigs_opts& opts); + + // + // eigs_gen() for real matrices + + template + inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); + + template + inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase& X, const uword n_eigvals, const std::complex sigma, const eigs_opts& opts); + + template + inline static bool eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); + + template + inline static bool eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const std::complex sigma, const eigs_opts& opts); + + // + // eigs_gen() for complex matrices + + template + inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase< std::complex, T1>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); + + template + inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase< std::complex, T1>& X, const uword n_eigvals, const std::complex sigma, const eigs_opts& opts); + + template + inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat< std::complex >& X, const uword n_eigvals, const form_type form_val, const std::complex sigma, const eigs_opts& opts); + + // + // spsolve() via SuperLU + + template + inline static bool spsolve_simple(Mat& out, const SpBase& A, const Base& B, const superlu_opts& user_opts); + + template + inline static bool spsolve_refine(Mat& out, typename T1::pod_type& out_rcond, const SpBase& A, const Base& B, const superlu_opts& user_opts); + + // + // support functions + + #if defined(ARMA_USE_SUPERLU) + + template + inline static typename get_pod_type::result norm1(superlu::SuperMatrix* A); + + template + inline static typename get_pod_type::result lu_rcond(superlu::SuperMatrix* L, superlu::SuperMatrix* U, typename get_pod_type::result norm_val); + + inline static void set_superlu_opts(superlu::superlu_options_t& options, const superlu_opts& user_opts); + + template + inline static bool copy_to_supermatrix(superlu::SuperMatrix& out, const SpMat& A); + + template + inline static bool copy_to_supermatrix_with_shift(superlu::SuperMatrix& out, const SpMat& A, const eT shift); + + // // for debugging only + // template + // inline static void copy_to_spmat(SpMat& out, const superlu::SuperMatrix& A); + + template + inline static bool wrap_to_supermatrix(superlu::SuperMatrix& out, const Mat& A); + + inline static void destroy_supermatrix(superlu::SuperMatrix& out); + + #endif + + + + private: + + // calls arpack saupd()/naupd() because the code is so similar for each + // all of the extra variables are later used by seupd()/neupd(), but those + // functions are very different and we can't combine their code + + template + inline static void run_aupd_plain + ( + const uword n_eigvals, char* which, + const SpMat& X, const SpMat& Xst, const bool sym, + blas_int& n, eT& tol, blas_int& maxiter, + podarray& resid, blas_int& ncv, podarray& v, blas_int& ldv, + podarray& iparam, podarray& ipntr, + podarray& workd, podarray& workl, blas_int& lworkl, podarray& rwork, + blas_int& info + ); + + template + inline static void run_aupd_shiftinvert + ( + const uword n_eigvals, const T sigma, + const SpMat& X, const bool sym, + blas_int& n, eT& tol, blas_int& maxiter, + podarray& resid, blas_int& ncv, podarray& v, blas_int& ldv, + podarray& iparam, podarray& ipntr, + podarray& workd, podarray& workl, blas_int& lworkl, podarray& rwork, + blas_int& info + ); + + + template + inline static bool rudimentary_sym_check(const SpMat& X); + + template + inline static bool rudimentary_sym_check(const SpMat< std::complex >& X); + }; + + + +template +struct eigs_randu_filler + { + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + inline eigs_randu_filler(); + + inline void fill(podarray& X, const uword N); + }; + + +template +struct eigs_randu_filler< std::complex > + { + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + inline eigs_randu_filler(); + + inline void fill(podarray< std::complex >& X, const uword N); + }; + + + +#if defined(ARMA_USE_SUPERLU) + +class superlu_supermatrix_wrangler + { + private: + + bool used = false; + + arma_aligned superlu::SuperMatrix m; + + public: + + inline ~superlu_supermatrix_wrangler(); + inline superlu_supermatrix_wrangler(); + + inline superlu_supermatrix_wrangler(const superlu_supermatrix_wrangler&) = delete; + inline void operator= (const superlu_supermatrix_wrangler&) = delete; + + inline superlu::SuperMatrix& get_ref(); + inline superlu::SuperMatrix* get_ptr(); + }; + + +class superlu_stat_wrangler + { + private: + + arma_aligned superlu::SuperLUStat_t stat; + + public: + + inline ~superlu_stat_wrangler(); + inline superlu_stat_wrangler(); + + inline superlu_stat_wrangler(const superlu_stat_wrangler&) = delete; + inline void operator= (const superlu_stat_wrangler&) = delete; + + inline superlu::SuperLUStat_t* get_ptr(); + }; + + +template +class superlu_array_wrangler + { + private: + + arma_aligned eT* mem = nullptr; + + public: + + inline ~superlu_array_wrangler(); + inline superlu_array_wrangler(); + inline superlu_array_wrangler(const uword n_elem); + + inline void set_size(const uword n_elem); + inline void reset(); + + inline superlu_array_wrangler(const superlu_array_wrangler&) = delete; + inline void operator= (const superlu_array_wrangler&) = delete; + + inline eT* get_ptr(); + }; + + +template +class superlu_worker + { + private: + + bool factorisation_valid = false; + + superlu_supermatrix_wrangler* l = nullptr; + superlu_supermatrix_wrangler* u = nullptr; + + superlu_array_wrangler perm_c; + superlu_array_wrangler perm_r; + + superlu_stat_wrangler stat; + + public: + + inline ~superlu_worker(); + inline superlu_worker(); + + inline bool factorise(typename get_pod_type::result& out_rcond, const SpMat& A, const superlu_opts& user_opts); + + inline bool solve(Mat& X, const Mat& B); + + inline superlu_worker(const superlu_worker&) = delete; + inline void operator= (const superlu_worker&) = delete; + }; + +#endif + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/sp_auxlib_meat.hpp b/src/armadillo/include/armadillo_bits/sp_auxlib_meat.hpp new file mode 100644 index 0000000..dbfdf2d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/sp_auxlib_meat.hpp @@ -0,0 +1,2814 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup sp_auxlib +//! @{ + + +inline +sp_auxlib::form_type +sp_auxlib::interpret_form_str(const char* form_str) + { + arma_extra_debug_sigprint(); + + // the order of the 3 if statements below is important + if( form_str == nullptr ) { return form_none; } + if( form_str[0] == char(0) ) { return form_none; } + if( form_str[1] == char(0) ) { return form_none; } + + const char c1 = form_str[0]; + const char c2 = form_str[1]; + + if(c1 == 'l') + { + if(c2 == 'm') { return form_lm; } + if(c2 == 'r') { return form_lr; } + if(c2 == 'i') { return form_li; } + if(c2 == 'a') { return form_la; } + } + else + if(c1 == 's') + { + if(c2 == 'm') { return form_sm; } + if(c2 == 'r') { return form_sr; } + if(c2 == 'i') { return form_si; } + if(c2 == 'a') { return form_sa; } + } + + return form_none; + } + + + +//! immediate eigendecomposition of symmetric real sparse object +template +inline +bool +sp_auxlib::eigs_sym(Col& eigval, Mat& eigvec, const SpBase& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(X.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "eigs_sym(): given matrix must be square sized" ); + + if((arma_config::debug) && (sp_auxlib::rudimentary_sym_check(U.M) == false)) + { + if(is_cx::no ) { arma_debug_warn_level(1, "eigs_sym(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "eigs_sym(): given matrix is not hermitian"); } + } + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) + { + arma_debug_warn_level(3, "eigs_sym(): detected non-finite elements"); + return false; + } + + // TODO: investigate optional redirection of "sm" to ARPACK as it's capable of shift-invert; + // TODO: in shift-invert mode, "sm" maps to "lm" of the shift-inverted matrix (with sigma = 0) + + #if defined(ARMA_USE_NEWARP) + { + return sp_auxlib::eigs_sym_newarp(eigval, eigvec, U.M, n_eigvals, form_val, opts); + } + #elif defined(ARMA_USE_ARPACK) + { + constexpr eT sigma = eT(0); + + return sp_auxlib::eigs_sym_arpack(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(n_eigvals); + arma_ignore(form_val); + arma_ignore(opts); + + arma_stop_logic_error("eigs_sym(): use of NEWARP or ARPACK must be enabled"); + return false; + } + #endif + } + + + +//! immediate eigendecomposition of symmetric real sparse object +template +inline +bool +sp_auxlib::eigs_sym(Col& eigval, Mat& eigvec, const SpBase& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(X.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "eigs_sym(): given matrix must be square sized" ); + + if((arma_config::debug) && (sp_auxlib::rudimentary_sym_check(U.M) == false)) + { + if(is_cx::no ) { arma_debug_warn_level(1, "eigs_sym(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "eigs_sym(): given matrix is not hermitian"); } + } + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) + { + arma_debug_warn_level(3, "eigs_sym(): detected non-finite elements"); + return false; + } + + #if (defined(ARMA_USE_NEWARP) && defined(ARMA_USE_SUPERLU)) + { + return sp_auxlib::eigs_sym_newarp(eigval, eigvec, U.M, n_eigvals, sigma, opts); + } + #elif (defined(ARMA_USE_ARPACK) && defined(ARMA_USE_SUPERLU)) + { + constexpr form_type form_val = form_sigma; + + return sp_auxlib::eigs_sym_arpack(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(n_eigvals); + arma_ignore(sigma); + arma_ignore(opts); + + arma_stop_logic_error("eigs_sym(): use of NEWARP or ARPACK as well as SuperLU must be enabled to use 'sigma'"); + return false; + } + #endif + } + + + +template +inline +bool +sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_NEWARP) + { + arma_debug_check( (form_val != form_lm) && (form_val != form_sm) && (form_val != form_la) && (form_val != form_sa), "eigs_sym(): unknown form specified" ); + + if(X.is_square() == false) { return false; } + + const newarp::SparseGenMatProd op(X); + + arma_debug_check( (n_eigvals >= op.n_rows), "eigs_sym(): n_eigvals must be less than the number of rows in the matrix" ); + + // If the matrix is empty, the case is trivial. + if( (op.n_cols == 0) || (n_eigvals == 0) ) // We already know n_cols == n_rows. + { + eigval.reset(); + eigvec.reset(); + return true; + } + + uword n = op.n_rows; + + // Use max(2*k+1, 20) as default subspace dimension for the sym case; MATLAB uses max(2*k, 20), but we need to be backward-compatible. + uword ncv_default = uword( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); + + // Use opts.subdim only if it's within the limits, otherwise cap it. + uword ncv = ncv_default; + + if(opts.subdim != 0) + { + if(opts.subdim < (n_eigvals + 1)) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim must be greater than k; using k+1 instead of ", opts.subdim); + ncv = uword(n_eigvals + 1); + } + else + if(opts.subdim > n) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = uword(opts.subdim); + } + } + + // Re-check that we are within the limits + if(ncv < (n_eigvals + 1)) { ncv = (n_eigvals + 1); } + if(ncv > n ) { ncv = n; } + + eT tol = (std::max)(eT(opts.tol), std::numeric_limits::epsilon()); + + uword maxiter = uword(opts.maxiter); + + // eigval.set_size(n_eigvals); + // eigvec.set_size(n, n_eigvals); + + bool status = true; + + uword nconv = 0; + + try + { + if(form_val == form_lm) + { + newarp::SymEigsSolver< eT, newarp::EigsSelect::LARGEST_MAGN, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + else + if(form_val == form_sm) + { + newarp::SymEigsSolver< eT, newarp::EigsSelect::SMALLEST_MAGN, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + else + if(form_val == form_la) + { + newarp::SymEigsSolver< eT, newarp::EigsSelect::LARGEST_ALGE, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + else + if(form_val == form_sa) + { + newarp::SymEigsSolver< eT, newarp::EigsSelect::SMALLEST_ALGE, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + } + catch(const std::runtime_error&) + { + status = false; + } + + if(status == true) + { + if(nconv == 0) { status = false; } + } + + return status; + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(X); + arma_ignore(n_eigvals); + arma_ignore(form_val); + arma_ignore(opts); + + return false; + } + #endif + } + + + +template +inline +bool +sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_NEWARP) + { + if(X.is_square() == false) { return false; } + + const newarp::SparseGenRealShiftSolve op(X, sigma); + + if(op.valid == false) { return false; } + + arma_debug_check( (n_eigvals >= op.n_rows), "eigs_sym(): n_eigvals must be less than the number of rows in the matrix" ); + + // If the matrix is empty, the case is trivial. + if( (op.n_cols == 0) || (n_eigvals == 0) ) // We already know n_cols == n_rows. + { + eigval.reset(); + eigvec.reset(); + return true; + } + + uword n = op.n_rows; + + // Use max(2*k+1, 20) as default subspace dimension for the sym case; MATLAB uses max(2*k, 20), but we need to be backward-compatible. + uword ncv_default = uword( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); + + // Use opts.subdim only if it's within the limits, otherwise cap it. + uword ncv = ncv_default; + + if(opts.subdim != 0) + { + if(opts.subdim < (n_eigvals + 1)) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim must be greater than k; using k+1 instead of ", opts.subdim); + ncv = uword(n_eigvals + 1); + } + else + if(opts.subdim > n) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = uword(opts.subdim); + } + } + + // Re-check that we are within the limits + if(ncv < (n_eigvals + 1)) { ncv = (n_eigvals + 1); } + if(ncv > n ) { ncv = n; } + + eT tol = (std::max)(eT(opts.tol), std::numeric_limits::epsilon()); + + uword maxiter = uword(opts.maxiter); + + // eigval.set_size(n_eigvals); + // eigvec.set_size(n, n_eigvals); + + bool status = true; + + uword nconv = 0; + + try + { + newarp::SymEigsShiftSolver< eT, newarp::EigsSelect::LARGEST_MAGN, newarp::SparseGenRealShiftSolve > eigs(op, n_eigvals, ncv, sigma); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + catch(const std::runtime_error&) + { + status = false; + } + + if(status == true) + { + if(nconv == 0) { status = false; } + } + + return status; + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(X); + arma_ignore(n_eigvals); + arma_ignore(sigma); + arma_ignore(opts); + + return false; + } + #endif + } + + + +template +inline +bool +sp_auxlib::eigs_sym_arpack(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eT sigma, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_ARPACK) + { + arma_debug_check( (form_val != form_lm) && (form_val != form_sm) && (form_val != form_la) && (form_val != form_sa) && (form_val != form_sigma), "eigs_sym(): unknown form specified" ); + + if(X.is_square() == false) { return false; } + + char which_sm[3] = "SM"; + char which_lm[3] = "LM"; + char which_sa[3] = "SA"; + char which_la[3] = "LA"; + char* which; + + switch(form_val) + { + case form_sm: which = which_sm; break; + case form_lm: which = which_lm; break; + case form_sa: which = which_sa; break; + case form_la: which = which_la; break; + + default: which = which_lm; break; + } + + // Make sure we aren't asking for every eigenvalue. + // The _saupd() functions allow asking for one more eigenvalue than the _naupd() functions. + arma_debug_check( (n_eigvals >= X.n_rows), "eigs_sym(): n_eigvals must be less than the number of rows in the matrix" ); + + // If the matrix is empty, the case is trivial. + if( (X.n_cols == 0) || (n_eigvals == 0) ) // We already know n_cols == n_rows. + { + eigval.reset(); + eigvec.reset(); + return true; + } + + // Set up variables that get used for neupd(). + blas_int n, ncv, ncv_default, ldv, lworkl, info, maxiter; + + eT tol = eT(opts.tol); + maxiter = blas_int(opts.maxiter); + + podarray resid, v, workd, workl; + podarray iparam, ipntr; + podarray rwork; // Not used in this case. + + n = blas_int(X.n_rows); // The size of the matrix. + + // Use max(2*k+1, 20) as default subspace dimension for the sym case; MATLAB uses max(2*k, 20), but we need to be backward-compatible. + ncv_default = blas_int( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); + + // Use opts.subdim only if it's within the limits + ncv = ncv_default; + + if(opts.subdim != 0) + { + if(opts.subdim < (n_eigvals + 1)) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim must be greater than k; using k+1 instead of ", opts.subdim); + ncv = blas_int(n_eigvals + 1); + } + else + if(blas_int(opts.subdim) > n) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = blas_int(opts.subdim); + } + } + + if(use_sigma) + //if(form_val == form_sigma) + { + run_aupd_shiftinvert(n_eigvals, sigma, X, true /* sym, not gen */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + } + else + { + const SpMat Xst = X.st(); + + run_aupd_plain(n_eigvals, which, X, Xst, true /* sym, not gen */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + } + + if(info != 0) { return false; } + + // The process has converged, and now we need to recover the actual eigenvectors using seupd() + blas_int rvec = 1; // .TRUE + blas_int nev = blas_int(n_eigvals); + + char howmny = 'A'; + char bmat = 'I'; // We are considering the standard eigenvalue problem. + + podarray select(ncv, arma_zeros_indicator()); // Logical array of dimension NCV. + blas_int ldz = n; + + // seupd() will output directly into the eigval and eigvec objects. + eigval.zeros( n_eigvals); + eigvec.zeros(n, n_eigvals); + + arpack::seupd(&rvec, &howmny, select.memptr(), eigval.memptr(), eigvec.memptr(), &ldz, (eT*) &sigma, &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, &info); + + // Check for errors. + if(info != 0) { arma_debug_warn_level(1, "eigs_sym(): ARPACK error ", info, " in seupd()"); return false; } + + return (info == 0); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(X); + arma_ignore(n_eigvals); + arma_ignore(form_val); + arma_ignore(sigma); + arma_ignore(opts); + + return false; + } + #endif + } + + + +//! immediate eigendecomposition of non-symmetric real sparse object +template +inline +bool +sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(X.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "eigs_gen(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) + { + arma_debug_warn_level(3, "eigs_gen(): detected non-finite elements"); + return false; + } + + // TODO: investigate optional redirection of "sm" to ARPACK as it's capable of shift-invert; + // TODO: in shift-invert mode, "sm" maps to "lm" of the shift-inverted matrix (with sigma = 0) + + #if defined(ARMA_USE_NEWARP) + { + return sp_auxlib::eigs_gen_newarp(eigval, eigvec, U.M, n_eigvals, form_val, opts); + } + #elif defined(ARMA_USE_ARPACK) + { + constexpr std::complex sigma = T(0); + + return sp_auxlib::eigs_gen_arpack(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(n_eigvals); + arma_ignore(form_val); + arma_ignore(opts); + + arma_stop_logic_error("eigs_gen(): use of NEWARP or ARPACK must be enabled"); + return false; + } + #endif + } + + + +//! immediate eigendecomposition of non-symmetric real sparse object +template +inline +bool +sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase& X, const uword n_eigvals, const std::complex sigma, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(X.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "eigs_gen(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) + { + arma_debug_warn_level(3, "eigs_gen(): detected non-finite elements"); + return false; + } + + #if (defined(ARMA_USE_ARPACK) && defined(ARMA_USE_SUPERLU)) + { + constexpr form_type form_val = form_sigma; + + return sp_auxlib::eigs_gen_arpack(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(n_eigvals); + arma_ignore(sigma); + arma_ignore(opts); + + arma_stop_logic_error("eigs_gen(): use of ARPACK and SuperLU must be enabled to use 'sigma'"); + return false; + } + #endif + } + + + +template +inline +bool +sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_NEWARP) + { + arma_debug_check( (form_val != form_lm) && (form_val != form_sm) && (form_val != form_lr) && (form_val != form_sr) && (form_val != form_li) && (form_val != form_si), "eigs_gen(): unknown form specified" ); + + if(X.is_square() == false) { return false; } + + const newarp::SparseGenMatProd op(X); + + arma_debug_check( (n_eigvals + 1 >= op.n_rows), "eigs_gen(): n_eigvals + 1 must be less than the number of rows in the matrix" ); + + // If the matrix is empty, the case is trivial. + if( (op.n_cols == 0) || (n_eigvals == 0) ) // We already know n_cols == n_rows. + { + eigval.reset(); + eigvec.reset(); + return true; + } + + uword n = op.n_rows; + + // Use max(2*k+1, 20) as default subspace dimension for the gen case; same as MATLAB. + uword ncv_default = uword( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); + + // Use opts.subdim only if it's within the limits + uword ncv = ncv_default; + + if(opts.subdim != 0) + { + if(opts.subdim < (n_eigvals + 3)) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim must be greater than k+2; using k+3 instead of ", opts.subdim); + ncv = uword(n_eigvals + 3); + } + else + if(opts.subdim > n) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = uword(opts.subdim); + } + } + + // Re-check that we are within the limits + if(ncv < (n_eigvals + 3)) { ncv = (n_eigvals + 3); } + if(ncv > n ) { ncv = n; } + + T tol = (std::max)(T(opts.tol), std::numeric_limits::epsilon()); + + uword maxiter = uword(opts.maxiter); + + // eigval.set_size(n_eigvals); + // eigvec.set_size(n, n_eigvals); + + bool status = true; + + uword nconv = 0; + + try + { + if(form_val == form_lm) + { + newarp::GenEigsSolver< T, newarp::EigsSelect::LARGEST_MAGN, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + else + if(form_val == form_sm) + { + newarp::GenEigsSolver< T, newarp::EigsSelect::SMALLEST_MAGN, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + else + if(form_val == form_lr) + { + newarp::GenEigsSolver< T, newarp::EigsSelect::LARGEST_REAL, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + else + if(form_val == form_sr) + { + newarp::GenEigsSolver< T, newarp::EigsSelect::SMALLEST_REAL, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + else + if(form_val == form_li) + { + newarp::GenEigsSolver< T, newarp::EigsSelect::LARGEST_IMAG, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + else + if(form_val == form_si) + { + newarp::GenEigsSolver< T, newarp::EigsSelect::SMALLEST_IMAG, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + } + catch(const std::runtime_error&) + { + status = false; + } + + if(status == true) + { + if(nconv == 0) { status = false; } + } + + return status; + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(X); + arma_ignore(n_eigvals); + arma_ignore(form_val); + arma_ignore(opts); + + return false; + } + #endif + } + + + + +template +inline +bool +sp_auxlib::eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const std::complex sigma, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_ARPACK) + { + arma_debug_check( (form_val != form_lm) && (form_val != form_sm) && (form_val != form_lr) && (form_val != form_sr) && (form_val != form_li) && (form_val != form_si) && (form_val != form_sigma), "eigs_gen(): unknown form specified" ); + + if(X.is_square() == false) { return false; } + + char which_lm[3] = "LM"; + char which_sm[3] = "SM"; + char which_lr[3] = "LR"; + char which_sr[3] = "SR"; + char which_li[3] = "LI"; + char which_si[3] = "SI"; + + char* which; + + switch(form_val) + { + case form_lm: which = which_lm; break; + case form_sm: which = which_sm; break; + case form_lr: which = which_lr; break; + case form_sr: which = which_sr; break; + case form_li: which = which_li; break; + case form_si: which = which_si; break; + + default: which = which_lm; + } + + // Make sure we aren't asking for every eigenvalue. + arma_debug_check( (n_eigvals + 1 >= X.n_rows), "eigs_gen(): n_eigvals + 1 must be less than the number of rows in the matrix" ); + + // If the matrix is empty, the case is trivial. + if( (X.n_cols == 0) || (n_eigvals == 0) ) // We already know n_cols == n_rows. + { + eigval.reset(); + eigvec.reset(); + return true; + } + + // Set up variables that get used for neupd(). + blas_int n, ncv, ncv_default, ldv, lworkl, info, maxiter; + + T tol = T(opts.tol); + maxiter = blas_int(opts.maxiter); + + podarray resid, v, workd, workl; + podarray iparam, ipntr; + podarray rwork; // Not used in the real case. + + n = blas_int(X.n_rows); // The size of the matrix. + + // Use max(2*k+1, 20) as default subspace dimension for the gen case; same as MATLAB. + ncv_default = blas_int( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); + + // Use opts.subdim only if it's within the limits + ncv = ncv_default; + + if(opts.subdim != 0) + { + if(opts.subdim < (n_eigvals + 3)) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim must be greater than k+2; using k+3 instead of ", opts.subdim); + ncv = blas_int(n_eigvals + 3); + } + else + if(blas_int(opts.subdim) > n) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = blas_int(opts.subdim); + } + } + + // WARNING!!! + // We are still not able to apply truly complex shifts to real matrices, + // in which case the OP that ARPACK wants is different (see [s/d]naupd). + // Also, if sigma contains a non-zero imaginary part, retrieving the eigenvalues + // becomes utterly messy (see [s/d]eupd, remark #3). + // We should never get to the point in which the imaginary part of sigma is non-zero; + // the user-facing functions currently convert X from real to complex if a complex sigma is detected. + // The check here is just for extra safety, and as a reminder of what's missing. + T sigmar = real(sigma); + T sigmai = imag(sigma); + + if(use_sigma) + //if(form_val == form_sigma) + { + if(sigmai != T(0)) { arma_stop_logic_error("eigs_gen(): complex 'sigma' not applicable to real matrix"); return false; } + + run_aupd_shiftinvert(n_eigvals, sigmar, X, false /* gen, not sym */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + } + else + { + const SpMat Xst = X.st(); + + run_aupd_plain(n_eigvals, which, X, Xst, false /* gen, not sym */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + } + + if(info != 0) { return false; } + + // The process has converged, and now we need to recover the actual eigenvectors using neupd(). + blas_int rvec = 1; // .TRUE + blas_int nev = blas_int(n_eigvals); + + char howmny = 'A'; + char bmat = 'I'; // We are considering the standard eigenvalue problem. + + podarray select(ncv, arma_zeros_indicator()); // logical array of dimension NCV + podarray dr(nev + 1, arma_zeros_indicator()); // real array of dimension NEV + 1 + podarray di(nev + 1, arma_zeros_indicator()); // real array of dimension NEV + 1 + podarray z(n * (nev + 1), arma_zeros_indicator()); // real N by NEV array if HOWMNY = 'A' + podarray workev(3 * ncv, arma_zeros_indicator()); + + blas_int ldz = n; + + arpack::neupd(&rvec, &howmny, select.memptr(), dr.memptr(), di.memptr(), z.memptr(), &ldz, (T*) &sigmar, (T*) &sigmai, workev.memptr(), &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, rwork.memptr(), &info); + + // Check for errors. + if(info != 0) { arma_debug_warn_level(1, "eigs_gen(): ARPACK error ", info, " in neupd()"); return false; } + + // Put it into the outputs. + eigval.set_size(n_eigvals); + eigvec.zeros(n, n_eigvals); + + for(uword i = 0; i < n_eigvals; ++i) + { + eigval[i] = std::complex(dr[i], di[i]); + } + + // Now recover the eigenvectors. + for(uword i = 0; i < n_eigvals; ++i) + { + // ARPACK ?neupd lays things out kinda odd in memory; + // so does LAPACK ?geev -- see auxlib::eig_gen() + if((i < n_eigvals - 1) && (eigval[i] == std::conj(eigval[i + 1]))) + { + for(uword j = 0; j < uword(n); ++j) + { + eigvec.at(j, i) = std::complex(z[n * i + j], z[n * (i + 1) + j]); + eigvec.at(j, i + 1) = std::complex(z[n * i + j], -z[n * (i + 1) + j]); + } + ++i; // Skip the next one. + } + else + if((i == n_eigvals - 1) && (std::complex(eigval[i]).imag() != 0.0)) + { + // We don't have the matched conjugate eigenvalue. + for(uword j = 0; j < uword(n); ++j) + { + eigvec.at(j, i) = std::complex(z[n * i + j], z[n * (i + 1) + j]); + } + } + else + { + // The eigenvector is entirely real. + for(uword j = 0; j < uword(n); ++j) + { + eigvec.at(j, i) = std::complex(z[n * i + j], T(0)); + } + } + } + + return (info == 0); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(X); + arma_ignore(n_eigvals); + arma_ignore(form_val); + arma_ignore(sigma); + arma_ignore(opts); + + return false; + } + #endif + } + + + +//! immediate eigendecomposition of non-symmetric complex sparse object +template +inline +bool +sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase< std::complex, T1>& X_expr, const uword n_eigvals, const form_type form_val, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(X_expr.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "eigs_gen(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) + { + arma_debug_warn_level(3, "eigs_gen(): detected non-finite elements"); + return false; + } + + constexpr std::complex sigma = T(0); + + return sp_auxlib::eigs_gen(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); + } + + + +//! immediate eigendecomposition of non-symmetric complex sparse object +template +inline +bool +sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase< std::complex, T1>& X, const uword n_eigvals, const std::complex sigma, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(X.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "eigs_gen(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) + { + arma_debug_warn_level(3, "eigs_gen(): detected non-finite elements"); + return false; + } + + #if (defined(ARMA_USE_ARPACK) && defined(ARMA_USE_SUPERLU)) + { + constexpr form_type form_val = form_sigma; + + return sp_auxlib::eigs_gen(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(n_eigvals); + arma_ignore(sigma); + arma_ignore(opts); + + arma_stop_logic_error("eigs_gen(): use of ARPACK and SuperLU must be enabled to use 'sigma'"); + return false; + } + #endif + } + + + +template +inline +bool +sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat< std::complex >& X, const uword n_eigvals, const form_type form_val, const std::complex sigma, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_ARPACK) + { + // typedef typename std::complex eT; + + arma_debug_check( (form_val != form_lm) && (form_val != form_sm) && (form_val != form_lr) && (form_val != form_sr) && (form_val != form_li) && (form_val != form_si) && (form_val != form_sigma), "eigs_gen(): unknown form specified" ); + + if(X.is_square() == false) { return false; } + + char which_lm[3] = "LM"; + char which_sm[3] = "SM"; + char which_lr[3] = "LR"; + char which_sr[3] = "SR"; + char which_li[3] = "LI"; + char which_si[3] = "SI"; + + char* which; + + switch(form_val) + { + case form_lm: which = which_lm; break; + case form_sm: which = which_sm; break; + case form_lr: which = which_lr; break; + case form_sr: which = which_sr; break; + case form_li: which = which_li; break; + case form_si: which = which_si; break; + + default: which = which_lm; + } + + // Make sure we aren't asking for every eigenvalue. + arma_debug_check( (n_eigvals + 1 >= X.n_rows), "eigs_gen(): n_eigvals + 1 must be less than the number of rows in the matrix" ); + + // If the matrix is empty, the case is trivial. + if( (X.n_cols == 0) || (n_eigvals == 0) ) // We already know n_cols == n_rows. + { + eigval.reset(); + eigvec.reset(); + return true; + } + + // Set up variables that get used for neupd(). + blas_int n, ncv, ncv_default, ldv, lworkl, info, maxiter; + + T tol = T(opts.tol); + maxiter = blas_int(opts.maxiter); + + podarray< std::complex > resid, v, workd, workl; + podarray iparam, ipntr; + podarray rwork; + + n = blas_int(X.n_rows); // The size of the matrix. + + // Use max(2*k+1, 20) as default subspace dimension for the gen case; same as MATLAB. + ncv_default = blas_int( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); + + // Use opts.subdim only if it's within the limits + ncv = ncv_default; + + if(opts.subdim != 0) + { + if(opts.subdim < (n_eigvals + 3)) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim must be greater than k+2; using k+3 instead of ", opts.subdim); + ncv = blas_int(n_eigvals + 3); + } + else + if(blas_int(opts.subdim) > n) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = blas_int(opts.subdim); + } + } + + if(use_sigma) + //if(form_val == form_sigma) + { + run_aupd_shiftinvert(n_eigvals, sigma, X, false /* gen, not sym */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + } + else + { + const SpMat< std::complex > Xst = X.st(); + + run_aupd_plain(n_eigvals, which, X, Xst, false /* gen, not sym */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + } + + if(info != 0) { return false; } + + // The process has converged, and now we need to recover the actual eigenvectors using neupd(). + blas_int rvec = 1; // .TRUE + blas_int nev = blas_int(n_eigvals); + + char howmny = 'A'; + char bmat = 'I'; // We are considering the standard eigenvalue problem. + + podarray select(ncv, arma_zeros_indicator()); // logical array of dimension NCV + podarray> d(nev + 1, arma_zeros_indicator()); // complex array of dimension NEV + 1 + podarray> z(n * nev, arma_zeros_indicator()); // complex N by NEV array if HOWMNY = 'A' + podarray> workev(2 * ncv, arma_zeros_indicator()); + + blas_int ldz = n; + + // Prepare the outputs; neupd() will write directly to them. + eigval.zeros(n_eigvals); + eigvec.zeros(n, n_eigvals); + + arpack::neupd(&rvec, &howmny, select.memptr(), eigval.memptr(), +(std::complex*) NULL, eigvec.memptr(), &ldz, (std::complex*) &sigma, (std::complex*) NULL, workev.memptr(), &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, rwork.memptr(), &info); + + // Check for errors. + if(info != 0) { arma_debug_warn_level(1, "eigs_gen(): ARPACK error ", info, " in neupd()"); return false; } + + return (info == 0); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(X); + arma_ignore(n_eigvals); + arma_ignore(form_val); + arma_ignore(sigma); + arma_ignore(opts); + + arma_stop_logic_error("eigs_gen(): use of ARPACK must be enabled for decomposition of complex matrices"); + return false; + } + #endif + } + + + +template +inline +bool +sp_auxlib::spsolve_simple(Mat& X, const SpBase& A_expr, const Base& B_expr, const superlu_opts& user_opts) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_SUPERLU) + { + typedef typename T1::elem_type eT; + + superlu::superlu_options_t options; + sp_auxlib::set_superlu_opts(options, user_opts); + + const unwrap_spmat tmp1(A_expr.get_ref()); + const SpMat& A = tmp1.M; + + X = B_expr.get_ref(); // superlu::gssv() uses X as input (the B matrix) and as output (the solution) + + if(A.is_square() == false) + { + X.soft_reset(); + arma_stop_logic_error("spsolve(): solving under-determined / over-determined systems is currently not supported"); + return false; + } + + arma_debug_check( (A.n_rows != X.n_rows), "spsolve(): number of rows in the given objects must be the same", [&](){ X.soft_reset(); } ); + + if(A.is_empty() || X.is_empty()) + { + X.zeros(A.n_cols, X.n_cols); + return true; + } + + if(A.n_nonzero == uword(0)) { X.soft_reset(); return false; } + + if(arma_config::check_nonfinite && (A.internal_has_nonfinite() || X.internal_has_nonfinite())) + { + arma_debug_warn_level(3, "spsolve(): detected non-finite elements"); + return false; + } + + if(arma_config::debug) + { + bool overflow = false; + + overflow = (A.n_nonzero > INT_MAX); + overflow = (A.n_rows > INT_MAX) || overflow; + overflow = (A.n_cols > INT_MAX) || overflow; + overflow = (X.n_rows > INT_MAX) || overflow; + overflow = (X.n_cols > INT_MAX) || overflow; + + if(overflow) + { + arma_stop_runtime_error("spsolve(): integer overflow: matrix dimensions are too large for integer type used by SuperLU"); + return false; + } + } + + superlu_supermatrix_wrangler x; + superlu_supermatrix_wrangler a; + + const bool status_x = wrap_to_supermatrix(x.get_ref(), X); + const bool status_a = copy_to_supermatrix(a.get_ref(), A); + + if( (status_x == false) || (status_a == false) ) { X.soft_reset(); return false; } + + superlu_supermatrix_wrangler l; + superlu_supermatrix_wrangler u; + + // paranoia: use SuperLU's memory allocation, in case it reallocs + + superlu_array_wrangler perm_c(A.n_cols+1); // extra paranoia: increase array length by 1 + superlu_array_wrangler perm_r(A.n_rows+1); + + superlu_stat_wrangler stat; + + int info = 0; // Return code. + + arma_extra_debug_print("superlu::gssv()"); + superlu::gssv(&options, a.get_ptr(), perm_c.get_ptr(), perm_r.get_ptr(), l.get_ptr(), u.get_ptr(), x.get_ptr(), stat.get_ptr(), &info); + + + // Process the return code. + if( (info > 0) && (info <= int(A.n_cols)) ) + { + // std::ostringstream tmp; + // tmp << "spsolve(): could not solve system; LU factorisation completed, but detected zero in U(" << (info-1) << ',' << (info-1) << ')'; + // arma_debug_warn_level(1, tmp.str()); + } + else + if(info > int(A.n_cols)) + { + arma_debug_warn_level(1, "spsolve(): memory allocation failure"); + } + else + if(info < 0) + { + arma_debug_warn_level(1, "spsolve(): unknown SuperLU error code from gssv(): ", info); + } + + // No need to extract the data from x, since it's using the same memory as X + + return (info == 0); + } + #else + { + arma_ignore(X); + arma_ignore(A_expr); + arma_ignore(B_expr); + arma_ignore(user_opts); + arma_stop_logic_error("spsolve(): use of SuperLU must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +sp_auxlib::spsolve_refine(Mat& X, typename T1::pod_type& out_rcond, const SpBase& A_expr, const Base& B_expr, const superlu_opts& user_opts) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_SUPERLU) + { + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + superlu::superlu_options_t options; + sp_auxlib::set_superlu_opts(options, user_opts); + + const unwrap_spmat tmp1(A_expr.get_ref()); + const SpMat& A = tmp1.M; + + const unwrap tmp2(B_expr.get_ref()); + const Mat& B_unwrap = tmp2.M; + + const bool B_is_modified = ( (user_opts.equilibrate) || (&B_unwrap == &X) ); + + Mat B_copy; if(B_is_modified) { B_copy = B_unwrap; } + + const Mat& B = (B_is_modified) ? B_copy : B_unwrap; + + if(A.is_square() == false) + { + X.soft_reset(); + arma_stop_logic_error("spsolve(): solving under-determined / over-determined systems is currently not supported"); + return false; + } + + arma_debug_check( (A.n_rows != B.n_rows), "spsolve(): number of rows in the given objects must be the same", [&](){ X.soft_reset(); } ); + + X.zeros(A.n_cols, B.n_cols); // set the elements to zero, as we don't trust the SuperLU spaghetti code + + if(A.is_empty() || B.is_empty()) { return true; } + + if(A.n_nonzero == uword(0)) { X.soft_reset(); return false; } + + if(arma_config::check_nonfinite && (A.internal_has_nonfinite() || B.internal_has_nonfinite())) + { + arma_debug_warn_level(3, "spsolve(): detected non-finite elements"); + return false; + } + + if(arma_config::debug) + { + bool overflow; + + overflow = (A.n_nonzero > INT_MAX); + overflow = (A.n_rows > INT_MAX) || overflow; + overflow = (A.n_cols > INT_MAX) || overflow; + overflow = (B.n_rows > INT_MAX) || overflow; + overflow = (B.n_cols > INT_MAX) || overflow; + overflow = (X.n_rows > INT_MAX) || overflow; + overflow = (X.n_cols > INT_MAX) || overflow; + + if(overflow) + { + arma_stop_runtime_error("spsolve(): integer overflow: matrix dimensions are too large for integer type used by SuperLU"); + return false; + } + } + + superlu_supermatrix_wrangler x; + superlu_supermatrix_wrangler a; + superlu_supermatrix_wrangler b; + + const bool status_x = wrap_to_supermatrix(x.get_ref(), X); + const bool status_a = copy_to_supermatrix(a.get_ref(), A); // NOTE: superlu::gssvx() modifies 'a' if equilibration is enabled + const bool status_b = wrap_to_supermatrix(b.get_ref(), B); // NOTE: superlu::gssvx() modifies 'b' if equilibration is enabled + + if( (status_x == false) || (status_a == false) || (status_b == false) ) { X.soft_reset(); return false; } + + superlu_supermatrix_wrangler l; + superlu_supermatrix_wrangler u; + + // paranoia: use SuperLU's memory allocation, in case it reallocs + + superlu_array_wrangler perm_c(A.n_cols+1); // extra paranoia: increase array length by 1 + superlu_array_wrangler perm_r(A.n_rows+1); + superlu_array_wrangler etree(A.n_cols+1); + + superlu_array_wrangler R(A.n_rows+1); + superlu_array_wrangler C(A.n_cols+1); + superlu_array_wrangler ferr(B.n_cols+1); + superlu_array_wrangler berr(B.n_cols+1); + + superlu::GlobalLU_t glu; + arrayops::fill_zeros(reinterpret_cast(&glu), sizeof(superlu::GlobalLU_t)); + + superlu::mem_usage_t mu; + arrayops::fill_zeros(reinterpret_cast(&mu), sizeof(superlu::mem_usage_t)); + + superlu_stat_wrangler stat; + + char equed[8] = {}; // extra characters for paranoia + T rpg = T(0); + T rcond = T(0); + int info = int(0); // Return code. + + char work[8] = {}; + int lwork = int(0); // 0 means superlu will allocate memory + + arma_extra_debug_print("superlu::gssvx()"); + superlu::gssvx(&options, a.get_ptr(), perm_c.get_ptr(), perm_r.get_ptr(), etree.get_ptr(), equed, R.get_ptr(), C.get_ptr(), l.get_ptr(), u.get_ptr(), &work[0], lwork, b.get_ptr(), x.get_ptr(), &rpg, &rcond, ferr.get_ptr(), berr.get_ptr(), &glu, &mu, stat.get_ptr(), &info); + + bool status = false; + + // Process the return code. + if(info == 0) + { + status = true; + } + if( (info > 0) && (info <= int(A.n_cols)) ) + { + // std::ostringstream tmp; + // tmp << "spsolve(): could not solve system; LU factorisation completed, but detected zero in U(" << (info-1) << ',' << (info-1) << ')'; + // arma_debug_warn_level(1, tmp.str()); + } + else + if( (info == int(A.n_cols+1)) && (user_opts.allow_ugly) ) + { + arma_debug_warn_level(2, "spsolve(): system is singular to working precision (rcond: ", rcond, ")"); + status = true; + } + else + if(info > int(A.n_cols+1)) + { + arma_debug_warn_level(1, "spsolve(): memory allocation failure"); + } + else + if(info < 0) + { + arma_debug_warn_level(1, "spsolve(): unknown SuperLU error code from gssvx(): ", info); + } + + // No need to extract the data from x, since it's using the same memory as X + + out_rcond = rcond; + + return status; + } + #else + { + arma_ignore(X); + arma_ignore(out_rcond); + arma_ignore(A_expr); + arma_ignore(B_expr); + arma_ignore(user_opts); + arma_stop_logic_error("spsolve(): use of SuperLU must be enabled"); + return false; + } + #endif + } + + + +#if defined(ARMA_USE_SUPERLU) + + template + inline + typename get_pod_type::result + sp_auxlib::norm1(superlu::SuperMatrix* A) + { + arma_extra_debug_sigprint(); + + char norm_id = '1'; + + arma_extra_debug_print("superlu::langs()"); + return superlu::langs(&norm_id, A); + } + + + + template + inline + typename get_pod_type::result + sp_auxlib::lu_rcond(superlu::SuperMatrix* L, superlu::SuperMatrix* U, typename get_pod_type::result norm_val) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + char norm_id = '1'; + T rcond_out = T(0); + int info = int(0); + + superlu_stat_wrangler stat; + + arma_extra_debug_print("superlu::gscon()"); + superlu::gscon(&norm_id, L, U, norm_val, &rcond_out, stat.get_ptr(), &info); + + return (info == 0) ? T(rcond_out) : T(0); + } + + + + inline + void + sp_auxlib::set_superlu_opts(superlu::superlu_options_t& options, const superlu_opts& user_opts) + { + arma_extra_debug_sigprint(); + + // default options as the starting point + superlu::set_default_opts(&options); + + // our settings + options.Trans = superlu::NOTRANS; + options.ConditionNumber = superlu::YES; + + // process user_opts + + if(user_opts.equilibrate == true) { options.Equil = superlu::YES; } + if(user_opts.equilibrate == false) { options.Equil = superlu::NO; } + + if(user_opts.symmetric == true) { options.SymmetricMode = superlu::YES; } + if(user_opts.symmetric == false) { options.SymmetricMode = superlu::NO; } + + options.DiagPivotThresh = user_opts.pivot_thresh; + + if(user_opts.permutation == superlu_opts::NATURAL) { options.ColPerm = superlu::NATURAL; } + if(user_opts.permutation == superlu_opts::MMD_ATA) { options.ColPerm = superlu::MMD_ATA; } + if(user_opts.permutation == superlu_opts::MMD_AT_PLUS_A) { options.ColPerm = superlu::MMD_AT_PLUS_A; } + if(user_opts.permutation == superlu_opts::COLAMD) { options.ColPerm = superlu::COLAMD; } + + if(user_opts.refine == superlu_opts::REF_NONE) { options.IterRefine = superlu::NOREFINE; } + if(user_opts.refine == superlu_opts::REF_SINGLE) { options.IterRefine = superlu::SLU_SINGLE; } + if(user_opts.refine == superlu_opts::REF_DOUBLE) { options.IterRefine = superlu::SLU_DOUBLE; } + if(user_opts.refine == superlu_opts::REF_EXTRA) { options.IterRefine = superlu::SLU_EXTRA; } + } + + + + template + inline + bool + sp_auxlib::copy_to_supermatrix(superlu::SuperMatrix& out, const SpMat& A) + { + arma_extra_debug_sigprint(); + + // We store in column-major CSC. + out.Stype = superlu::SLU_NC; + + if( is_float::value) { out.Dtype = superlu::SLU_S; } + else if( is_double::value) { out.Dtype = superlu::SLU_D; } + else if( is_cx_float::value) { out.Dtype = superlu::SLU_C; } + else if(is_cx_double::value) { out.Dtype = superlu::SLU_Z; } + + out.Mtype = superlu::SLU_GE; // Just a general matrix. We don't know more now. + + // We have to actually create the object which stores the data. + // This gets cleaned by destroy_supermatrix(). + // We have to use SuperLU's problematic memory allocation routines since they are + // not guaranteed to be new and delete. See the comments in def_superlu.hpp + superlu::NCformat* nc = (superlu::NCformat*)superlu::malloc(sizeof(superlu::NCformat)); + + if(nc == nullptr) { return false; } + + A.sync(); + + nc->nnz = A.n_nonzero; + nc->nzval = (void*) superlu::malloc(sizeof(eT) * A.n_nonzero ); + nc->colptr = (superlu::int_t*)superlu::malloc(sizeof(superlu::int_t) * (A.n_cols + 1)); + nc->rowind = (superlu::int_t*)superlu::malloc(sizeof(superlu::int_t) * A.n_nonzero ); + + if( (nc->nzval == nullptr) || (nc->colptr == nullptr) || (nc->rowind == nullptr) ) { return false; } + + // Fill the matrix. + arrayops::copy((eT*) nc->nzval, A.values, A.n_nonzero); + + // // These have to be copied by hand, because the types may differ. + // for(uword i = 0; i <= A.n_cols; ++i) { nc->colptr[i] = (int_t) A.col_ptrs[i]; } + // for(uword i = 0; i < A.n_nonzero; ++i) { nc->rowind[i] = (int_t) A.row_indices[i]; } + + arrayops::convert(nc->colptr, A.col_ptrs, A.n_cols+1 ); + arrayops::convert(nc->rowind, A.row_indices, A.n_nonzero); + + out.nrow = superlu::int_t(A.n_rows); + out.ncol = superlu::int_t(A.n_cols); + out.Store = (void*) nc; + + return true; + } + + + + // memory efficient implementation of out = A - shift*I, where A is a square matrix + template + inline + bool + sp_auxlib::copy_to_supermatrix_with_shift(superlu::SuperMatrix& out, const SpMat& A, const eT shift) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (A.is_square() == false), "sp_auxlib::copy_to_supermatrix_with_shift(): given matrix must be square sized" ); + + if(shift == eT(0)) + { + arma_extra_debug_print("sp_auxlib::copy_to_supermatrix_with_shift(): shift is zero; redirecting to sp_auxlib::copy_to_supermatrix()"); + return sp_auxlib::copy_to_supermatrix(out, A); + } + + // We store in column-major CSC. + out.Stype = superlu::SLU_NC; + + if( is_float::value) { out.Dtype = superlu::SLU_S; } + else if( is_double::value) { out.Dtype = superlu::SLU_D; } + else if( is_cx_float::value) { out.Dtype = superlu::SLU_C; } + else if(is_cx_double::value) { out.Dtype = superlu::SLU_Z; } + + out.Mtype = superlu::SLU_GE; // Just a general matrix. We don't know more now. + + // We have to actually create the object which stores the data. + // This gets cleaned by destroy_supermatrix(). + superlu::NCformat* nc = (superlu::NCformat*)superlu::malloc(sizeof(superlu::NCformat)); + + if(nc == nullptr) { return false; } + + A.sync(); + + uword n_nonzero_diag_old = 0; + uword n_nonzero_diag_new = 0; + + const uword n_search_cols = (std::min)(A.n_rows, A.n_cols); + + for(uword j=0; j < n_search_cols; ++j) + { + const uword col_offset = A.col_ptrs[j ]; + const uword next_col_offset = A.col_ptrs[j + 1]; + + const uword* start_ptr = &(A.row_indices[ col_offset]); + const uword* end_ptr = &(A.row_indices[next_col_offset]); + + const uword wanted_row = j; + + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, wanted_row); // binary search + + if( (pos_ptr != end_ptr) && ((*pos_ptr) == wanted_row) ) + { + // element on the main diagonal is non-zero + ++n_nonzero_diag_old; + + const uword offset = uword(pos_ptr - start_ptr); + const uword index = offset + col_offset; + + const eT new_val = A.values[index] - shift; + + if(new_val != eT(0)) { ++n_nonzero_diag_new; } + } + else + { + // element on the main diagonal is zero, but sigma is non-zero, + // so the number of new non-zero elments on the diagonal is increased + ++n_nonzero_diag_new; + } + } + + const uword out_n_nonzero = A.n_nonzero - n_nonzero_diag_old + n_nonzero_diag_new; + + arma_extra_debug_print( arma_str::format("A.n_nonzero: %u") % A.n_nonzero ); + arma_extra_debug_print( arma_str::format("n_nonzero_diag_old: %u") % n_nonzero_diag_old ); + arma_extra_debug_print( arma_str::format("n_nonzero_diag_new: %u") % n_nonzero_diag_new ); + arma_extra_debug_print( arma_str::format("out_n_nonzero: %u") % out_n_nonzero ); + + nc->nnz = out_n_nonzero; + nc->nzval = (void*) superlu::malloc(sizeof(eT) * out_n_nonzero ); + nc->colptr = (superlu::int_t*)superlu::malloc(sizeof(superlu::int_t) * (A.n_cols + 1)); + nc->rowind = (superlu::int_t*)superlu::malloc(sizeof(superlu::int_t) * out_n_nonzero ); + + if( (nc->nzval == nullptr) || (nc->colptr == nullptr) || (nc->rowind == nullptr) ) { return false; } + + // fill the matrix column by column, and insert diagonal elements when necessary + + nc->colptr[0] = 0; + + eT* values_current = (eT*) nc->nzval; + superlu::int_t* rowind_current = nc->rowind; + + uword count = 0; + + for(uword j=0; j < A.n_cols; ++j) + { + const uword idx_start = A.col_ptrs[j ]; + const uword idx_end = A.col_ptrs[j + 1]; + + const eT* values_start = values_current; + + uword i = idx_start; + + // elements in the upper triangular part, excluding the main diagonal + for(; (i < idx_end) && (A.row_indices[i] < j); ++i) + { + (*values_current) = A.values[i]; + (*rowind_current) = superlu::int_t(A.row_indices[i]); + + ++values_current; + ++rowind_current; + + ++count; + } + + // elements on the main diagonal + if( (i < idx_end) && (A.row_indices[i] == j) ) + { + // A(j,j) is non-zero + + const eT new_diag_val = A.values[i] - shift; + + if(new_diag_val != eT(0)) + { + (*values_current) = new_diag_val; + (*rowind_current) = superlu::int_t(j); + + ++values_current; + ++rowind_current; + + ++count; + } + + ++i; + } + else + { + // A(j,j) is zero, so insert a new element + + if(j < n_search_cols) + { + (*values_current) = -shift; + (*rowind_current) = superlu::int_t(j); + + ++values_current; + ++rowind_current; + + ++count; + } + } + + // elements in the lower triangular part, excluding the main diagonal + for(; i < idx_end; ++i) + { + (*values_current) = A.values[i]; + (*rowind_current) = superlu::int_t(A.row_indices[i]); + + ++values_current; + ++rowind_current; + + ++count; + } + + // number of non-zero elements in the j-th column of out + const uword nnz_col = values_current - values_start; + nc->colptr[j + 1] = superlu::int_t(nc->colptr[j] + nnz_col); + } + + arma_extra_debug_print( arma_str::format("count: %u") % count ); + + arma_check( (count != out_n_nonzero), "internal error: sp_auxlib::copy_to_supermatrix_with_shift(): count != out_n_nonzero" ); + + out.nrow = superlu::int_t(A.n_rows); + out.ncol = superlu::int_t(A.n_cols); + out.Store = (void*) nc; + + return true; + } + + + +// // for debugging only +// template +// inline +// void +// sp_auxlib::copy_to_spmat(SpMat& out, const superlu::SuperMatrix& A) +// { +// arma_extra_debug_sigprint(); +// +// bool type_matched = false; +// +// if( is_float::value) { type_matched = (A.Dtype == superlu::SLU_S); } +// else if( is_double::value) { type_matched = (A.Dtype == superlu::SLU_D); } +// else if( is_cx_float::value) { type_matched = (A.Dtype == superlu::SLU_C); } +// else if(is_cx_double::value) { type_matched = (A.Dtype == superlu::SLU_Z); } +// +// arma_debug_check( (type_matched == false), "copy_to_spmat(): type mismatch" ); +// arma_debug_check( (A.Mtype != superlu::SLU_GE), "copy_to_spmat(): unknown layout" ); +// +// // NOTE: the l and u instances of SuperMatrix resulting from superlu::gstrf() +// // NOTE: do not have the superlu::SLU_GE layout +// +// const superlu::NCformat* nc = (const superlu::NCformat*)(A.Store); +// +// if(nc == nullptr) { out.reset(); return; } +// +// if( (nc->nzval == nullptr) || (nc->colptr == nullptr) || (nc->rowind == nullptr) ) { out.reset(); return; } +// +// const uword A_n_rows = uword(A.nrow ); +// const uword A_n_cols = uword(A.ncol ); +// const uword A_n_nonzero = uword(nc->nnz); +// +// if(A_n_nonzero == 0) { out.zeros(A_n_rows, A_n_cols); return; } +// +// out.reserve(A_n_rows, A_n_cols, A_n_nonzero); +// +// arrayops::copy(access::rwp(out.values), (const eT*)(nc->nzval), A_n_nonzero); +// +// arrayops::convert(access::rwp(out.col_ptrs), nc->colptr, A_n_cols+1 ); +// arrayops::convert(access::rwp(out.row_indices), nc->rowind, A_n_nonzero); +// +// out.remove_zeros(); // in case SuperLU has bugs and stores zeros in sparse matrices +// } + + + + template + inline + bool + sp_auxlib::wrap_to_supermatrix(superlu::SuperMatrix& out, const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: this function re-uses memory from matrix A + + // This is being stored as a dense matrix. + out.Stype = superlu::SLU_DN; + + if( is_float::value) { out.Dtype = superlu::SLU_S; } + else if( is_double::value) { out.Dtype = superlu::SLU_D; } + else if( is_cx_float::value) { out.Dtype = superlu::SLU_C; } + else if(is_cx_double::value) { out.Dtype = superlu::SLU_Z; } + + out.Mtype = superlu::SLU_GE; + + // We have to create the object that stores the data. + superlu::DNformat* dn = (superlu::DNformat*)superlu::malloc(sizeof(superlu::DNformat)); + + if(dn == nullptr) { return false; } + + dn->lda = A.n_rows; + dn->nzval = (void*) A.memptr(); // re-use memory instead of copying + + out.nrow = A.n_rows; + out.ncol = A.n_cols; + out.Store = (void*) dn; + + return true; + } + + + + inline + void + sp_auxlib::destroy_supermatrix(superlu::SuperMatrix& out) + { + arma_extra_debug_sigprint(); + + // Clean up. + if(out.Stype == superlu::SLU_NC) + { + superlu::destroy_compcol_mat(&out); + } + else + if(out.Stype == superlu::SLU_NCP) + { + superlu::destroy_compcolperm_mat(&out); + } + else + if(out.Stype == superlu::SLU_DN) + { + // superlu::destroy_dense_mat(&out); + + // since dn->nzval is set to re-use memory from a Mat object (which manages its own memory), + // we cannot simply call superlu::destroy_dense_mat(). + // Only the out.Store structure can be freed. + + superlu::DNformat* dn = (superlu::DNformat*) out.Store; + + if(dn != nullptr) { superlu::free(dn); } + } + else + if(out.Stype == superlu::SLU_SC) + { + superlu::destroy_supernode_mat(&out); + } + else + { + // Uh, crap. + + std::ostringstream tmp; + + tmp << "sp_auxlib::destroy_supermatrix(): unhandled Stype" << std::endl; + tmp << "Stype val: " << out.Stype << std::endl; + tmp << "Stype name: "; + + if(out.Stype == superlu::SLU_NC) { tmp << "SLU_NC"; } + if(out.Stype == superlu::SLU_NCP) { tmp << "SLU_NCP"; } + if(out.Stype == superlu::SLU_NR) { tmp << "SLU_NR"; } + if(out.Stype == superlu::SLU_SC) { tmp << "SLU_SC"; } + if(out.Stype == superlu::SLU_SCP) { tmp << "SLU_SCP"; } + if(out.Stype == superlu::SLU_SR) { tmp << "SLU_SR"; } + if(out.Stype == superlu::SLU_DN) { tmp << "SLU_DN"; } + if(out.Stype == superlu::SLU_NR_loc) { tmp << "SLU_NR_loc"; } + + arma_debug_warn_level(1, tmp.str()); + arma_stop_runtime_error("internal error: sp_auxlib::destroy_supermatrix()"); + } + } + +#endif + + + +template +inline +void +sp_auxlib::run_aupd_plain + ( + const uword n_eigvals, char* which, + const SpMat& X, const SpMat& Xst, const bool sym, + blas_int& n, eT& tol, blas_int& maxiter, + podarray& resid, blas_int& ncv, podarray& v, blas_int& ldv, + podarray& iparam, podarray& ipntr, + podarray& workd, podarray& workl, blas_int& lworkl, podarray& rwork, + blas_int& info + ) + { + #if defined(ARMA_USE_ARPACK) + { + // ARPACK provides a "reverse communication interface" which is an + // entertainingly archaic FORTRAN software engineering technique that + // basically means that we call saupd()/naupd() and it tells us with some + // return code what we need to do next (usually a matrix-vector product) and + // then call it again. So this results in some type of iterative process + // where we call saupd()/naupd() many times. + + blas_int ido = 0; // This must be 0 for the first call. + char bmat = 'I'; // We are considering the standard eigenvalue problem. + n = X.n_rows; // The size of the matrix (should already be set outside). + blas_int nev = n_eigvals; + + // resid.zeros(n); + eigs_randu_filler randu_filler; + randu_filler.fill(resid, n); // use deterministic starting point + + // Two contraints on NCV: (NCV > NEV) for sym problems or + // (NCV > NEV + 2) for gen problems and (NCV <= N) + // + // We're calling either arpack::saupd() or arpack::naupd(), + // which have slighly different minimum constraint and recommended value for NCV: + // http://www.caam.rice.edu/software/ARPACK/UG/node136.html + // http://www.caam.rice.edu/software/ARPACK/UG/node138.html + + if(ncv < (nev + (sym ? 1 : 3))) { ncv = (nev + (sym ? 1 : 3)); } + if(ncv > n ) { ncv = n; } + + v.zeros(n * ncv); // Array N by NCV (output). + rwork.zeros(ncv); // Work array of size NCV for complex calls. + ldv = n; // "Leading dimension of V exactly as declared in the calling program." + + // IPARAM: integer array of length 11. + iparam.zeros(11); + iparam(0) = 1; // Exact shifts (not provided by us). + iparam(2) = maxiter; // Maximum iterations; all the examples use 300, but they were written in the ancient times. + iparam(6) = 1; // Mode 1: A * x = lambda * x. + + // IPNTR: integer array of length 14 (output). + ipntr.zeros(14); + + // Real work array used in the basic Arnoldi iteration for reverse communication. + workd.zeros(3 * n); + + // lworkl must be at least 3 * NCV^2 + 6 * NCV. + lworkl = 3 * (ncv * ncv) + 6 * ncv; + + // Real work array of length lworkl. + workl.zeros(lworkl); + + // info = 0; // resid to be filled with random values by ARPACK (non-deterministic) + info = 1; // resid is already filled with random values (deterministic) + + // All the parameters have been set or created. Time to loop a lot. + while(ido != 99) + { + // Call saupd() or naupd() with the current parameters. + if(sym) + { + arma_extra_debug_print("arpack::saupd()"); + arpack::saupd(&ido, &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, &info); + } + else + { + arma_extra_debug_print("arpack::naupd()"); + arpack::naupd(&ido, &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, rwork.memptr(), &info); + } + + // What do we do now? + switch (ido) + { + case -1: + // fallthrough + case 1: + { + // We need to calculate the matrix-vector multiplication y = OP * x + // where x is of length n and starts at workd(ipntr(0)), and y is of + // length n and starts at workd(ipntr(1)). + + // // OLD METHOD + // + // // operator*(sp_mat, vec) doesn't properly put the result into the + // // right place so we'll just reimplement it here for now... + // + // // Set the output to point at the right memory. We have to subtract + // // one from FORTRAN pointers... + // Col out(workd.memptr() + ipntr(1) - 1, n, false /* don't copy */); + // // Set the input to point at the right memory. + // Col in(workd.memptr() + ipntr(0) - 1, n, false /* don't copy */); + // + // out.zeros(); + // + // T* out_mem = out.memptr(); + // const T* in_mem = in.memptr(); + // + // typename SpMat::const_iterator X_it = X.begin(); + // + // const uword X_nnz = X.n_nonzero; + // + // for(uword count=0; count < X_nnz; ++count, ++X_it) + // { + // const eT X_it_val = (*X_it); + // const uword X_it_row = X_it.row(); + // const uword X_it_col = X_it.col(); + // + // out_mem[X_it_row] += X_it_val * in_mem[X_it_col]; + // } + // + // // No need to modify memory further since it was all done in-place. + + + // NEW METHOD + // + // both operator*(rowvec, sp_mat) and operator*(sp_mat, colvec) can now write to an existing object + + Row out(workd.memptr() + ipntr(1) - 1, n, false, true); + Row in(workd.memptr() + ipntr(0) - 1, n, false, true); + + out = in * Xst; + + break; + } + case 99: + // Nothing to do here, things have converged. + break; + default: + { + return; // Parent frame can look at the value of info. + } + } + } + + // The process has ended; check the return code. + if( (info != 0) && (info != 1) ) + { + // Print warnings if there was a failure. + + if(sym) + { + arma_debug_warn_level(1, "eigs_sym(): ARPACK error ", info, " in saupd()"); + } + else + { + arma_debug_warn_level(1, "eigs_gen(): ARPACK error ", info, " in naupd()"); + } + + return; // Parent frame can look at the value of info. + } + } + #else + { + arma_ignore(n_eigvals); + arma_ignore(which); + arma_ignore(X); + arma_ignore(sym); + arma_ignore(n); + arma_ignore(tol); + arma_ignore(maxiter); + arma_ignore(resid); + arma_ignore(ncv); + arma_ignore(v); + arma_ignore(ldv); + arma_ignore(iparam); + arma_ignore(ipntr); + arma_ignore(workd); + arma_ignore(workl); + arma_ignore(lworkl); + arma_ignore(rwork); + arma_ignore(info); + } + #endif + } + + + +// Here 'sigma' is 'T', but should be 'eT'. +// Applying complex shifts to real matrices is currently not directly implemented +template +inline +void +sp_auxlib::run_aupd_shiftinvert + ( + const uword n_eigvals, const T sigma, + const SpMat& X, const bool sym, + blas_int& n, eT& tol, blas_int& maxiter, + podarray& resid, blas_int& ncv, podarray& v, blas_int& ldv, + podarray& iparam, podarray& ipntr, + podarray& workd, podarray& workl, blas_int& lworkl, podarray& rwork, + blas_int& info + ) + { + // TODO: inconsistent use of type names: T can be complex while eT can be real + + #if (defined(ARMA_USE_ARPACK) && defined(ARMA_USE_SUPERLU)) + { + char which_lm[3] = "LM"; + + char* which = which_lm; // NOTE: which_lm is the assumed operation when using shift-invert + + blas_int ido = 0; // This must be 0 for the first call. + char bmat = 'I'; // We are considering the standard eigenvalue problem. + n = X.n_rows; // The size of the matrix (should already be set outside). + blas_int nev = n_eigvals; + + // resid.zeros(n); + eigs_randu_filler randu_filler; + randu_filler.fill(resid, n); // use deterministic starting point + + // Two contraints on NCV: (NCV > NEV) for sym problems or + // (NCV > NEV + 2) for gen problems and (NCV <= N) + // + // We're calling either arpack::saupd() or arpack::naupd(), + // which have slighly different minimum constraint and recommended value for NCV: + // http://www.caam.rice.edu/software/ARPACK/UG/node136.html + // http://www.caam.rice.edu/software/ARPACK/UG/node138.html + + if(ncv < (nev + (sym ? 1 : 3))) { ncv = (nev + (sym ? 1 : 3)); } + if(ncv > n ) { ncv = n; } + + v.zeros(n * ncv); // Array N by NCV (output). + rwork.zeros(ncv); // Work array of size NCV for complex calls. + ldv = n; // "Leading dimension of V exactly as declared in the calling program." + + // IPARAM: integer array of length 11. + iparam.zeros(11); + iparam(0) = 1; // Exact shifts (not provided by us). + iparam(2) = maxiter; // Maximum iterations; all the examples use 300, but they were written in the ancient times. + // iparam(6) = 1; // Mode 1: A * x = lambda * x. + + // Change IPARAM for shift-invert + iparam(6) = 3; // Mode 3: A * x = lambda * M * x, M symmetric semi-definite. OP = inv[A - sigma*M]*M (A complex) or Real_Part{ inv[A - sigma*M]*M } (A real) and B = M. + + // IPNTR: integer array of length 14 (output). + ipntr.zeros(14); + + // Real work array used in the basic Arnoldi iteration for reverse communication. + workd.zeros(3 * n); + + // lworkl must be at least 3 * NCV^2 + 6 * NCV. + lworkl = 3 * (ncv * ncv) + 6 * ncv; + + // Real work array of length lworkl. + workl.zeros(lworkl); + + // info = 0; // resid to be filled with random values by ARPACK (non-deterministic) + info = 1; // resid is already filled with random values (deterministic) + + superlu_opts superlu_opts_default; + superlu::superlu_options_t options; + sp_auxlib::set_superlu_opts(options, superlu_opts_default); + int lwork = 0; + superlu::trans_t trans = superlu::NOTRANS; + + superlu::GlobalLU_t Glu; /* Not needed on return. */ + arrayops::fill_zeros(reinterpret_cast(&Glu), sizeof(superlu::GlobalLU_t)); + + superlu_supermatrix_wrangler x; + superlu_supermatrix_wrangler xC; + + const bool status_x = sp_auxlib::copy_to_supermatrix_with_shift(x.get_ref(), X, sigma); + + if(status_x == false) + { + arma_stop_runtime_error("run_aupd_shiftinvert(): could not construct SuperLU matrix"); + info = blas_int(-1); + return; + } + + // // for debugging only + // if(true) + // { + // cout << "*** testing output of copy_to_supermatrix_with_shift()" << endl; + // cout << "*** sigma: " << sigma << endl; + // + // SpMat Y(X); + // Y.diag() -= sigma; + // + // SpMat Z; + // + // sp_auxlib::copy_to_spmat(Z, x.get_ref()); + // + // cout << "*** size(Y): " << arma::size(Y) << endl; + // cout << "*** size(Z): " << arma::size(Z) << endl; + // cout << "*** accu(abs(Y)): " << accu(abs(Y)) << endl; + // cout << "*** accu(abs(Z)): " << accu(abs(Z)) << endl; + // + // if(arma::size(Y) == arma::size(Z)) + // { + // cout << "*** error: " << accu(abs(Y-Z)) << endl; + // } + // } + + superlu_supermatrix_wrangler l; + superlu_supermatrix_wrangler u; + + superlu_array_wrangler perm_c(X.n_cols+1); // paranoia: increase array length by 1 + superlu_array_wrangler perm_r(X.n_rows+1); + superlu_array_wrangler etree(X.n_cols+1); + + superlu_stat_wrangler stat; + + int panel_size = superlu::sp_ispec_environ(1); + int relax = superlu::sp_ispec_environ(2); + int slu_info = 0; // Return code. + + arma_extra_debug_print("superlu::gstrf()"); + superlu::get_permutation_c(options.ColPerm, x.get_ptr(), perm_c.get_ptr()); + superlu::sp_preorder_mat(&options, x.get_ptr(), perm_c.get_ptr(), etree.get_ptr(), xC.get_ptr()); + superlu::gstrf(&options, xC.get_ptr(), relax, panel_size, etree.get_ptr(), NULL, lwork, perm_c.get_ptr(), perm_r.get_ptr(), l.get_ptr(), u.get_ptr(), &Glu, stat.get_ptr(), &slu_info); + + if(slu_info != 0) + { + arma_debug_warn_level(2, "matrix is singular to working precision"); + info = blas_int(-1); + return; + } + + // NOTE: potential problem with inconsistent/mismatched use of eT and T types + eT x_norm_val = sp_auxlib::norm1(x.get_ptr()); + eT x_rcond = sp_auxlib::lu_rcond(l.get_ptr(), u.get_ptr(), x_norm_val); + + if( (x_rcond < std::numeric_limits::epsilon()) || arma_isnan(x_rcond) ) + { + arma_debug_warn_level(2, "matrix is singular to working precision (rcond: ", x_rcond, ")"); + info = blas_int(-1); + return; + } + + // All the parameters have been set or created. Time to loop a lot. + while(ido != 99) + { + // Call saupd() or naupd() with the current parameters. + if(sym) + { + arma_extra_debug_print("arpack::saupd()"); + arpack::saupd(&ido, &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, &info); + } + else + { + arma_extra_debug_print("arpack::naupd()"); + arpack::naupd(&ido, &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, rwork.memptr(), &info); + } + + // What do we do now? + switch (ido) + { + case -1: + // fallthrough + case 1: + { + // We need to calculate the matrix-vector multiplication y = OP * x + // where x is of length n and starts at workd(ipntr(0)), and y is of + // length n and starts at workd(ipntr(1)). + + // Set the output to point at the right memory. We have to subtract + // one from FORTRAN pointers... + Col out(workd.memptr() + ipntr(1) - 1, n, false /* don't copy */); + // Set the input to point at the right memory. + Col in(workd.memptr() + ipntr(0) - 1, n, false /* don't copy */); + + // Consider getting the LU factorization from ZGSTRF, and then + // solve the system L*U*out = in (possibly with permutation matrix?) + // Instead of "spsolve(out,X,in)" we call gstrf above and gstrs below + + out = in; + superlu_supermatrix_wrangler out_slu; + + const bool status_out_slu = sp_auxlib::wrap_to_supermatrix(out_slu.get_ref(), out); + + if(status_out_slu == false) { arma_stop_runtime_error("run_aupd_shiftinvert(): could not construct SuperLU matrix"); return; } + + arma_extra_debug_print("superlu::gstrs()"); + superlu::gstrs(trans, l.get_ptr(), u.get_ptr(), perm_c.get_ptr(), perm_r.get_ptr(), out_slu.get_ptr(), stat.get_ptr(), &info); + + // No need to modify memory further since it was all done in-place. + + break; + } + case 99: + // Nothing to do here, things have converged. + break; + default: + { + return; // Parent frame can look at the value of info. + } + } + } + + // The process has ended; check the return code. + if( (info != 0) && (info != 1) ) + { + // Print warnings if there was a failure. + + if(sym) + { + arma_debug_warn_level(2, "eigs_sym(): ARPACK error ", info, " in saupd()"); + } + else + { + arma_debug_warn_level(2, "eigs_gen(): ARPACK error ", info, " in naupd()"); + } + + return; // Parent frame can look at the value of info. + } + } + #else + { + arma_ignore(n_eigvals); + arma_ignore(sigma); + arma_ignore(X); + arma_ignore(sym); + arma_ignore(n); + arma_ignore(tol); + arma_ignore(maxiter); + arma_ignore(resid); + arma_ignore(ncv); + arma_ignore(v); + arma_ignore(ldv); + arma_ignore(iparam); + arma_ignore(ipntr); + arma_ignore(workd); + arma_ignore(workl); + arma_ignore(lworkl); + arma_ignore(rwork); + arma_ignore(info); + } + #endif + } + + + +template +inline +bool +sp_auxlib::rudimentary_sym_check(const SpMat& X) + { + arma_extra_debug_sigprint(); + + if(X.n_rows != X.n_cols) { return false; } + + const eT tol = eT(10000) * std::numeric_limits::epsilon(); // allow some leeway + + typename SpMat::const_iterator it = X.begin(); + typename SpMat::const_iterator it_end = X.end(); + + const uword n_check_limit = (std::max)( uword(2), uword(X.n_nonzero/100) ); + + uword n_check = 1; + + while( (it != it_end) && (n_check <= n_check_limit) ) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + if(it_row != it_col) + { + const eT A = (*it); + const eT B = X.at( it_col, it_row ); // deliberately swapped + + const eT C = (std::max)(std::abs(A), std::abs(B)); + + const eT delta = std::abs(A - B); + + if(( (delta <= tol) || (delta <= (C * tol)) ) == false) { return false; } + + ++n_check; + } + + ++it; + } + + return true; + } + + + +template +inline +bool +sp_auxlib::rudimentary_sym_check(const SpMat< std::complex >& X) + { + arma_extra_debug_sigprint(); + + // NOTE: the function name is a misnomer, as it checks for hermitian complex matrices; + // NOTE: for simplicity of use, the function name is the same as for real matrices + + typedef typename std::complex eT; + + if(X.n_rows != X.n_cols) { return false; } + + const T tol = T(10000) * std::numeric_limits::epsilon(); // allow some leeway + + typename SpMat::const_iterator it = X.begin(); + typename SpMat::const_iterator it_end = X.end(); + + const uword n_check_limit = (std::max)( uword(2), uword(X.n_nonzero/100) ); + + uword n_check = 1; + + while( (it != it_end) && (n_check <= n_check_limit) ) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + if(it_row != it_col) + { + const eT A = (*it); + const eT B = X.at( it_col, it_row ); // deliberately swapped + + const T C_real = (std::max)(std::abs(A.real()), std::abs(B.real())); + const T C_imag = (std::max)(std::abs(A.imag()), std::abs(B.imag())); + + const T delta_real = std::abs(A.real() - B.real()); + const T delta_imag = std::abs(A.imag() + B.imag()); // take into account the conjugate + + const bool okay_real = ( (delta_real <= tol) || (delta_real <= (C_real * tol)) ); + const bool okay_imag = ( (delta_imag <= tol) || (delta_imag <= (C_imag * tol)) ); + + if( (okay_real == false) || (okay_imag == false) ) { return false; } + + ++n_check; + } + else + { + const eT A = (*it); + + if(std::abs(A.imag()) > tol) { return false; } + } + + ++it; + } + + return true; + } + + + +// + + + +template +inline +eigs_randu_filler::eigs_randu_filler() + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type local_seed_type; + + local_engine.seed(local_seed_type(123)); + + typedef typename std::uniform_real_distribution::param_type local_param_type; + + local_u_distr.param(local_param_type(-1.0, +1.0)); + } + + +template +inline +void +eigs_randu_filler::fill(podarray& X, const uword N) + { + arma_extra_debug_sigprint(); + + X.set_size(N); + + eT* X_mem = X.memptr(); + + for(uword i=0; i +inline +eigs_randu_filler< std::complex >::eigs_randu_filler() + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type local_seed_type; + + local_engine.seed(local_seed_type(123)); + + typedef typename std::uniform_real_distribution::param_type local_param_type; + + local_u_distr.param(local_param_type(-1.0, +1.0)); + } + + +template +inline +void +eigs_randu_filler< std::complex >::fill(podarray< std::complex >& X, const uword N) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + X.set_size(N); + + eT* X_mem = X.memptr(); + + for(uword i=0; i(&m); + bool all_zero = true; + + for(size_t i=0; i < sizeof(superlu::SuperMatrix); ++i) + { + if(m_char[i] != char(0)) { all_zero = false; break; } + } + + if(all_zero == false) { sp_auxlib::destroy_supermatrix(m); } + } + +inline +superlu_supermatrix_wrangler::superlu_supermatrix_wrangler() + { + arma_extra_debug_sigprint_this(this); + + arrayops::fill_zeros(reinterpret_cast(&m), sizeof(superlu::SuperMatrix)); + } + +inline +superlu::SuperMatrix& +superlu_supermatrix_wrangler::get_ref() + { + used = true; + + return m; + } + +inline +superlu::SuperMatrix* +superlu_supermatrix_wrangler::get_ptr() + { + used = true; + + return &m; + } + + +// + + +inline +superlu_stat_wrangler::~superlu_stat_wrangler() + { + arma_extra_debug_sigprint_this(this); + + superlu::free_stat(&stat); + } + +inline +superlu_stat_wrangler::superlu_stat_wrangler() + { + arma_extra_debug_sigprint_this(this); + + arrayops::fill_zeros(reinterpret_cast(&stat), sizeof(superlu::SuperLUStat_t)); + + superlu::init_stat(&stat); + } + +inline +superlu::SuperLUStat_t* +superlu_stat_wrangler::get_ptr() + { + return &stat; + } + + +// + + +template +inline +superlu_array_wrangler::~superlu_array_wrangler() + { + arma_extra_debug_sigprint_this(this); + + (*this).reset(); + } + +template +inline +superlu_array_wrangler::superlu_array_wrangler() + : mem(nullptr) + { + arma_extra_debug_sigprint_this(this); + } + +template +inline +superlu_array_wrangler::superlu_array_wrangler(const uword n_elem) + : mem(nullptr) + { + arma_extra_debug_sigprint_this(this); + + (*this).set_size(n_elem); + } + +template +inline +void +superlu_array_wrangler::set_size(const uword n_elem) + { + arma_extra_debug_sigprint(); + + if(mem != nullptr) { (*this).reset(); } + + mem = (eT*)(superlu::malloc(n_elem * sizeof(eT))); + + arma_check_bad_alloc( (mem == nullptr), "superlu::malloc(): out of memory" ); + + arrayops::fill_zeros(mem, n_elem); + } + +template +inline +void +superlu_array_wrangler::reset() + { + arma_extra_debug_sigprint(); + + if(mem != nullptr) + { + superlu::free(mem); + mem = nullptr; + } + } + +template +inline +eT* +superlu_array_wrangler::get_ptr() + { + return mem; + } + + +// + + +template +inline +superlu_worker::~superlu_worker() + { + arma_extra_debug_sigprint_this(this); + + if(l != nullptr) { delete l; l = nullptr; } + if(u != nullptr) { delete u; u = nullptr; } + } + + +template +inline +superlu_worker::superlu_worker() + { + arma_extra_debug_sigprint_this(this); + } + + +template +inline +bool +superlu_worker::factorise(typename get_pod_type::result& out_rcond, const SpMat& A, const superlu_opts& user_opts) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + factorisation_valid = false; + + if(l != nullptr) { delete l; l = nullptr; } + if(u != nullptr) { delete u; u = nullptr; } + + l = new(std::nothrow) superlu_supermatrix_wrangler; + u = new(std::nothrow) superlu_supermatrix_wrangler; + + if( (l == nullptr) || (u == nullptr) ) + { + arma_debug_warn_level(3, "superlu_worker()::factorise(): could not construct SuperLU matrix"); + return false; + } + + superlu_supermatrix_wrangler& l_ref = (*l); + superlu_supermatrix_wrangler& u_ref = (*u); + + superlu::superlu_options_t options; + sp_auxlib::set_superlu_opts(options, user_opts); + + superlu_supermatrix_wrangler AA; + superlu_supermatrix_wrangler AAc; + + const bool status_AA = sp_auxlib::copy_to_supermatrix(AA.get_ref(), A); + + if(status_AA == false) + { + arma_debug_warn_level(3, "superlu_worker()::factorise(): could not construct SuperLU matrix"); + return false; + } + + (*this).perm_c.set_size(A.n_cols+1); // paranoia: increase array length by 1 + (*this).perm_r.set_size(A.n_rows+1); + + superlu_array_wrangler etree(A.n_cols+1); + + superlu::GlobalLU_t Glu; + arrayops::fill_zeros(reinterpret_cast(&Glu), sizeof(superlu::GlobalLU_t)); + + int panel_size = superlu::sp_ispec_environ(1); + int relax = superlu::sp_ispec_environ(2); + int lwork = 0; + int info = 0; + + arma_extra_debug_print("superlu::superlu::get_permutation_c()"); + superlu::get_permutation_c(options.ColPerm, AA.get_ptr(), perm_c.get_ptr()); + + arma_extra_debug_print("superlu::superlu::sp_preorder_mat()"); + superlu::sp_preorder_mat(&options, AA.get_ptr(), perm_c.get_ptr(), etree.get_ptr(), AAc.get_ptr()); + + arma_extra_debug_print("superlu::gstrf()"); + superlu::gstrf(&options, AAc.get_ptr(), relax, panel_size, etree.get_ptr(), NULL, lwork, perm_c.get_ptr(), perm_r.get_ptr(), l_ref.get_ptr(), u_ref.get_ptr(), &Glu, stat.get_ptr(), &info); + + if(info != 0) + { + arma_debug_warn_level(3, "superlu_worker()::factorise(): LU factorisation failed"); + return false; + } + + const T AA_norm = sp_auxlib::norm1(AA.get_ptr()); + const T AA_rcond = sp_auxlib::lu_rcond(l_ref.get_ptr(), u_ref.get_ptr(), AA_norm); + + out_rcond = AA_rcond; + + if(arma_isnan(AA_rcond)) { return false; } + // if(AA_rcond == T(0)) { return false; } + + factorisation_valid = true; + + return true; + } + + +template +inline +bool +superlu_worker::solve(Mat& X, const Mat& B) + { + arma_extra_debug_sigprint(); + + if(factorisation_valid == false) { return false; } + if( (l == nullptr) || (u == nullptr) ) { return false; } + + superlu_supermatrix_wrangler& l_ref = (*l); + superlu_supermatrix_wrangler& u_ref = (*u); + + X = B; + + superlu_supermatrix_wrangler XX; + + const bool status_XX = sp_auxlib::wrap_to_supermatrix(XX.get_ref(), X); + + if(status_XX == false) + { + arma_debug_warn_level(3, "superlu_worker()::solve(): could not construct SuperLU matrix"); + return false; + } + + superlu::trans_t trans = superlu::NOTRANS; + int info = 0; + + arma_extra_debug_print("superlu::gstrs()"); + superlu::gstrs(trans, l_ref.get_ptr(), u_ref.get_ptr(), perm_c.get_ptr(), perm_r.get_ptr(), XX.get_ptr(), stat.get_ptr(), &info); + + return (info == 0); + } + + +#endif + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/span.hpp b/src/armadillo/include/armadillo_bits/span.hpp new file mode 100644 index 0000000..14774f1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/span.hpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup span +//! @{ + + +struct span_alt {}; + + +template +class span_base + { + public: + static const span_alt all; + }; + + +template +const span_alt span_base::all = span_alt(); + + +class span : public span_base<> + { + public: + + uword a; + uword b; + bool whole; + + inline + span() + : a(0) + , b(0) + , whole(true) + { + } + + + inline + span(const span_alt&) + : a(0) + , b(0) + , whole(true) + { + } + + + inline + explicit + span(const uword in_a) + : a(in_a) + , b(in_a) + , whole(false) + { + } + + + // the "explicit" keyword is required here to prevent automatic conversion of {a,b} + // into an instance of span() when submatrices are specified + inline + explicit + span(const uword in_a, const uword in_b) + : a(in_a) + , b(in_b) + , whole(false) + { + } + + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spdiagview_bones.hpp b/src/armadillo/include/armadillo_bits/spdiagview_bones.hpp new file mode 100644 index 0000000..238e8a3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spdiagview_bones.hpp @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spdiagview +//! @{ + + +//! Class for storing data required to extract and set the diagonals of a sparse matrix +template +class spdiagview : public SpBase< eT, spdiagview > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + arma_aligned const SpMat& m; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + const uword row_offset; + const uword col_offset; + + const uword n_rows; // equal to n_elem + const uword n_elem; + + static constexpr uword n_cols = 1; + + + protected: + + arma_inline spdiagview(const SpMat& in_m, const uword in_row_offset, const uword in_col_offset, const uword len); + + + public: + + inline ~spdiagview(); + inline spdiagview() = delete; + + inline void operator=(const spdiagview& x); + + inline void operator+=(const eT val); + inline void operator-=(const eT val); + inline void operator*=(const eT val); + inline void operator/=(const eT val); + + template inline void operator= (const Base& x); + template inline void operator+=(const Base& x); + template inline void operator-=(const Base& x); + template inline void operator%=(const Base& x); + template inline void operator/=(const Base& x); + + template inline void operator= (const SpBase& x); + template inline void operator+=(const SpBase& x); + template inline void operator-=(const SpBase& x); + template inline void operator%=(const SpBase& x); + template inline void operator/=(const SpBase& x); + + inline SpMat_MapMat_val operator[](const uword ii); + inline eT operator[](const uword ii) const; + + inline SpMat_MapMat_val at(const uword ii); + inline eT at(const uword ii) const; + + inline SpMat_MapMat_val operator()(const uword ii); + inline eT operator()(const uword ii) const; + + inline SpMat_MapMat_val at(const uword in_n_row, const uword); + inline eT at(const uword in_n_row, const uword) const; + + inline SpMat_MapMat_val operator()(const uword in_n_row, const uword in_n_col); + inline eT operator()(const uword in_n_row, const uword in_n_col) const; + + + inline void replace(const eT old_val, const eT new_val); + + inline void clean(const pod_type threshold); + + inline void clamp(const eT min_val, const eT max_val); + + inline void fill(const eT val); + inline void zeros(); + inline void ones(); + inline void randu(); + inline void randn(); + + + inline static void extract(SpMat& out, const spdiagview& in); + inline static void extract( Mat& out, const spdiagview& in); + + + friend class SpMat; + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spdiagview_meat.hpp b/src/armadillo/include/armadillo_bits/spdiagview_meat.hpp new file mode 100644 index 0000000..603cadc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spdiagview_meat.hpp @@ -0,0 +1,1073 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spdiagview +//! @{ + + +template +inline +spdiagview::~spdiagview() + { + arma_extra_debug_sigprint(); + } + + +template +arma_inline +spdiagview::spdiagview(const SpMat& in_m, const uword in_row_offset, const uword in_col_offset, const uword in_len) + : m(in_m) + , row_offset(in_row_offset) + , col_offset(in_col_offset) + , n_rows(in_len) + , n_elem(in_len) + { + arma_extra_debug_sigprint(); + } + + + +//! set a diagonal of our matrix using a diagonal from a foreign matrix +template +inline +void +spdiagview::operator= (const spdiagview& x) + { + arma_extra_debug_sigprint(); + + spdiagview& d = *this; + + arma_debug_check( (d.n_elem != x.n_elem), "spdiagview: diagonals have incompatible lengths" ); + + SpMat& d_m = const_cast< SpMat& >(d.m); + const SpMat& x_m = x.m; + + if( (&d_m == &x_m) || ((d.row_offset == 0) && (d.col_offset == 0)) ) + { + const Mat tmp(x); + + (*this).operator=(tmp); + } + else + { + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const uword x_row_offset = x.row_offset; + const uword x_col_offset = x.col_offset; + + for(uword i=0; i < d_n_elem; ++i) + { + d_m.at(i + d_row_offset, i + d_col_offset) = x_m.at(i + x_row_offset, i + x_col_offset); + } + } + } + + + +template +inline +void +spdiagview::operator+=(const eT val) + { + arma_extra_debug_sigprint(); + + if(val == eT(0)) { return; } + + SpMat& t_m = const_cast< SpMat& >(m); + + const uword t_n_elem = n_elem; + const uword t_row_offset = row_offset; + const uword t_col_offset = col_offset; + + for(uword i=0; i < t_n_elem; ++i) + { + t_m.at(i + t_row_offset, i + t_col_offset) += val; + } + } + + + +template +inline +void +spdiagview::operator-=(const eT val) + { + arma_extra_debug_sigprint(); + + if(val == eT(0)) { return; } + + SpMat& t_m = const_cast< SpMat& >(m); + + const uword t_n_elem = n_elem; + const uword t_row_offset = row_offset; + const uword t_col_offset = col_offset; + + for(uword i=0; i < t_n_elem; ++i) + { + t_m.at(i + t_row_offset, i + t_col_offset) -= val; + } + } + + + +template +inline +void +spdiagview::operator*=(const eT val) + { + arma_extra_debug_sigprint(); + + if(val == eT(0)) { (*this).zeros(); return; } + + SpMat& t_m = const_cast< SpMat& >(m); + + const uword t_n_elem = n_elem; + const uword t_row_offset = row_offset; + const uword t_col_offset = col_offset; + + for(uword i=0; i < t_n_elem; ++i) + { + t_m.at(i + t_row_offset, i + t_col_offset) *= val; + } + } + + + +template +inline +void +spdiagview::operator/=(const eT val) + { + arma_extra_debug_sigprint(); + + SpMat& t_m = const_cast< SpMat& >(m); + + const uword t_n_elem = n_elem; + const uword t_row_offset = row_offset; + const uword t_col_offset = col_offset; + + for(uword i=0; i < t_n_elem; ++i) + { + t_m.at(i + t_row_offset, i + t_col_offset) /= val; + } + } + + + +//! set a diagonal of our matrix using data from a foreign object +template +template +inline +void +spdiagview::operator= (const Base& o) + { + arma_extra_debug_sigprint(); + + spdiagview& d = *this; + + SpMat& d_m = const_cast< SpMat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + if(is_same_type< T1, Gen, gen_zeros> >::yes) + { + const Proxy P(o.get_ref()); + + arma_debug_check( (d_n_elem != P.get_n_elem()), "spdiagview: given object has incompatible size" ); + + (*this).zeros(); + + return; + } + + if(is_same_type< T1, Gen, gen_ones> >::yes) + { + const Proxy P(o.get_ref()); + + arma_debug_check( (d_n_elem != P.get_n_elem()), "spdiagview: given object has incompatible size" ); + + (*this).ones(); + + return; + } + + const quasi_unwrap U(o.get_ref()); + const Mat& x = U.M; + + const eT* x_mem = x.memptr(); + + arma_debug_check + ( + ( (d_n_elem != x.n_elem) || ((x.n_rows != 1) && (x.n_cols != 1)) ), + "spdiagview: given object has incompatible size" + ); + + if( (d_row_offset == 0) && (d_col_offset == 0) ) + { + SpMat tmp1; + + tmp1.eye(d_m.n_rows, d_m.n_cols); + + bool has_zero = false; + + for(uword i=0; i < d_n_elem; ++i) + { + const eT val = x_mem[i]; + + access::rw(tmp1.values[i]) = val; + + if(val == eT(0)) { has_zero = true; } + } + + if(has_zero) { tmp1.remove_zeros(); } + + if(tmp1.n_nonzero == 0) { (*this).zeros(); return; } + + SpMat tmp2; + + spglue_merge::diagview_merge(tmp2, d_m, tmp1); + + d_m.steal_mem(tmp2); + } + else + { + for(uword i=0; i < d_n_elem; ++i) + { + d_m.at(i + d_row_offset, i + d_col_offset) = x_mem[i]; + } + } + } + + + +template +template +inline +void +spdiagview::operator+=(const Base& o) + { + arma_extra_debug_sigprint(); + + spdiagview& d = *this; + + SpMat& d_m = const_cast< SpMat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const Proxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "spdiagview: given object has incompatible size" + ); + + if( (is_Mat::stored_type>::value) || (Proxy::use_at) ) + { + const unwrap::stored_type> tmp(P.Q); + const Mat& x = tmp.M; + + const eT* x_mem = x.memptr(); + + for(uword i=0; i < d_n_elem; ++i) + { + d_m.at(i + d_row_offset, i + d_col_offset) += x_mem[i]; + } + } + else + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i < d_n_elem; ++i) + { + d_m.at(i + d_row_offset, i + d_col_offset) += Pea[i]; + } + } + } + + + +template +template +inline +void +spdiagview::operator-=(const Base& o) + { + arma_extra_debug_sigprint(); + + spdiagview& d = *this; + + SpMat& d_m = const_cast< SpMat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const Proxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "spdiagview: given object has incompatible size" + ); + + if( (is_Mat::stored_type>::value) || (Proxy::use_at) ) + { + const unwrap::stored_type> tmp(P.Q); + const Mat& x = tmp.M; + + const eT* x_mem = x.memptr(); + + for(uword i=0; i < d_n_elem; ++i) + { + d_m.at(i + d_row_offset, i + d_col_offset) -= x_mem[i]; + } + } + else + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i < d_n_elem; ++i) + { + d_m.at(i + d_row_offset, i + d_col_offset) -= Pea[i]; + } + } + } + + + +template +template +inline +void +spdiagview::operator%=(const Base& o) + { + arma_extra_debug_sigprint(); + + spdiagview& d = *this; + + SpMat& d_m = const_cast< SpMat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const Proxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "spdiagview: given object has incompatible size" + ); + + if( (is_Mat::stored_type>::value) || (Proxy::use_at) ) + { + const unwrap::stored_type> tmp(P.Q); + const Mat& x = tmp.M; + + const eT* x_mem = x.memptr(); + + for(uword i=0; i < d_n_elem; ++i) + { + d_m.at(i + d_row_offset, i + d_col_offset) *= x_mem[i]; + } + } + else + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i < d_n_elem; ++i) + { + d_m.at(i + d_row_offset, i + d_col_offset) *= Pea[i]; + } + } + } + + + +template +template +inline +void +spdiagview::operator/=(const Base& o) + { + arma_extra_debug_sigprint(); + + spdiagview& d = *this; + + SpMat& d_m = const_cast< SpMat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const Proxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "spdiagview: given object has incompatible size" + ); + + if( (is_Mat::stored_type>::value) || (Proxy::use_at) ) + { + const unwrap::stored_type> tmp(P.Q); + const Mat& x = tmp.M; + + const eT* x_mem = x.memptr(); + + for(uword i=0; i < d_n_elem; ++i) + { + d_m.at(i + d_row_offset, i + d_col_offset) /= x_mem[i]; + } + } + else + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i < d_n_elem; ++i) + { + d_m.at(i + d_row_offset, i + d_col_offset) /= Pea[i]; + } + } + } + + + +//! set a diagonal of our matrix using data from a foreign object +template +template +inline +void +spdiagview::operator= (const SpBase& o) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U( o.get_ref() ); + const SpMat& x = U.M; + + arma_debug_check + ( + ( (n_elem != x.n_elem) || ((x.n_rows != 1) && (x.n_cols != 1)) ), + "spdiagview: given object has incompatible size" + ); + + const Mat tmp(x); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +spdiagview::operator+=(const SpBase& o) + { + arma_extra_debug_sigprint(); + + spdiagview& d = *this; + + SpMat& d_m = const_cast< SpMat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const SpProxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "spdiagview: given object has incompatible size" + ); + + if( SpProxy::use_iterator || P.is_alias(d_m) ) + { + const SpMat tmp(P.Q); + + if(tmp.n_cols == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) += tmp.at(i,0); } + } + else + if(tmp.n_rows == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) += tmp.at(0,i); } + } + } + else + { + if(P.get_n_cols() == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) += P.at(i,0); } + } + else + if(P.get_n_rows() == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) += P.at(0,i); } + } + } + } + + + +template +template +inline +void +spdiagview::operator-=(const SpBase& o) + { + arma_extra_debug_sigprint(); + + spdiagview& d = *this; + + SpMat& d_m = const_cast< SpMat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const SpProxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "spdiagview: given object has incompatible size" + ); + + if( SpProxy::use_iterator || P.is_alias(d_m) ) + { + const SpMat tmp(P.Q); + + if(tmp.n_cols == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) -= tmp.at(i,0); } + } + else + if(tmp.n_rows == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) -= tmp.at(0,i); } + } + } + else + { + if(P.get_n_cols() == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) -= P.at(i,0); } + } + else + if(P.get_n_rows() == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) -= P.at(0,i); } + } + } + } + + + +template +template +inline +void +spdiagview::operator%=(const SpBase& o) + { + arma_extra_debug_sigprint(); + + spdiagview& d = *this; + + SpMat& d_m = const_cast< SpMat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const SpProxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "spdiagview: given object has incompatible size" + ); + + if( SpProxy::use_iterator || P.is_alias(d_m) ) + { + const SpMat tmp(P.Q); + + if(tmp.n_cols == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) *= tmp.at(i,0); } + } + else + if(tmp.n_rows == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) *= tmp.at(0,i); } + } + } + else + { + if(P.get_n_cols() == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) *= P.at(i,0); } + } + else + if(P.get_n_rows() == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) *= P.at(0,i); } + } + } + } + + + +template +template +inline +void +spdiagview::operator/=(const SpBase& o) + { + arma_extra_debug_sigprint(); + + spdiagview& d = *this; + + SpMat& d_m = const_cast< SpMat& >(d.m); + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + const SpProxy P( o.get_ref() ); + + arma_debug_check + ( + ( (d_n_elem != P.get_n_elem()) || ((P.get_n_rows() != 1) && (P.get_n_cols() != 1)) ), + "spdiagview: given object has incompatible size" + ); + + if( SpProxy::use_iterator || P.is_alias(d_m) ) + { + const SpMat tmp(P.Q); + + if(tmp.n_cols == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) /= tmp.at(i,0); } + } + else + if(tmp.n_rows == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) /= tmp.at(0,i); } + } + } + else + { + if(P.get_n_cols() == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) /= P.at(i,0); } + } + else + if(P.get_n_rows() == 1) + { + for(uword i=0; i < d_n_elem; ++i) { d_m.at(i + d_row_offset, i + d_col_offset) /= P.at(0,i); } + } + } + } + + + +template +inline +void +spdiagview::extract(SpMat& out, const spdiagview& d) + { + arma_extra_debug_sigprint(); + + const SpMat& d_m = d.m; + + const uword d_n_elem = d.n_elem; + const uword d_row_offset = d.row_offset; + const uword d_col_offset = d.col_offset; + + Col cache(d_n_elem, arma_nozeros_indicator()); + eT* cache_mem = cache.memptr(); + + uword d_n_nonzero = 0; + + for(uword i=0; i < d_n_elem; ++i) + { + const eT val = d_m.at(i + d_row_offset, i + d_col_offset); + + cache_mem[i] = val; + + d_n_nonzero += (val != eT(0)) ? uword(1) : uword(0); + } + + out.reserve(d_n_elem, 1, d_n_nonzero); + + uword count = 0; + for(uword i=0; i < d_n_elem; ++i) + { + const eT val = cache_mem[i]; + + if(val != eT(0)) + { + access::rw(out.row_indices[count]) = i; + access::rw(out.values[count]) = val; + ++count; + } + } + + access::rw(out.col_ptrs[0]) = 0; + access::rw(out.col_ptrs[1]) = d_n_nonzero; + } + + + +//! extract a diagonal and store it as a dense column vector +template +inline +void +spdiagview::extract(Mat& out, const spdiagview& in) + { + arma_extra_debug_sigprint(); + + // NOTE: we're assuming that the 'out' matrix has already been set to the correct size; + // size setting is done by either the Mat contructor or Mat::operator=() + + const SpMat& in_m = in.m; + + const uword in_n_elem = in.n_elem; + const uword in_row_offset = in.row_offset; + const uword in_col_offset = in.col_offset; + + eT* out_mem = out.memptr(); + + for(uword i=0; i < in_n_elem; ++i) + { + out_mem[i] = in_m.at(i + in_row_offset, i + in_col_offset); + } + } + + + +template +inline +SpMat_MapMat_val +spdiagview::operator[](const uword i) + { + return (const_cast< SpMat& >(m)).at(i+row_offset, i+col_offset); + } + + + +template +inline +eT +spdiagview::operator[](const uword i) const + { + return m.at(i+row_offset, i+col_offset); + } + + + +template +inline +SpMat_MapMat_val +spdiagview::at(const uword i) + { + return (const_cast< SpMat& >(m)).at(i+row_offset, i+col_offset); + } + + + +template +inline +eT +spdiagview::at(const uword i) const + { + return m.at(i+row_offset, i+col_offset); + } + + + +template +inline +SpMat_MapMat_val +spdiagview::operator()(const uword i) + { + arma_debug_check_bounds( (i >= n_elem), "spdiagview::operator(): out of bounds" ); + + return (const_cast< SpMat& >(m)).at(i+row_offset, i+col_offset); + } + + + +template +inline +eT +spdiagview::operator()(const uword i) const + { + arma_debug_check_bounds( (i >= n_elem), "spdiagview::operator(): out of bounds" ); + + return m.at(i+row_offset, i+col_offset); + } + + + +template +inline +SpMat_MapMat_val +spdiagview::at(const uword row, const uword) + { + return (const_cast< SpMat& >(m)).at(row+row_offset, row+col_offset); + } + + + +template +inline +eT +spdiagview::at(const uword row, const uword) const + { + return m.at(row+row_offset, row+col_offset); + } + + + +template +inline +SpMat_MapMat_val +spdiagview::operator()(const uword row, const uword col) + { + arma_debug_check_bounds( ((row >= n_elem) || (col > 0)), "spdiagview::operator(): out of bounds" ); + + return (const_cast< SpMat& >(m)).at(row+row_offset, row+col_offset); + } + + + +template +inline +eT +spdiagview::operator()(const uword row, const uword col) const + { + arma_debug_check_bounds( ((row >= n_elem) || (col > 0)), "spdiagview::operator(): out of bounds" ); + + return m.at(row+row_offset, row+col_offset); + } + + + +template +inline +void +spdiagview::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + if(old_val == eT(0)) + { + arma_debug_warn_level(1, "spdiagview::replace(): replacement not done, as old_val = 0"); + } + else + { + Mat tmp(*this); + + tmp.replace(old_val, new_val); + + (*this).operator=(tmp); + } + } + + + +template +inline +void +spdiagview::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clean(threshold); + + (*this).operator=(tmp); + } + + + +template +inline +void +spdiagview::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp.clamp(min_val, max_val); + + (*this).operator=(tmp); + } + + + +template +inline +void +spdiagview::fill(const eT val) + { + arma_extra_debug_sigprint(); + + if( (row_offset == 0) && (col_offset == 0) && (m.sync_state != 1) ) + { + if(val == eT(0)) + { + SpMat tmp(arma_reserve_indicator(), m.n_rows, m.n_cols, m.n_nonzero); // worst case scenario + + typename SpMat::const_iterator it = m.begin(); + typename SpMat::const_iterator it_end = m.end(); + + uword count = 0; + + for(; it != it_end; ++it) + { + const uword row = it.row(); + const uword col = it.col(); + + if(row != col) + { + access::rw(tmp.values[count]) = (*it); + access::rw(tmp.row_indices[count]) = row; + access::rw(tmp.col_ptrs[col + 1])++; + ++count; + } + } + + for(uword i=0; i < tmp.n_cols; ++i) + { + access::rw(tmp.col_ptrs[i + 1]) += tmp.col_ptrs[i]; + } + + // quick resize without reallocating memory and copying data + access::rw( tmp.n_nonzero) = count; + access::rw( tmp.values[count]) = eT(0); + access::rw(tmp.row_indices[count]) = uword(0); + + access::rw(m).steal_mem(tmp); + } + else // val != eT(0) + { + SpMat tmp1; + + tmp1.eye(m.n_rows, m.n_cols); + + if(val != eT(1)) { tmp1 *= val; } + + SpMat tmp2; + + spglue_merge::diagview_merge(tmp2, m, tmp1); + + access::rw(m).steal_mem(tmp2); + } + } + else + { + SpMat& x = const_cast< SpMat& >(m); + + const uword local_n_elem = n_elem; + + for(uword i=0; i < local_n_elem; ++i) + { + x.at(i+row_offset, i+col_offset) = val; + } + } + } + + + +template +inline +void +spdiagview::zeros() + { + arma_extra_debug_sigprint(); + + (*this).fill(eT(0)); + } + + + +template +inline +void +spdiagview::ones() + { + arma_extra_debug_sigprint(); + + (*this).fill(eT(1)); + } + + + +template +inline +void +spdiagview::randu() + { + arma_extra_debug_sigprint(); + + SpMat& x = const_cast< SpMat& >(m); + + const uword local_n_elem = n_elem; + + for(uword i=0; i < local_n_elem; ++i) + { + x.at(i+row_offset, i+col_offset) = eT(arma_rng::randu()); + } + } + + + +template +inline +void +spdiagview::randn() + { + arma_extra_debug_sigprint(); + + SpMat& x = const_cast< SpMat& >(m); + + const uword local_n_elem = n_elem; + + for(uword i=0; i < local_n_elem; ++i) + { + x.at(i+row_offset, i+col_offset) = eT(arma_rng::randn()); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_join_bones.hpp b/src/armadillo/include/armadillo_bits/spglue_join_bones.hpp new file mode 100644 index 0000000..93829b7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_join_bones.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_join +//! @{ + + + +class spglue_join_cols + { + public: + + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = (T1::is_col && T2::is_col); + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(SpMat& out, const SpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); + + template + inline static void apply(SpMat& out, const SpBase& A, const SpBase& B, const SpBase& C); + + template + inline static void apply(SpMat& out, const SpBase& A, const SpBase& B, const SpBase& C, const SpBase& D); + }; + + + +class spglue_join_rows + { + public: + + template + struct traits + { + static constexpr bool is_row = (T1::is_row && T2::is_row); + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(SpMat& out, const SpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); + + template + inline static void apply(SpMat& out, const SpBase& A, const SpBase& B, const SpBase& C); + + template + inline static void apply(SpMat& out, const SpBase& A, const SpBase& B, const SpBase& C, const SpBase& D); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_join_meat.hpp b/src/armadillo/include/armadillo_bits/spglue_join_meat.hpp new file mode 100644 index 0000000..4a3e244 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_join_meat.hpp @@ -0,0 +1,350 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_join +//! @{ + + + +template +inline +void +spglue_join_cols::apply(SpMat& out, const SpGlue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(X.A); + const unwrap_spmat UB(X.B); + + if(UA.is_alias(out) || UB.is_alias(out)) + { + SpMat tmp; + + spglue_join_cols::apply_noalias(tmp, UA.M, UB.M); + + out.steal_mem(tmp); + } + else + { + spglue_join_cols::apply_noalias(out, UA.M, UB.M); + } + } + + + +template +inline +void +spglue_join_cols::apply_noalias(SpMat& out, const SpMat& A, const SpMat& B) + { + arma_extra_debug_sigprint(); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + arma_debug_check + ( + ( (A_n_cols != B_n_cols) && ( (A_n_rows > 0) || (A_n_cols > 0) ) && ( (B_n_rows > 0) || (B_n_cols > 0) ) ), + "join_cols() / join_vert(): number of columns must be the same" + ); + + out.set_size( A_n_rows + B_n_rows, (std::max)(A_n_cols, B_n_cols) ); + + if( out.n_elem > 0 ) + { + if(A.is_empty() == false) + { + out.submat(0, 0, A_n_rows-1, out.n_cols-1) = A; + } + + if(B.is_empty() == false) + { + out.submat(A_n_rows, 0, out.n_rows-1, out.n_cols-1) = B; + } + } + } + + + +template +inline +void +spglue_join_cols::apply(SpMat& out, const SpBase& A_expr, const SpBase& B_expr, const SpBase& C_expr) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat UA(A_expr.get_ref()); + const unwrap_spmat UB(B_expr.get_ref()); + const unwrap_spmat UC(C_expr.get_ref()); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + const SpMat& C = UC.M; + + const uword out_n_rows = A.n_rows + B.n_rows + C.n_rows; + const uword out_n_cols = (std::max)((std::max)(A.n_cols, B.n_cols), C.n_cols); + + arma_debug_check( ((A.n_cols != out_n_cols) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + arma_debug_check( ((B.n_cols != out_n_cols) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + arma_debug_check( ((C.n_cols != out_n_cols) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + + out.set_size(out_n_rows, out_n_cols); + + if(out.n_elem == 0) { return; } + + uword row_start = 0; + uword row_end_p1 = 0; + + if(A.n_elem > 0) { row_end_p1 += A.n_rows; out.rows(row_start, row_end_p1 - 1) = A; } + + row_start = row_end_p1; + + if(B.n_elem > 0) { row_end_p1 += B.n_rows; out.rows(row_start, row_end_p1 - 1) = B; } + + row_start = row_end_p1; + + if(C.n_elem > 0) { row_end_p1 += C.n_rows; out.rows(row_start, row_end_p1 - 1) = C; } + } + + + +template +inline +void +spglue_join_cols::apply(SpMat& out, const SpBase& A_expr, const SpBase& B_expr, const SpBase& C_expr, const SpBase& D_expr) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat UA(A_expr.get_ref()); + const unwrap_spmat UB(B_expr.get_ref()); + const unwrap_spmat UC(C_expr.get_ref()); + const unwrap_spmat UD(D_expr.get_ref()); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + const SpMat& C = UC.M; + const SpMat& D = UD.M; + + const uword out_n_rows = A.n_rows + B.n_rows + C.n_rows + D.n_rows; + const uword out_n_cols = (std::max)(((std::max)((std::max)(A.n_cols, B.n_cols), C.n_cols)), D.n_cols); + + arma_debug_check( ((A.n_cols != out_n_cols) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + arma_debug_check( ((B.n_cols != out_n_cols) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + arma_debug_check( ((C.n_cols != out_n_cols) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + arma_debug_check( ((D.n_cols != out_n_cols) && ((D.n_rows > 0) || (D.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" ); + + out.set_size(out_n_rows, out_n_cols); + + if(out.n_elem == 0) { return; } + + uword row_start = 0; + uword row_end_p1 = 0; + + if(A.n_elem > 0) { row_end_p1 += A.n_rows; out.rows(row_start, row_end_p1 - 1) = A; } + + row_start = row_end_p1; + + if(B.n_elem > 0) { row_end_p1 += B.n_rows; out.rows(row_start, row_end_p1 - 1) = B; } + + row_start = row_end_p1; + + if(C.n_elem > 0) { row_end_p1 += C.n_rows; out.rows(row_start, row_end_p1 - 1) = C; } + + row_start = row_end_p1; + + if(D.n_elem > 0) { row_end_p1 += D.n_rows; out.rows(row_start, row_end_p1 - 1) = D; } + } + + + +template +inline +void +spglue_join_rows::apply(SpMat& out, const SpGlue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(X.A); + const unwrap_spmat UB(X.B); + + if(UA.is_alias(out) || UB.is_alias(out)) + { + SpMat tmp; + + spglue_join_rows::apply_noalias(tmp, UA.M, UB.M); + + out.steal_mem(tmp); + } + else + { + spglue_join_rows::apply_noalias(out, UA.M, UB.M); + } + } + + + +template +inline +void +spglue_join_rows::apply_noalias(SpMat& out, const SpMat& A, const SpMat& B) + { + arma_extra_debug_sigprint(); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + const uword A_n_nz = A.n_nonzero; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + const uword B_n_nz = B.n_nonzero; + + arma_debug_check + ( + ( (A_n_rows != B.n_rows) && ( (A_n_rows > 0) || (A_n_cols > 0) ) && ( (B_n_rows > 0) || (B_n_cols > 0) ) ), + "join_rows() / join_horiz(): number of rows must be the same" + ); + + const uword C_n_rows = (std::max)(A_n_rows, B_n_rows); + const uword C_n_cols = A_n_cols + B_n_cols; + const uword C_n_nz = A_n_nz + B_n_nz; + + if( ((C_n_rows * C_n_cols) == 0) || (C_n_nz == 0) ) + { + out.zeros(C_n_rows, C_n_cols); + return; + } + + out.reserve(C_n_rows, C_n_cols, C_n_nz); + + arrayops::copy( access::rwp(out.values), A.values, A_n_nz ); + arrayops::copy( access::rwp(out.values) + A_n_nz, B.values, B_n_nz+1 ); + + arrayops::copy( access::rwp(out.row_indices), A.row_indices, A_n_nz ); + arrayops::copy( access::rwp(out.row_indices) + A_n_nz, B.row_indices, B_n_nz+1 ); + + arrayops::copy( access::rwp(out.col_ptrs), A.col_ptrs, A_n_cols ); + arrayops::copy( access::rwp(out.col_ptrs) + A_n_cols, B.col_ptrs, B_n_cols+2 ); + + arrayops::inplace_plus( access::rwp(out.col_ptrs) + A_n_cols, A_n_nz, B_n_cols+1 ); + + + // // OLD METHOD + // + // umat locs(2, C_n_nz, arma_nozeros_indicator()); + // Col vals( C_n_nz, arma_nozeros_indicator()); + // + // uword* locs_mem = locs.memptr(); + // eT* vals_mem = vals.memptr(); + // + // typename SpMat::const_iterator A_it = A.begin(); + // + // for(uword i=0; i < A_n_nz; ++i) + // { + // const uword row = A_it.row(); + // const uword col = A_it.col(); + // + // (*locs_mem) = row; locs_mem++; + // (*locs_mem) = col; locs_mem++; + // + // (*vals_mem) = (*A_it); vals_mem++; + // + // ++A_it; + // } + // + // typename SpMat::const_iterator B_it = B.begin(); + // + // for(uword i=0; i < B_n_nz; ++i) + // { + // const uword row = B_it.row(); + // const uword col = A_n_cols + B_it.col(); + // + // (*locs_mem) = row; locs_mem++; + // (*locs_mem) = col; locs_mem++; + // + // (*vals_mem) = (*B_it); vals_mem++; + // + // ++B_it; + // } + // + // // TODO: the first element of B within C will always have a larger index than the last element of A in C; + // // TODO: so, is sorting really necessary here? + // SpMat tmp(locs, vals, C_n_rows, C_n_cols, true, false); + // + // out.steal_mem(tmp); + } + + + +template +inline +void +spglue_join_rows::apply(SpMat& out, const SpBase& A_expr, const SpBase& B_expr, const SpBase& C_expr) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat UA(A_expr.get_ref()); + const unwrap_spmat UB(B_expr.get_ref()); + const unwrap_spmat UC(C_expr.get_ref()); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + const SpMat& C = UC.M; + + SpMat tmp; + + spglue_join_rows::apply_noalias(tmp, A, B); + spglue_join_rows::apply_noalias(out, tmp, C); + } + + + +template +inline +void +spglue_join_rows::apply(SpMat& out, const SpBase& A_expr, const SpBase& B_expr, const SpBase& C_expr, const SpBase& D_expr) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat UA(A_expr.get_ref()); + const unwrap_spmat UB(B_expr.get_ref()); + const unwrap_spmat UC(C_expr.get_ref()); + const unwrap_spmat UD(D_expr.get_ref()); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + const SpMat& C = UC.M; + const SpMat& D = UD.M; + + SpMat AB; + SpMat ABC; + + spglue_join_rows::apply_noalias(AB, A, B); + spglue_join_rows::apply_noalias(ABC, AB, C); + spglue_join_rows::apply_noalias(out, ABC, D); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_kron_bones.hpp b/src/armadillo/include/armadillo_bits/spglue_kron_bones.hpp new file mode 100644 index 0000000..e0d33b2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_kron_bones.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_kron +//! @{ + + + +class spglue_kron + { + public: + + template + struct traits + { + static constexpr bool is_row = (T1::is_row && T2::is_row); + static constexpr bool is_col = (T1::is_col && T2::is_col); + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(SpMat& out, const SpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_kron_meat.hpp b/src/armadillo/include/armadillo_bits/spglue_kron_meat.hpp new file mode 100644 index 0000000..b45f3e7 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_kron_meat.hpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_kron +//! @{ + + + +template +inline +void +spglue_kron::apply(SpMat& out, const SpGlue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(X.A); + const unwrap_spmat UB(X.B); + + if(UA.is_alias(out) || UB.is_alias(out)) + { + SpMat tmp; + + spglue_kron::apply_noalias(tmp, UA.M, UB.M); + + out.steal_mem(tmp); + } + else + { + spglue_kron::apply_noalias(out, UA.M, UB.M); + } + } + + + +template +inline +void +spglue_kron::apply_noalias(SpMat& out, const SpMat& A, const SpMat& B) + { + arma_extra_debug_sigprint(); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + const uword out_n_nonzero = A.n_nonzero * B.n_nonzero; + + out.reserve(A_n_rows * B_n_rows, A_n_cols * B_n_cols, out_n_nonzero); + + if(out_n_nonzero == 0) { return; } + + access::rw(out.col_ptrs[0]) = 0; + + uword count = 0; + + for(uword A_col=0; A_col < A_n_cols; ++A_col) + for(uword B_col=0; B_col < B_n_cols; ++B_col) + { + for(uword A_i = A.col_ptrs[A_col]; A_i < A.col_ptrs[A_col+1]; ++A_i) + { + const uword out_row = A.row_indices[A_i] * B_n_rows; + + const eT A_val = A.values[A_i]; + + for(uword B_i = B.col_ptrs[B_col]; B_i < B.col_ptrs[B_col+1]; ++B_i) + { + access::rw(out.values[count]) = A_val * B.values[B_i]; + access::rw(out.row_indices[count]) = out_row + B.row_indices[B_i]; + + count++; + } + } + + access::rw(out.col_ptrs[A_col * B_n_cols + B_col + 1]) = count; + } + } + + + +// template +// inline +// void +// spglue_kron::apply(SpMat& out, const SpGlue& X) +// { +// arma_extra_debug_sigprint(); +// +// typedef typename T1::elem_type eT; +// +// const unwrap_spmat UA(X.A); +// const unwrap_spmat UB(X.B); +// +// const SpMat& A = UA.M; +// const SpMat& B = UB.M; +// +// umat locs(2, A.n_nonzero * B.n_nonzero, arma_nozeros_indicator()); +// Col vals( A.n_nonzero * B.n_nonzero, arma_nozeros_indicator()); +// +// uword* locs_mem = locs.memptr(); +// eT* vals_mem = vals.memptr(); +// +// typename SpMat::const_iterator A_it = A.begin(); +// typename SpMat::const_iterator A_it_end = A.end(); +// +// typename SpMat::const_iterator B_it_start = B.begin(); +// typename SpMat::const_iterator B_it_end = B.end(); +// +// const uword B_n_rows = B.n_rows; +// const uword B_n_cols = B.n_cols; +// +// uword i = 0; +// +// while(A_it != A_it_end) +// { +// typename SpMat::const_iterator B_it = B_it_start; +// +// const uword loc_row = A_it.row() * B_n_rows; +// const uword loc_col = A_it.col() * B_n_cols; +// +// const eT A_val = (*A_it); +// +// while(B_it != B_it_end) +// { +// (*locs_mem) = loc_row + B_it.row(); locs_mem++; +// (*locs_mem) = loc_col + B_it.col(); locs_mem++; +// +// vals_mem[i] = A_val * (*B_it); +// +// ++i; +// ++B_it; +// } +// +// ++A_it; +// } +// +// out = SpMat(locs, vals, A.n_rows*B.n_rows, A.n_cols*B.n_cols); +// } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_max_bones.hpp b/src/armadillo/include/armadillo_bits/spglue_max_bones.hpp new file mode 100644 index 0000000..156eeb5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_max_bones.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_max +//! @{ + + + +class spglue_max + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat& out, const SpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb); + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); + + template + inline static void dense_sparse_max(Mat& out, const Base& X, const SpBase& Y); + + template + inline + static + typename enable_if2::no, eT>::result + elem_max(const eT& a, const eT& b); + + template + inline + static + typename enable_if2::yes, eT>::result + elem_max(const eT& a, const eT& b); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_max_meat.hpp b/src/armadillo/include/armadillo_bits/spglue_max_meat.hpp new file mode 100644 index 0000000..4ee19a6 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_max_meat.hpp @@ -0,0 +1,222 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_max +//! @{ + + + +template +inline +void +spglue_max::apply(SpMat& out, const SpGlue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy pa(X.A); + const SpProxy pb(X.B); + + const bool is_alias = pa.is_alias(out) || pb.is_alias(out); + + if(is_alias == false) + { + spglue_max::apply_noalias(out, pa, pb); + } + else + { + SpMat tmp; + + spglue_max::apply_noalias(tmp, pa, pb); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_max::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise max()"); + + const uword max_n_nonzero = pa.get_n_nonzero() + pb.get_n_nonzero(); + + // Resize memory to upper bound + out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero); + + // Now iterate across both matrices. + typename SpProxy::const_iterator_type x_it = pa.begin(); + typename SpProxy::const_iterator_type x_end = pa.end(); + + typename SpProxy::const_iterator_type y_it = pb.begin(); + typename SpProxy::const_iterator_type y_end = pb.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + eT out_val; + + const uword x_it_col = x_it.col(); + const uword x_it_row = x_it.row(); + + const uword y_it_col = y_it.col(); + const uword y_it_row = y_it.row(); + + bool use_y_loc = false; + + if(x_it == y_it) + { + out_val = elem_max(eT(*x_it), eT(*y_it)); + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + out_val = elem_max(eT(*x_it), eT(0)); + + ++x_it; + } + else + { + out_val = elem_max(eT(*y_it), eT(0)); + + ++y_it; + + use_y_loc = true; + } + } + + if(out_val != eT(0)) + { + access::rw(out.values[count]) = out_val; + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + } + + arma_check( (count > max_n_nonzero), "internal error: spglue_max::apply_noalias(): count > max_n_nonzero" ); + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +template +inline +void +spglue_max::apply_noalias(SpMat& out, const SpMat& A, const SpMat& B) + { + arma_extra_debug_sigprint(); + + const SpProxy< SpMat > pa(A); + const SpProxy< SpMat > pb(B); + + spglue_max::apply_noalias(out, pa, pb); + } + + + +template +inline +void +spglue_max::dense_sparse_max(Mat& out, const Base& X, const SpBase& Y) + { + arma_extra_debug_sigprint(); + + // NOTE: this function assumes there is no aliasing between matrix 'out' and X + + const Proxy pa(X.get_ref()); + const SpProxy pb(Y.get_ref()); + + const uword n_rows = pa.get_n_rows(); + const uword n_cols = pa.get_n_cols(); + + arma_debug_assert_same_size( n_rows, n_cols, pb.get_n_rows(), pb.get_n_cols(), "element-wise max()" ); + + out.set_size(n_rows, n_cols); + + for(uword c=0; c < n_cols; ++c) + for(uword r=0; r < n_rows; ++r) + { + out.at(r,c) = elem_max(pa.at(r,c), pb.at(r,c)); + } + } + + + +//! max of non-complex elements +template +inline +typename enable_if2::no, eT>::result +spglue_max::elem_max(const eT& a, const eT& b) + { + return (std::max)(a, b); + } + + + +//! max of complex elements +template +inline +typename enable_if2::yes, eT>::result +spglue_max::elem_max(const eT& a, const eT& b) + { + return (std::abs(a) > std::abs(b)) ? a : b; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_merge_bones.hpp b/src/armadillo/include/armadillo_bits/spglue_merge_bones.hpp new file mode 100644 index 0000000..2ef6c59 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_merge_bones.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_merge +//! @{ + + + +class spglue_merge + { + public: + + template + inline static void subview_merge(SpSubview& sv, const SpMat& B); + + template + inline static void subview_merge(SpSubview& sv, const Mat& B); + + template + inline static void symmat_merge(SpMat& out, const SpMat& A, const SpMat& B); + + template + inline static void diagview_merge(SpMat& out, const SpMat& A, const SpMat& B); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_merge_meat.hpp b/src/armadillo/include/armadillo_bits/spglue_merge_meat.hpp new file mode 100644 index 0000000..18339da --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_merge_meat.hpp @@ -0,0 +1,554 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_merge +//! @{ + + + +template +inline +void +spglue_merge::subview_merge(SpSubview& sv, const SpMat& B) + { + arma_extra_debug_sigprint(); + + if(sv.n_elem == 0) { return; } + + if(B.n_nonzero == 0) { sv.zeros(); return; } + + SpMat& A = access::rw(sv.m); + + const uword merge_n_nonzero = A.n_nonzero - sv.n_nonzero + B.n_nonzero; + + const uword sv_row_start = sv.aux_row1; + const uword sv_col_start = sv.aux_col1; + + const uword sv_row_end = sv.aux_row1 + sv.n_rows - 1; + const uword sv_col_end = sv.aux_col1 + sv.n_cols - 1; + + + if(A.n_nonzero == sv.n_nonzero) + { + // A is either all zeros or has all of its elements in the subview + // so the merge is equivalent to overwrite of A + + SpMat tmp(arma_reserve_indicator(), A.n_rows, A.n_cols, B.n_nonzero); + + typename SpMat::const_iterator B_it = B.begin(); + typename SpMat::const_iterator B_it_end = B.end(); + + uword tmp_count = 0; + + for(; B_it != B_it_end; ++B_it) + { + access::rw(tmp.values[tmp_count]) = (*B_it); + access::rw(tmp.row_indices[tmp_count]) = B_it.row() + sv_row_start; + access::rw(tmp.col_ptrs[B_it.col() + sv_col_start + 1])++; + ++tmp_count; + } + + for(uword i=0; i < tmp.n_cols; ++i) + { + access::rw(tmp.col_ptrs[i + 1]) += tmp.col_ptrs[i]; + } + + A.steal_mem(tmp); + + access::rw(sv.n_nonzero) = B.n_nonzero; + + return; + } + + + if(sv.n_nonzero > (A.n_nonzero/2)) + { + // A has most of its elements in the subview, + // so regenerate A with zeros in the subview region + // in order to increase merging efficiency + + sv.zeros(); + } + + + SpMat out(arma_reserve_indicator(), A.n_rows, A.n_cols, merge_n_nonzero); + + typename SpMat::const_iterator x_it = A.begin(); + typename SpMat::const_iterator x_end = A.end(); + + typename SpMat::const_iterator y_it = B.begin(); + typename SpMat::const_iterator y_end = B.end(); + + uword count = 0; + + bool x_it_valid = (x_it != x_end); + bool y_it_valid = (y_it != y_end); + + while(x_it_valid || y_it_valid) + { + eT out_val = eT(0); + + const uword x_it_row = (x_it_valid) ? uword(x_it.row()) : uword(0); + const uword x_it_col = (x_it_valid) ? uword(x_it.col()) : uword(0); + + const uword y_it_row = (y_it_valid) ? uword(sv_row_start + y_it.row()) : uword(0); + const uword y_it_col = (y_it_valid) ? uword(sv_col_start + y_it.col()) : uword(0); + + bool use_y_loc = false; + + if(x_it_valid && y_it_valid) + { + if( (x_it_row == y_it_row) && (x_it_col == y_it_col) ) + { + out_val = (*y_it); + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + const bool x_inside_box = ((x_it_row >= sv_row_start) && (x_it_row <= sv_row_end)) && ((x_it_col >= sv_col_start) && (x_it_col <= sv_col_end)); + + out_val = (x_inside_box) ? eT(0) : (*x_it); + + ++x_it; + } + else + { + out_val = (*y_it); + + ++y_it; + + use_y_loc = true; + } + } + } + else + if(x_it_valid) + { + const bool x_inside_box = ((x_it_row >= sv_row_start) && (x_it_row <= sv_row_end)) && ((x_it_col >= sv_col_start) && (x_it_col <= sv_col_end)); + + out_val = (x_inside_box) ? eT(0) : (*x_it); + + ++x_it; + } + else + if(y_it_valid) + { + out_val = (*y_it); + + ++y_it; + + use_y_loc = true; + } + + if(out_val != eT(0)) + { + access::rw(out.values[count]) = out_val; + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + } + + x_it_valid = (x_it != x_end); + y_it_valid = (y_it != y_end); + } + + arma_check( (count != merge_n_nonzero), "internal error: spglue_merge::subview_merge(): count != merge_n_nonzero" ); + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + A.steal_mem(out); + + access::rw(sv.n_nonzero) = B.n_nonzero; + } + + + +template +inline +void +spglue_merge::subview_merge(SpSubview& sv, const Mat& B) + { + arma_extra_debug_sigprint(); + + if(sv.n_elem == 0) { return; } + + const eT* B_memptr = B.memptr(); + const uword B_n_elem = B.n_elem; + + uword B_n_nonzero = 0; + + for(uword i=0; i < B_n_elem; ++i) + { + B_n_nonzero += (B_memptr[i] != eT(0)) ? uword(1) : uword(0); + } + + if(B_n_nonzero == 0) { sv.zeros(); return; } + + SpMat& A = access::rw(sv.m); + + const uword merge_n_nonzero = A.n_nonzero - sv.n_nonzero + B_n_nonzero; + + const uword sv_row_start = sv.aux_row1; + const uword sv_col_start = sv.aux_col1; + + const uword sv_row_end = sv.aux_row1 + sv.n_rows - 1; + const uword sv_col_end = sv.aux_col1 + sv.n_cols - 1; + + + if(A.n_nonzero == sv.n_nonzero) + { + // A is either all zeros or has all of its elements in the subview + // so the merge is equivalent to overwrite of A + + SpMat tmp(arma_reserve_indicator(), A.n_rows, A.n_cols, B_n_nonzero); + + typename Mat::const_row_col_iterator B_it = B.begin_row_col(); + typename Mat::const_row_col_iterator B_it_end = B.end_row_col(); + + uword tmp_count = 0; + + for(; B_it != B_it_end; ++B_it) + { + const eT val = (*B_it); + + if(val != eT(0)) + { + access::rw(tmp.values[tmp_count]) = val; + access::rw(tmp.row_indices[tmp_count]) = B_it.row() + sv_row_start; + access::rw(tmp.col_ptrs[B_it.col() + sv_col_start + 1])++; + ++tmp_count; + } + } + + for(uword i=0; i < tmp.n_cols; ++i) + { + access::rw(tmp.col_ptrs[i + 1]) += tmp.col_ptrs[i]; + } + + A.steal_mem(tmp); + + access::rw(sv.n_nonzero) = B_n_nonzero; + + return; + } + + + if(sv.n_nonzero > (A.n_nonzero/2)) + { + // A has most of its elements in the subview, + // so regenerate A with zeros in the subview region + // in order to increase merging efficiency + + sv.zeros(); + } + + + SpMat out(arma_reserve_indicator(), A.n_rows, A.n_cols, merge_n_nonzero); + + typename SpMat::const_iterator x_it = A.begin(); + typename SpMat::const_iterator x_end = A.end(); + + typename Mat::const_row_col_iterator y_it = B.begin_row_col(); + typename Mat::const_row_col_iterator y_end = B.end_row_col(); + + uword count = 0; + + bool x_it_valid = (x_it != x_end); + bool y_it_valid = (y_it != y_end); + + while(x_it_valid || y_it_valid) + { + eT out_val = eT(0); + + const uword x_it_row = (x_it_valid) ? uword(x_it.row()) : uword(0); + const uword x_it_col = (x_it_valid) ? uword(x_it.col()) : uword(0); + + const uword y_it_row = (y_it_valid) ? uword(sv_row_start + y_it.row()) : uword(0); + const uword y_it_col = (y_it_valid) ? uword(sv_col_start + y_it.col()) : uword(0); + + bool use_y_loc = false; + + if(x_it_valid && y_it_valid) + { + if( (x_it_row == y_it_row) && (x_it_col == y_it_col) ) + { + out_val = (*y_it); + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + const bool x_inside_box = ((x_it_row >= sv_row_start) && (x_it_row <= sv_row_end)) && ((x_it_col >= sv_col_start) && (x_it_col <= sv_col_end)); + + out_val = (x_inside_box) ? eT(0) : (*x_it); + + ++x_it; + } + else + { + out_val = (*y_it); + + ++y_it; + + use_y_loc = true; + } + } + } + else + if(x_it_valid) + { + const bool x_inside_box = ((x_it_row >= sv_row_start) && (x_it_row <= sv_row_end)) && ((x_it_col >= sv_col_start) && (x_it_col <= sv_col_end)); + + out_val = (x_inside_box) ? eT(0) : (*x_it); + + ++x_it; + } + else + if(y_it_valid) + { + out_val = (*y_it); + + ++y_it; + + use_y_loc = true; + } + + if(out_val != eT(0)) + { + access::rw(out.values[count]) = out_val; + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + } + + x_it_valid = (x_it != x_end); + y_it_valid = (y_it != y_end); + } + + arma_check( (count != merge_n_nonzero), "internal error: spglue_merge::subview_merge(): count != merge_n_nonzero" ); + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + A.steal_mem(out); + + access::rw(sv.n_nonzero) = B_n_nonzero; + } + + + +template +inline +void +spglue_merge::symmat_merge(SpMat& out, const SpMat& A, const SpMat& B) + { + arma_extra_debug_sigprint(); + + out.reserve(A.n_rows, A.n_cols, 2*A.n_nonzero); // worst case scenario + + typename SpMat::const_iterator x_it = A.begin(); + typename SpMat::const_iterator x_end = A.end(); + + typename SpMat::const_iterator y_it = B.begin(); + typename SpMat::const_iterator y_end = B.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + eT out_val; + + const uword x_it_col = x_it.col(); + const uword x_it_row = x_it.row(); + + const uword y_it_col = y_it.col(); + const uword y_it_row = y_it.row(); + + bool use_y_loc = false; + + if(x_it == y_it) + { + // this can only happen on the diagonal + + out_val = (*x_it); + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + out_val = (*x_it); + + ++x_it; + } + else + { + out_val = (*y_it); + + ++y_it; + + use_y_loc = true; + } + } + + access::rw(out.values[count]) = out_val; + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + + + +template +inline +void +spglue_merge::diagview_merge(SpMat& out, const SpMat& A, const SpMat& B) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that B has non-zero elements only on the main diagonal + + out.reserve(A.n_rows, A.n_cols, A.n_nonzero + B.n_nonzero); // worst case scenario + + typename SpMat::const_iterator x_it = A.begin(); + typename SpMat::const_iterator x_end = A.end(); + + typename SpMat::const_iterator y_it = B.begin(); + typename SpMat::const_iterator y_end = B.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + eT out_val = eT(0); + + const uword x_it_col = x_it.col(); + const uword x_it_row = x_it.row(); + + const uword y_it_col = y_it.col(); + const uword y_it_row = y_it.row(); + + bool use_y_loc = false; + + if(x_it == y_it) + { + // this can only happen on the diagonal + + out_val = (*y_it); + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + if(x_it_col != x_it_row) { out_val = (*x_it); } // don't take values from the main diagonal of A + + ++x_it; + } + else + { + if(y_it_col == y_it_row) { out_val = (*y_it); use_y_loc = true; } // take values only from the main diagonal of B + + ++y_it; + } + } + + if(out_val != eT(0)) + { + access::rw(out.values[count]) = out_val; + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + } + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_min_bones.hpp b/src/armadillo/include/armadillo_bits/spglue_min_bones.hpp new file mode 100644 index 0000000..93e8c59 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_min_bones.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_min +//! @{ + + + +class spglue_min + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat& out, const SpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb); + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); + + template + inline static void dense_sparse_min(Mat& out, const Base& X, const SpBase& Y); + + template + inline + static + typename enable_if2::no, eT>::result + elem_min(const eT& a, const eT& b); + + template + inline + static + typename enable_if2::yes, eT>::result + elem_min(const eT& a, const eT& b); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_min_meat.hpp b/src/armadillo/include/armadillo_bits/spglue_min_meat.hpp new file mode 100644 index 0000000..cdfc197 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_min_meat.hpp @@ -0,0 +1,222 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_min +//! @{ + + + +template +inline +void +spglue_min::apply(SpMat& out, const SpGlue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy pa(X.A); + const SpProxy pb(X.B); + + const bool is_alias = pa.is_alias(out) || pb.is_alias(out); + + if(is_alias == false) + { + spglue_min::apply_noalias(out, pa, pb); + } + else + { + SpMat tmp; + + spglue_min::apply_noalias(tmp, pa, pb); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_min::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise min()"); + + const uword max_n_nonzero = pa.get_n_nonzero() + pb.get_n_nonzero(); + + // Resize memory to upper bound + out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero); + + // Now iterate across both matrices. + typename SpProxy::const_iterator_type x_it = pa.begin(); + typename SpProxy::const_iterator_type x_end = pa.end(); + + typename SpProxy::const_iterator_type y_it = pb.begin(); + typename SpProxy::const_iterator_type y_end = pb.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + eT out_val; + + const uword x_it_col = x_it.col(); + const uword x_it_row = x_it.row(); + + const uword y_it_col = y_it.col(); + const uword y_it_row = y_it.row(); + + bool use_y_loc = false; + + if(x_it == y_it) + { + out_val = elem_min(eT(*x_it), eT(*y_it)); + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + out_val = elem_min(eT(*x_it), eT(0)); + + ++x_it; + } + else + { + out_val = elem_min(eT(*y_it), eT(0)); + + ++y_it; + + use_y_loc = true; + } + } + + if(out_val != eT(0)) + { + access::rw(out.values[count]) = out_val; + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + } + + arma_check( (count > max_n_nonzero), "internal error: spglue_min::apply_noalias(): count > max_n_nonzero" ); + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +template +inline +void +spglue_min::apply_noalias(SpMat& out, const SpMat& A, const SpMat& B) + { + arma_extra_debug_sigprint(); + + const SpProxy< SpMat > pa(A); + const SpProxy< SpMat > pb(B); + + spglue_min::apply_noalias(out, pa, pb); + } + + + +template +inline +void +spglue_min::dense_sparse_min(Mat& out, const Base& X, const SpBase& Y) + { + arma_extra_debug_sigprint(); + + // NOTE: this function assumes there is no aliasing between matrix 'out' and X + + const Proxy pa(X.get_ref()); + const SpProxy pb(Y.get_ref()); + + const uword n_rows = pa.get_n_rows(); + const uword n_cols = pa.get_n_cols(); + + arma_debug_assert_same_size( n_rows, n_cols, pb.get_n_rows(), pb.get_n_cols(), "element-wise min()" ); + + out.set_size(n_rows, n_cols); + + for(uword c=0; c < n_cols; ++c) + for(uword r=0; r < n_rows; ++r) + { + out.at(r,c) = elem_min(pa.at(r,c), pb.at(r,c)); + } + } + + + +// min of non-complex elements +template +inline +typename enable_if2::no, eT>::result +spglue_min::elem_min(const eT& a, const eT& b) + { + return (std::min)(a, b); + } + + + +// min of complex elements +template +inline +typename enable_if2::yes, eT>::result +spglue_min::elem_min(const eT& a, const eT& b) + { + return (std::abs(a) < std::abs(b)) ? a : b; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_minus_bones.hpp b/src/armadillo/include/armadillo_bits/spglue_minus_bones.hpp new file mode 100644 index 0000000..39463c3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_minus_bones.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_minus +//! @{ + + + +class spglue_minus + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat& out, const SpGlue& X); + + template + inline static void apply_noalias(SpMat& result, const SpProxy& pa, const SpProxy& pb); + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); + }; + + + +class spglue_minus_mixed + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat::eT>& out, const mtSpGlue::eT, T1, T2, spglue_minus_mixed>& expr); + + template + inline static void sparse_minus_dense(Mat< typename promote_type::result>& out, const T1& X, const T2& Y); + + template + inline static void dense_minus_sparse(Mat< typename promote_type::result>& out, const T1& X, const T2& Y); + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/spglue_minus_meat.hpp b/src/armadillo/include/armadillo_bits/spglue_minus_meat.hpp new file mode 100644 index 0000000..1ad7161 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_minus_meat.hpp @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_minus +//! @{ + + + +template +inline +void +spglue_minus::apply(SpMat& out, const SpGlue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy pa(X.A); + const SpProxy pb(X.B); + + const bool is_alias = pa.is_alias(out) || pb.is_alias(out); + + if(is_alias == false) + { + spglue_minus::apply_noalias(out, pa, pb); + } + else + { + SpMat tmp; + + spglue_minus::apply_noalias(tmp, pa, pb); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_minus::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "subtraction"); + + if(pa.get_n_nonzero() == 0) { out = pb.Q; out *= eT(-1); return; } + if(pb.get_n_nonzero() == 0) { out = pa.Q; return; } + + const uword max_n_nonzero = pa.get_n_nonzero() + pb.get_n_nonzero(); + + // Resize memory to upper bound + out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero); + + // Now iterate across both matrices. + typename SpProxy::const_iterator_type x_it = pa.begin(); + typename SpProxy::const_iterator_type x_end = pa.end(); + + typename SpProxy::const_iterator_type y_it = pb.begin(); + typename SpProxy::const_iterator_type y_end = pb.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + eT out_val; + + const uword x_it_row = x_it.row(); + const uword x_it_col = x_it.col(); + + const uword y_it_row = y_it.row(); + const uword y_it_col = y_it.col(); + + bool use_y_loc = false; + + if(x_it == y_it) + { + out_val = (*x_it) - (*y_it); + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + out_val = (*x_it); + + ++x_it; + } + else + { + out_val = -(*y_it); // take the negative + + ++y_it; + + use_y_loc = true; + } + } + + if(out_val != eT(0)) + { + access::rw(out.values[count]) = out_val; + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + } + + arma_check( (count > max_n_nonzero), "internal error: spglue_minus::apply_noalias(): count > max_n_nonzero" ); + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +template +inline +void +spglue_minus::apply_noalias(SpMat& out, const SpMat& A, const SpMat& B) + { + arma_extra_debug_sigprint(); + + const SpProxy< SpMat > pa(A); + const SpProxy< SpMat > pb(B); + + spglue_minus::apply_noalias(out, pa, pb); + } + + + +// + + + +template +inline +void +spglue_minus_mixed::apply(SpMat::eT>& out, const mtSpGlue::eT, T1, T2, spglue_minus_mixed>& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + if( (is_same_type::no) && (is_same_type::yes) ) + { + // upgrade T1 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + + const SpMat& BB = reinterpret_cast< const SpMat& >(B); + + out = AA - BB; + } + else + if( (is_same_type::yes) && (is_same_type::no) ) + { + // upgrade T2 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + const SpMat& AA = reinterpret_cast< const SpMat& >(A); + + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + out = AA - BB; + } + else + { + // upgrade T1 and T2 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + out = AA - BB; + } + } + + + +template +inline +void +spglue_minus_mixed::sparse_minus_dense(Mat< typename promote_type::result>& out, const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + const quasi_unwrap UB(Y); + const Mat& B = UB.M; + + const uword B_n_elem = B.n_elem; + const eT2* B_mem = B.memptr(); + + out.set_size(B.n_rows, B.n_cols); + + out_eT* out_mem = out.memptr(); + + for(uword i=0; i pa(X); + + arma_debug_assert_same_size( pa.get_n_rows(), pa.get_n_cols(), out.n_rows, out.n_cols, "subtraction" ); + + typename SpProxy::const_iterator_type it = pa.begin(); + typename SpProxy::const_iterator_type it_end = pa.end(); + + while(it != it_end) + { + out.at(it.row(), it.col()) += out_eT(*it); + ++it; + } + } + + + +template +inline +void +spglue_minus_mixed::dense_minus_sparse(Mat< typename promote_type::result>& out, const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + if(is_same_type::no) + { + out = conv_to< Mat >::from(X); + } + else + { + const quasi_unwrap UA(X); + + const Mat& A = UA.M; + + out = reinterpret_cast< const Mat& >(A); + } + + const SpProxy pb(Y); + + arma_debug_assert_same_size( out.n_rows, out.n_cols, pb.get_n_rows(), pb.get_n_cols(), "subtraction" ); + + typename SpProxy::const_iterator_type it = pb.begin(); + typename SpProxy::const_iterator_type it_end = pb.end(); + + while(it != it_end) + { + out.at(it.row(), it.col()) -= out_eT(*it); + ++it; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_plus_bones.hpp b/src/armadillo/include/armadillo_bits/spglue_plus_bones.hpp new file mode 100644 index 0000000..b92cb71 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_plus_bones.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_plus +//! @{ + + + +class spglue_plus + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat& out, const SpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb); + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); + }; + + + +class spglue_plus_mixed + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat::eT>& out, const mtSpGlue::eT, T1, T2, spglue_plus_mixed>& expr); + + template + inline static void dense_plus_sparse(Mat< typename promote_type::result>& out, const T1& X, const T2& Y); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_plus_meat.hpp b/src/armadillo/include/armadillo_bits/spglue_plus_meat.hpp new file mode 100644 index 0000000..b8eada0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_plus_meat.hpp @@ -0,0 +1,295 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_plus +//! @{ + + + +template +inline +void +spglue_plus::apply(SpMat& out, const SpGlue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy pa(X.A); + const SpProxy pb(X.B); + + const bool is_alias = pa.is_alias(out) || pb.is_alias(out); + + if(is_alias == false) + { + spglue_plus::apply_noalias(out, pa, pb); + } + else + { + SpMat tmp; + + spglue_plus::apply_noalias(tmp, pa, pb); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_plus::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "addition"); + + if(pa.get_n_nonzero() == 0) { out = pb.Q; return; } + if(pb.get_n_nonzero() == 0) { out = pa.Q; return; } + + const uword max_n_nonzero = pa.get_n_nonzero() + pb.get_n_nonzero(); + + // Resize memory to upper bound + out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero); + + // Now iterate across both matrices. + typename SpProxy::const_iterator_type x_it = pa.begin(); + typename SpProxy::const_iterator_type x_end = pa.end(); + + typename SpProxy::const_iterator_type y_it = pb.begin(); + typename SpProxy::const_iterator_type y_end = pb.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + eT out_val; + + const uword x_it_col = x_it.col(); + const uword x_it_row = x_it.row(); + + const uword y_it_col = y_it.col(); + const uword y_it_row = y_it.row(); + + bool use_y_loc = false; + + if(x_it == y_it) + { + out_val = (*x_it) + (*y_it); + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + out_val = (*x_it); + + ++x_it; + } + else + { + out_val = (*y_it); + + ++y_it; + + use_y_loc = true; + } + } + + if(out_val != eT(0)) + { + access::rw(out.values[count]) = out_val; + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + } + + arma_check( (count > max_n_nonzero), "internal error: spglue_plus::apply_noalias(): count > max_n_nonzero" ); + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +template +inline +void +spglue_plus::apply_noalias(SpMat& out, const SpMat& A, const SpMat& B) + { + arma_extra_debug_sigprint(); + + const SpProxy< SpMat > pa(A); + const SpProxy< SpMat > pb(B); + + spglue_plus::apply_noalias(out, pa, pb); + } + + + +// + + + +template +inline +void +spglue_plus_mixed::apply(SpMat::eT>& out, const mtSpGlue::eT, T1, T2, spglue_plus_mixed>& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + if( (is_same_type::no) && (is_same_type::yes) ) + { + // upgrade T1 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + + const SpMat& BB = reinterpret_cast< const SpMat& >(B); + + out = AA + BB; + } + else + if( (is_same_type::yes) && (is_same_type::no) ) + { + // upgrade T2 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + const SpMat& AA = reinterpret_cast< const SpMat& >(A); + + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + out = AA + BB; + } + else + { + // upgrade T1 and T2 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + out = AA + BB; + } + } + + + +template +inline +void +spglue_plus_mixed::dense_plus_sparse(Mat< typename promote_type::result>& out, const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + if(is_same_type::no) + { + out = conv_to< Mat >::from(X); + } + else + { + const quasi_unwrap UA(X); + + const Mat& A = UA.M; + + out = reinterpret_cast< const Mat& >(A); + } + + const SpProxy pb(Y); + + arma_debug_assert_same_size( out.n_rows, out.n_cols, pb.get_n_rows(), pb.get_n_cols(), "addition" ); + + typename SpProxy::const_iterator_type it = pb.begin(); + typename SpProxy::const_iterator_type it_end = pb.end(); + + while(it != it_end) + { + out.at(it.row(), it.col()) += out_eT(*it); + ++it; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_relational_bones.hpp b/src/armadillo/include/armadillo_bits/spglue_relational_bones.hpp new file mode 100644 index 0000000..f84caf4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_relational_bones.hpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_relational +//! @{ + + + +class spglue_rel_lt + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat& out, const mtSpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB); + }; + + + +class spglue_rel_gt + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat& out, const mtSpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB); + }; + + + +class spglue_rel_and + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat& out, const mtSpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB); + }; + + + +class spglue_rel_or + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat& out, const mtSpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_relational_meat.hpp b/src/armadillo/include/armadillo_bits/spglue_relational_meat.hpp new file mode 100644 index 0000000..92564ab --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_relational_meat.hpp @@ -0,0 +1,545 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_relational +//! @{ + + + +template +inline +void +spglue_rel_lt::apply(SpMat& out, const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + const SpProxy PA(X.A); + const SpProxy PB(X.B); + + const bool is_alias = PA.is_alias(out) || PB.is_alias(out); + + if(is_alias == false) + { + spglue_rel_lt::apply_noalias(out, PA, PB); + } + else + { + SpMat tmp; + + spglue_rel_lt::apply_noalias(tmp, PA, PB); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_rel_lt::apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_assert_same_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "operator<"); + + const uword max_n_nonzero = PA.get_n_nonzero() + PB.get_n_nonzero(); + + // Resize memory to upper bound + out.reserve(PA.get_n_rows(), PA.get_n_cols(), max_n_nonzero); + + // Now iterate across both matrices. + typename SpProxy::const_iterator_type x_it = PA.begin(); + typename SpProxy::const_iterator_type x_end = PA.end(); + + typename SpProxy::const_iterator_type y_it = PB.begin(); + typename SpProxy::const_iterator_type y_end = PB.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + uword out_val; + + const uword x_it_col = x_it.col(); + const uword x_it_row = x_it.row(); + + const uword y_it_col = y_it.col(); + const uword y_it_row = y_it.row(); + + bool use_y_loc = false; + + if(x_it == y_it) + { + out_val = ((*x_it) < (*y_it)) ? uword(1) : uword(0); + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + out_val = ((*x_it) < eT(0)) ? uword(1) : uword(0); + + ++x_it; + } + else + { + out_val = (eT(0) < (*y_it)) ? uword(1) : uword(0); + + ++y_it; + + use_y_loc = true; + } + } + + if(out_val != uword(0)) + { + access::rw(out.values[count]) = out_val; + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + } + + arma_check( (count > max_n_nonzero), "internal error: spglue_rel_lt::apply_noalias(): count > max_n_nonzero" ); + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +// + + + +template +inline +void +spglue_rel_gt::apply(SpMat& out, const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + const SpProxy PA(X.A); + const SpProxy PB(X.B); + + const bool is_alias = PA.is_alias(out) || PB.is_alias(out); + + if(is_alias == false) + { + spglue_rel_gt::apply_noalias(out, PA, PB); + } + else + { + SpMat tmp; + + spglue_rel_gt::apply_noalias(tmp, PA, PB); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_rel_gt::apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_assert_same_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "operator>"); + + const uword max_n_nonzero = PA.get_n_nonzero() + PB.get_n_nonzero(); + + // Resize memory to upper bound + out.reserve(PA.get_n_rows(), PA.get_n_cols(), max_n_nonzero); + + // Now iterate across both matrices. + typename SpProxy::const_iterator_type x_it = PA.begin(); + typename SpProxy::const_iterator_type x_end = PA.end(); + + typename SpProxy::const_iterator_type y_it = PB.begin(); + typename SpProxy::const_iterator_type y_end = PB.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + uword out_val; + + const uword x_it_col = x_it.col(); + const uword x_it_row = x_it.row(); + + const uword y_it_col = y_it.col(); + const uword y_it_row = y_it.row(); + + bool use_y_loc = false; + + if(x_it == y_it) + { + out_val = ((*x_it) > (*y_it)) ? uword(1) : uword(0); + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + out_val = ((*x_it) > eT(0)) ? uword(1) : uword(0); + + ++x_it; + } + else + { + out_val = (eT(0) > (*y_it)) ? uword(1) : uword(0); + + ++y_it; + + use_y_loc = true; + } + } + + if(out_val != uword(0)) + { + access::rw(out.values[count]) = out_val; + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + } + + arma_check( (count > max_n_nonzero), "internal error: spglue_rel_gt::apply_noalias(): count > max_n_nonzero" ); + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +// + + + +template +inline +void +spglue_rel_and::apply(SpMat& out, const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + const SpProxy PA(X.A); + const SpProxy PB(X.B); + + const bool is_alias = PA.is_alias(out) || PB.is_alias(out); + + if(is_alias == false) + { + spglue_rel_and::apply_noalias(out, PA, PB); + } + else + { + SpMat tmp; + + spglue_rel_and::apply_noalias(tmp, PA, PB); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_rel_and::apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_assert_same_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "operator&&"); + + if( (PA.get_n_nonzero() == 0) || (PB.get_n_nonzero() == 0) ) + { + out.zeros(PA.get_n_rows(), PA.get_n_cols()); + return; + } + + const uword max_n_nonzero = (std::min)(PA.get_n_nonzero(), PB.get_n_nonzero()); + + // Resize memory to upper bound + out.reserve(PA.get_n_rows(), PA.get_n_cols(), max_n_nonzero); + + // Now iterate across both matrices. + typename SpProxy::const_iterator_type x_it = PA.begin(); + typename SpProxy::const_iterator_type x_end = PA.end(); + + typename SpProxy::const_iterator_type y_it = PB.begin(); + typename SpProxy::const_iterator_type y_end = PB.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + const uword x_it_row = x_it.row(); + const uword x_it_col = x_it.col(); + + const uword y_it_row = y_it.row(); + const uword y_it_col = y_it.col(); + + if(x_it == y_it) + { + access::rw(out.values[count]) = uword(1); + + access::rw(out.row_indices[count]) = x_it_row; + access::rw(out.col_ptrs[x_it_col + 1])++; + ++count; + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + ++x_it; + } + else + { + ++y_it; + } + } + + arma_check( (count > max_n_nonzero), "internal error: spglue_rel_and::apply_noalias(): count > max_n_nonzero" ); + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +// + + + +template +inline +void +spglue_rel_or::apply(SpMat& out, const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + const SpProxy PA(X.A); + const SpProxy PB(X.B); + + const bool is_alias = PA.is_alias(out) || PB.is_alias(out); + + if(is_alias == false) + { + spglue_rel_or::apply_noalias(out, PA, PB); + } + else + { + SpMat tmp; + + spglue_rel_or::apply_noalias(tmp, PA, PB); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_rel_or::apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_assert_same_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "operator||"); + + const uword max_n_nonzero = PA.get_n_nonzero() + PB.get_n_nonzero(); + + // Resize memory to upper bound + out.reserve(PA.get_n_rows(), PA.get_n_cols(), max_n_nonzero); + + // Now iterate across both matrices. + typename SpProxy::const_iterator_type x_it = PA.begin(); + typename SpProxy::const_iterator_type x_end = PA.end(); + + typename SpProxy::const_iterator_type y_it = PB.begin(); + typename SpProxy::const_iterator_type y_end = PB.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + const uword x_it_col = x_it.col(); + const uword x_it_row = x_it.row(); + + const uword y_it_col = y_it.col(); + const uword y_it_row = y_it.row(); + + bool use_y_loc = false; + + if(x_it == y_it) + { + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + ++x_it; + } + else + { + ++y_it; + + use_y_loc = true; + } + } + + access::rw(out.values[count]) = uword(1); + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + + arma_check( (count > max_n_nonzero), "internal error: spglue_rel_or::apply_noalias(): count > max_n_nonzero" ); + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_schur_bones.hpp b/src/armadillo/include/armadillo_bits/spglue_schur_bones.hpp new file mode 100644 index 0000000..605de3a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_schur_bones.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_schur +//! @{ + + + +class spglue_schur + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat& out, const SpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb); + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); + }; + + + +class spglue_schur_misc + : public traits_glue_or + { + public: + + template + inline static void dense_schur_sparse(SpMat& out, const T1& x, const T2& y); + }; + + + +class spglue_schur_mixed + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat::eT>& out, const mtSpGlue::eT, T1, T2, spglue_schur_mixed>& expr); + + template + inline static void dense_schur_sparse(SpMat< typename promote_type::result>& out, const T1& X, const T2& Y); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_schur_meat.hpp b/src/armadillo/include/armadillo_bits/spglue_schur_meat.hpp new file mode 100644 index 0000000..1ad8a8f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_schur_meat.hpp @@ -0,0 +1,382 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_schur +//! @{ + + + +template +inline +void +spglue_schur::apply(SpMat& out, const SpGlue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy pa(X.A); + const SpProxy pb(X.B); + + const bool is_alias = pa.is_alias(out) || pb.is_alias(out); + + if(is_alias == false) + { + spglue_schur::apply_noalias(out, pa, pb); + } + else + { + SpMat tmp; + + spglue_schur::apply_noalias(tmp, pa, pb); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_schur::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication"); + + if( (pa.get_n_nonzero() == 0) || (pb.get_n_nonzero() == 0) ) + { + out.zeros(pa.get_n_rows(), pa.get_n_cols()); + return; + } + + const uword max_n_nonzero = (std::min)(pa.get_n_nonzero(), pb.get_n_nonzero()); + + // Resize memory to upper bound + out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero); + + // Now iterate across both matrices. + typename SpProxy::const_iterator_type x_it = pa.begin(); + typename SpProxy::const_iterator_type x_end = pa.end(); + + typename SpProxy::const_iterator_type y_it = pb.begin(); + typename SpProxy::const_iterator_type y_end = pb.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + const uword x_it_row = x_it.row(); + const uword x_it_col = x_it.col(); + + const uword y_it_row = y_it.row(); + const uword y_it_col = y_it.col(); + + if(x_it == y_it) + { + const eT out_val = (*x_it) * (*y_it); + + if(out_val != eT(0)) + { + access::rw(out.values[count]) = out_val; + + access::rw(out.row_indices[count]) = x_it_row; + access::rw(out.col_ptrs[x_it_col + 1])++; + ++count; + } + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + ++x_it; + } + else + { + ++y_it; + } + } + + arma_check( (count > max_n_nonzero), "internal error: spglue_schur::apply_noalias(): count > max_n_nonzero" ); + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +template +inline +void +spglue_schur::apply_noalias(SpMat& out, const SpMat& A, const SpMat& B) + { + arma_extra_debug_sigprint(); + + const SpProxy< SpMat > pa(A); + const SpProxy< SpMat > pb(B); + + spglue_schur::apply_noalias(out, pa, pb); + } + + + +// +// +// + + + +template +inline +void +spglue_schur_misc::dense_schur_sparse(SpMat& out, const T1& x, const T2& y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy pa(x); + const SpProxy pb(y); + + arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication"); + + const uword max_n_nonzero = pb.get_n_nonzero(); + + // Resize memory to upper bound. + out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero); + + uword count = 0; + + typename SpProxy::const_iterator_type it = pb.begin(); + typename SpProxy::const_iterator_type it_end = pb.end(); + + while(it != it_end) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + const eT val = (*it) * pa.at(it_row, it_col); + + if(val != eT(0)) + { + access::rw( out.values[count]) = val; + access::rw( out.row_indices[count]) = it_row; + access::rw(out.col_ptrs[it_col + 1])++; + ++count; + } + + ++it; + + arma_check( (count > max_n_nonzero), "internal error: spglue_schur_misc::dense_schur_sparse(): count > max_n_nonzero" ); + } + + // Fix column pointers. + for(uword c = 1; c <= out.n_cols; ++c) + { + access::rw(out.col_ptrs[c]) += out.col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +// + + + +template +inline +void +spglue_schur_mixed::apply(SpMat::eT>& out, const mtSpGlue::eT, T1, T2, spglue_schur_mixed>& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + if( (is_same_type::no) && (is_same_type::yes) ) + { + // upgrade T1 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + + const SpMat& BB = reinterpret_cast< const SpMat& >(B); + + out = AA % BB; + } + else + if( (is_same_type::yes) && (is_same_type::no) ) + { + // upgrade T2 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + const SpMat& AA = reinterpret_cast< const SpMat& >(A); + + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + out = AA % BB; + } + else + { + // upgrade T1 and T2 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + out = AA % BB; + } + } + + + +template +inline +void +spglue_schur_mixed::dense_schur_sparse(SpMat< typename promote_type::result>& out, const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + const Proxy pa(X); + const SpProxy pb(Y); + + arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication"); + + // count new size + uword new_n_nonzero = 0; + + typename SpProxy::const_iterator_type it = pb.begin(); + typename SpProxy::const_iterator_type it_end = pb.end(); + + while(it != it_end) + { + if( (out_eT(*it) * out_eT(pa.at(it.row(), it.col()))) != out_eT(0) ) { ++new_n_nonzero; } + + ++it; + } + + // Resize memory accordingly. + out.reserve(pa.get_n_rows(), pa.get_n_cols(), new_n_nonzero); + + uword count = 0; + + typename SpProxy::const_iterator_type it2 = pb.begin(); + + while(it2 != it_end) + { + const uword it2_row = it2.row(); + const uword it2_col = it2.col(); + + const out_eT val = out_eT(*it2) * out_eT(pa.at(it2_row, it2_col)); + + if(val != out_eT(0)) + { + access::rw( out.values[count]) = val; + access::rw( out.row_indices[count]) = it2_row; + access::rw(out.col_ptrs[it2_col + 1])++; + ++count; + } + + ++it2; + } + + // Fix column pointers. + for(uword c = 1; c <= out.n_cols; ++c) + { + access::rw(out.col_ptrs[c]) += out.col_ptrs[c - 1]; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_times_bones.hpp b/src/armadillo/include/armadillo_bits/spglue_times_bones.hpp new file mode 100644 index 0000000..63c21b3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_times_bones.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_times +//! @{ + + + +class spglue_times + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(SpMat& out, const SpGlue& X); + + template + inline static void apply(SpMat& out, const SpGlue,T2,spglue_times>& X); + + template + inline static void apply_noalias(SpMat& c, const SpMat& x, const SpMat& y); + }; + + + +class spglue_times_mixed + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(SpMat::eT>& out, const mtSpGlue::eT, T1, T2, spglue_times_mixed>& expr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spglue_times_meat.hpp b/src/armadillo/include/armadillo_bits/spglue_times_meat.hpp new file mode 100644 index 0000000..852dcad --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spglue_times_meat.hpp @@ -0,0 +1,369 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spglue_times +//! @{ + + + +template +inline +void +spglue_times::apply(SpMat& out, const SpGlue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(X.A); + const unwrap_spmat UB(X.B); + + const bool is_alias = (UA.is_alias(out) || UB.is_alias(out)); + + if(is_alias == false) + { + spglue_times::apply_noalias(out, UA.M, UB.M); + } + else + { + SpMat tmp; + + spglue_times::apply_noalias(tmp, UA.M, UB.M); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_times::apply(SpMat& out, const SpGlue,T2,spglue_times>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(X.A.m); + const unwrap_spmat UB(X.B); + + const bool is_alias = (UA.is_alias(out) || UB.is_alias(out)); + + if(is_alias == false) + { + spglue_times::apply_noalias(out, UA.M, UB.M); + } + else + { + SpMat tmp; + + spglue_times::apply_noalias(tmp, UA.M, UB.M); + + out.steal_mem(tmp); + } + + out *= X.A.aux; + } + + + +template +inline +void +spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y) + { + arma_extra_debug_sigprint(); + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + const uword y_n_rows = y.n_rows; + const uword y_n_cols = y.n_cols; + + arma_debug_assert_mul_size(x_n_rows, x_n_cols, y_n_rows, y_n_cols, "matrix multiplication"); + + // First we must determine the structure of the new matrix (column pointers). + // This follows the algorithm described in 'Sparse Matrix Multiplication + // Package (SMMP)' (R.E. Bank and C.C. Douglas, 2001). Their description of + // "SYMBMM" does not include anything about memory allocation. In addition it + // does not consider that there may be elements which space may be allocated + // for but which evaluate to zero anyway. So we have to modify the algorithm + // to work that way. For the "SYMBMM" implementation we will not determine + // the row indices but instead just the column pointers. + + //SpMat c(x_n_rows, y_n_cols); // Initializes col_ptrs to 0. + c.zeros(x_n_rows, y_n_cols); + + //if( (x.n_elem == 0) || (y.n_elem == 0) ) { return; } + if( (x.n_nonzero == 0) || (y.n_nonzero == 0) ) { return; } + + // Auxiliary storage which denotes when items have been found. + podarray index(x_n_rows); + index.fill(x_n_rows); // Fill with invalid links. + + typename SpMat::const_iterator y_it = y.begin(); + typename SpMat::const_iterator y_end = y.end(); + + // SYMBMM: calculate column pointers for resultant matrix to obtain a good + // upper bound on the number of nonzero elements. + uword cur_col_length = 0; + uword last_ind = x_n_rows + 1; + do + { + const uword y_it_row = y_it.row(); + + // Look through the column that this point (*y_it) could affect. + typename SpMat::const_iterator x_it = x.begin_col_no_sync(y_it_row); + + while(x_it.col() == y_it_row) + { + const uword x_it_row = x_it.row(); + + // A point at x(i, j) and y(j, k) implies a point at c(i, k). + if(index[x_it_row] == x_n_rows) + { + index[x_it_row] = last_ind; + last_ind = x_it_row; + ++cur_col_length; + } + + ++x_it; + } + + const uword old_col = y_it.col(); + ++y_it; + + // See if column incremented. + if(old_col != y_it.col()) + { + // Set column pointer (this is not a cumulative count; that is done later). + access::rw(c.col_ptrs[old_col + 1]) = cur_col_length; + cur_col_length = 0; + + // Return index markers to zero. Use last_ind for traversal. + while(last_ind != x_n_rows + 1) + { + const uword tmp = index[last_ind]; + index[last_ind] = x_n_rows; + last_ind = tmp; + } + } + } + while(y_it != y_end); + + // Accumulate column pointers. + for(uword i = 0; i < c.n_cols; ++i) + { + access::rw(c.col_ptrs[i + 1]) += c.col_ptrs[i]; + } + + // Now that we know a decent bound on the number of nonzero elements, + // allocate the memory and fill it. + + const uword max_n_nonzero = c.col_ptrs[c.n_cols]; + + c.mem_resize(max_n_nonzero); + + // Now the implementation of the NUMBMM algorithm. + uword cur_pos = 0; // Current position in c matrix. + podarray sums(x_n_rows); // Partial sums. + sums.zeros(); + + podarray sorted_indices(x_n_rows); // upper bound + + // last_ind is already set to x_n_rows, and cur_col_length is already set to 0. + // We will loop through all columns as necessary. + uword cur_col = 0; + while(cur_col < c.n_cols) + { + // Skip to next column with elements in it. + while((cur_col < c.n_cols) && (c.col_ptrs[cur_col] == c.col_ptrs[cur_col + 1])) + { + // Update current column pointer to actual number of nonzero elements up + // to this point. + access::rw(c.col_ptrs[cur_col]) = cur_pos; + ++cur_col; + } + + if(cur_col == c.n_cols) { break; } + + // Update current column pointer. + access::rw(c.col_ptrs[cur_col]) = cur_pos; + + // Check all elements in this column. + typename SpMat::const_iterator y_col_it = y.begin_col_no_sync(cur_col); + + while(y_col_it.col() == cur_col) + { + const uword y_col_it_row = y_col_it.row(); + + // Check all elements in the column of the other matrix corresponding to + // the row of this column. + typename SpMat::const_iterator x_col_it = x.begin_col_no_sync(y_col_it_row); + + const eT y_value = (*y_col_it); + + while(x_col_it.col() == y_col_it_row) + { + const uword x_col_it_row = x_col_it.row(); + + // A point at x(i, j) and y(j, k) implies a point at c(i, k). + // Add to partial sum. + const eT x_value = (*x_col_it); + sums[x_col_it_row] += (x_value * y_value); + + // Add point if it hasn't already been marked. + if(index[x_col_it_row] == x_n_rows) + { + index[x_col_it_row] = last_ind; + last_ind = x_col_it_row; + } + + ++x_col_it; + } + + ++y_col_it; + } + + // Now sort the indices that were used in this column. + uword cur_index = 0; + while(last_ind != x_n_rows + 1) + { + const uword tmp = last_ind; + + // Check that it wasn't a "fake" nonzero element. + if(sums[tmp] != eT(0)) + { + // Assign to next open position. + sorted_indices[cur_index] = tmp; + ++cur_index; + } + + last_ind = index[tmp]; + index[tmp] = x_n_rows; + } + + // Now sort the indices. + if(cur_index != 0) + { + op_sort::direct_sort_ascending(sorted_indices.memptr(), cur_index); + + for(uword k = 0; k < cur_index; ++k) + { + const uword row = sorted_indices[k]; + access::rw(c.row_indices[cur_pos]) = row; + access::rw(c.values[cur_pos]) = sums[row]; + sums[row] = eT(0); + ++cur_pos; + } + } + + // Move to next column. + ++cur_col; + } + + // Update last column pointer and resize to actual memory size. + + // access::rw(c.col_ptrs[c.n_cols]) = cur_pos; + // c.mem_resize(cur_pos); + + access::rw(c.col_ptrs[c.n_cols]) = cur_pos; + + if(cur_pos < max_n_nonzero) { c.mem_resize(cur_pos); } + } + + + +// +// +// + + + +template +inline +void +spglue_times_mixed::apply(SpMat::eT>& out, const mtSpGlue::eT, T1, T2, spglue_times_mixed>& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename eT_promoter::eT out_eT; + + if( (is_same_type::no) && (is_same_type::yes) ) + { + // upgrade T1 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + + const SpMat& BB = reinterpret_cast< const SpMat& >(B); + + out = AA * BB; + } + else + if( (is_same_type::yes) && (is_same_type::no) ) + { + // upgrade T2 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + const SpMat& AA = reinterpret_cast< const SpMat& >(A); + + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + out = AA * BB; + } + else + { + // upgrade T1 and T2 + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + out = AA * BB; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_diagmat_bones.hpp b/src/armadillo/include/armadillo_bits/spop_diagmat_bones.hpp new file mode 100644 index 0000000..41b1ae3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_diagmat_bones.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_diagmat +//! @{ + + +class spop_diagmat + : public traits_op_default + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + + template + inline static void apply_noalias(SpMat& out, const SpBase& expr); + + template + inline static void apply_noalias(SpMat& out, const SpGlue& expr); + + template + inline static void apply_noalias(SpMat& out, const SpGlue& expr); + + template + inline static void apply_noalias(SpMat& out, const SpGlue& expr); + + template + inline static void apply_noalias(SpMat& out, const SpGlue& expr); + + }; + + + +class spop_diagmat2 + : public traits_op_default + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + + template + inline static void apply_noalias(SpMat& out, const SpMat& X, const uword row_offset, const uword col_offset); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_diagmat_meat.hpp b/src/armadillo/include/armadillo_bits/spop_diagmat_meat.hpp new file mode 100644 index 0000000..a6f9faf --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_diagmat_meat.hpp @@ -0,0 +1,456 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_diagmat +//! @{ + + + +template +inline +void +spop_diagmat::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(in.is_alias(out) == false) + { + spop_diagmat::apply_noalias(out, in.m); + } + else + { + SpMat tmp; + + spop_diagmat::apply_noalias(tmp, in.m); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spop_diagmat::apply_noalias(SpMat& out, const SpBase& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy P(expr.get_ref()); + + const uword P_n_rows = P.get_n_rows(); + const uword P_n_cols = P.get_n_cols(); + const uword P_n_nz = P.get_n_nonzero(); + + const bool P_is_vec = (P_n_rows == 1) || (P_n_cols == 1); + + if(P_is_vec) // generate a diagonal matrix out of a vector + { + const uword N = (P_n_rows == 1) ? P_n_cols : P_n_rows; + + out.zeros(N, N); + + if(P_n_nz == 0) { return; } + + typename SpProxy::const_iterator_type it = P.begin(); + + if(P_n_cols == 1) + { + for(uword i=0; i < P_n_nz; ++i) + { + const uword row = it.row(); + + out.at(row,row) = (*it); + + ++it; + } + } + else + if(P_n_rows == 1) + { + for(uword i=0; i < P_n_nz; ++i) + { + const uword col = it.col(); + + out.at(col,col) = (*it); + + ++it; + } + } + } + else // generate a diagonal matrix out of a matrix + { + out.zeros(P_n_rows, P_n_cols); + + const uword N = (std::min)(P_n_rows, P_n_cols); + + if( (is_SpMat::stored_type>::value) && (P_n_nz >= 5*N) ) + { + const unwrap_spmat::stored_type> U(P.Q); + + const SpMat& X = U.M; + + for(uword i=0; i < N; ++i) + { + const eT val = X.at(i,i); // use binary search + + if(val != eT(0)) { out.at(i,i) = val; } + } + } + else + { + if(P_n_nz == 0) { return; } + + typename SpProxy::const_iterator_type it = P.begin(); + + for(uword i=0; i < P_n_nz; ++i) + { + const uword row = it.row(); + const uword col = it.col(); + + if(row == col) { out.at(row,row) = (*it); } + + ++it; + } + } + } + } + + + +template +inline +void +spop_diagmat::apply_noalias(SpMat& out, const SpGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + arma_debug_assert_same_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "addition"); + + const bool is_vec = (A.n_rows == 1) || (A.n_cols == 1); + + if(is_vec) // generate a diagonal matrix out of a vector + { + const uword N = (A.n_rows == 1) ? A.n_cols : A.n_rows; + + out.zeros(N,N); + + if(A.n_rows == 1) + { + for(uword i=0; i < N; ++i) { out.at(i,i) = A.at(0,i) + B.at(0,i); } + } + else + { + for(uword i=0; i < N; ++i) { out.at(i,i) = A.at(i,0) + B.at(i,0); } + } + } + else // generate a diagonal matrix out of a matrix + { + SpMat AA; spop_diagmat::apply_noalias(AA, A); + SpMat BB; spop_diagmat::apply_noalias(BB, B); + + out = AA + BB; + } + } + + + +template +inline +void +spop_diagmat::apply_noalias(SpMat& out, const SpGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + arma_debug_assert_same_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "subtraction"); + + const bool is_vec = (A.n_rows == 1) || (A.n_cols == 1); + + if(is_vec) // generate a diagonal matrix out of a vector + { + const uword N = (A.n_rows == 1) ? A.n_cols : A.n_rows; + + out.zeros(N,N); + + if(A.n_rows == 1) + { + for(uword i=0; i < N; ++i) { out.at(i,i) = A.at(0,i) - B.at(0,i); } + } + else + { + for(uword i=0; i < N; ++i) { out.at(i,i) = A.at(i,0) - B.at(i,0); } + } + } + else // generate a diagonal matrix out of a matrix + { + SpMat AA; spop_diagmat::apply_noalias(AA, A); + SpMat BB; spop_diagmat::apply_noalias(BB, B); + + out = AA - BB; + } + } + + + +template +inline +void +spop_diagmat::apply_noalias(SpMat& out, const SpGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + arma_debug_assert_same_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "element-wise multiplication"); + + const bool is_vec = (A.n_rows == 1) || (A.n_cols == 1); + + if(is_vec) // generate a diagonal matrix out of a vector + { + const uword N = (A.n_rows == 1) ? A.n_cols : A.n_rows; + + out.zeros(N,N); + + if(A.n_rows == 1) + { + for(uword i=0; i < N; ++i) { out.at(i,i) = A.at(0,i) * B.at(0,i); } + } + else + { + for(uword i=0; i < N; ++i) { out.at(i,i) = A.at(i,0) * B.at(i,0); } + } + } + else // generate a diagonal matrix out of a matrix + { + SpMat AA; spop_diagmat::apply_noalias(AA, A); + SpMat BB; spop_diagmat::apply_noalias(BB, B); + + out = AA % BB; + } + } + + + +template +inline +void +spop_diagmat::apply_noalias(SpMat& out, const SpGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(expr.A); + const unwrap_spmat UB(expr.B); + + const SpMat& A = UA.M; + const SpMat& B = UB.M; + + arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + const uword C_n_rows = A.n_rows; + const uword C_n_cols = B.n_cols; + + const bool is_vec = (C_n_rows == 1) || (C_n_cols == 1); + + if(is_vec) // generate a diagonal matrix out of a vector + { + const SpMat C = A*B; + + spop_diagmat::apply_noalias(out, C); + } + else // generate a diagonal matrix out of a matrix + { + const uword N = (std::min)(C_n_rows, C_n_cols); + + if( (A.n_nonzero >= 5*N) || (B.n_nonzero >= 5*N) ) + { + out.zeros(C_n_rows, C_n_cols); + + for(uword k=0; k < N; ++k) + { + typename SpMat::const_col_iterator B_it = B.begin_col_no_sync(k); + typename SpMat::const_col_iterator B_it_end = B.end_col_no_sync(k); + + eT acc = eT(0); + + while(B_it != B_it_end) + { + const eT B_val = (*B_it); + const uword i = B_it.row(); + + acc += A.at(k,i) * B_val; + + ++B_it; + } + + out(k,k) = acc; + } + } + else + { + const SpMat C = A*B; + + spop_diagmat::apply_noalias(out, C); + } + } + } + + + +// +// + + + +template +inline +void +spop_diagmat2::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword row_offset = in.aux_uword_a; + const uword col_offset = in.aux_uword_b; + + const unwrap_spmat U(in.m); + + if(U.is_alias(out)) + { + SpMat tmp; + + spop_diagmat2::apply_noalias(tmp, U.M, row_offset, col_offset); + + out.steal_mem(tmp); + } + else + { + spop_diagmat2::apply_noalias(out, U.M, row_offset, col_offset); + } + } + + + +template +inline +void +spop_diagmat2::apply_noalias(SpMat& out, const SpMat& X, const uword row_offset, const uword col_offset) + { + arma_extra_debug_sigprint(); + + const uword n_rows = X.n_rows; + const uword n_cols = X.n_cols; + const uword n_elem = X.n_elem; + + if(n_elem == 0) { out.reset(); return; } + + const bool X_is_vec = (n_rows == 1) || (n_cols == 1); + + if(X_is_vec) // generate a diagonal matrix out of a vector + { + const uword n_pad = (std::max)(row_offset, col_offset); + + out.zeros(n_elem + n_pad, n_elem + n_pad); + + const uword X_n_nz = X.n_nonzero; + + if(X_n_nz == 0) { return; } + + typename SpMat::const_iterator it = X.begin(); + + if(n_cols == 1) + { + for(uword i=0; i < X_n_nz; ++i) + { + const uword row = it.row(); + + out.at(row_offset + row, col_offset + row) = (*it); + + ++it; + } + } + else + if(n_rows == 1) + { + for(uword i=0; i < X_n_nz; ++i) + { + const uword col = it.col(); + + out.at(row_offset + col, col_offset + col) = (*it); + + ++it; + } + } + } + else // generate a diagonal matrix out of a matrix + { + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "diagmat(): requested diagonal out of bounds" + ); + + out.zeros(n_rows, n_cols); + + if(X.n_nonzero == 0) { return; } + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + for(uword i=0; i + struct traits + { + static constexpr bool is_row = T1::is_col; // deliberately swapped + static constexpr bool is_col = T1::is_row; + static constexpr bool is_xvec = T1::is_xvec; + }; + + template + inline static void apply(SpMat& out, const SpOp& in, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply(SpMat& out, const SpOp& in, const typename arma_cx_only::result* junk = nullptr); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_htrans_meat.hpp b/src/armadillo/include/armadillo_bits/spop_htrans_meat.hpp new file mode 100644 index 0000000..624d399 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_htrans_meat.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_htrans +//! @{ + + + +template +inline +void +spop_htrans::apply(SpMat& out, const SpOp& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + spop_strans::apply(out, in); + } + + + +template +inline +void +spop_htrans::apply(SpMat& out, const SpOp& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + spop_strans::apply(out, in); + + const uword N = out.n_nonzero; + + for(uword i=0; i + inline static void apply(SpMat& out, const SpOp& in); + + // + + template + inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_not_cx::result* junk = nullptr); + + template + inline static typename T1::elem_type vector_max(const T1& X, const typename arma_not_cx::result* junk = nullptr); + + template + inline static typename arma_not_cx::result max(const SpBase& X); + + template + inline static typename arma_not_cx::result max_with_index(const SpProxy& P, uword& index_of_max_val); + + // + + template + inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_cx_only::result* junk = nullptr); + + template + inline static typename T1::elem_type vector_max(const T1& X, const typename arma_cx_only::result* junk = nullptr); + + template + inline static typename arma_cx_only::result max(const SpBase& X); + + template + inline static typename arma_cx_only::result max_with_index(const SpProxy& P, uword& index_of_max_val); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_max_meat.hpp b/src/armadillo/include/armadillo_bits/spop_max_meat.hpp new file mode 100644 index 0000000..8f40a0e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_max_meat.hpp @@ -0,0 +1,686 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_max +//! @{ + + + +template +inline +void +spop_max::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 1), "max(): parameter 'dim' must be 0 or 1" ); + + const SpProxy p(in.m); + + const uword p_n_rows = p.get_n_rows(); + const uword p_n_cols = p.get_n_cols(); + + if( (p_n_rows == 0) || (p_n_cols == 0) || (p.get_n_nonzero() == 0) ) + { + if(dim == 0) { out.zeros((p_n_rows > 0) ? 1 : 0, p_n_cols); } + if(dim == 1) { out.zeros(p_n_rows, (p_n_cols > 0) ? 1 : 0); } + + return; + } + + spop_max::apply_proxy(out, p, dim); + } + + + +template +inline +void +spop_max::apply_proxy + ( + SpMat& out, + const SpProxy& p, + const uword dim, + const typename arma_not_cx::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type it_end = p.end(); + + const uword p_n_cols = p.get_n_cols(); + const uword p_n_rows = p.get_n_rows(); + + if(dim == 0) // find the maximum in each column + { + Row value(p_n_cols, arma_zeros_indicator()); + urowvec count(p_n_cols, arma_zeros_indicator()); + + while(it != it_end) + { + const uword col = it.col(); + + value[col] = (count[col] == 0) ? (*it) : (std::max)(value[col], (*it)); + count[col]++; + ++it; + } + + for(uword col=0; col value(p_n_rows, arma_zeros_indicator()); + ucolvec count(p_n_rows, arma_zeros_indicator()); + + while(it != it_end) + { + const uword row = it.row(); + + value[row] = (count[row] == 0) ? (*it) : (std::max)(value[row], (*it)); + count[row]++; + ++it; + } + + for(uword row=0; row +inline +typename T1::elem_type +spop_max::vector_max + ( + const T1& x, + const typename arma_not_cx::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + const SpProxy p(x); + + if(p.get_n_elem() == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + if(p.get_n_nonzero() == 0) { return eT(0); } + + if(SpProxy::use_iterator == false) + { + // direct access of values + if(p.get_n_nonzero() == p.get_n_elem()) + { + return op_max::direct_max(p.get_values(), p.get_n_nonzero()); + } + else + { + return (std::max)(eT(0), op_max::direct_max(p.get_values(), p.get_n_nonzero())); + } + } + else + { + // use iterator + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type it_end = p.end(); + + eT result = (*it); + ++it; + + while(it != it_end) + { + if((*it) > result) { result = (*it); } + + ++it; + } + + if(p.get_n_nonzero() == p.get_n_elem()) + { + return result; + } + else + { + return (std::max)(eT(0), result); + } + } + } + + + +template +inline +typename arma_not_cx::result +spop_max::max(const SpBase& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + const uword n_nonzero = P.get_n_nonzero(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + eT max_val = priv::most_neg(); + + if(SpProxy::use_iterator) + { + // We have to iterate over the elements. + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + if((*it) > max_val) { max_val = *it; } + + ++it; + } + } + else + { + // We can do direct access of the values, row_indices, and col_ptrs. + // We don't need the location of the max value, so we can just call out to + // other functions... + max_val = op_max::direct_max(P.get_values(), n_nonzero); + } + + if(n_elem == n_nonzero) + { + return max_val; + } + else + { + return (std::max)(eT(0), max_val); + } + } + + + +template +inline +typename arma_not_cx::result +spop_max::max_with_index(const SpProxy& P, uword& index_of_max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + const uword n_nonzero = P.get_n_nonzero(); + const uword n_rows = P.get_n_rows(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + index_of_max_val = uword(0); + + return Datum::nan; + } + + eT max_val = priv::most_neg(); + + if(SpProxy::use_iterator) + { + // We have to iterate over the elements. + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + if((*it) > max_val) + { + max_val = *it; + index_of_max_val = it.row() + it.col() * n_rows; + } + + ++it; + } + } + else + { + // We can do direct access. + max_val = op_max::direct_max(P.get_values(), n_nonzero, index_of_max_val); + + // Convert to actual position in matrix. + const uword row = P.get_row_indices()[index_of_max_val]; + uword col = 0; + while(P.get_col_ptrs()[++col] <= index_of_max_val) { } + index_of_max_val = (col - 1) * n_rows + row; + } + + + if(n_elem != n_nonzero) + { + max_val = (std::max)(eT(0), max_val); + + // If the max_val is a nonzero element, we need its actual position in the matrix. + if(max_val == eT(0)) + { + // Find first zero element. + uword last_row = 0; + uword last_col = 0; + + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + // Have we moved more than one position from the last place? + if((it.col() == last_col) && (it.row() - last_row > 1)) + { + index_of_max_val = it.col() * n_rows + last_row + 1; + break; + } + else if((it.col() >= last_col + 1) && (last_row < n_rows - 1)) + { + index_of_max_val = last_col * n_rows + last_row + 1; + break; + } + else if((it.col() == last_col + 1) && (it.row() > 0)) + { + index_of_max_val = it.col() * n_rows; + break; + } + else if(it.col() > last_col + 1) + { + index_of_max_val = (last_col + 1) * n_rows; + break; + } + + last_row = it.row(); + last_col = it.col(); + ++it; + } + } + } + + return max_val; + } + + + +template +inline +void +spop_max::apply_proxy + ( + SpMat& out, + const SpProxy& p, + const uword dim, + const typename arma_cx_only::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type it_end = p.end(); + + const uword p_n_cols = p.get_n_cols(); + const uword p_n_rows = p.get_n_rows(); + + if(dim == 0) // find the maximum in each column + { + Row rawval(p_n_cols, arma_zeros_indicator()); + Row< T> absval(p_n_cols, arma_zeros_indicator()); + + while(it != it_end) + { + const uword col = it.col(); + + const eT& v = (*it); + const T a = std::abs(v); + + if(a > absval[col]) + { + absval[col] = a; + rawval[col] = v; + } + + ++it; + } + + out = rawval; + } + else + if(dim == 1) // find the maximum in each row + { + Col rawval(p_n_rows, arma_zeros_indicator()); + Col< T> absval(p_n_rows, arma_zeros_indicator()); + + while(it != it_end) + { + const uword row = it.row(); + + const eT& v = (*it); + const T a = std::abs(v); + + if(a > absval[row]) + { + absval[row] = a; + rawval[row] = v; + } + + ++it; + } + + out = rawval; + } + } + + + +template +inline +typename T1::elem_type +spop_max::vector_max + ( + const T1& x, + const typename arma_cx_only::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const SpProxy p(x); + + if(p.get_n_elem() == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + if(p.get_n_nonzero() == 0) { return eT(0); } + + if(SpProxy::use_iterator == false) + { + // direct access of values + if(p.get_n_nonzero() == p.get_n_elem()) + { + return op_max::direct_max(p.get_values(), p.get_n_nonzero()); + } + else + { + const eT val1 = eT(0); + const eT val2 = op_max::direct_max(p.get_values(), p.get_n_nonzero()); + + return ( std::abs(val1) >= std::abs(val2) ) ? val1 : val2; + } + } + else + { + // use iterator + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type it_end = p.end(); + + eT best_val_orig = *it; + T best_val_abs = std::abs(best_val_orig); + + ++it; + + while(it != it_end) + { + eT val_orig = *it; + T val_abs = std::abs(val_orig); + + if(val_abs > best_val_abs) + { + best_val_abs = val_abs; + best_val_orig = val_orig; + } + + ++it; + } + + if(p.get_n_nonzero() == p.get_n_elem()) + { + return best_val_orig; + } + else + { + const eT val1 = eT(0); + + return ( std::abs(val1) >= best_val_abs ) ? val1 : best_val_orig; + } + } + } + + + +template +inline +typename arma_cx_only::result +spop_max::max(const SpBase& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const SpProxy P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + const uword n_nonzero = P.get_n_nonzero(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + T max_val = priv::most_neg(); + eT ret_val; + + if(SpProxy::use_iterator) + { + // We have to iterate over the elements. + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + const T tmp_val = std::abs(*it); + + if(tmp_val > max_val) + { + max_val = tmp_val; + ret_val = *it; + } + + ++it; + } + } + else + { + // We can do direct access of the values, row_indices, and col_ptrs. + // We don't need the location of the max value, so we can just call out to + // other functions... + ret_val = op_max::direct_max(P.get_values(), n_nonzero); + max_val = std::abs(ret_val); + } + + if(n_elem == n_nonzero) + { + return max_val; + } + else + { + return (T(0) > max_val) ? eT(0) : ret_val; + } + } + + + +template +inline +typename arma_cx_only::result +spop_max::max_with_index(const SpProxy& P, uword& index_of_max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const uword n_elem = P.get_n_elem(); + const uword n_nonzero = P.get_n_nonzero(); + const uword n_rows = P.get_n_rows(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + index_of_max_val = uword(0); + + return Datum::nan; + } + + T max_val = priv::most_neg(); + + if(SpProxy::use_iterator) + { + // We have to iterate over the elements. + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + const T tmp_val = std::abs(*it); + + if(tmp_val > max_val) + { + max_val = tmp_val; + index_of_max_val = it.row() + it.col() * n_rows; + } + + ++it; + } + } + else + { + // We can do direct access. + max_val = std::abs(op_max::direct_max(P.get_values(), n_nonzero, index_of_max_val)); + + // Convert to actual position in matrix. + const uword row = P.get_row_indices()[index_of_max_val]; + uword col = 0; + while(P.get_col_ptrs()[++col] <= index_of_max_val) { } + index_of_max_val = (col - 1) * n_rows + row; + } + + + if(n_elem != n_nonzero) + { + max_val = (std::max)(T(0), max_val); + + // If the max_val is a nonzero element, we need its actual position in the matrix. + if(max_val == T(0)) + { + // Find first zero element. + uword last_row = 0; + uword last_col = 0; + + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + // Have we moved more than one position from the last place? + if((it.col() == last_col) && (it.row() - last_row > 1)) + { + index_of_max_val = it.col() * n_rows + last_row + 1; + break; + } + else if((it.col() >= last_col + 1) && (last_row < n_rows - 1)) + { + index_of_max_val = last_col * n_rows + last_row + 1; + break; + } + else if((it.col() == last_col + 1) && (it.row() > 0)) + { + index_of_max_val = it.col() * n_rows; + break; + } + else if(it.col() > last_col + 1) + { + index_of_max_val = (last_col + 1) * n_rows; + break; + } + + last_row = it.row(); + last_col = it.col(); + ++it; + } + } + } + + return P[index_of_max_val]; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_mean_bones.hpp b/src/armadillo/include/armadillo_bits/spop_mean_bones.hpp new file mode 100644 index 0000000..3d3e102 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_mean_bones.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_mean +//! @{ + + +//! Class for finding mean values of a sparse matrix +class spop_mean + : public traits_op_xvec + { + public: + + // Apply mean into an output sparse matrix (or vector). + template + inline static void apply(SpMat& out, const SpOp& in); + + template + inline static void apply_noalias_fast(SpMat& out, const SpProxy& p, const uword dim); + + template + inline static void apply_noalias_slow(SpMat& out, const SpProxy& p, const uword dim); + + // Take direct mean of a set of values. Length of array and number of values can be different. + template + inline static eT direct_mean(const eT* const X, const uword length, const uword N); + + template + inline static eT direct_mean_robust(const eT* const X, const uword length, const uword N); + + template + inline static typename T1::elem_type mean_all(const SpBase& X); + + template + inline static typename T1::elem_type mean_all(const SpOp& expr); + + // Take the mean using an iterator. + template + inline static eT iterator_mean(T1& it, const T1& end, const uword n_zero, const eT junk); + + template + inline static eT iterator_mean_robust(T1& it, const T1& end, const uword n_zero, const eT junk); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_mean_meat.hpp b/src/armadillo/include/armadillo_bits/spop_mean_meat.hpp new file mode 100644 index 0000000..dd97916 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_mean_meat.hpp @@ -0,0 +1,376 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_mean +//! @{ + + + +template +inline +void +spop_mean::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 1), "mean(): parameter 'dim' must be 0 or 1" ); + + const SpProxy p(in.m); + + if(p.is_alias(out) == false) + { + spop_mean::apply_noalias_fast(out, p, dim); + } + else + { + SpMat tmp; + + spop_mean::apply_noalias_fast(tmp, p, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spop_mean::apply_noalias_fast + ( + SpMat& out, + const SpProxy& p, + const uword dim + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const uword p_n_rows = p.get_n_rows(); + const uword p_n_cols = p.get_n_cols(); + + if( (p_n_rows == 0) || (p_n_cols == 0) || (p.get_n_nonzero() == 0) ) + { + if(dim == 0) { out.zeros((p_n_rows > 0) ? 1 : 0, p_n_cols); } + if(dim == 1) { out.zeros(p_n_rows, (p_n_cols > 0) ? 1 : 0); } + + return; + } + + if(dim == 0) // find the mean in each column + { + Row acc(p_n_cols, arma_zeros_indicator()); + + eT* acc_mem = acc.memptr(); + + if(SpProxy::use_iterator) + { + typename SpProxy::const_iterator_type it = p.begin(); + + const uword N = p.get_n_nonzero(); + + for(uword i=0; i < N; ++i) { acc_mem[it.col()] += (*it); ++it; } + + acc /= T(p_n_rows); + } + else + { + for(uword col = 0; col < p_n_cols; ++col) + { + acc_mem[col] = arrayops::accumulate + ( + &p.get_values()[p.get_col_ptrs()[col]], + p.get_col_ptrs()[col + 1] - p.get_col_ptrs()[col] + ) / T(p_n_rows); + } + } + + out = acc; + } + else + if(dim == 1) // find the mean in each row + { + Col acc(p_n_rows, arma_zeros_indicator()); + + eT* acc_mem = acc.memptr(); + + typename SpProxy::const_iterator_type it = p.begin(); + + const uword N = p.get_n_nonzero(); + + for(uword i=0; i < N; ++i) { acc_mem[it.row()] += (*it); ++it; } + + acc /= T(p_n_cols); + + out = acc; + } + + if(out.internal_has_nonfinite()) + { + spop_mean::apply_noalias_slow(out, p, dim); + } + } + + + +template +inline +void +spop_mean::apply_noalias_slow + ( + SpMat& out, + const SpProxy& p, + const uword dim + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword p_n_rows = p.get_n_rows(); + const uword p_n_cols = p.get_n_cols(); + + if(dim == 0) // find the mean in each column + { + arma_extra_debug_print("spop_mean::apply_noalias(): dim = 0"); + + out.set_size((p_n_rows > 0) ? 1 : 0, p_n_cols); + + if( (p_n_rows == 0) || (p.get_n_nonzero() == 0) ) { return; } + + for(uword col = 0; col < p_n_cols; ++col) + { + // Do we have to use an iterator or can we use memory directly? + if(SpProxy::use_iterator) + { + typename SpProxy::const_iterator_type it = p.begin_col(col); + typename SpProxy::const_iterator_type end = p.begin_col(col + 1); + + const uword n_zero = p_n_rows - (end.pos() - it.pos()); + + out.at(0,col) = spop_mean::iterator_mean(it, end, n_zero, eT(0)); + } + else + { + out.at(0,col) = spop_mean::direct_mean + ( + &p.get_values()[p.get_col_ptrs()[col]], + p.get_col_ptrs()[col + 1] - p.get_col_ptrs()[col], + p_n_rows + ); + } + } + } + else + if(dim == 1) // find the mean in each row + { + arma_extra_debug_print("spop_mean::apply_noalias(): dim = 1"); + + out.set_size(p_n_rows, (p_n_cols > 0) ? 1 : 0); + + if( (p_n_cols == 0) || (p.get_n_nonzero() == 0) ) { return; } + + for(uword row = 0; row < p_n_rows; ++row) + { + // We must use an iterator regardless of how it is stored. + typename SpProxy::const_row_iterator_type it = p.begin_row(row); + typename SpProxy::const_row_iterator_type end = p.end_row(row); + + const uword n_zero = p_n_cols - (end.pos() - it.pos()); + + out.at(row,0) = spop_mean::iterator_mean(it, end, n_zero, eT(0)); + } + } + } + + + +template +inline +eT +spop_mean::direct_mean + ( + const eT* const X, + const uword length, + const uword N + ) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const eT result = ((length > 0) && (N > 0)) ? eT(arrayops::accumulate(X, length) / T(N)) : eT(0); + + return arma_isfinite(result) ? result : spop_mean::direct_mean_robust(X, length, N); + } + + + +template +inline +eT +spop_mean::direct_mean_robust + ( + const eT* const X, + const uword length, + const uword N + ) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + uword i, j; + + eT r_mean = eT(0); + + const uword diff = (N - length); // number of zeros + + for(i = 0, j = 1; j < length; i += 2, j += 2) + { + const eT Xi = X[i]; + const eT Xj = X[j]; + + r_mean += (Xi - r_mean) / T(diff + j); + r_mean += (Xj - r_mean) / T(diff + j + 1); + } + + if(i < length) + { + const eT Xi = X[i]; + + r_mean += (Xi - r_mean) / T(diff + i + 1); + } + + return r_mean; + } + + + +template +inline +typename T1::elem_type +spop_mean::mean_all(const SpBase& X) + { + arma_extra_debug_sigprint(); + + SpProxy p(X.get_ref()); + + if(SpProxy::use_iterator) + { + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type end = p.end(); + + return spop_mean::iterator_mean(it, end, p.get_n_elem() - p.get_n_nonzero(), typename T1::elem_type(0)); + } + else // use_iterator == false; that is, we can directly access the values array + { + return spop_mean::direct_mean(p.get_values(), p.get_n_nonzero(), p.get_n_elem()); + } + } + + + +template +inline +typename T1::elem_type +spop_mean::mean_all(const SpOp& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const bool is_vectorise = \ + (is_same_type::yes) + || (is_same_type::yes) + || (is_same_type::yes); + + if(is_vectorise) + { + return spop_mean::mean_all(expr.m); + } + + const SpMat tmp = expr; + + return spop_mean::mean_all(tmp); + } + + + +template +inline +eT +spop_mean::iterator_mean(T1& it, const T1& end, const uword n_zero, const eT junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename get_pod_type::result T; + + eT acc = eT(0); + + T1 backup_it(it); // in case we have to use robust iterator_mean + + const uword it_begin_pos = it.pos(); + + while(it != end) + { + acc += (*it); + ++it; + } + + const uword count = n_zero + (it.pos() - it_begin_pos); + + const eT result = (count > 0) ? eT(acc / T(count)) : eT(0); + + return arma_isfinite(result) ? result : spop_mean::iterator_mean_robust(backup_it, end, n_zero, eT(0)); + } + + + +template +inline +eT +spop_mean::iterator_mean_robust(T1& it, const T1& end, const uword n_zero, const eT junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename get_pod_type::result T; + + eT r_mean = eT(0); + + const uword it_begin_pos = it.pos(); + + while(it != end) + { + r_mean += ((*it - r_mean) / T(n_zero + (it.pos() - it_begin_pos) + 1)); + ++it; + } + + return r_mean; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_min_bones.hpp b/src/armadillo/include/armadillo_bits/spop_min_bones.hpp new file mode 100644 index 0000000..1fd8ed3 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_min_bones.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_min +//! @{ + + +class spop_min + : public traits_op_xvec + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + + // + + template + inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_not_cx::result* junk = nullptr); + + template + inline static typename T1::elem_type vector_min(const T1& X, const typename arma_not_cx::result* junk = nullptr); + + template + inline static typename arma_not_cx::result min(const SpBase& X); + + template + inline static typename arma_not_cx::result min_with_index(const SpProxy& P, uword& index_of_min_val); + + // + + template + inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_cx_only::result* junk = nullptr); + + template + inline static typename T1::elem_type vector_min(const T1& X, const typename arma_cx_only::result* junk = nullptr); + + template + inline static typename arma_cx_only::result min(const SpBase& X); + + template + inline static typename arma_cx_only::result min_with_index(const SpProxy& P, uword& index_of_min_val); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_min_meat.hpp b/src/armadillo/include/armadillo_bits/spop_min_meat.hpp new file mode 100644 index 0000000..b47c33c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_min_meat.hpp @@ -0,0 +1,722 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_min +//! @{ + + + +template +inline +void +spop_min::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 1), "min(): parameter 'dim' must be 0 or 1" ); + + const SpProxy p(in.m); + + const uword p_n_rows = p.get_n_rows(); + const uword p_n_cols = p.get_n_cols(); + + if( (p_n_rows == 0) || (p_n_cols == 0) || (p.get_n_nonzero() == 0) ) + { + if(dim == 0) { out.zeros((p_n_rows > 0) ? 1 : 0, p_n_cols); } + if(dim == 1) { out.zeros(p_n_rows, (p_n_cols > 0) ? 1 : 0); } + + return; + } + + spop_min::apply_proxy(out, p, dim); + } + + + +template +inline +void +spop_min::apply_proxy + ( + SpMat& out, + const SpProxy& p, + const uword dim, + const typename arma_not_cx::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type it_end = p.end(); + + const uword p_n_cols = p.get_n_cols(); + const uword p_n_rows = p.get_n_rows(); + + if(dim == 0) // find the minimum in each column + { + Row value(p_n_cols, arma_zeros_indicator()); + urowvec count(p_n_cols, arma_zeros_indicator()); + + while(it != it_end) + { + const uword col = it.col(); + + value[col] = (count[col] == 0) ? (*it) : (std::min)(value[col], (*it)); + count[col]++; + ++it; + } + + for(uword col=0; col value(p_n_rows, arma_zeros_indicator()); + ucolvec count(p_n_rows, arma_zeros_indicator()); + + while(it != it_end) + { + const uword row = it.row(); + + value[row] = (count[row] == 0) ? (*it) : (std::min)(value[row], (*it)); + count[row]++; + ++it; + } + + for(uword row=0; row +inline +typename T1::elem_type +spop_min::vector_min + ( + const T1& x, + const typename arma_not_cx::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + const SpProxy p(x); + + if(p.get_n_elem() == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + if(p.get_n_nonzero() == 0) { return eT(0); } + + if(SpProxy::use_iterator == false) + { + // direct access of values + if(p.get_n_nonzero() == p.get_n_elem()) + { + return op_min::direct_min(p.get_values(), p.get_n_nonzero()); + } + else + { + return (std::min)(eT(0), op_min::direct_min(p.get_values(), p.get_n_nonzero())); + } + } + else + { + // use iterator + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type it_end = p.end(); + + eT result = (*it); + ++it; + + while(it != it_end) + { + if((*it) < result) { result = (*it); } + + ++it; + } + + if(p.get_n_nonzero() == p.get_n_elem()) + { + return result; + } + else + { + return (std::min)(eT(0), result); + } + } + } + + + +template +inline +typename arma_not_cx::result +spop_min::min(const SpBase& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + const uword n_nonzero = P.get_n_nonzero(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + eT min_val = priv::most_pos(); + + if(SpProxy::use_iterator) + { + // We have to iterate over the elements. + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + if((*it) < min_val) { min_val = *it; } + + ++it; + } + } + else + { + // We can do direct access of the values, row_indices, and col_ptrs. + // We don't need the location of the min value, so we can just call out to + // other functions... + min_val = op_min::direct_min(P.get_values(), n_nonzero); + } + + if(n_elem == n_nonzero) + { + return min_val; + } + else + { + return (std::min)(eT(0), min_val); + } + } + + + +template +inline +typename arma_not_cx::result +spop_min::min_with_index(const SpProxy& P, uword& index_of_min_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + const uword n_nonzero = P.get_n_nonzero(); + const uword n_rows = P.get_n_rows(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + index_of_min_val = uword(0); + + return Datum::nan; + } + + eT min_val = priv::most_pos(); + + if(SpProxy::use_iterator) + { + // We have to iterate over the elements. + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + if((*it) < min_val) + { + min_val = *it; + index_of_min_val = it.row() + it.col() * n_rows; + } + + ++it; + } + } + else + { + // We can do direct access. + min_val = op_min::direct_min(P.get_values(), n_nonzero, index_of_min_val); + + // Convert to actual position in matrix. + const uword row = P.get_row_indices()[index_of_min_val]; + uword col = 0; + while(P.get_col_ptrs()[++col] < index_of_min_val + 1) { } + index_of_min_val = (col - 1) * n_rows + row; + } + + + if(n_elem != n_nonzero) + { + min_val = (std::min)(eT(0), min_val); + + // If the min_val is a nonzero element, we need its actual position in the matrix. + if(min_val == eT(0)) + { + // Find first zero element. + uword last_row = 0; + uword last_col = 0; + + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + // Have we moved more than one position from the last place? + if((it.col() == last_col) && (it.row() - last_row > 1)) + { + index_of_min_val = it.col() * n_rows + last_row + 1; + break; + } + else if((it.col() >= last_col + 1) && (last_row < n_rows - 1)) + { + index_of_min_val = last_col * n_rows + last_row + 1; + break; + } + else if((it.col() == last_col + 1) && (it.row() > 0)) + { + index_of_min_val = it.col() * n_rows; + break; + } + else if(it.col() > last_col + 1) + { + index_of_min_val = (last_col + 1) * n_rows; + break; + } + + last_row = it.row(); + last_col = it.col(); + ++it; + } + } + } + + return min_val; + } + + + +template +inline +void +spop_min::apply_proxy + ( + SpMat& out, + const SpProxy& p, + const uword dim, + const typename arma_cx_only::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type it_end = p.end(); + + const uword p_n_cols = p.get_n_cols(); + const uword p_n_rows = p.get_n_rows(); + + if(dim == 0) // find the minimum in each column + { + Row rawval(p_n_cols, arma_zeros_indicator()); + Row< T> absval(p_n_cols, arma_zeros_indicator()); + urowvec count(p_n_cols, arma_zeros_indicator()); + + while(it != it_end) + { + const uword col = it.col(); + + const eT& v = (*it); + const T a = std::abs(v); + + if(count[col] == 0) + { + absval[col] = a; + rawval[col] = v; + } + else + { + if(a < absval[col]) + { + absval[col] = a; + rawval[col] = v; + } + } + + count[col]++; + ++it; + } + + for(uword col=0; col < p_n_cols; ++col) + { + if(count[col] < p_n_rows) + { + if(T(0) < absval[col]) { rawval[col] = eT(0); } + } + } + + out = rawval; + } + else + if(dim == 1) // find the minimum in each row + { + Col rawval(p_n_rows, arma_zeros_indicator()); + Col< T> absval(p_n_rows, arma_zeros_indicator()); + ucolvec count(p_n_rows, arma_zeros_indicator()); + + while(it != it_end) + { + const uword row = it.row(); + + const eT& v = (*it); + const T a = std::abs(v); + + if(count[row] == 0) + { + absval[row] = a; + rawval[row] = v; + } + else + { + if(a < absval[row]) + { + absval[row] = a; + rawval[row] = v; + } + } + + count[row]++; + ++it; + } + + for(uword row=0; row < p_n_rows; ++row) + { + if(count[row] < p_n_cols) + { + if(T(0) < absval[row]) { rawval[row] = eT(0); } + } + } + + out = rawval; + } + } + + + +template +inline +typename T1::elem_type +spop_min::vector_min + ( + const T1& x, + const typename arma_cx_only::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const SpProxy p(x); + + if(p.get_n_elem() == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + if(p.get_n_nonzero() == 0) { return eT(0); } + + if(SpProxy::use_iterator == false) + { + // direct access of values + if(p.get_n_nonzero() == p.get_n_elem()) + { + return op_min::direct_min(p.get_values(), p.get_n_nonzero()); + } + else + { + const eT val1 = eT(0); + const eT val2 = op_min::direct_min(p.get_values(), p.get_n_nonzero()); + + return ( std::abs(val1) < std::abs(val2) ) ? val1 : val2; + } + } + else + { + // use iterator + typename SpProxy::const_iterator_type it = p.begin(); + typename SpProxy::const_iterator_type it_end = p.end(); + + eT best_val_orig = *it; + T best_val_abs = std::abs(best_val_orig); + + ++it; + + while(it != it_end) + { + eT val_orig = *it; + T val_abs = std::abs(val_orig); + + if(val_abs < best_val_abs) + { + best_val_abs = val_abs; + best_val_orig = val_orig; + } + + ++it; + } + + if(p.get_n_nonzero() == p.get_n_elem()) + { + return best_val_orig; + } + else + { + const eT val1 = eT(0); + + return ( std::abs(val1) < best_val_abs ) ? val1 : best_val_orig; + } + } + } + + + +template +inline +typename arma_cx_only::result +spop_min::min(const SpBase& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const SpProxy P(X.get_ref()); + + const uword n_elem = P.get_n_elem(); + const uword n_nonzero = P.get_n_nonzero(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + T min_val = priv::most_pos(); + eT ret_val; + + if(SpProxy::use_iterator) + { + // We have to iterate over the elements. + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + const T tmp_val = std::abs(*it); + + if(tmp_val < min_val) + { + min_val = tmp_val; + ret_val = *it; + } + + ++it; + } + } + else + { + // We can do direct access of the values, row_indices, and col_ptrs. + // We don't need the location of the min value, so we can just call out to + // other functions... + ret_val = op_min::direct_min(P.get_values(), n_nonzero); + min_val = std::abs(ret_val); + } + + if(n_elem == n_nonzero) + { + return ret_val; + } + else + { + return (T(0) < min_val) ? eT(0) : ret_val; + } + } + + + +template +inline +typename arma_cx_only::result +spop_min::min_with_index(const SpProxy& P, uword& index_of_min_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + const uword n_elem = P.get_n_elem(); + const uword n_nonzero = P.get_n_nonzero(); + const uword n_rows = P.get_n_rows(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + index_of_min_val = uword(0); + + return Datum::nan; + } + + T min_val = priv::most_pos(); + + if(SpProxy::use_iterator) + { + // We have to iterate over the elements. + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + const T tmp_val = std::abs(*it); + + if(tmp_val < min_val) + { + min_val = tmp_val; + index_of_min_val = it.row() + it.col() * n_rows; + } + + ++it; + } + } + else + { + // We can do direct access. + min_val = std::abs(op_min::direct_min(P.get_values(), n_nonzero, index_of_min_val)); + + // Convert to actual position in matrix. + const uword row = P.get_row_indices()[index_of_min_val]; + uword col = 0; + while(P.get_col_ptrs()[++col] < index_of_min_val + 1) { } + index_of_min_val = (col - 1) * n_rows + row; + } + + + if(n_elem != n_nonzero) + { + min_val = (std::min)(T(0), min_val); + + // If the min_val is a nonzero element, we need its actual position in the matrix. + if(min_val == T(0)) + { + // Find first zero element. + uword last_row = 0; + uword last_col = 0; + + typedef typename SpProxy::const_iterator_type it_type; + + it_type it = P.begin(); + it_type it_end = P.end(); + + while(it != it_end) + { + // Have we moved more than one position from the last place? + if((it.col() == last_col) && (it.row() - last_row > 1)) + { + index_of_min_val = it.col() * n_rows + last_row + 1; + break; + } + else if((it.col() >= last_col + 1) && (last_row < n_rows - 1)) + { + index_of_min_val = last_col * n_rows + last_row + 1; + break; + } + else if((it.col() == last_col + 1) && (it.row() > 0)) + { + index_of_min_val = it.col() * n_rows; + break; + } + else if(it.col() > last_col + 1) + { + index_of_min_val = (last_col + 1) * n_rows; + break; + } + + last_row = it.row(); + last_col = it.col(); + ++it; + } + } + } + + return P[index_of_min_val]; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_misc_bones.hpp b/src/armadillo/include/armadillo_bits/spop_misc_bones.hpp new file mode 100644 index 0000000..42117f9 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_misc_bones.hpp @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_misc +//! @{ + + +class spop_scalar_times + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_cx_scalar_times + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat< std::complex >& out, const mtSpOp< std::complex, T1, spop_cx_scalar_times>& in); + }; + + + +class spop_square + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_sqrt + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_abs + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_cx_abs + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const mtSpOp& in); + }; + + + +class spop_arg + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_cx_arg + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const mtSpOp& in); + }; + + + +class spop_real + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const mtSpOp& in); + }; + + + +class spop_imag + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const mtSpOp& in); + }; + + + +class spop_conj + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_repelem + : public traits_op_default + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_reshape + : public traits_op_default + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_resize + : public traits_op_default + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_floor + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_ceil + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_round + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_trunc + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_sign + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_diagvec + : public traits_op_col + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_flipud + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_fliplr + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_misc_meat.hpp b/src/armadillo/include/armadillo_bits/spop_misc_meat.hpp new file mode 100644 index 0000000..1ef51cc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_misc_meat.hpp @@ -0,0 +1,596 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_misc +//! @{ + + + +namespace priv + { + template + struct functor_scalar_times + { + const eT k; + + functor_scalar_times(const eT in_k) : k(in_k) {} + + arma_inline eT operator()(const eT val) const { return val * k; } + }; + } + + + +template +inline +void +spop_scalar_times::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(in.aux != eT(0)) + { + out.init_xform(in.m, priv::functor_scalar_times(in.aux)); + } + else + { + const SpProxy P(in.m); + + out.zeros( P.get_n_rows(), P.get_n_cols() ); + } + } + + + +namespace priv + { + template + struct functor_cx_scalar_times + { + typedef std::complex out_eT; + + const out_eT k; + + functor_cx_scalar_times(const out_eT in_k) : k(in_k) {} + + arma_inline out_eT operator()(const T val) const { return val * k; } + }; + } + + + +template +inline +void +spop_cx_scalar_times::apply(SpMat< std::complex >& out, const mtSpOp< std::complex, T1, spop_cx_scalar_times >& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + typedef typename std::complex out_eT; + + if(in.aux_out_eT != out_eT(0)) + { + out.init_xform_mt(in.m, priv::functor_cx_scalar_times(in.aux_out_eT)); + } + else + { + const SpProxy P(in.m); + + out.zeros( P.get_n_rows(), P.get_n_cols() ); + } + } + + + +namespace priv + { + struct functor_square + { + template + arma_inline eT operator()(const eT val) const { return val*val; } + }; + } + + + +template +inline +void +spop_square::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform(in.m, priv::functor_square()); + } + + + +namespace priv + { + struct functor_sqrt + { + template + arma_inline eT operator()(const eT val) const { return eop_aux::sqrt(val); } + }; + } + + + +template +inline +void +spop_sqrt::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform(in.m, priv::functor_sqrt()); + } + + + +namespace priv + { + struct functor_abs + { + template + arma_inline eT operator()(const eT val) const { return eop_aux::arma_abs(val); } + }; + } + + + +template +inline +void +spop_abs::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform(in.m, priv::functor_abs()); + } + + + +namespace priv + { + struct functor_cx_abs + { + template + arma_inline T operator()(const std::complex& val) const { return std::abs(val); } + }; + } + + + +template +inline +void +spop_cx_abs::apply(SpMat& out, const mtSpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform_mt(in.m, priv::functor_cx_abs()); + } + + + +namespace priv + { + struct functor_arg + { + template + arma_inline eT operator()(const eT val) const { return arma_arg::eval(val); } + }; + } + + + +template +inline +void +spop_arg::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform(in.m, priv::functor_arg()); + } + + + +namespace priv + { + struct functor_cx_arg + { + template + arma_inline T operator()(const std::complex& val) const { return std::arg(val); } + }; + } + + + +template +inline +void +spop_cx_arg::apply(SpMat& out, const mtSpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform_mt(in.m, priv::functor_cx_arg()); + } + + + +namespace priv + { + struct functor_real + { + template + arma_inline T operator()(const std::complex& val) const { return val.real(); } + }; + } + + + +template +inline +void +spop_real::apply(SpMat& out, const mtSpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform_mt(in.m, priv::functor_real()); + } + + + +namespace priv + { + struct functor_imag + { + template + arma_inline T operator()(const std::complex& val) const { return val.imag(); } + }; + } + + + +template +inline +void +spop_imag::apply(SpMat& out, const mtSpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform_mt(in.m, priv::functor_imag()); + } + + + +namespace priv + { + struct functor_conj + { + template + arma_inline eT operator()(const eT val) const { return eop_aux::conj(val); } + }; + } + + + +template +inline +void +spop_conj::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform(in.m, priv::functor_conj()); + } + + + +template +inline +void +spop_repelem::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(in.m); + const SpMat& X = U.M; + + const uword copies_per_row = in.aux_uword_a; + const uword copies_per_col = in.aux_uword_b; + + const uword out_n_rows = X.n_rows * copies_per_row; + const uword out_n_cols = X.n_cols * copies_per_col; + const uword out_nnz = X.n_nonzero * copies_per_row * copies_per_col; + + if( (out_n_rows > 0) && (out_n_cols > 0) && (out_nnz > 0) ) + { + umat locs(2, out_nnz, arma_nozeros_indicator()); + Col vals( out_nnz, arma_nozeros_indicator()); + + uword* locs_mem = locs.memptr(); + eT* vals_mem = vals.memptr(); + + typename SpMat::const_iterator X_it = X.begin(); + typename SpMat::const_iterator X_end = X.end(); + + for(; X_it != X_end; ++X_it) + { + const uword col_base = copies_per_col * X_it.col(); + const uword row_base = copies_per_row * X_it.row(); + + const eT X_val = (*X_it); + + for(uword cols = 0; cols < copies_per_col; cols++) + for(uword rows = 0; rows < copies_per_row; rows++) + { + (*locs_mem) = row_base + rows; ++locs_mem; + (*locs_mem) = col_base + cols; ++locs_mem; + + (*vals_mem) = X_val; ++vals_mem; + } + } + + out = SpMat(locs, vals, out_n_rows, out_n_cols); + } + else + { + out.zeros(out_n_rows, out_n_cols); + } + } + + + +template +inline +void +spop_reshape::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out = in.m; + + out.reshape(in.aux_uword_a, in.aux_uword_b); + } + + + +template +inline +void +spop_resize::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out = in.m; + + out.resize(in.aux_uword_a, in.aux_uword_b); + } + + + +namespace priv + { + struct functor_floor + { + template + arma_inline eT operator()(const eT val) const { return eop_aux::floor(val); } + }; + } + + + +template +inline +void +spop_floor::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform(in.m, priv::functor_floor()); + } + + + +namespace priv + { + struct functor_ceil + { + template + arma_inline eT operator()(const eT val) const { return eop_aux::ceil(val); } + }; + } + + + +template +inline +void +spop_ceil::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform(in.m, priv::functor_ceil()); + } + + + +namespace priv + { + struct functor_round + { + template + arma_inline eT operator()(const eT val) const { return eop_aux::round(val); } + }; + } + + + +template +inline +void +spop_round::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform(in.m, priv::functor_round()); + } + + + +namespace priv + { + struct functor_trunc + { + template + arma_inline eT operator()(const eT val) const { return eop_aux::trunc(val); } + }; + } + + + +template +inline +void +spop_trunc::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform(in.m, priv::functor_trunc()); + } + + + +namespace priv + { + struct functor_sign + { + template + arma_inline eT operator()(const eT val) const { return arma_sign(val); } + }; + } + + + +template +inline +void +spop_sign::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out.init_xform(in.m, priv::functor_sign()); + } + + + +template +inline +void +spop_diagvec::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(in.m); + + const SpMat& X = U.M; + + const uword a = in.aux_uword_a; + const uword b = in.aux_uword_b; + + const uword row_offset = (b > 0) ? a : 0; + const uword col_offset = (b == 0) ? a : 0; + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= X.n_rows)) || ((col_offset > 0) && (col_offset >= X.n_cols)), + "diagvec(): requested diagonal out of bounds" + ); + + const uword len = (std::min)(X.n_rows - row_offset, X.n_cols - col_offset); + + Col cache(len, arma_nozeros_indicator()); + eT* cache_mem = cache.memptr(); + + uword n_nonzero = 0; + + for(uword i=0; i < len; ++i) + { + const eT val = X.at(i + row_offset, i + col_offset); + + cache_mem[i] = val; + + n_nonzero += (val != eT(0)) ? uword(1) : uword(0); + } + + out.reserve(len, 1, n_nonzero); + + uword count = 0; + for(uword i=0; i < len; ++i) + { + const eT val = cache_mem[i]; + + if(val != eT(0)) + { + access::rw(out.row_indices[count]) = i; + access::rw(out.values[count]) = val; + ++count; + } + } + + access::rw(out.col_ptrs[0]) = 0; + access::rw(out.col_ptrs[1]) = n_nonzero; + } + + + +template +inline +void +spop_flipud::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out = reverse(in.m, 0); + } + + + +template +inline +void +spop_fliplr::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + out = reverse(in.m, 1); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_norm_bones.hpp b/src/armadillo/include/armadillo_bits/spop_norm_bones.hpp new file mode 100644 index 0000000..1d94451 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_norm_bones.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_norm +//! @{ + + +class spop_norm + : public traits_op_default + { + public: + + template inline static typename get_pod_type::result mat_norm_1(const SpMat& X); + + template inline static typename get_pod_type::result mat_norm_2(const SpMat& X, const typename arma_real_only::result* junk = nullptr); + template inline static typename get_pod_type::result mat_norm_2(const SpMat& X, const typename arma_cx_only::result* junk = nullptr); + + template inline static typename get_pod_type::result mat_norm_inf(const SpMat& X); + + template inline static typename get_pod_type::result vec_norm_k(const eT* mem, const uword N, const uword k); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_norm_meat.hpp b/src/armadillo/include/armadillo_bits/spop_norm_meat.hpp new file mode 100644 index 0000000..402ea29 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_norm_meat.hpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_norm +//! @{ + + + +template +inline +typename get_pod_type::result +spop_norm::mat_norm_1(const SpMat& X) + { + arma_extra_debug_sigprint(); + + // TODO: this can be sped up with a dedicated implementation + return as_scalar( max( sum(abs(X), 0), 1) ); + } + + + +template +inline +typename get_pod_type::result +spop_norm::mat_norm_2(const SpMat& X, const typename arma_real_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + // norm = sqrt( largest eigenvalue of (A^H)*A ), where ^H is the conjugate transpose + // http://math.stackexchange.com/questions/4368/computing-the-largest-eigenvalue-of-a-very-large-sparse-matrix + + typedef typename get_pod_type::result T; + + const SpMat& A = X; + const SpMat B = trans(A); + + const SpMat C = (A.n_rows <= A.n_cols) ? (A*B) : (B*A); + + Col eigval; + eigs_sym(eigval, C, 1); + + return (eigval.n_elem > 0) ? T(std::sqrt(eigval[0])) : T(0); + } + + + +template +inline +typename get_pod_type::result +spop_norm::mat_norm_2(const SpMat& X, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename get_pod_type::result T; + + // we're calling eigs_gen(), which currently requires ARPACK + #if !defined(ARMA_USE_ARPACK) + { + arma_stop_logic_error("norm(): use of ARPACK must be enabled for norm of complex matrices"); + return T(0); + } + #endif + + const SpMat& A = X; + const SpMat B = trans(A); + + const SpMat C = (A.n_rows <= A.n_cols) ? (A*B) : (B*A); + + Col eigval; + eigs_gen(eigval, C, 1); + + return (eigval.n_elem > 0) ? T(std::sqrt(std::real(eigval[0]))) : T(0); + } + + + +template +inline +typename get_pod_type::result +spop_norm::mat_norm_inf(const SpMat& X) + { + arma_extra_debug_sigprint(); + + // TODO: this can be sped up with a dedicated implementation + return as_scalar( max( sum(abs(X), 1), 0) ); + } + + + +template +inline +typename get_pod_type::result +spop_norm::vec_norm_k(const eT* mem, const uword N, const uword k) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (k == 0), "norm(): unsupported vector norm type" ); + + // create a fake dense vector to allow reuse of code for dense vectors + Col fake_vector( access::rwp(mem), N, false ); + + const Proxy< Col > P_fake_vector(fake_vector); + + if(k == uword(1)) { return op_norm::vec_norm_1(P_fake_vector); } + if(k == uword(2)) { return op_norm::vec_norm_2(P_fake_vector); } + + return op_norm::vec_norm_k(P_fake_vector, int(k)); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_normalise_bones.hpp b/src/armadillo/include/armadillo_bits/spop_normalise_bones.hpp new file mode 100644 index 0000000..839b9ca --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_normalise_bones.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_normalise +//! @{ + + +class spop_normalise + : public traits_op_passthru + { + public: + + template + inline static void apply(SpMat& out, const SpOp& expr); + + template + inline static void apply_direct(SpMat& out, const SpMat& X, const uword p); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_normalise_meat.hpp b/src/armadillo/include/armadillo_bits/spop_normalise_meat.hpp new file mode 100644 index 0000000..96a1759 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_normalise_meat.hpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_normalise +//! @{ + + + +template +inline +void +spop_normalise::apply(SpMat& out, const SpOp& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword p = expr.aux_uword_a; + const uword dim = expr.aux_uword_b; + + arma_debug_check( (p == 0), "normalise(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "normalise(): parameter 'dim' must be 0 or 1" ); + + const unwrap_spmat U(expr.m); + + const SpMat& X = U.M; + + X.sync(); + + if( X.is_empty() || (X.n_nonzero == 0) ) { out.zeros(X.n_rows, X.n_cols); return; } + + if(dim == 0) + { + spop_normalise::apply_direct(out, X, p); + } + else + if(dim == 1) + { + SpMat tmp1; + SpMat tmp2; + + spop_strans::apply_noalias(tmp1, X); + + spop_normalise::apply_direct(tmp2, tmp1, p); + + spop_strans::apply_noalias(out, tmp2); + } + } + + + +template +inline +void +spop_normalise::apply_direct(SpMat& out, const SpMat& X, const uword p) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + SpMat tmp(arma_reserve_indicator(), X.n_rows, X.n_cols, X.n_nonzero); + + bool has_zero = false; + + podarray norm_vals(X.n_cols); + + T* norm_vals_mem = norm_vals.memptr(); + + for(uword col=0; col < X.n_cols; ++col) + { + const uword col_offset = X.col_ptrs[col ]; + const uword next_col_offset = X.col_ptrs[col + 1]; + + const eT* start_ptr = &X.values[ col_offset]; + const eT* end_ptr = &X.values[next_col_offset]; + + const uword n_elem = end_ptr - start_ptr; + + const Col fake_vec(const_cast(start_ptr), n_elem, false, false); + + const T norm_val = norm(fake_vec, p); + + norm_vals_mem[col] = (norm_val != T(0)) ? norm_val : T(1); + } + + const uword N = X.n_nonzero; + + typename SpMat::const_iterator it = X.begin(); + + for(uword i=0; i < N; ++i) + { + const uword row = it.row(); + const uword col = it.col(); + + const eT val = (*it) / norm_vals_mem[col]; + + if(val == eT(0)) { has_zero = true; } + + access::rw(tmp.values[i]) = val; + access::rw(tmp.row_indices[i]) = row; + access::rw(tmp.col_ptrs[col + 1])++; + + ++it; + } + + for(uword c=0; c < tmp.n_cols; ++c) + { + access::rw(tmp.col_ptrs[c + 1]) += tmp.col_ptrs[c]; + } + + if(has_zero) { tmp.remove_zeros(); } + + out.steal_mem(tmp); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_repmat_bones.hpp b/src/armadillo/include/armadillo_bits/spop_repmat_bones.hpp new file mode 100644 index 0000000..7ee6843 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_repmat_bones.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_repmat +//! @{ + + + +class spop_repmat + : public traits_op_default + { + public: + + template + inline static void apply(SpMat& out, const SpOp& X); + + template + inline static void apply_noalias(SpMat& out, const uword A_n_rows, const uword A_n_cols, const SpMat& B); + }; + + + +//! @} + + + diff --git a/src/armadillo/include/armadillo_bits/spop_repmat_meat.hpp b/src/armadillo/include/armadillo_bits/spop_repmat_meat.hpp new file mode 100644 index 0000000..4c09a3e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_repmat_meat.hpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_repmat +//! @{ + + + +template +inline +void +spop_repmat::apply(SpMat& out, const SpOp& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(X.m); + + if(U.is_alias(out)) + { + SpMat tmp; + + spop_repmat::apply_noalias(tmp, X.aux_uword_a, X.aux_uword_b, U.M); + + out.steal_mem(tmp); + } + else + { + spop_repmat::apply_noalias(out, X.aux_uword_a, X.aux_uword_b, U.M); + } + } + + + +template +inline +void +spop_repmat::apply_noalias(SpMat& out, const uword A_n_rows, const uword A_n_cols, const SpMat& B) + { + arma_extra_debug_sigprint(); + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + const uword out_n_nonzero = A_n_rows * A_n_cols * B.n_nonzero; + + out.reserve(A_n_rows * B_n_rows, A_n_cols * B_n_cols, out_n_nonzero); + + if(out_n_nonzero == 0) { return; } + + access::rw(out.col_ptrs[0]) = 0; + + uword count = 0; + + for(uword A_col=0; A_col < A_n_cols; ++A_col) + for(uword B_col=0; B_col < B_n_cols; ++B_col) + { + for(uword A_row=0; A_row < A_n_rows; ++A_row) + { + const uword out_row = A_row * B_n_rows; + + for(uword B_i = B.col_ptrs[B_col]; B_i < B.col_ptrs[B_col+1]; ++B_i) + { + access::rw(out.values[count]) = B.values[B_i]; + access::rw(out.row_indices[count]) = out_row + B.row_indices[B_i]; + + count++; + } + } + + access::rw(out.col_ptrs[A_col * B_n_cols + B_col + 1]) = count; + } + } + + + +// template +// inline +// void +// spop_repmat::apply(SpMat& out, const SpOp& in) +// { +// arma_extra_debug_sigprint(); +// +// typedef typename T1::elem_type eT; +// +// const unwrap_spmat U(in.m); +// const SpMat& X = U.M; +// +// const uword X_n_rows = X.n_rows; +// const uword X_n_cols = X.n_cols; +// +// const uword copies_per_row = in.aux_uword_a; +// const uword copies_per_col = in.aux_uword_b; +// +// // out.set_size(X_n_rows * copies_per_row, X_n_cols * copies_per_col); +// // +// // const uword out_n_rows = out.n_rows; +// // const uword out_n_cols = out.n_cols; +// // +// // if( (out_n_rows > 0) && (out_n_cols > 0) ) +// // { +// // for(uword col = 0; col < out_n_cols; col += X_n_cols) +// // for(uword row = 0; row < out_n_rows; row += X_n_rows) +// // { +// // out.submat(row, col, row+X_n_rows-1, col+X_n_cols-1) = X; +// // } +// // } +// +// const uword out_n_rows = X_n_rows * copies_per_row; +// const uword out_n_cols = X_n_cols * copies_per_col; +// const uword out_nnz = X.n_nonzero * copies_per_row * copies_per_col; +// +// if( (out_n_rows > 0) && (out_n_cols > 0) && (out_nnz > 0) ) +// { +// umat locs(2, out_nnz, arma_nozeros_indicator()); +// Col vals( out_nnz, arma_nozeros_indicator()); +// +// uword* locs_mem = locs.memptr(); +// eT* vals_mem = vals.memptr(); +// +// typename SpMat::const_iterator X_begin = X.begin(); +// typename SpMat::const_iterator X_end = X.end(); +// typename SpMat::const_iterator X_it; +// +// for(uword col_offset = 0; col_offset < out_n_cols; col_offset += X_n_cols) +// for(uword row_offset = 0; row_offset < out_n_rows; row_offset += X_n_rows) +// { +// for(X_it = X_begin; X_it != X_end; ++X_it) +// { +// const uword out_row = row_offset + X_it.row(); +// const uword out_col = col_offset + X_it.col(); +// +// (*locs_mem) = out_row; ++locs_mem; +// (*locs_mem) = out_col; ++locs_mem; +// +// (*vals_mem) = (*X_it); ++vals_mem; +// } +// } +// +// out = SpMat(locs, vals, out_n_rows, out_n_cols); +// } +// else +// { +// out.zeros(out_n_rows, out_n_cols); +// } +// } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_reverse_bones.hpp b/src/armadillo/include/armadillo_bits/spop_reverse_bones.hpp new file mode 100644 index 0000000..3ea80b6 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_reverse_bones.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_reverse +//! @{ + + +class spop_reverse + : public traits_op_passthru + { + public: + + template + inline static void apply_spmat(SpMat& out, const SpMat& X, const uword dim); + + template + inline static void apply_proxy(SpMat& out, const T1& X, const uword dim); + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_reverse_meat.hpp b/src/armadillo/include/armadillo_bits/spop_reverse_meat.hpp new file mode 100644 index 0000000..2ba6e45 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_reverse_meat.hpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_reverse +//! @{ + + + +template +inline +void +spop_reverse::apply_spmat(SpMat& out, const SpMat& X, const uword dim) + { + arma_extra_debug_sigprint(); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + const uword X_n_rows_m1 = X_n_rows - 1; + const uword X_n_cols_m1 = X_n_cols - 1; + + const uword N = X.n_nonzero; + + if(N == uword(0)) + { + out.zeros(X_n_rows, X_n_cols); + return; + } + + umat locs(2, N, arma_nozeros_indicator()); + + uword* locs_mem = locs.memptr(); + + typename SpMat::const_iterator it = X.begin(); + + if(dim == 0) + { + for(uword i=0; i < N; ++i) + { + const uword row = it.row(); + const uword col = it.col(); + + (*locs_mem) = X_n_rows_m1 - row; locs_mem++; + (*locs_mem) = col; locs_mem++; + + ++it; + } + } + else + if(dim == 1) + { + for(uword i=0; i < N; ++i) + { + const uword row = it.row(); + const uword col = it.col(); + + (*locs_mem) = row; locs_mem++; + (*locs_mem) = X_n_cols_m1 - col; locs_mem++; + + ++it; + } + } + + const Col vals(const_cast(X.values), N, false); + + SpMat tmp(locs, vals, X_n_rows, X_n_cols, true, false); + + out.steal_mem(tmp); + } + + + +template +inline +void +spop_reverse::apply_proxy(SpMat& out, const T1& X, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy P(X); + + const uword P_n_rows = P.get_n_rows(); + const uword P_n_cols = P.get_n_cols(); + + const uword P_n_rows_m1 = P_n_rows - 1; + const uword P_n_cols_m1 = P_n_cols - 1; + + const uword N = P.get_n_nonzero(); + + if(N == uword(0)) + { + out.zeros(P_n_rows, P_n_cols); + return; + } + + umat locs(2, N, arma_nozeros_indicator()); + Col vals( N, arma_nozeros_indicator()); + + uword* locs_mem = locs.memptr(); + eT* vals_mem = vals.memptr(); + + typename SpProxy::const_iterator_type it = P.begin(); + + if(dim == 0) + { + for(uword i=0; i < N; ++i) + { + const uword row = it.row(); + const uword col = it.col(); + + (*locs_mem) = P_n_rows_m1 - row; locs_mem++; + (*locs_mem) = col; locs_mem++; + + (*vals_mem) = (*it); vals_mem++; + + ++it; + } + } + else + if(dim == 1) + { + for(uword i=0; i < N; ++i) + { + const uword row = it.row(); + const uword col = it.col(); + + (*locs_mem) = row; locs_mem++; + (*locs_mem) = P_n_cols_m1 - col; locs_mem++; + + (*vals_mem) = (*it); vals_mem++; + + ++it; + } + } + + SpMat tmp(locs, vals, P_n_rows, P_n_cols, true, false); + + out.steal_mem(tmp); + } + + + +template +inline +void +spop_reverse::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + const uword dim = in.aux_uword_a; + + arma_debug_check( (dim > 1), "reverse(): parameter 'dim' must be 0 or 1" ); + + if(is_SpMat::value) + { + const unwrap_spmat tmp(in.m); + + spop_reverse::apply_spmat(out, tmp.M, dim); + } + else + { + spop_reverse::apply_proxy(out, in.m, dim); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_strans_bones.hpp b/src/armadillo/include/armadillo_bits/spop_strans_bones.hpp new file mode 100644 index 0000000..7d52cca --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_strans_bones.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_strans +//! @{ + + +//! simple transpose operation (no complex conjugates) for sparse matrices + +class spop_strans + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_col; // deliberately swapped + static constexpr bool is_col = T1::is_row; + static constexpr bool is_xvec = T1::is_xvec; + }; + + template + inline static void apply_noalias(SpMat& B, const SpMat& A); + + template + inline static void apply(SpMat& out, const SpOp& in); + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_strans_meat.hpp b/src/armadillo/include/armadillo_bits/spop_strans_meat.hpp new file mode 100644 index 0000000..4cb83c9 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_strans_meat.hpp @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_strans +//! @{ + + + +template +inline +void +spop_strans::apply_noalias(SpMat& B, const SpMat& A) + { + arma_extra_debug_sigprint(); + + B.reserve(A.n_cols, A.n_rows, A.n_nonzero); // deliberately swapped + + if(A.n_nonzero == 0) { return; } + + // This follows the TRANSP algorithm described in + // 'Sparse Matrix Multiplication Package (SMMP)' + // (R.E. Bank and C.C. Douglas, 2001) + + const uword m = A.n_rows; + const uword n = A.n_cols; + + const eT* a = A.values; + eT* b = access::rwp(B.values); + + const uword* ia = A.col_ptrs; + const uword* ja = A.row_indices; + + uword* ib = access::rwp(B.col_ptrs); + uword* jb = access::rwp(B.row_indices); + + // // ib is already zeroed, as B is freshly constructed + // + // for(uword i=0; i < (m+1); ++i) + // { + // ib[i] = 0; + // } + + for(uword i=0; i < n; ++i) + { + for(uword j = ia[i]; j < ia[i+1]; ++j) + { + ib[ ja[j] + 1 ]++; + } + } + + for(uword i=0; i < m; ++i) + { + ib[i+1] += ib[i]; + } + + for(uword i=0; i < n; ++i) + { + for(uword j = ia[i]; j < ia[i+1]; ++j) + { + const uword jj = ja[j]; + + const uword ib_jj = ib[jj]; + + jb[ib_jj] = i; + + b[ib_jj] = a[j]; + + ib[jj]++; + } + } + + for(uword i = m-1; i >= 1; --i) + { + ib[i] = ib[i-1]; + } + + ib[0] = 0; + } + + + +template +inline +void +spop_strans::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(in.m); + + if(U.is_alias(out)) + { + SpMat tmp; + + spop_strans::apply_noalias(tmp, U.M); + + out.steal_mem(tmp); + } + else + { + spop_strans::apply_noalias(out, U.M); + } + } + + + +//! for transpose of non-complex matrices, redirected from spop_htrans::apply() +template +inline +void +spop_strans::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(in.m); + + if(U.is_alias(out)) + { + SpMat tmp; + + spop_strans::apply_noalias(tmp, U.M); + + out.steal_mem(tmp); + } + else + { + spop_strans::apply_noalias(out, U.M); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_sum_bones.hpp b/src/armadillo/include/armadillo_bits/spop_sum_bones.hpp new file mode 100644 index 0000000..2e4e558 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_sum_bones.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_sum +//! @{ + + +class spop_sum + : public traits_op_xvec + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_sum_meat.hpp b/src/armadillo/include/armadillo_bits/spop_sum_meat.hpp new file mode 100644 index 0000000..63badfd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_sum_meat.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_sum +//! @{ + + + +template +inline +void +spop_sum::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword dim = in.aux_uword_a; + arma_debug_check( (dim > 1), "sum(): parameter 'dim' must be 0 or 1" ); + + const SpProxy p(in.m); + + const uword p_n_rows = p.get_n_rows(); + const uword p_n_cols = p.get_n_cols(); + + if(p.get_n_nonzero() == 0) + { + if(dim == 0) { out.zeros(1,p_n_cols); } + if(dim == 1) { out.zeros(p_n_rows,1); } + + return; + } + + if(dim == 0) // find the sum in each column + { + Row acc(p_n_cols, arma_zeros_indicator()); + + eT* acc_mem = acc.memptr(); + + if(SpProxy::use_iterator) + { + typename SpProxy::const_iterator_type it = p.begin(); + + const uword N = p.get_n_nonzero(); + + for(uword i=0; i < N; ++i) + { + acc_mem[it.col()] += (*it); + ++it; + } + } + else + { + for(uword col = 0; col < p_n_cols; ++col) + { + acc_mem[col] = arrayops::accumulate + ( + &p.get_values()[p.get_col_ptrs()[col]], + p.get_col_ptrs()[col + 1] - p.get_col_ptrs()[col] + ); + } + } + + out = acc; + } + else + if(dim == 1) // find the sum in each row + { + Col acc(p_n_rows, arma_zeros_indicator()); + + eT* acc_mem = acc.memptr(); + + typename SpProxy::const_iterator_type it = p.begin(); + + const uword N = p.get_n_nonzero(); + + for(uword i=0; i < N; ++i) + { + acc_mem[it.row()] += (*it); + ++it; + } + + out = acc; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_symmat_bones.hpp b/src/armadillo/include/armadillo_bits/spop_symmat_bones.hpp new file mode 100644 index 0000000..cf130f2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_symmat_bones.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_symmat +//! @{ + + + +class spop_symmat + : public traits_op_default + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_symmat_cx + : public traits_op_default + { + public: + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_symmat_meat.hpp b/src/armadillo/include/armadillo_bits/spop_symmat_meat.hpp new file mode 100644 index 0000000..2ce7cba --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_symmat_meat.hpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_symmat +//! @{ + + + +template +inline +void +spop_symmat::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(in.m); + const SpMat& X = U.M; + + arma_debug_check( (X.n_rows != X.n_cols), "symmatu()/symmatl(): given matrix must be square sized" ); + + if(X.n_nonzero == uword(0)) { out.zeros(X.n_rows, X.n_cols); return; } + + const bool upper = (in.aux_uword_a == 0); + + const SpMat A = (upper) ? trimatu(X) : trimatl(X); // in this case trimatu() and trimatl() return the same type + const SpMat B = A.st(); + + spglue_merge::symmat_merge(out, A, B); + } + + + +template +inline +void +spop_symmat_cx::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(in.m); + const SpMat& X = U.M; + + arma_debug_check( (X.n_rows != X.n_cols), "symmatu()/symmatl(): given matrix must be square sized" ); + + if(X.n_nonzero == uword(0)) { out.zeros(X.n_rows, X.n_cols); return; } + + const bool upper = (in.aux_uword_a == 0); + const bool do_conj = (in.aux_uword_b == 1); + + const SpMat A = (upper) ? trimatu(X) : trimatl(X); // in this case trimatu() and trimatl() return the same type + + if(do_conj) + { + const SpMat B = A.t(); + + spglue_merge::symmat_merge(out, A, B); + } + else + { + const SpMat B = A.st(); + + spglue_merge::symmat_merge(out, A, B); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_trimat_bones.hpp b/src/armadillo/include/armadillo_bits/spop_trimat_bones.hpp new file mode 100644 index 0000000..5b3aecc --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_trimat_bones.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_trimat +//! @{ + + + +class spop_trimat + : public traits_op_default + { + public: + + template + inline static void apply_noalias(SpMat& out, const SpProxy& P, const bool upper); + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_trimatu_ext + : public traits_op_default + { + public: + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const uword row_offset, const uword col_offset); + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_trimatl_ext + : public traits_op_default + { + public: + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const uword row_offset, const uword col_offset); + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_trimat_meat.hpp b/src/armadillo/include/armadillo_bits/spop_trimat_meat.hpp new file mode 100644 index 0000000..548b8e8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_trimat_meat.hpp @@ -0,0 +1,366 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_trimat +//! @{ + + + +template +inline +void +spop_trimat::apply_noalias(SpMat& out, const SpProxy& P, const bool upper) + { + arma_extra_debug_sigprint(); + + typename SpProxy::const_iterator_type it = P.begin(); + + const uword old_n_nonzero = P.get_n_nonzero(); + uword new_n_nonzero = 0; + + if(upper) + { + // upper triangular: count elements on the diagonal and above the diagonal + + for(uword i=0; i < old_n_nonzero; ++i) + { + new_n_nonzero += (it.row() <= it.col()) ? uword(1) : uword(0); + ++it; + } + } + else + { + // lower triangular: count elements on the diagonal and below the diagonal + + for(uword i=0; i < old_n_nonzero; ++i) + { + new_n_nonzero += (it.row() >= it.col()) ? uword(1) : uword(0); + ++it; + } + } + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + out.reserve(n_rows, n_cols, new_n_nonzero); + + uword new_index = 0; + + it = P.begin(); + + if(upper) + { + // upper triangular: copy elements on the diagonal and above the diagonal + + for(uword i=0; i < old_n_nonzero; ++i) + { + const uword row = it.row(); + const uword col = it.col(); + + if(row <= col) + { + access::rw(out.values[new_index]) = (*it); + access::rw(out.row_indices[new_index]) = row; + access::rw(out.col_ptrs[col + 1])++; + ++new_index; + } + + ++it; + } + } + else + { + // lower triangular: copy elements on the diagonal and below the diagonal + + for(uword i=0; i < old_n_nonzero; ++i) + { + const uword row = it.row(); + const uword col = it.col(); + + if(row >= col) + { + access::rw(out.values[new_index]) = (*it); + access::rw(out.row_indices[new_index]) = row; + access::rw(out.col_ptrs[col + 1])++; + ++new_index; + } + + ++it; + } + } + + for(uword i=0; i < n_cols; ++i) + { + access::rw(out.col_ptrs[i + 1]) += out.col_ptrs[i]; + } + } + + + +template +inline +void +spop_trimat::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const SpProxy P(in.m); + + arma_debug_check( (P.get_n_rows() != P.get_n_cols()), "trimatu()/trimatl(): given matrix must be square sized" ); + + const bool upper = (in.aux_uword_a == 0); + + if(P.is_alias(out)) + { + SpMat tmp; + spop_trimat::apply_noalias(tmp, P, upper); + out.steal_mem(tmp); + } + else + { + spop_trimat::apply_noalias(out, P, upper); + } + } + + + +// + + + +template +inline +void +spop_trimatu_ext::apply_noalias(SpMat& out, const SpMat& A, const uword row_offset, const uword col_offset) + { + arma_extra_debug_sigprint(); + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatu(): requested diagonal is out of bounds" ); + + if(A.n_nonzero == 0) { out.zeros(n_rows, n_cols); return; } + + out.reserve(n_rows, n_cols, A.n_nonzero); // upper bound on n_nonzero + + uword count = 0; + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + for(uword i=0; i < n_cols; ++i) + { + const uword col = i + col_offset; + + if(i < N) + { + typename SpMat::const_col_iterator it = A.begin_col_no_sync(col); + typename SpMat::const_col_iterator it_end = A.end_col_no_sync(col); + + const uword end_row = i + row_offset; + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + + if(it_row <= end_row) + { + const uword it_col = it.col(); + + access::rw(out.values[count]) = (*it); + access::rw(out.row_indices[count]) = it_row; + access::rw(out.col_ptrs[it_col + 1])++; + ++count; + } + else + { + break; + } + } + } + else + { + if(col < n_cols) + { + typename SpMat::const_col_iterator it = A.begin_col_no_sync(col); + typename SpMat::const_col_iterator it_end = A.end_col_no_sync(col); + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + access::rw(out.values[count]) = (*it); + access::rw(out.row_indices[count]) = it_row; + access::rw(out.col_ptrs[it_col + 1])++; + ++count; + } + } + } + } + + for(uword i=0; i < n_cols; ++i) + { + access::rw(out.col_ptrs[i + 1]) += out.col_ptrs[i]; + } + + if(count < A.n_nonzero) { out.mem_resize(count); } + } + + + +template +inline +void +spop_trimatu_ext::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(in.m); + const SpMat& A = U.M; + + arma_debug_check( (A.is_square() == false), "trimatu(): given matrix must be square sized" ); + + const uword row_offset = in.aux_uword_a; + const uword col_offset = in.aux_uword_b; + + if(U.is_alias(out)) + { + SpMat tmp; + spop_trimatu_ext::apply_noalias(tmp, A, row_offset, col_offset); + out.steal_mem(tmp); + } + else + { + spop_trimatu_ext::apply_noalias(out, A, row_offset, col_offset); + } + } + + + +// + + + +template +inline +void +spop_trimatl_ext::apply_noalias(SpMat& out, const SpMat& A, const uword row_offset, const uword col_offset) + { + arma_extra_debug_sigprint(); + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatl(): requested diagonal is out of bounds" ); + + if(A.n_nonzero == 0) { out.zeros(n_rows, n_cols); return; } + + out.reserve(n_rows, n_cols, A.n_nonzero); // upper bound on n_nonzero + + uword count = 0; + + if(col_offset > 0) + { + typename SpMat::const_col_iterator it = A.begin_col_no_sync(0); + typename SpMat::const_col_iterator it_end = A.end_col_no_sync(col_offset-1); + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + access::rw(out.values[count]) = (*it); + access::rw(out.row_indices[count]) = it_row; + access::rw(out.col_ptrs[it_col + 1])++; + ++count; + } + } + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + for(uword i=0; i < N; ++i) + { + const uword start_row = i + row_offset; + const uword col = i + col_offset; + + typename SpMat::const_col_iterator it = A.begin_col_no_sync(col); + typename SpMat::const_col_iterator it_end = A.end_col_no_sync(col); + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + + if(it_row >= start_row) + { + const uword it_col = it.col(); + + access::rw(out.values[count]) = (*it); + access::rw(out.row_indices[count]) = it_row; + access::rw(out.col_ptrs[it_col + 1])++; + ++count; + } + } + } + + for(uword i=0; i < n_cols; ++i) + { + access::rw(out.col_ptrs[i + 1]) += out.col_ptrs[i]; + } + + if(count < A.n_nonzero) { out.mem_resize(count); } + } + + + +template +inline +void +spop_trimatl_ext::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(in.m); + const SpMat& A = U.M; + + arma_debug_check( (A.is_square() == false), "trimatl(): given matrix must be square sized" ); + + const uword row_offset = in.aux_uword_a; + const uword col_offset = in.aux_uword_b; + + if(U.is_alias(out)) + { + SpMat tmp; + spop_trimatl_ext::apply_noalias(tmp, A, row_offset, col_offset); + out.steal_mem(tmp); + } + else + { + spop_trimatl_ext::apply_noalias(out, A, row_offset, col_offset); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_var_bones.hpp b/src/armadillo/include/armadillo_bits/spop_var_bones.hpp new file mode 100644 index 0000000..09f0e24 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_var_bones.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_var +//! @{ + + + +//! Class for finding variance values of a sparse matrix +class spop_var + : public traits_op_xvec + { + public: + + template + inline static void apply(SpMat& out, const mtSpOp& in); + + template + inline static void apply_noalias(SpMat& out, const SpProxy& p, const uword norm_type, const uword dim); + + // Calculate variance of a sparse vector, where we can directly use the memory. + template + inline static typename T1::pod_type var_vec(const T1& X, const uword norm_type = 0); + + // Calculate the variance directly. Because this is for sparse matrices, we + // specify both the number of elements in the array (the length of the array) + // as well as the actual number of elements when zeros are included. + template + inline static eT direct_var(const eT* const X, const uword length, const uword N, const uword norm_type = 0); + + // For complex numbers. + + template + inline static T direct_var(const std::complex* const X, const uword length, const uword N, const uword norm_type = 0); + + // Calculate the variance using iterators, for non-complex numbers. + template + inline static eT iterator_var(T1& it, const T1& end, const uword n_zero, const uword norm_type, const eT junk1, const typename arma_not_cx::result* junk2 = nullptr); + + // Calculate the variance using iterators, for complex numbers. + template + inline static typename get_pod_type::result iterator_var(T1& it, const T1& end, const uword n_zero, const uword norm_type, const eT junk1, const typename arma_cx_only::result* junk2 = nullptr); + + }; + + + +//! @} + diff --git a/src/armadillo/include/armadillo_bits/spop_var_meat.hpp b/src/armadillo/include/armadillo_bits/spop_var_meat.hpp new file mode 100644 index 0000000..8d01a44 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_var_meat.hpp @@ -0,0 +1,414 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_var +//! @{ + + + +template +inline +void +spop_var::apply(SpMat& out, const mtSpOp& in) + { + arma_extra_debug_sigprint(); + + //typedef typename T1::elem_type in_eT; + typedef typename T1::pod_type out_eT; + + const uword norm_type = in.aux_uword_a; + const uword dim = in.aux_uword_b; + + arma_debug_check( (norm_type > 1), "var(): parameter 'norm_type' must be 0 or 1" ); + arma_debug_check( (dim > 1), "var(): parameter 'dim' must be 0 or 1" ); + + const SpProxy p(in.m); + + if(p.is_alias(out) == false) + { + spop_var::apply_noalias(out, p, norm_type, dim); + } + else + { + SpMat tmp; + + spop_var::apply_noalias(tmp, p, norm_type, dim); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spop_var::apply_noalias + ( + SpMat& out, + const SpProxy& p, + const uword norm_type, + const uword dim + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_eT; + //typedef typename T1::pod_type out_eT; + + const uword p_n_rows = p.get_n_rows(); + const uword p_n_cols = p.get_n_cols(); + + // TODO: this is slow; rewrite based on the approach used by sparse mean() + + if(dim == 0) // find variance in each column + { + arma_extra_debug_print("spop_var::apply_noalias(): dim = 0"); + + out.set_size((p_n_rows > 0) ? 1 : 0, p_n_cols); + + if( (p_n_rows == 0) || (p.get_n_nonzero() == 0) ) { return; } + + for(uword col = 0; col < p_n_cols; ++col) + { + if(SpProxy::use_iterator) + { + // We must use an iterator; we can't access memory directly. + typename SpProxy::const_iterator_type it = p.begin_col(col); + typename SpProxy::const_iterator_type end = p.begin_col(col + 1); + + const uword n_zero = p_n_rows - (end.pos() - it.pos()); + + // in_eT is used just to get the specialization right (complex / noncomplex) + out.at(0, col) = spop_var::iterator_var(it, end, n_zero, norm_type, in_eT(0)); + } + else + { + // We can use direct memory access to calculate the variance. + out.at(0, col) = spop_var::direct_var + ( + &p.get_values()[p.get_col_ptrs()[col]], + p.get_col_ptrs()[col + 1] - p.get_col_ptrs()[col], + p_n_rows, + norm_type + ); + } + } + } + else + if(dim == 1) // find variance in each row + { + arma_extra_debug_print("spop_var::apply_noalias(): dim = 1"); + + out.set_size(p_n_rows, (p_n_cols > 0) ? 1 : 0); + + if( (p_n_cols == 0) || (p.get_n_nonzero() == 0) ) { return; } + + for(uword row = 0; row < p_n_rows; ++row) + { + // We have to use an iterator here regardless of whether or not we can + // directly access memory. + typename SpProxy::const_row_iterator_type it = p.begin_row(row); + typename SpProxy::const_row_iterator_type end = p.end_row(row); + + const uword n_zero = p_n_cols - (end.pos() - it.pos()); + + out.at(row, 0) = spop_var::iterator_var(it, end, n_zero, norm_type, in_eT(0)); + } + } + } + + + +template +inline +typename T1::pod_type +spop_var::var_vec + ( + const T1& X, + const uword norm_type + ) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (norm_type > 1), "var(): parameter 'norm_type' must be 0 or 1" ); + + // conditionally unwrap it into a temporary and then directly operate. + + const unwrap_spmat tmp(X); + + return direct_var(tmp.M.values, tmp.M.n_nonzero, tmp.M.n_elem, norm_type); + } + + + +template +inline +eT +spop_var::direct_var + ( + const eT* const X, + const uword length, + const uword N, + const uword norm_type + ) + { + arma_extra_debug_sigprint(); + + if(length >= 2 && N >= 2) + { + const eT acc1 = spop_mean::direct_mean(X, length, N); + + eT acc2 = eT(0); + eT acc3 = eT(0); + + uword i, j; + + for(i = 0, j = 1; j < length; i += 2, j += 2) + { + const eT Xi = X[i]; + const eT Xj = X[j]; + + const eT tmpi = acc1 - Xi; + const eT tmpj = acc1 - Xj; + + acc2 += tmpi * tmpi + tmpj * tmpj; + acc3 += tmpi + tmpj; + } + + if(i < length) + { + const eT Xi = X[i]; + + const eT tmpi = acc1 - Xi; + + acc2 += tmpi * tmpi; + acc3 += tmpi; + } + + // Now add in all zero elements. + acc2 += (N - length) * (acc1 * acc1); + acc3 += (N - length) * acc1; + + const eT norm_val = (norm_type == 0) ? eT(N - 1) : eT(N); + const eT var_val = (acc2 - (acc3 * acc3) / eT(N)) / norm_val; + + return var_val; + } + else if(length == 1 && N > 1) // if N == 1, then variance is zero. + { + const eT mean = X[0] / eT(N); + const eT val = mean - X[0]; + + const eT acc2 = (val * val) + (N - length) * (mean * mean); + const eT acc3 = val + (N - length) * mean; + + const eT norm_val = (norm_type == 0) ? eT(N - 1) : eT(N); + const eT var_val = (acc2 - (acc3 * acc3) / eT(N)) / norm_val; + + return var_val; + } + else + { + return eT(0); + } + } + + + +template +inline +T +spop_var::direct_var + ( + const std::complex* const X, + const uword length, + const uword N, + const uword norm_type + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + if(length >= 2 && N >= 2) + { + const eT acc1 = spop_mean::direct_mean(X, length, N); + + T acc2 = T(0); + eT acc3 = eT(0); + + for(uword i = 0; i < length; ++i) + { + const eT tmp = acc1 - X[i]; + + acc2 += std::norm(tmp); + acc3 += tmp; + } + + // Add zero elements to sums + acc2 += std::norm(acc1) * T(N - length); + acc3 += acc1 * T(N - length); + + const T norm_val = (norm_type == 0) ? T(N - 1) : T(N); + const T var_val = (acc2 - std::norm(acc3) / T(N)) / norm_val; + + return var_val; + } + else if(length == 1 && N > 1) // if N == 1, then variance is zero. + { + const eT mean = X[0] / T(N); + const eT val = mean - X[0]; + + const T acc2 = std::norm(val) + (N - length) * std::norm(mean); + const eT acc3 = val + T(N - length) * mean; + + const T norm_val = (norm_type == 0) ? T(N - 1) : T(N); + const T var_val = (acc2 - std::norm(acc3) / T(N)) / norm_val; + + return var_val; + } + else + { + return T(0); // All elements are zero + } + } + + + +template +inline +eT +spop_var::iterator_var + ( + T1& it, + const T1& end, + const uword n_zero, + const uword norm_type, + const eT junk1, + const typename arma_not_cx::result* junk2 + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + T1 new_it(it); // for mean + // T1 backup_it(it); // in case we have to call robust iterator_var + eT mean = spop_mean::iterator_mean(new_it, end, n_zero, eT(0)); + + eT acc2 = eT(0); + eT acc3 = eT(0); + + const uword it_begin_pos = it.pos(); + + while(it != end) + { + const eT tmp = mean - (*it); + + acc2 += (tmp * tmp); + acc3 += (tmp); + + ++it; + } + + const uword n_nonzero = (it.pos() - it_begin_pos); + if(n_nonzero == 0) + { + return eT(0); + } + + if(n_nonzero + n_zero == 1) + { + return eT(0); // only one element + } + + // Add in entries for zeros. + acc2 += eT(n_zero) * (mean * mean); + acc3 += eT(n_zero) * mean; + + const eT norm_val = (norm_type == 0) ? eT(n_zero + n_nonzero - 1) : eT(n_zero + n_nonzero); + const eT var_val = (acc2 - (acc3 * acc3) / eT(n_nonzero + n_zero)) / norm_val; + + return var_val; + } + + + +template +inline +typename get_pod_type::result +spop_var::iterator_var + ( + T1& it, + const T1& end, + const uword n_zero, + const uword norm_type, + const eT junk1, + const typename arma_cx_only::result* junk2 + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + typedef typename get_pod_type::result T; + + T1 new_it(it); // for mean + // T1 backup_it(it); // in case we have to call robust iterator_var + eT mean = spop_mean::iterator_mean(new_it, end, n_zero, eT(0)); + + T acc2 = T(0); + eT acc3 = eT(0); + + const uword it_begin_pos = it.pos(); + + while(it != end) + { + eT tmp = mean - (*it); + + acc2 += std::norm(tmp); + acc3 += (tmp); + + ++it; + } + + const uword n_nonzero = (it.pos() - it_begin_pos); + if(n_nonzero == 0) + { + return T(0); + } + + if(n_nonzero + n_zero == 1) + { + return T(0); // only one element + } + + // Add in entries for zero elements. + acc2 += T(n_zero) * std::norm(mean); + acc3 += T(n_zero) * mean; + + const T norm_val = (norm_type == 0) ? T(n_zero + n_nonzero - 1) : T(n_zero + n_nonzero); + const T var_val = (acc2 - std::norm(acc3) / T(n_nonzero + n_zero)) / norm_val; + + return var_val; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_vecnorm_bones.hpp b/src/armadillo/include/armadillo_bits/spop_vecnorm_bones.hpp new file mode 100644 index 0000000..eaeecd8 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_vecnorm_bones.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_vecnorm +//! @{ + + +class spop_vecnorm + : public traits_op_xvec + { + public: + + template + inline static void apply(SpMat& out, const mtSpOp& expr); + + template + inline static void apply_direct(Mat< typename get_pod_type::result >& out, const SpMat& X, const uword k); + }; + + +// + + +class spop_vecnorm_ext + : public traits_op_xvec + { + public: + + template + inline static void apply(SpMat& out, const mtSpOp& expr); + + template + inline static void apply_direct(Mat< typename get_pod_type::result >& out, const SpMat& X, const uword method_id); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_vecnorm_meat.hpp b/src/armadillo/include/armadillo_bits/spop_vecnorm_meat.hpp new file mode 100644 index 0000000..56ef7d9 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_vecnorm_meat.hpp @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_vecnorm +//! @{ + + + +template +inline +void +spop_vecnorm::apply(SpMat& out, const mtSpOp& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const uword k = expr.aux_uword_a; + const uword dim = expr.aux_uword_b; + + arma_debug_check( (k == 0), "vecnorm(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "vecnorm(): parameter 'dim' must be 0 or 1" ); + + const unwrap_spmat U(expr.m); + const SpMat& X = U.M; + + X.sync(); + + if(dim == 0) + { + Mat tmp; + + spop_vecnorm::apply_direct(tmp, X, k); + + out = tmp; + } + else + if(dim == 1) + { + Mat< T> tmp; + SpMat Xt; + + spop_strans::apply_noalias(Xt, X); + + spop_vecnorm::apply_direct(tmp, Xt, k); + + out = tmp.t(); + } + } + + + +template +inline +void +spop_vecnorm::apply_direct(Mat< typename get_pod_type::result >& out, const SpMat& X, const uword k) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + out.zeros(1, X.n_cols); + + T* out_mem = out.memptr(); + + for(uword col=0; col < X.n_cols; ++col) + { + const uword col_offset = X.col_ptrs[col ]; + const uword next_col_offset = X.col_ptrs[col + 1]; + + const eT* start_ptr = &X.values[ col_offset]; + const eT* end_ptr = &X.values[next_col_offset]; + + const uword n_elem = end_ptr - start_ptr; + + T out_val = T(0); + + if(n_elem > 0) + { + const Col tmp(const_cast(start_ptr), n_elem, false, false); + + const Proxy< Col > P(tmp); + + if(k == uword(1)) { out_val = op_norm::vec_norm_1(P); } + if(k == uword(2)) { out_val = op_norm::vec_norm_2(P); } + } + + out_mem[col] = out_val; + } + } + + + +// + + + +template +inline +void +spop_vecnorm_ext::apply(SpMat& out, const mtSpOp& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const uword method_id = expr.aux_uword_a; + const uword dim = expr.aux_uword_b; + + arma_debug_check( (method_id == 0), "vecnorm(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "vecnorm(): parameter 'dim' must be 0 or 1" ); + + const unwrap_spmat U(expr.m); + const SpMat& X = U.M; + + X.sync(); + + if(dim == 0) + { + Mat tmp; + + spop_vecnorm_ext::apply_direct(tmp, X, method_id); + + out = tmp; + } + else + if(dim == 1) + { + Mat< T> tmp; + SpMat Xt; + + spop_strans::apply_noalias(Xt, X); + + spop_vecnorm_ext::apply_direct(tmp, Xt, method_id); + + out = tmp.t(); + } + } + + + +template +inline +void +spop_vecnorm_ext::apply_direct(Mat< typename get_pod_type::result >& out, const SpMat& X, const uword method_id) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + out.zeros(1, X.n_cols); + + T* out_mem = out.memptr(); + + for(uword col=0; col < X.n_cols; ++col) + { + const uword col_offset = X.col_ptrs[col ]; + const uword next_col_offset = X.col_ptrs[col + 1]; + + const eT* start_ptr = &X.values[ col_offset]; + const eT* end_ptr = &X.values[next_col_offset]; + + const uword n_elem = end_ptr - start_ptr; + + T out_val = T(0); + + if(n_elem > 0) + { + const Col tmp(const_cast(start_ptr), n_elem, false, false); + + const Proxy< Col > P(tmp); + + if(method_id == uword(1)) + { + out_val = op_norm::vec_norm_max(P); + } + else + if(method_id == uword(2)) + { + const T tmp_val = op_norm::vec_norm_min(P); + + out_val = (n_elem < X.n_rows) ? T((std::min)(T(0), tmp_val)) : T(tmp_val); + } + } + + out_mem[col] = out_val; + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_vectorise_bones.hpp b/src/armadillo/include/armadillo_bits/spop_vectorise_bones.hpp new file mode 100644 index 0000000..3e38b25 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_vectorise_bones.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup spop_vectorise +//! @{ + + +class spop_vectorise_col + : public traits_op_col + { + public: + + template inline static void apply(SpMat& out, const SpOp& in); + + template inline static void apply_direct(SpMat& out, const T1& expr); + }; + + + +class spop_vectorise_row + : public traits_op_row + { + public: + + template inline static void apply(SpMat& out, const SpOp& in); + + template inline static void apply_direct(SpMat& out, const T1& expr); + }; + + + +class spop_vectorise_all + : public traits_op_xvec + { + public: + + template inline static void apply(SpMat& out, const SpOp& in); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spop_vectorise_meat.hpp b/src/armadillo/include/armadillo_bits/spop_vectorise_meat.hpp new file mode 100644 index 0000000..5673390 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spop_vectorise_meat.hpp @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup spop_vectorise +//! @{ + + + +template +inline +void +spop_vectorise_col::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + spop_vectorise_col::apply_direct(out, in.m); + } + + + +template +inline +void +spop_vectorise_col::apply_direct(SpMat& out, const T1& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(out.vec_state == 0) + { + out = expr; + + out.reshape(out.n_elem, 1); + } + else + { + SpMat tmp = expr; + + tmp.reshape(tmp.n_elem, 1); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spop_vectorise_row::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + spop_vectorise_row::apply_direct(out, in.m); + } + + + +template +inline +void +spop_vectorise_row::apply_direct(SpMat& out, const T1& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // NOTE: row-wise vectorisation of sparse matrices is not recommended due to the CSC storage format + + if(out.vec_state == 0) + { + out = strans(expr); + + out.reshape(1, out.n_elem); + } + else + { + SpMat tmp = strans(expr); + + tmp.reshape(1, tmp.n_elem); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spop_vectorise_all::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + const uword dim = in.aux_uword_a; + + if(dim == 0) + { + spop_vectorise_col::apply_direct(out, in.m); + } + else + { + spop_vectorise_row::apply_direct(out, in.m); + } + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spsolve_factoriser_bones.hpp b/src/armadillo/include/armadillo_bits/spsolve_factoriser_bones.hpp new file mode 100644 index 0000000..4616e26 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spsolve_factoriser_bones.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spsolve_factoriser +//! @{ + + + +class spsolve_factoriser + { + private: + + void_ptr worker_ptr = nullptr; + uword elem_type_indicator = 0; + uword n_rows = 0; + double rcond_value = double(0); + + template inline void delete_worker(); + + inline void cleanup(); + + + public: + + inline ~spsolve_factoriser(); + inline spsolve_factoriser(); + + inline void reset(); + + inline double rcond() const; + + template inline bool factorise(const SpBase& A_expr, const spsolve_opts_base& settings = spsolve_opts_none(), const typename arma_blas_type_only::result* junk = nullptr); + + template inline bool solve(Mat& X, const Base& B_expr, const typename arma_blas_type_only::result* junk = nullptr); + + inline spsolve_factoriser(const spsolve_factoriser&) = delete; + inline void operator= (const spsolve_factoriser&) = delete; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/spsolve_factoriser_meat.hpp b/src/armadillo/include/armadillo_bits/spsolve_factoriser_meat.hpp new file mode 100644 index 0000000..4450a15 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/spsolve_factoriser_meat.hpp @@ -0,0 +1,289 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spsolve_factoriser +//! @{ + + + +template +inline +void +spsolve_factoriser::delete_worker() + { + arma_extra_debug_sigprint(); + + if(worker_ptr != nullptr) + { + worker_type* ptr = reinterpret_cast(worker_ptr); + + delete ptr; + + worker_ptr = nullptr; + } + } + + + +inline +void +spsolve_factoriser::cleanup() + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_SUPERLU) + { + if(elem_type_indicator == 1) { delete_worker< superlu_worker< float> >(); } + else if(elem_type_indicator == 2) { delete_worker< superlu_worker< double> >(); } + else if(elem_type_indicator == 3) { delete_worker< superlu_worker< cx_float> >(); } + else if(elem_type_indicator == 4) { delete_worker< superlu_worker >(); } + } + #endif + + worker_ptr = nullptr; + elem_type_indicator = 0; + n_rows = 0; + rcond_value = double(0); + } + + + +inline +spsolve_factoriser::~spsolve_factoriser() + { + arma_extra_debug_sigprint_this(this); + + cleanup(); + } + + + +inline +spsolve_factoriser::spsolve_factoriser() + { + arma_extra_debug_sigprint_this(this); + } + + + +inline +void +spsolve_factoriser::reset() + { + arma_extra_debug_sigprint(); + + cleanup(); + } + + + +inline +double +spsolve_factoriser::rcond() const + { + arma_extra_debug_sigprint(); + + return rcond_value; + } + + + +template +inline +bool +spsolve_factoriser::factorise + ( + const SpBase& A_expr, + const spsolve_opts_base& settings, + const typename arma_blas_type_only::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + #if defined(ARMA_USE_SUPERLU) + { + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + typedef superlu_worker worker_type; + + // + + cleanup(); + + // + + const unwrap_spmat U(A_expr.get_ref()); + const SpMat& A = U.M; + + if(A.is_square() == false) + { + arma_debug_warn_level(1, "spsolve_factoriser::factorise(): solving under-determined / over-determined systems is currently not supported"); + return false; + } + + n_rows = A.n_rows; + + // + + superlu_opts superlu_opts_default; + + const superlu_opts& opts = (settings.id == 1) ? static_cast(settings) : superlu_opts_default; + + if( (opts.pivot_thresh < double(0)) || (opts.pivot_thresh > double(1)) ) + { + arma_debug_warn_level(1, "spsolve_factoriser::factorise(): pivot_thresh must be in the [0,1] interval" ); + return false; + } + + // + + worker_ptr = new(std::nothrow) worker_type; + + if(worker_ptr == nullptr) + { + arma_debug_warn_level(3, "spsolve_factoriser::factorise(): could not construct worker object"); + return false; + } + + // + + if( is_float::value) { elem_type_indicator = 1; } + else if( is_double::value) { elem_type_indicator = 2; } + else if( is_cx_float::value) { elem_type_indicator = 3; } + else if(is_cx_double::value) { elem_type_indicator = 4; } + + // + + worker_type* local_worker_ptr = reinterpret_cast(worker_ptr); + worker_type& local_worker_ref = (*local_worker_ptr); + + // + + T local_rcond_value = T(0); + + const bool status = local_worker_ref.factorise(local_rcond_value, A, opts); + + rcond_value = double(local_rcond_value); + + if( (status == false) || arma_isnan(local_rcond_value) || ((opts.allow_ugly == false) && (local_rcond_value < std::numeric_limits::epsilon())) ) + { + arma_debug_warn_level(3, "spsolve_factoriser::factorise(): factorisation failed; rcond: ", local_rcond_value); + delete_worker(); + return false; + } + + return true; + } + #else + { + arma_ignore(A_expr); + arma_ignore(settings); + arma_stop_logic_error("spsolve_factoriser::factorise(): use of SuperLU must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +spsolve_factoriser::solve + ( + Mat& X, + const Base& B_expr, + const typename arma_blas_type_only::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + #if defined(ARMA_USE_SUPERLU) + { + typedef typename T1::elem_type eT; + + typedef superlu_worker worker_type; + + if(worker_ptr == nullptr) + { + arma_debug_warn_level(2, "spsolve_factoriser::solve(): no factorisation available"); + X.soft_reset(); + return false; + } + + bool type_mismatch = false; + + if( (is_float::value) && (elem_type_indicator != 1) ) { type_mismatch = true; } + else if( (is_double::value) && (elem_type_indicator != 2) ) { type_mismatch = true; } + else if( (is_cx_float::value) && (elem_type_indicator != 3) ) { type_mismatch = true; } + else if((is_cx_double::value) && (elem_type_indicator != 4) ) { type_mismatch = true; } + + if(type_mismatch) + { + arma_debug_warn_level(1, "spsolve_factoriser::solve(): matrix type mismatch"); + X.soft_reset(); + return false; + } + + const quasi_unwrap U(B_expr.get_ref()); + const Mat& B = U.M; + + if(n_rows != B.n_rows) + { + arma_debug_warn_level(1, "spsolve_factoriser::solve(): matrix size mismatch"); + X.soft_reset(); + return false; + } + + const bool is_alias = U.is_alias(X); + + Mat tmp; + Mat& out = is_alias ? tmp : X; + + worker_type* local_worker_ptr = reinterpret_cast(worker_ptr); + worker_type& local_worker_ref = (*local_worker_ptr); + + const bool status = local_worker_ref.solve(out,B); + + if(is_alias) { X.steal_mem(tmp); } + + if(status == false) + { + arma_debug_warn_level(3, "spsolve_factoriser::solve(): solution not found"); + X.soft_reset(); + return false; + } + + return true; + } + #else + { + arma_ignore(X); + arma_ignore(B_expr); + arma_stop_logic_error("spsolve_factoriser::solve(): use of SuperLU must be enabled"); + return false; + } + #endif + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/strip.hpp b/src/armadillo/include/armadillo_bits/strip.hpp new file mode 100644 index 0000000..73850c5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/strip.hpp @@ -0,0 +1,231 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup strip +//! @{ + + + +template +struct strip_diagmat + { + typedef T1 stored_type; + + inline + strip_diagmat(const T1& X) + : M(X) + { + arma_extra_debug_sigprint(); + } + + static constexpr bool do_diagmat = false; + + const T1& M; + }; + + + +template +struct strip_diagmat< Op > + { + typedef T1 stored_type; + + inline + strip_diagmat(const Op& X) + : M(X.m) + { + arma_extra_debug_sigprint(); + } + + static constexpr bool do_diagmat = true; + + const T1& M; + }; + + + +template +struct strip_inv + { + typedef T1 stored_type; + + inline + strip_inv(const T1& X) + : M(X) + { + arma_extra_debug_sigprint(); + } + + const T1& M; + + static constexpr bool do_inv_gen = false; + static constexpr bool do_inv_spd = false; + }; + + + +template +struct strip_inv< Op > + { + typedef T1 stored_type; + + inline + strip_inv(const Op& X) + : M(X.m) + { + arma_extra_debug_sigprint(); + } + + const T1& M; + + static constexpr bool do_inv_gen = true; + static constexpr bool do_inv_spd = false; + }; + + + +template +struct strip_inv< Op > + { + typedef T1 stored_type; + + inline + strip_inv(const Op& X) + : M(X.m) + { + arma_extra_debug_sigprint(); + } + + const T1& M; + + static constexpr bool do_inv_gen = false; + static constexpr bool do_inv_spd = true; + }; + + + +template +struct strip_trimat + { + typedef T1 stored_type; + + const T1& M; + + static constexpr bool do_trimat = false; + static constexpr bool do_triu = false; + static constexpr bool do_tril = false; + + inline + strip_trimat(const T1& X) + : M(X) + { + arma_extra_debug_sigprint(); + } + }; + + + +template +struct strip_trimat< Op > + { + typedef T1 stored_type; + + const T1& M; + + static constexpr bool do_trimat = true; + + const bool do_triu; + const bool do_tril; + + inline + strip_trimat(const Op& X) + : M(X.m) + , do_triu(X.aux_uword_a == 0) + , do_tril(X.aux_uword_a == 1) + { + arma_extra_debug_sigprint(); + } + }; + + + +// + + + +template +struct sp_strip_trans + { + typedef T1 stored_type; + + inline + sp_strip_trans(const T1& X) + : M(X) + { + arma_extra_debug_sigprint(); + } + + static constexpr bool do_htrans = false; + static constexpr bool do_strans = false; + + const T1& M; + }; + + + +template +struct sp_strip_trans< SpOp > + { + typedef T1 stored_type; + + inline + sp_strip_trans(const SpOp& X) + : M(X.m) + { + arma_extra_debug_sigprint(); + } + + static constexpr bool do_htrans = true; + static constexpr bool do_strans = false; + + const T1& M; + }; + + + +template +struct sp_strip_trans< SpOp > + { + typedef T1 stored_type; + + inline + sp_strip_trans(const SpOp& X) + : M(X.m) + { + arma_extra_debug_sigprint(); + } + + static constexpr bool do_htrans = false; + static constexpr bool do_strans = true; + + const T1& M; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_bones.hpp b/src/armadillo/include/armadillo_bits/subview_bones.hpp new file mode 100644 index 0000000..95553a1 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_bones.hpp @@ -0,0 +1,673 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview +//! @{ + + +//! Class for storing data required to construct or apply operations to a submatrix +//! (ie. where the submatrix starts and ends as well as a reference/pointer to the original matrix), +template +class subview : public Base< eT, subview > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + arma_aligned const Mat& m; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + const uword aux_row1; + const uword aux_col1; + + const uword n_rows; + const uword n_cols; + const uword n_elem; + + protected: + + arma_inline subview(const Mat& in_m, const uword in_row1, const uword in_col1, const uword in_n_rows, const uword in_n_cols); + + + public: + + inline ~subview(); + inline subview() = delete; + + inline subview(const subview& in); + inline subview( subview&& in); + + template inline void inplace_op(const eT val ); + template inline void inplace_op(const Base& x, const char* identifier); + template inline void inplace_op(const subview& x, const char* identifier); + + // deliberately returning void + + inline void operator= (const eT val); + inline void operator+= (const eT val); + inline void operator-= (const eT val); + inline void operator*= (const eT val); + inline void operator/= (const eT val); + + inline void operator= (const subview& x); + inline void operator+= (const subview& x); + inline void operator-= (const subview& x); + inline void operator%= (const subview& x); + inline void operator/= (const subview& x); + + template inline void operator= (const Base& x); + template inline void operator+= (const Base& x); + template inline void operator-= (const Base& x); + template inline void operator%= (const Base& x); + template inline void operator/= (const Base& x); + + template inline void operator= (const SpBase& x); + template inline void operator+= (const SpBase& x); + template inline void operator-= (const SpBase& x); + template inline void operator%= (const SpBase& x); + template inline void operator/= (const SpBase& x); + + template + inline typename enable_if2< is_same_type::value, void>::result operator=(const Gen& x); + + inline void operator=(const std::initializer_list& list); + inline void operator=(const std::initializer_list< std::initializer_list >& list); + + + inline static void extract(Mat& out, const subview& in); + + inline static void plus_inplace(Mat& out, const subview& in); + inline static void minus_inplace(Mat& out, const subview& in); + inline static void schur_inplace(Mat& out, const subview& in); + inline static void div_inplace(Mat& out, const subview& in); + + template inline void for_each(functor F); + template inline void for_each(functor F) const; + + template inline void transform(functor F); + template inline void imbue(functor F); + + inline void replace(const eT old_val, const eT new_val); + + inline void clean(const pod_type threshold); + + inline void clamp(const eT min_val, const eT max_val); + + inline void fill(const eT val); + inline void zeros(); + inline void ones(); + inline void eye(); + inline void randu(); + inline void randn(); + + arma_warn_unused inline eT at_alt (const uword ii) const; + + arma_warn_unused inline eT& operator[](const uword ii); + arma_warn_unused inline eT operator[](const uword ii) const; + + arma_warn_unused inline eT& operator()(const uword ii); + arma_warn_unused inline eT operator()(const uword ii) const; + + arma_warn_unused inline eT& operator()(const uword in_row, const uword in_col); + arma_warn_unused inline eT operator()(const uword in_row, const uword in_col) const; + + arma_warn_unused inline eT& at(const uword in_row, const uword in_col); + arma_warn_unused inline eT at(const uword in_row, const uword in_col) const; + + arma_warn_unused inline eT& front(); + arma_warn_unused inline eT front() const; + + arma_warn_unused inline eT& back(); + arma_warn_unused inline eT back() const; + + arma_inline eT* colptr(const uword in_col); + arma_inline const eT* colptr(const uword in_col) const; + + template + inline bool check_overlap(const subview& x) const; + + arma_warn_unused inline bool is_vec() const; + arma_warn_unused inline bool is_finite() const; + arma_warn_unused inline bool is_zero(const pod_type tol = 0) const; + + arma_warn_unused inline bool has_inf() const; + arma_warn_unused inline bool has_nan() const; + arma_warn_unused inline bool has_nonfinite() const; + + inline subview_row row(const uword row_num); + inline const subview_row row(const uword row_num) const; + + inline subview_row operator()(const uword row_num, const span& col_span); + inline const subview_row operator()(const uword row_num, const span& col_span) const; + + inline subview_col col(const uword col_num); + inline const subview_col col(const uword col_num) const; + + inline subview_col operator()(const span& row_span, const uword col_num); + inline const subview_col operator()(const span& row_span, const uword col_num) const; + + inline Col unsafe_col(const uword col_num); + inline const Col unsafe_col(const uword col_num) const; + + inline subview rows(const uword in_row1, const uword in_row2); + inline const subview rows(const uword in_row1, const uword in_row2) const; + + inline subview cols(const uword in_col1, const uword in_col2); + inline const subview cols(const uword in_col1, const uword in_col2) const; + + inline subview submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2); + inline const subview submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const; + + inline subview submat (const span& row_span, const span& col_span); + inline const subview submat (const span& row_span, const span& col_span) const; + + inline subview operator()(const span& row_span, const span& col_span); + inline const subview operator()(const span& row_span, const span& col_span) const; + + inline subview_each1< subview, 0 > each_col(); + inline subview_each1< subview, 1 > each_row(); + + template inline subview_each2< subview, 0, T1 > each_col(const Base& indices); + template inline subview_each2< subview, 1, T1 > each_row(const Base& indices); + + inline void each_col(const std::function< void( Col&) >& F); + inline void each_col(const std::function< void(const Col&) >& F) const; + + inline void each_row(const std::function< void( Row&) >& F); + inline void each_row(const std::function< void(const Row&) >& F) const; + + inline diagview diag(const sword in_id = 0); + inline const diagview diag(const sword in_id = 0) const; + + inline void swap_rows(const uword in_row1, const uword in_row2); + inline void swap_cols(const uword in_col1, const uword in_col2); + + + class const_iterator; + + class iterator + { + public: + + inline iterator(); + inline iterator(const iterator& X); + inline iterator(subview& in_sv, const uword in_row, const uword in_col); + + arma_warn_unused inline eT& operator*(); + + inline iterator& operator++(); + arma_warn_unused inline iterator operator++(int); + + arma_warn_unused inline bool operator==(const iterator& rhs) const; + arma_warn_unused inline bool operator!=(const iterator& rhs) const; + arma_warn_unused inline bool operator==(const const_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_iterator& rhs) const; + + typedef std::forward_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef eT* pointer; + typedef eT& reference; + + arma_aligned Mat* M; + arma_aligned eT* current_ptr; + arma_aligned uword current_row; + arma_aligned uword current_col; + + arma_aligned const uword aux_row1; + arma_aligned const uword aux_row2_p1; + }; + + + class const_iterator + { + public: + + inline const_iterator(); + inline const_iterator(const iterator& X); + inline const_iterator(const const_iterator& X); + inline const_iterator(const subview& in_sv, const uword in_row, const uword in_col); + + arma_warn_unused inline const eT& operator*(); + + inline const_iterator& operator++(); + arma_warn_unused inline const_iterator operator++(int); + + arma_warn_unused inline bool operator==(const iterator& rhs) const; + arma_warn_unused inline bool operator!=(const iterator& rhs) const; + arma_warn_unused inline bool operator==(const const_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_iterator& rhs) const; + + // So that we satisfy the STL iterator types. + typedef std::forward_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef const eT* pointer; + typedef const eT& reference; + + arma_aligned const Mat* M; + arma_aligned const eT* current_ptr; + arma_aligned uword current_row; + arma_aligned uword current_col; + + arma_aligned const uword aux_row1; + arma_aligned const uword aux_row2_p1; + }; + + + class const_row_iterator; + + class row_iterator + { + public: + + inline row_iterator(); + inline row_iterator(const row_iterator& X); + inline row_iterator(subview& in_sv, const uword in_row, const uword in_col); + + arma_warn_unused inline eT& operator* (); + + inline row_iterator& operator++(); + arma_warn_unused inline row_iterator operator++(int); + + arma_warn_unused inline bool operator!=(const row_iterator& X) const; + arma_warn_unused inline bool operator==(const row_iterator& X) const; + arma_warn_unused inline bool operator!=(const const_row_iterator& X) const; + arma_warn_unused inline bool operator==(const const_row_iterator& X) const; + + typedef std::forward_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef eT* pointer; + typedef eT& reference; + + arma_aligned Mat* M; + arma_aligned uword current_row; + arma_aligned uword current_col; + + arma_aligned const uword aux_col1; + arma_aligned const uword aux_col2_p1; + }; + + + class const_row_iterator + { + public: + + inline const_row_iterator(); + inline const_row_iterator(const row_iterator& X); + inline const_row_iterator(const const_row_iterator& X); + inline const_row_iterator(const subview& in_sv, const uword in_row, const uword in_col); + + arma_warn_unused inline const eT& operator*() const; + + inline const_row_iterator& operator++(); + arma_warn_unused inline const_row_iterator operator++(int); + + arma_warn_unused inline bool operator!=(const row_iterator& X) const; + arma_warn_unused inline bool operator==(const row_iterator& X) const; + arma_warn_unused inline bool operator!=(const const_row_iterator& X) const; + arma_warn_unused inline bool operator==(const const_row_iterator& X) const; + + typedef std::forward_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef const eT* pointer; + typedef const eT& reference; + + arma_aligned const Mat* M; + arma_aligned uword current_row; + arma_aligned uword current_col; + + arma_aligned const uword aux_col1; + arma_aligned const uword aux_col2_p1; + }; + + + + inline iterator begin(); + inline const_iterator begin() const; + inline const_iterator cbegin() const; + + inline iterator end(); + inline const_iterator end() const; + inline const_iterator cend() const; + + + friend class Mat; + }; + + + +template +class subview_col : public subview + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + const eT* colmem; + + inline void operator= (const subview& x); + inline void operator= (const subview_col& x); + inline void operator= (const eT val); + inline void operator= (const std::initializer_list& list); + + template inline void operator= (const Base& x); + template inline void operator= (const SpBase& x); + + template + inline typename enable_if2< is_same_type::value, void>::result operator=(const Gen& x); + + arma_warn_unused arma_inline const Op,op_htrans> t() const; + arma_warn_unused arma_inline const Op,op_htrans> ht() const; + arma_warn_unused arma_inline const Op,op_strans> st() const; + + arma_warn_unused arma_inline const Op,op_strans> as_row() const; + + inline void fill(const eT val); + inline void zeros(); + inline void ones(); + + arma_inline eT at_alt (const uword i) const; + + arma_inline eT& operator[](const uword i); + arma_inline eT operator[](const uword i) const; + + inline eT& operator()(const uword i); + inline eT operator()(const uword i) const; + + inline eT& operator()(const uword in_row, const uword in_col); + inline eT operator()(const uword in_row, const uword in_col) const; + + inline eT& at(const uword in_row, const uword in_col); + inline eT at(const uword in_row, const uword in_col) const; + + arma_inline eT* colptr(const uword in_col); + arma_inline const eT* colptr(const uword in_col) const; + + inline subview_col rows(const uword in_row1, const uword in_row2); + inline const subview_col rows(const uword in_row1, const uword in_row2) const; + + inline subview_col subvec(const uword in_row1, const uword in_row2); + inline const subview_col subvec(const uword in_row1, const uword in_row2) const; + + inline subview_col subvec(const uword start_row, const SizeMat& s); + inline const subview_col subvec(const uword start_row, const SizeMat& s) const; + + inline subview_col head(const uword N); + inline const subview_col head(const uword N) const; + + inline subview_col tail(const uword N); + inline const subview_col tail(const uword N) const; + + arma_warn_unused inline eT min() const; + arma_warn_unused inline eT max() const; + + inline eT min(uword& index_of_min_val) const; + inline eT max(uword& index_of_max_val) const; + + arma_warn_unused inline uword index_min() const; + arma_warn_unused inline uword index_max() const; + + inline subview_col(const subview_col& in); + inline subview_col( subview_col&& in); + + + protected: + + inline subview_col(const Mat& in_m, const uword in_col); + inline subview_col(const Mat& in_m, const uword in_col, const uword in_row1, const uword in_n_rows); + inline subview_col() = delete; + + + friend class Mat; + friend class Col; + friend class subview; + }; + + + +template +class subview_cols : public subview + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + inline subview_cols(const subview_cols& in); + inline subview_cols( subview_cols&& in); + + inline void operator= (const subview& x); + inline void operator= (const subview_cols& x); + inline void operator= (const eT val); + inline void operator= (const std::initializer_list& list); + inline void operator= (const std::initializer_list< std::initializer_list >& list); + + template inline void operator= (const Base& x); + template inline void operator= (const SpBase& x); + + template + inline typename enable_if2< is_same_type::value, void>::result operator=(const Gen& x); + + arma_warn_unused arma_inline const Op,op_htrans> t() const; + arma_warn_unused arma_inline const Op,op_htrans> ht() const; + arma_warn_unused arma_inline const Op,op_strans> st() const; + + arma_warn_unused arma_inline const Op,op_vectorise_col> as_col() const; + + arma_warn_unused inline eT at_alt (const uword ii) const; + + arma_warn_unused inline eT& operator[](const uword ii); + arma_warn_unused inline eT operator[](const uword ii) const; + + arma_warn_unused inline eT& operator()(const uword ii); + arma_warn_unused inline eT operator()(const uword ii) const; + + arma_warn_unused inline eT& operator()(const uword in_row, const uword in_col); + arma_warn_unused inline eT operator()(const uword in_row, const uword in_col) const; + + arma_warn_unused inline eT& at(const uword in_row, const uword in_col); + arma_warn_unused inline eT at(const uword in_row, const uword in_col) const; + + arma_inline eT* colptr(const uword in_col); + arma_inline const eT* colptr(const uword in_col) const; + + protected: + + inline subview_cols(const Mat& in_m, const uword in_col1, const uword in_n_cols); + inline subview_cols() = delete; + + friend class Mat; + friend class subview; + }; + + + +template +class subview_row : public subview + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + inline void operator= (const subview& x); + inline void operator= (const subview_row& x); + inline void operator= (const eT val); + inline void operator= (const std::initializer_list& list); + + template inline void operator= (const Base& x); + template inline void operator= (const SpBase& x); + + template + inline typename enable_if2< is_same_type::value, void>::result operator=(const Gen& x); + + arma_warn_unused arma_inline const Op,op_htrans> t() const; + arma_warn_unused arma_inline const Op,op_htrans> ht() const; + arma_warn_unused arma_inline const Op,op_strans> st() const; + + arma_warn_unused arma_inline const Op,op_strans> as_col() const; + + inline eT at_alt (const uword i) const; + + inline eT& operator[](const uword i); + inline eT operator[](const uword i) const; + + inline eT& operator()(const uword i); + inline eT operator()(const uword i) const; + + inline eT& operator()(const uword in_row, const uword in_col); + inline eT operator()(const uword in_row, const uword in_col) const; + + inline eT& at(const uword in_row, const uword in_col); + inline eT at(const uword in_row, const uword in_col) const; + + inline subview_row cols(const uword in_col1, const uword in_col2); + inline const subview_row cols(const uword in_col1, const uword in_col2) const; + + inline subview_row subvec(const uword in_col1, const uword in_col2); + inline const subview_row subvec(const uword in_col1, const uword in_col2) const; + + inline subview_row subvec(const uword start_col, const SizeMat& s); + inline const subview_row subvec(const uword start_col, const SizeMat& s) const; + + inline subview_row head(const uword N); + inline const subview_row head(const uword N) const; + + inline subview_row tail(const uword N); + inline const subview_row tail(const uword N) const; + + arma_warn_unused inline uword index_min() const; + arma_warn_unused inline uword index_max() const; + + inline typename subview::row_iterator begin(); + inline typename subview::const_row_iterator begin() const; + inline typename subview::const_row_iterator cbegin() const; + + inline typename subview::row_iterator end(); + inline typename subview::const_row_iterator end() const; + inline typename subview::const_row_iterator cend() const; + + inline subview_row(const subview_row& in); + inline subview_row( subview_row&& in); + + + protected: + + inline subview_row(const Mat& in_m, const uword in_row); + inline subview_row(const Mat& in_m, const uword in_row, const uword in_col1, const uword in_n_cols); + inline subview_row() = delete; + + + friend class Mat; + friend class Row; + friend class subview; + }; + + + +template +class subview_row_strans : public Base< eT, subview_row_strans > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const subview_row& sv_row; + + const uword n_rows; // equal to n_elem + const uword n_elem; + + static constexpr uword n_cols = 1; + + + inline explicit subview_row_strans(const subview_row& in_sv_row); + + inline void extract(Mat& out) const; + + inline eT at_alt (const uword i) const; + + inline eT operator[](const uword i) const; + inline eT operator()(const uword i) const; + + inline eT operator()(const uword in_row, const uword in_col) const; + inline eT at(const uword in_row, const uword in_col) const; + }; + + + +template +class subview_row_htrans : public Base< eT, subview_row_htrans > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const subview_row& sv_row; + + const uword n_rows; // equal to n_elem + const uword n_elem; + + static constexpr uword n_cols = 1; + + + inline explicit subview_row_htrans(const subview_row& in_sv_row); + + inline void extract(Mat& out) const; + + inline eT at_alt (const uword i) const; + + inline eT operator[](const uword i) const; + inline eT operator()(const uword i) const; + + inline eT operator()(const uword in_row, const uword in_col) const; + inline eT at(const uword in_row, const uword in_col) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_cube_bones.hpp b/src/armadillo/include/armadillo_bits/subview_cube_bones.hpp new file mode 100644 index 0000000..ae71e67 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_cube_bones.hpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_cube +//! @{ + + +//! Class for storing data required to construct or apply operations to a subcube +//! (ie. where the subcube starts and ends as well as a reference/pointer to the original cube), +template +class subview_cube : public BaseCube< eT, subview_cube > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + arma_aligned const Cube& m; + + const uword aux_row1; + const uword aux_col1; + const uword aux_slice1; + + const uword n_rows; + const uword n_cols; + const uword n_elem_slice; + const uword n_slices; + const uword n_elem; + + + protected: + + arma_inline subview_cube(const Cube& in_m, const uword in_row1, const uword in_col1, const uword in_slice1, const uword in_n_rows, const uword in_n_cols, const uword in_n_slices); + + + public: + + inline ~subview_cube(); + inline subview_cube() = delete; + + inline subview_cube(const subview_cube& in); + inline subview_cube( subview_cube&& in); + + template inline void inplace_op(const eT val ); + template inline void inplace_op(const BaseCube& x, const char* identifier); + template inline void inplace_op(const subview_cube& x, const char* identifier); + + inline void operator= (const eT val); + inline void operator+= (const eT val); + inline void operator-= (const eT val); + inline void operator*= (const eT val); + inline void operator/= (const eT val); + + // deliberately returning void + template inline void operator= (const BaseCube& x); + template inline void operator+= (const BaseCube& x); + template inline void operator-= (const BaseCube& x); + template inline void operator%= (const BaseCube& x); + template inline void operator/= (const BaseCube& x); + + inline void operator= (const subview_cube& x); + inline void operator+= (const subview_cube& x); + inline void operator-= (const subview_cube& x); + inline void operator%= (const subview_cube& x); + inline void operator/= (const subview_cube& x); + + template inline void operator= (const Base& x); + template inline void operator+= (const Base& x); + template inline void operator-= (const Base& x); + template inline void operator%= (const Base& x); + template inline void operator/= (const Base& x); + + template inline void operator=(const GenCube& x); + + inline static void extract(Cube& out, const subview_cube& in); + inline static void plus_inplace(Cube& out, const subview_cube& in); + inline static void minus_inplace(Cube& out, const subview_cube& in); + inline static void schur_inplace(Cube& out, const subview_cube& in); + inline static void div_inplace(Cube& out, const subview_cube& in); + + inline static void extract(Mat& out, const subview_cube& in); + inline static void plus_inplace(Mat& out, const subview_cube& in); + inline static void minus_inplace(Mat& out, const subview_cube& in); + inline static void schur_inplace(Mat& out, const subview_cube& in); + inline static void div_inplace(Mat& out, const subview_cube& in); + + template inline void for_each(functor F); + template inline void for_each(functor F) const; + + template inline void transform(functor F); + template inline void imbue(functor F); + + inline void each_slice(const std::function< void( Mat&) >& F); + inline void each_slice(const std::function< void(const Mat&) >& F) const; + + inline void replace(const eT old_val, const eT new_val); + + inline void clean(const pod_type threshold); + + inline void clamp(const eT min_val, const eT max_val); + + inline void fill(const eT val); + inline void zeros(); + inline void ones(); + inline void randu(); + inline void randn(); + + arma_warn_unused inline bool is_finite() const; + arma_warn_unused inline bool is_zero(const pod_type tol = 0) const; + + arma_warn_unused inline bool has_inf() const; + arma_warn_unused inline bool has_nan() const; + arma_warn_unused inline bool has_nonfinite() const; + + inline eT at_alt (const uword i) const; + + inline eT& operator[](const uword i); + inline eT operator[](const uword i) const; + + inline eT& operator()(const uword i); + inline eT operator()(const uword i) const; + + arma_inline eT& operator()(const uword in_row, const uword in_col, const uword in_slice); + arma_inline eT operator()(const uword in_row, const uword in_col, const uword in_slice) const; + + arma_inline eT& at(const uword in_row, const uword in_col, const uword in_slice); + arma_inline eT at(const uword in_row, const uword in_col, const uword in_slice) const; + + arma_inline eT* slice_colptr(const uword in_slice, const uword in_col); + arma_inline const eT* slice_colptr(const uword in_slice, const uword in_col) const; + + template + inline bool check_overlap(const subview_cube& x) const; + + inline bool check_overlap(const Mat& x) const; + + + class const_iterator; + + class iterator + { + public: + + inline iterator(); + inline iterator(const iterator& X); + inline iterator(subview_cube& in_sv, const uword in_row, const uword in_col, const uword in_slice); + + arma_warn_unused inline eT& operator*(); + + inline iterator& operator++(); + arma_warn_unused inline iterator operator++(int); + + arma_warn_unused inline bool operator==(const iterator& rhs) const; + arma_warn_unused inline bool operator!=(const iterator& rhs) const; + arma_warn_unused inline bool operator==(const const_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_iterator& rhs) const; + + typedef std::forward_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef eT* pointer; + typedef eT& reference; + + arma_aligned Cube* M; + arma_aligned eT* current_ptr; + arma_aligned uword current_row; + arma_aligned uword current_col; + arma_aligned uword current_slice; + + arma_aligned const uword aux_row1; + arma_aligned const uword aux_col1; + + arma_aligned const uword aux_row2_p1; + arma_aligned const uword aux_col2_p1; + }; + + + class const_iterator + { + public: + + inline const_iterator(); + inline const_iterator(const iterator& X); + inline const_iterator(const const_iterator& X); + inline const_iterator(const subview_cube& in_sv, const uword in_row, const uword in_col, const uword in_slice); + + arma_warn_unused inline const eT& operator*(); + + inline const_iterator& operator++(); + arma_warn_unused inline const_iterator operator++(int); + + arma_warn_unused inline bool operator==(const iterator& rhs) const; + arma_warn_unused inline bool operator!=(const iterator& rhs) const; + arma_warn_unused inline bool operator==(const const_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_iterator& rhs) const; + + // So that we satisfy the STL iterator types. + typedef std::forward_iterator_tag iterator_category; + typedef eT value_type; + typedef std::ptrdiff_t difference_type; // TODO: not certain on this one + typedef const eT* pointer; + typedef const eT& reference; + + arma_aligned const Cube* M; + arma_aligned const eT* current_ptr; + arma_aligned uword current_row; + arma_aligned uword current_col; + arma_aligned uword current_slice; + + arma_aligned const uword aux_row1; + arma_aligned const uword aux_col1; + + arma_aligned const uword aux_row2_p1; + arma_aligned const uword aux_col2_p1; + }; + + + inline iterator begin(); + inline const_iterator begin() const; + inline const_iterator cbegin() const; + + inline iterator end(); + inline const_iterator end() const; + inline const_iterator cend() const; + + + friend class Mat; + friend class Cube; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_cube_each_bones.hpp b/src/armadillo/include/armadillo_bits/subview_cube_each_bones.hpp new file mode 100644 index 0000000..29d81fd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_cube_each_bones.hpp @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_cube_each +//! @{ + + + +template +class subview_cube_each_common + { + public: + + const Cube& P; + + template + inline void check_size(const Mat& A) const; + + + protected: + + arma_inline subview_cube_each_common(const Cube& in_p); + inline subview_cube_each_common() = delete; + + template + arma_cold inline const std::string incompat_size_string(const Mat& A) const; + }; + + + + +template +class subview_cube_each1 : public subview_cube_each_common + { + protected: + + arma_inline subview_cube_each1(const Cube& in_p); + inline subview_cube_each1() = delete; + + + public: + + inline ~subview_cube_each1(); + + // deliberately returning void + template inline void operator= (const Base& x); + template inline void operator+= (const Base& x); + template inline void operator-= (const Base& x); + template inline void operator%= (const Base& x); + template inline void operator/= (const Base& x); + template inline void operator*= (const Base& x); + + + friend class Cube; + }; + + + +template +class subview_cube_each2 : public subview_cube_each_common + { + protected: + + inline subview_cube_each2(const Cube& in_p, const Base& in_indices); + inline subview_cube_each2() = delete; + + + public: + + const Base& base_indices; + + inline void check_indices(const Mat& indices) const; + inline ~subview_cube_each2(); + + // deliberately returning void + template inline void operator= (const Base& x); + template inline void operator+= (const Base& x); + template inline void operator-= (const Base& x); + template inline void operator%= (const Base& x); + template inline void operator/= (const Base& x); + + + friend class Cube; + }; + + + +class subview_cube_each1_aux + { + public: + + template + static inline Cube operator_plus(const subview_cube_each1& X, const Base& Y); + + template + static inline Cube operator_minus(const subview_cube_each1& X, const Base& Y); + + template + static inline Cube operator_minus(const Base& X, const subview_cube_each1& Y); + + template + static inline Cube operator_schur(const subview_cube_each1& X, const Base& Y); + + template + static inline Cube operator_div(const subview_cube_each1& X,const Base& Y); + + template + static inline Cube operator_div(const Base& X, const subview_cube_each1& Y); + + template + static inline Cube operator_times(const subview_cube_each1& X,const Base& Y); + + template + static inline Cube operator_times(const Base& X, const subview_cube_each1& Y); + }; + + + +class subview_cube_each2_aux + { + public: + + template + static inline Cube operator_plus(const subview_cube_each2& X, const Base& Y); + + template + static inline Cube operator_minus(const subview_cube_each2& X, const Base& Y); + + template + static inline Cube operator_minus(const Base& X, const subview_cube_each2& Y); + + template + static inline Cube operator_schur(const subview_cube_each2& X, const Base& Y); + + template + static inline Cube operator_div(const subview_cube_each2& X, const Base& Y); + + template + static inline Cube operator_div(const Base& X, const subview_cube_each2& Y); + + // TODO: operator_times + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_cube_each_meat.hpp b/src/armadillo/include/armadillo_bits/subview_cube_each_meat.hpp new file mode 100644 index 0000000..a306918 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_cube_each_meat.hpp @@ -0,0 +1,1035 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_cube_each +//! @{ + + +// +// +// subview_cube_each_common + +template +inline +subview_cube_each_common::subview_cube_each_common(const Cube& in_p) + : P(in_p) + { + arma_extra_debug_sigprint(); + } + + + +template +template +inline +void +subview_cube_each_common::check_size(const Mat& A) const + { + if(arma_config::debug) + { + if( (A.n_rows != P.n_rows) || (A.n_cols != P.n_cols) ) + { + arma_stop_logic_error( incompat_size_string(A) ); + } + } + } + + + +template +template +inline +const std::string +subview_cube_each_common::incompat_size_string(const Mat& A) const + { + std::ostringstream tmp; + + tmp << "each_slice(): incompatible size; expected " << P.n_rows << 'x' << P.n_cols << ", got " << A.n_rows << 'x' << A.n_cols; + + return tmp.str(); + } + + + +// +// +// subview_cube_each1 + + + +template +inline +subview_cube_each1::~subview_cube_each1() + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_cube_each1::subview_cube_each1(const Cube& in_p) + : subview_cube_each_common::subview_cube_each_common(in_p) + { + arma_extra_debug_sigprint(); + } + + + +template +template +inline +void +subview_cube_each1::operator= (const Base& in) + { + arma_extra_debug_sigprint(); + + Cube& p = access::rw(subview_cube_each_common::P); + + const unwrap tmp( in.get_ref() ); + const Mat& A = tmp.M; + + subview_cube_each_common::check_size(A); + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < p_n_slices; ++i) { arrayops::copy( p.slice_memptr(i), A_mem, p_n_elem_slice ); } + } + + + +template +template +inline +void +subview_cube_each1::operator+= (const Base& in) + { + arma_extra_debug_sigprint(); + + Cube& p = access::rw(subview_cube_each_common::P); + + const unwrap tmp( in.get_ref() ); + const Mat& A = tmp.M; + + subview_cube_each_common::check_size(A); + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < p_n_slices; ++i) { arrayops::inplace_plus( p.slice_memptr(i), A_mem, p_n_elem_slice ); } + } + + + +template +template +inline +void +subview_cube_each1::operator-= (const Base& in) + { + arma_extra_debug_sigprint(); + + Cube& p = access::rw(subview_cube_each_common::P); + + const unwrap tmp( in.get_ref() ); + const Mat& A = tmp.M; + + subview_cube_each_common::check_size(A); + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < p_n_slices; ++i) { arrayops::inplace_minus( p.slice_memptr(i), A_mem, p_n_elem_slice ); } + } + + + +template +template +inline +void +subview_cube_each1::operator%= (const Base& in) + { + arma_extra_debug_sigprint(); + + Cube& p = access::rw(subview_cube_each_common::P); + + const unwrap tmp( in.get_ref() ); + const Mat& A = tmp.M; + + subview_cube_each_common::check_size(A); + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < p_n_slices; ++i) { arrayops::inplace_mul( p.slice_memptr(i), A_mem, p_n_elem_slice ); } + } + + + +template +template +inline +void +subview_cube_each1::operator/= (const Base& in) + { + arma_extra_debug_sigprint(); + + Cube& p = access::rw(subview_cube_each_common::P); + + const unwrap tmp( in.get_ref() ); + const Mat& A = tmp.M; + + subview_cube_each_common::check_size(A); + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < p_n_slices; ++i) { arrayops::inplace_div( p.slice_memptr(i), A_mem, p_n_elem_slice ); } + } + + + +template +template +inline +void +subview_cube_each1::operator*= (const Base& in) + { + arma_extra_debug_sigprint(); + + Cube& C = access::rw(subview_cube_each_common::P); + + C = C.each_slice() * in.get_ref(); + } + + + +// +// +// subview_cube_each2 + + + +template +inline +subview_cube_each2::~subview_cube_each2() + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_cube_each2::subview_cube_each2(const Cube& in_p, const Base& in_indices) + : subview_cube_each_common::subview_cube_each_common(in_p) + , base_indices(in_indices) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +subview_cube_each2::check_indices(const Mat& indices) const + { + arma_debug_check( ((indices.is_vec() == false) && (indices.is_empty() == false)), "each_slice(): list of indices must be a vector" ); + } + + + +template +template +inline +void +subview_cube_each2::operator= (const Base& in) + { + arma_extra_debug_sigprint(); + + Cube& p = access::rw(subview_cube_each_common::P); + + const unwrap tmp( in.get_ref() ); + const Mat& A = tmp.M; + + subview_cube_each_common::check_size(A); + + const unwrap U( base_indices.get_ref() ); + + check_indices(U.M); + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword slice = indices_mem[i]; + + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + + arrayops::copy(p.slice_memptr(slice), A_mem, p_n_elem_slice); + } + } + + + +template +template +inline +void +subview_cube_each2::operator+= (const Base& in) + { + arma_extra_debug_sigprint(); + + Cube& p = access::rw(subview_cube_each_common::P); + + const unwrap tmp( in.get_ref() ); + const Mat& A = tmp.M; + + subview_cube_each_common::check_size(A); + + const unwrap U( base_indices.get_ref() ); + + check_indices(U.M); + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword slice = indices_mem[i]; + + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + + arrayops::inplace_plus(p.slice_memptr(slice), A_mem, p_n_elem_slice); + } + } + + + +template +template +inline +void +subview_cube_each2::operator-= (const Base& in) + { + arma_extra_debug_sigprint(); + + Cube& p = access::rw(subview_cube_each_common::P); + + const unwrap tmp( in.get_ref() ); + const Mat& A = tmp.M; + + subview_cube_each_common::check_size(A); + + const unwrap U( base_indices.get_ref() ); + + check_indices(U.M); + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword slice = indices_mem[i]; + + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + + arrayops::inplace_minus(p.slice_memptr(slice), A_mem, p_n_elem_slice); + } + } + + + +template +template +inline +void +subview_cube_each2::operator%= (const Base& in) + { + arma_extra_debug_sigprint(); + + Cube& p = access::rw(subview_cube_each_common::P); + + const unwrap tmp( in.get_ref() ); + const Mat& A = tmp.M; + + subview_cube_each_common::check_size(A); + + const unwrap U( base_indices.get_ref() ); + + check_indices(U.M); + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword slice = indices_mem[i]; + + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + + arrayops::inplace_mul(p.slice_memptr(slice), A_mem, p_n_elem_slice); + } + } + + + +template +template +inline +void +subview_cube_each2::operator/= (const Base& in) + { + arma_extra_debug_sigprint(); + + Cube& p = access::rw(subview_cube_each_common::P); + + const unwrap tmp( in.get_ref() ); + const Mat& A = tmp.M; + + subview_cube_each_common::check_size(A); + + const unwrap U( base_indices.get_ref() ); + + check_indices(U.M); + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword slice = indices_mem[i]; + + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + + arrayops::inplace_div(p.slice_memptr(slice), A_mem, p_n_elem_slice); + } + } + + + +// +// +// subview_cube_each1_aux + + + +template +inline +Cube +subview_cube_each1_aux::operator_plus + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + const uword p_n_slices = p.n_slices; + + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); + + const unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + X.check_size(A); + + for(uword i=0; i < p_n_slices; ++i) + { + Mat out_slice( out.slice_memptr(i), p_n_rows, p_n_cols, false, true); + const Mat p_slice(const_cast(p.slice_memptr(i)), p_n_rows, p_n_cols, false, true); + + out_slice = p_slice + A; + } + + return out; + } + + + +template +inline +Cube +subview_cube_each1_aux::operator_minus + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + const uword p_n_slices = p.n_slices; + + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); + + const unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + X.check_size(A); + + for(uword i=0; i < p_n_slices; ++i) + { + Mat out_slice( out.slice_memptr(i), p_n_rows, p_n_cols, false, true); + const Mat p_slice(const_cast(p.slice_memptr(i)), p_n_rows, p_n_cols, false, true); + + out_slice = p_slice - A; + } + + return out; + } + + + +template +inline +Cube +subview_cube_each1_aux::operator_minus + ( + const Base& X, + const subview_cube_each1& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = Y.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + const uword p_n_slices = p.n_slices; + + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); + + const unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + Y.check_size(A); + + for(uword i=0; i < p_n_slices; ++i) + { + Mat out_slice( out.slice_memptr(i), p_n_rows, p_n_cols, false, true); + const Mat p_slice(const_cast(p.slice_memptr(i)), p_n_rows, p_n_cols, false, true); + + out_slice = A - p_slice; + } + + return out; + } + + + +template +inline +Cube +subview_cube_each1_aux::operator_schur + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + const uword p_n_slices = p.n_slices; + + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); + + const unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + X.check_size(A); + + for(uword i=0; i < p_n_slices; ++i) + { + Mat out_slice( out.slice_memptr(i), p_n_rows, p_n_cols, false, true); + const Mat p_slice(const_cast(p.slice_memptr(i)), p_n_rows, p_n_cols, false, true); + + out_slice = p_slice % A; + } + + return out; + } + + + +template +inline +Cube +subview_cube_each1_aux::operator_div + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + const uword p_n_slices = p.n_slices; + + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); + + const unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + X.check_size(A); + + for(uword i=0; i < p_n_slices; ++i) + { + Mat out_slice( out.slice_memptr(i), p_n_rows, p_n_cols, false, true); + const Mat p_slice(const_cast(p.slice_memptr(i)), p_n_rows, p_n_cols, false, true); + + out_slice = p_slice / A; + } + + return out; + } + + + +template +inline +Cube +subview_cube_each1_aux::operator_div + ( + const Base& X, + const subview_cube_each1& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = Y.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + const uword p_n_slices = p.n_slices; + + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); + + const unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + Y.check_size(A); + + for(uword i=0; i < p_n_slices; ++i) + { + Mat out_slice( out.slice_memptr(i), p_n_rows, p_n_cols, false, true); + const Mat p_slice(const_cast(p.slice_memptr(i)), p_n_rows, p_n_cols, false, true); + + out_slice = A / p_slice; + } + + return out; + } + + + +template +inline +Cube +subview_cube_each1_aux::operator_times + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& C = X.P; + + const unwrap tmp(Y.get_ref()); + const Mat& M = tmp.M; + + Cube out(C.n_rows, M.n_cols, C.n_slices, arma_nozeros_indicator()); + + for(uword i=0; i < C.n_slices; ++i) + { + Mat out_slice( out.slice_memptr(i), C.n_rows, M.n_cols, false, true); + const Mat C_slice(const_cast(C.slice_memptr(i)), C.n_rows, C.n_cols, false, true); + + out_slice = C_slice * M; + } + + return out; + } + + + +template +inline +Cube +subview_cube_each1_aux::operator_times + ( + const Base& X, + const subview_cube_each1& Y + ) + { + arma_extra_debug_sigprint(); + + const unwrap tmp(X.get_ref()); + const Mat& M = tmp.M; + + const Cube& C = Y.P; + + Cube out(M.n_rows, C.n_cols, C.n_slices, arma_nozeros_indicator()); + + for(uword i=0; i < C.n_slices; ++i) + { + Mat out_slice( out.slice_memptr(i), M.n_rows, C.n_cols, false, true); + const Mat C_slice(const_cast(C.slice_memptr(i)), C.n_rows, C.n_cols, false, true); + + out_slice = M * C_slice; + } + + return out; + } + + + +// +// +// subview_cube_each2_aux + + + +template +inline +Cube +subview_cube_each2_aux::operator_plus + ( + const subview_cube_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = X.P; + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + Cube out = p; + + const unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(X.base_indices.get_ref()); + + X.check_size(A); + X.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword slice = indices_mem[i]; + + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + + arrayops::inplace_plus(out.slice_memptr(slice), A_mem, p_n_elem_slice); + } + + return out; + } + + + +template +inline +Cube +subview_cube_each2_aux::operator_minus + ( + const subview_cube_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = X.P; + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + Cube out = p; + + const unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(X.base_indices.get_ref()); + + X.check_size(A); + X.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword slice = indices_mem[i]; + + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + + arrayops::inplace_minus(out.slice_memptr(slice), A_mem, p_n_elem_slice); + } + + return out; + } + + + +template +inline +Cube +subview_cube_each2_aux::operator_minus + ( + const Base& X, + const subview_cube_each2& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = Y.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + const uword p_n_slices = p.n_slices; + + Cube out = p; + + const unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(Y.base_indices.get_ref()); + + Y.check_size(A); + Y.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + for(uword i=0; i < N; ++i) + { + const uword slice = indices_mem[i]; + + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + + Mat out_slice( out.slice_memptr(slice), p_n_rows, p_n_cols, false, true); + const Mat p_slice(const_cast(p.slice_memptr(slice)), p_n_rows, p_n_cols, false, true); + + out_slice = A - p_slice; + } + + return out; + } + + + +template +inline +Cube +subview_cube_each2_aux::operator_schur + ( + const subview_cube_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = X.P; + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + Cube out = p; + + const unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(X.base_indices.get_ref()); + + X.check_size(A); + X.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword slice = indices_mem[i]; + + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + + arrayops::inplace_mul(out.slice_memptr(slice), A_mem, p_n_elem_slice); + } + + return out; + } + + + +template +inline +Cube +subview_cube_each2_aux::operator_div + ( + const subview_cube_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = X.P; + + const uword p_n_slices = p.n_slices; + const uword p_n_elem_slice = p.n_elem_slice; + + Cube out = p; + + const unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(X.base_indices.get_ref()); + + X.check_size(A); + X.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword slice = indices_mem[i]; + + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + + arrayops::inplace_div(out.slice_memptr(slice), A_mem, p_n_elem_slice); + } + + return out; + } + + + +template +inline +Cube +subview_cube_each2_aux::operator_div + ( + const Base& X, + const subview_cube_each2& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& p = Y.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + const uword p_n_slices = p.n_slices; + + Cube out = p; + + const unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(Y.base_indices.get_ref()); + + Y.check_size(A); + Y.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + for(uword i=0; i < N; ++i) + { + const uword slice = indices_mem[i]; + + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + + Mat out_slice( out.slice_memptr(slice), p_n_rows, p_n_cols, false, true); + const Mat p_slice(const_cast(p.slice_memptr(slice)), p_n_rows, p_n_cols, false, true); + + out_slice = A / p_slice; + } + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_cube_meat.hpp b/src/armadillo/include/armadillo_bits/subview_cube_meat.hpp new file mode 100644 index 0000000..039e80c --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_cube_meat.hpp @@ -0,0 +1,2722 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_cube +//! @{ + + +template +inline +subview_cube::~subview_cube() + { + arma_extra_debug_sigprint_this(this); + } + + + +template +arma_inline +subview_cube::subview_cube + ( + const Cube& in_m, + const uword in_row1, + const uword in_col1, + const uword in_slice1, + const uword in_n_rows, + const uword in_n_cols, + const uword in_n_slices + ) + : m (in_m) + , aux_row1 (in_row1) + , aux_col1 (in_col1) + , aux_slice1 (in_slice1) + , n_rows (in_n_rows) + , n_cols (in_n_cols) + , n_elem_slice(in_n_rows * in_n_cols) + , n_slices (in_n_slices) + , n_elem (n_elem_slice * in_n_slices) + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +subview_cube::subview_cube(const subview_cube& in) + : m (in.m ) + , aux_row1 (in.aux_row1 ) + , aux_col1 (in.aux_col1 ) + , aux_slice1 (in.aux_slice1 ) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem_slice(in.n_elem_slice) + , n_slices (in.n_slices ) + , n_elem (in.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + } + + + +template +inline +subview_cube::subview_cube(subview_cube&& in) + : m (in.m ) + , aux_row1 (in.aux_row1 ) + , aux_col1 (in.aux_col1 ) + , aux_slice1 (in.aux_slice1 ) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem_slice(in.n_elem_slice) + , n_slices (in.n_slices ) + , n_elem (in.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + + // for paranoia + + access::rw(in.aux_row1 ) = 0; + access::rw(in.aux_col1 ) = 0; + access::rw(in.aux_slice1 ) = 0; + access::rw(in.n_rows ) = 0; + access::rw(in.n_cols ) = 0; + access::rw(in.n_elem_slice) = 0; + access::rw(in.n_slices ) = 0; + access::rw(in.n_elem ) = 0; + } + + + +template +template +inline +void +subview_cube::inplace_op(const eT val) + { + arma_extra_debug_sigprint(); + + subview_cube& t = *this; + + const uword t_n_rows = t.n_rows; + const uword t_n_cols = t.n_cols; + const uword t_n_slices = t.n_slices; + + for(uword s=0; s < t_n_slices; ++s) + for(uword c=0; c < t_n_cols; ++c) + { + if(is_same_type::yes) { arrayops::inplace_plus ( slice_colptr(s,c), val, t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_minus( slice_colptr(s,c), val, t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_mul ( slice_colptr(s,c), val, t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_div ( slice_colptr(s,c), val, t_n_rows ); } + } + } + + + + + + +template +template +inline +void +subview_cube::inplace_op(const BaseCube& in, const char* identifier) + { + arma_extra_debug_sigprint(); + + const ProxyCube P(in.get_ref()); + + subview_cube& t = *this; + + const uword t_n_rows = t.n_rows; + const uword t_n_cols = t.n_cols; + const uword t_n_slices = t.n_slices; + + arma_debug_assert_same_size(t, P, identifier); + + const bool use_mp = arma_config::openmp && ProxyCube::use_mp && mp_gate::eval(t.n_elem); + const bool has_overlap = P.has_overlap(t); + + if(has_overlap) { arma_extra_debug_print("aliasing or overlap detected"); } + + if( (is_Cube::stored_type>::value) || (use_mp) || (has_overlap) ) + { + const unwrap_cube_check::stored_type> tmp(P.Q, has_overlap); + const Cube& B = tmp.M; + + if( (is_same_type::yes) && (t.aux_row1 == 0) && (t_n_rows == t.m.n_rows) ) + { + for(uword s=0; s < t_n_slices; ++s) + { + arrayops::copy( t.slice_colptr(s,0), B.slice_colptr(s,0), t.n_elem_slice ); + } + } + else + { + for(uword s=0; s < t_n_slices; ++s) + for(uword c=0; c < t_n_cols; ++c) + { + if(is_same_type::yes) { arrayops::copy ( t.slice_colptr(s,c), B.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_plus ( t.slice_colptr(s,c), B.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_minus( t.slice_colptr(s,c), B.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_mul ( t.slice_colptr(s,c), B.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_div ( t.slice_colptr(s,c), B.slice_colptr(s,c), t_n_rows ); } + } + } + } + else // use the Proxy + { + if(ProxyCube::use_at) + { + for(uword s=0; s < t_n_slices; ++s) + for(uword c=0; c < t_n_cols; ++c) + { + eT* t_col_data = t.slice_colptr(s,c); + + for(uword r=0; r < t_n_rows; ++r) + { + const eT tmp = P.at(r,c,s); + + if(is_same_type::yes) { (*t_col_data) = tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) += tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) -= tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) *= tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) /= tmp; t_col_data++; } + } + } + } + else + { + typename ProxyCube::ea_type Pea = P.get_ea(); + + uword count = 0; + + for(uword s=0; s < t_n_slices; ++s) + for(uword c=0; c < t_n_cols; ++c) + { + eT* t_col_data = t.slice_colptr(s,c); + + for(uword r=0; r < t_n_rows; ++r) + { + const eT tmp = Pea[count]; count++; + + if(is_same_type::yes) { (*t_col_data) = tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) += tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) -= tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) *= tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) /= tmp; t_col_data++; } + } + } + } + } + } + + + +template +template +inline +void +subview_cube::inplace_op(const subview_cube& x, const char* identifier) + { + arma_extra_debug_sigprint(); + + if(check_overlap(x)) + { + const Cube tmp(x); + + if(is_same_type::yes) { (*this).operator= (tmp); } + if(is_same_type::yes) { (*this).operator+=(tmp); } + if(is_same_type::yes) { (*this).operator-=(tmp); } + if(is_same_type::yes) { (*this).operator%=(tmp); } + if(is_same_type::yes) { (*this).operator/=(tmp); } + + return; + } + + subview_cube& t = *this; + + arma_debug_assert_same_size(t, x, identifier); + + const uword t_n_rows = t.n_rows; + const uword t_n_cols = t.n_cols; + const uword t_n_slices = t.n_slices; + + for(uword s=0; s < t_n_slices; ++s) + for(uword c=0; c < t_n_cols; ++c) + { + if(is_same_type::yes) { arrayops::copy ( t.slice_colptr(s,c), x.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_plus ( t.slice_colptr(s,c), x.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_minus( t.slice_colptr(s,c), x.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_mul ( t.slice_colptr(s,c), x.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_div ( t.slice_colptr(s,c), x.slice_colptr(s,c), t_n_rows ); } + } + } + + + +template +inline +void +subview_cube::operator= (const eT val) + { + arma_extra_debug_sigprint(); + + if(n_elem != 1) + { + arma_debug_assert_same_size(n_rows, n_cols, n_slices, 1, 1, 1, "copy into subcube"); + } + + Cube& Q = const_cast< Cube& >(m); + + Q.at(aux_row1, aux_col1, aux_slice1) = val; + } + + + +template +inline +void +subview_cube::operator+= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_cube::operator-= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_cube::operator*= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_cube::operator/= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +template +inline +void +subview_cube::operator= (const BaseCube& in) + { + arma_extra_debug_sigprint(); + + inplace_op(in, "copy into subcube"); + } + + + +template +template +inline +void +subview_cube::operator+= (const BaseCube& in) + { + arma_extra_debug_sigprint(); + + inplace_op(in, "addition"); + } + + + +template +template +inline +void +subview_cube::operator-= (const BaseCube& in) + { + arma_extra_debug_sigprint(); + + inplace_op(in, "subtraction"); + } + + + +template +template +inline +void +subview_cube::operator%= (const BaseCube& in) + { + arma_extra_debug_sigprint(); + + inplace_op(in, "element-wise multiplication"); + } + + + +template +template +inline +void +subview_cube::operator/= (const BaseCube& in) + { + arma_extra_debug_sigprint(); + + inplace_op(in, "element-wise division"); + } + + + +//! x.subcube(...) = y.subcube(...) +template +inline +void +subview_cube::operator= (const subview_cube& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x, "copy into subcube"); + } + + + +template +inline +void +subview_cube::operator+= (const subview_cube& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x, "addition"); + } + + + +template +inline +void +subview_cube::operator-= (const subview_cube& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x, "subtraction"); + } + + + +template +inline +void +subview_cube::operator%= (const subview_cube& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x, "element-wise multiplication"); + } + + + +template +inline +void +subview_cube::operator/= (const subview_cube& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x, "element-wise division"); + } + + + +template +template +inline +void +subview_cube::operator= (const Base& in) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(in.get_ref()); + + const Mat& x = tmp.M; + subview_cube& t = *this; + + const uword t_n_rows = t.n_rows; + const uword t_n_cols = t.n_cols; + const uword t_n_slices = t.n_slices; + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + + if( ((x_n_rows == 1) || (x_n_cols == 1)) && (t_n_rows == 1) && (t_n_cols == 1) && (x.n_elem == t_n_slices) ) + { + Cube& Q = const_cast< Cube& >(t.m); + + const uword t_aux_row1 = t.aux_row1; + const uword t_aux_col1 = t.aux_col1; + const uword t_aux_slice1 = t.aux_slice1; + + const eT* x_mem = x.memptr(); + + uword i,j; + for(i=0, j=1; j < t_n_slices; i+=2, j+=2) + { + const eT tmp_i = x_mem[i]; + const eT tmp_j = x_mem[j]; + + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + i) = tmp_i; + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + j) = tmp_j; + } + + if(i < t_n_slices) + { + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + i) = x_mem[i]; + } + } + else + if( (t_n_rows == x_n_rows) && (t_n_cols == x_n_cols) && (t_n_slices == 1) ) + { + // interpret the matrix as a cube with one slice + + for(uword col = 0; col < t_n_cols; ++col) + { + arrayops::copy( t.slice_colptr(0, col), x.colptr(col), t_n_rows ); + } + } + else + if( (t_n_rows == x_n_rows) && (t_n_cols == 1) && (t_n_slices == x_n_cols) ) + { + for(uword i=0; i < t_n_slices; ++i) + { + arrayops::copy( t.slice_colptr(i, 0), x.colptr(i), t_n_rows ); + } + } + else + if( (t_n_rows == 1) && (t_n_cols == x_n_rows) && (t_n_slices == x_n_cols) ) + { + Cube& Q = const_cast< Cube& >(t.m); + + const uword t_aux_row1 = t.aux_row1; + const uword t_aux_col1 = t.aux_col1; + const uword t_aux_slice1 = t.aux_slice1; + + for(uword slice=0; slice < t_n_slices; ++slice) + { + const uword mod_slice = t_aux_slice1 + slice; + + const eT* x_colptr = x.colptr(slice); + + uword i,j; + for(i=0, j=1; j < t_n_cols; i+=2, j+=2) + { + const eT tmp_i = x_colptr[i]; + const eT tmp_j = x_colptr[j]; + + Q.at(t_aux_row1, t_aux_col1 + i, mod_slice) = tmp_i; + Q.at(t_aux_row1, t_aux_col1 + j, mod_slice) = tmp_j; + } + + if(i < t_n_cols) + { + Q.at(t_aux_row1, t_aux_col1 + i, mod_slice) = x_colptr[i]; + } + } + } + else + { + if(arma_config::debug) + { + arma_stop_logic_error( arma_incompat_size_string(t, x, "copy into subcube") ); + } + } + } + + + +template +template +inline +void +subview_cube::operator+= (const Base& in) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(in.get_ref()); + + const Mat& x = tmp.M; + subview_cube& t = *this; + + const uword t_n_rows = t.n_rows; + const uword t_n_cols = t.n_cols; + const uword t_n_slices = t.n_slices; + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + + if( ((x_n_rows == 1) || (x_n_cols == 1)) && (t_n_rows == 1) && (t_n_cols == 1) && (x.n_elem == t_n_slices) ) + { + Cube& Q = const_cast< Cube& >(t.m); + + const uword t_aux_row1 = t.aux_row1; + const uword t_aux_col1 = t.aux_col1; + const uword t_aux_slice1 = t.aux_slice1; + + const eT* x_mem = x.memptr(); + + uword i,j; + for(i=0, j=1; j < t_n_slices; i+=2, j+=2) + { + const eT tmp_i = x_mem[i]; + const eT tmp_j = x_mem[j]; + + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + i) += tmp_i; + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + j) += tmp_j; + } + + if(i < t_n_slices) + { + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + i) += x_mem[i]; + } + } + else + if( (t_n_rows == x_n_rows) && (t_n_cols == x_n_cols) && (t_n_slices == 1) ) + { + for(uword col = 0; col < t_n_cols; ++col) + { + arrayops::inplace_plus( t.slice_colptr(0, col), x.colptr(col), t_n_rows ); + } + } + else + if( (t_n_rows == x_n_rows) && (t_n_cols == 1) && (t_n_slices == x_n_cols) ) + { + for(uword i=0; i < t_n_slices; ++i) + { + arrayops::inplace_plus( t.slice_colptr(i, 0), x.colptr(i), t_n_rows ); + } + } + else + if( (t_n_rows == 1) && (t_n_cols == x_n_rows) && (t_n_slices == x_n_cols) ) + { + Cube& Q = const_cast< Cube& >(t.m); + + const uword t_aux_row1 = t.aux_row1; + const uword t_aux_col1 = t.aux_col1; + const uword t_aux_slice1 = t.aux_slice1; + + for(uword slice=0; slice < t_n_slices; ++slice) + { + const uword mod_slice = t_aux_slice1 + slice; + + const eT* x_colptr = x.colptr(slice); + + uword i,j; + for(i=0, j=1; j < t_n_cols; i+=2, j+=2) + { + const eT tmp_i = x_colptr[i]; + const eT tmp_j = x_colptr[j]; + + Q.at(t_aux_row1, t_aux_col1 + i, mod_slice) += tmp_i; + Q.at(t_aux_row1, t_aux_col1 + j, mod_slice) += tmp_j; + } + + if(i < t_n_cols) + { + Q.at(t_aux_row1, t_aux_col1 + i, mod_slice) += x_colptr[i]; + } + } + } + else + { + if(arma_config::debug) + { + arma_stop_logic_error( arma_incompat_size_string(t, x, "addition") ); + } + } + } + + + +template +template +inline +void +subview_cube::operator-= (const Base& in) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(in.get_ref()); + + const Mat& x = tmp.M; + subview_cube& t = *this; + + const uword t_n_rows = t.n_rows; + const uword t_n_cols = t.n_cols; + const uword t_n_slices = t.n_slices; + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + + if( ((x_n_rows == 1) || (x_n_cols == 1)) && (t_n_rows == 1) && (t_n_cols == 1) && (x.n_elem == t_n_slices) ) + { + Cube& Q = const_cast< Cube& >(t.m); + + const uword t_aux_row1 = t.aux_row1; + const uword t_aux_col1 = t.aux_col1; + const uword t_aux_slice1 = t.aux_slice1; + + const eT* x_mem = x.memptr(); + + uword i,j; + for(i=0, j=1; j < t_n_slices; i+=2, j+=2) + { + const eT tmp_i = x_mem[i]; + const eT tmp_j = x_mem[j]; + + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + i) -= tmp_i; + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + j) -= tmp_j; + } + + if(i < t_n_slices) + { + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + i) -= x_mem[i]; + } + } + else + if( (t_n_rows == x_n_rows) && (t_n_cols == x_n_cols) && (t_n_slices == 1) ) + { + for(uword col = 0; col < t_n_cols; ++col) + { + arrayops::inplace_minus( t.slice_colptr(0, col), x.colptr(col), t_n_rows ); + } + } + else + if( (t_n_rows == x_n_rows) && (t_n_cols == 1) && (t_n_slices == x_n_cols) ) + { + for(uword i=0; i < t_n_slices; ++i) + { + arrayops::inplace_minus( t.slice_colptr(i, 0), x.colptr(i), t_n_rows ); + } + } + else + if( (t_n_rows == 1) && (t_n_cols == x_n_rows) && (t_n_slices == x_n_cols) ) + { + Cube& Q = const_cast< Cube& >(t.m); + + const uword t_aux_row1 = t.aux_row1; + const uword t_aux_col1 = t.aux_col1; + const uword t_aux_slice1 = t.aux_slice1; + + for(uword slice=0; slice < t_n_slices; ++slice) + { + const uword mod_slice = t_aux_slice1 + slice; + + const eT* x_colptr = x.colptr(slice); + + uword i,j; + for(i=0, j=1; j < t_n_cols; i+=2, j+=2) + { + const eT tmp_i = x_colptr[i]; + const eT tmp_j = x_colptr[j]; + + Q.at(t_aux_row1, t_aux_col1 + i, mod_slice) -= tmp_i; + Q.at(t_aux_row1, t_aux_col1 + j, mod_slice) -= tmp_j; + } + + if(i < t_n_cols) + { + Q.at(t_aux_row1, t_aux_col1 + i, mod_slice) -= x_colptr[i]; + } + } + } + else + { + if(arma_config::debug) + { + arma_stop_logic_error( arma_incompat_size_string(t, x, "subtraction") ); + } + } + } + + + +template +template +inline +void +subview_cube::operator%= (const Base& in) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(in.get_ref()); + + const Mat& x = tmp.M; + subview_cube& t = *this; + + const uword t_n_rows = t.n_rows; + const uword t_n_cols = t.n_cols; + const uword t_n_slices = t.n_slices; + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + + if( ((x_n_rows == 1) || (x_n_cols == 1)) && (t_n_rows == 1) && (t_n_cols == 1) && (x.n_elem == t_n_slices) ) + { + Cube& Q = const_cast< Cube& >(t.m); + + const uword t_aux_row1 = t.aux_row1; + const uword t_aux_col1 = t.aux_col1; + const uword t_aux_slice1 = t.aux_slice1; + + const eT* x_mem = x.memptr(); + + uword i,j; + for(i=0, j=1; j < t_n_slices; i+=2, j+=2) + { + const eT tmp_i = x_mem[i]; + const eT tmp_j = x_mem[j]; + + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + i) *= tmp_i; + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + j) *= tmp_j; + } + + if(i < t_n_slices) + { + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + i) *= x_mem[i]; + } + } + else + if( (t_n_rows == x_n_rows) && (t_n_cols == x_n_cols) && (t_n_slices == 1) ) + { + for(uword col = 0; col < t_n_cols; ++col) + { + arrayops::inplace_mul( t.slice_colptr(0, col), x.colptr(col), t_n_rows ); + } + } + else + if( (t_n_rows == x_n_rows) && (t_n_cols == 1) && (t_n_slices == x_n_cols) ) + { + for(uword i=0; i < t_n_slices; ++i) + { + arrayops::inplace_mul( t.slice_colptr(i, 0), x.colptr(i), t_n_rows ); + } + } + else + if( (t_n_rows == 1) && (t_n_cols == x_n_rows) && (t_n_slices == x_n_cols) ) + { + Cube& Q = const_cast< Cube& >(t.m); + + const uword t_aux_row1 = t.aux_row1; + const uword t_aux_col1 = t.aux_col1; + const uword t_aux_slice1 = t.aux_slice1; + + for(uword slice=0; slice < t_n_slices; ++slice) + { + const uword mod_slice = t_aux_slice1 + slice; + + const eT* x_colptr = x.colptr(slice); + + uword i,j; + for(i=0, j=1; j < t_n_cols; i+=2, j+=2) + { + const eT tmp_i = x_colptr[i]; + const eT tmp_j = x_colptr[j]; + + Q.at(t_aux_row1, t_aux_col1 + i, mod_slice) *= tmp_i; + Q.at(t_aux_row1, t_aux_col1 + j, mod_slice) *= tmp_j; + } + + if(i < t_n_cols) + { + Q.at(t_aux_row1, t_aux_col1 + i, mod_slice) *= x_colptr[i]; + } + } + } + else + { + if(arma_config::debug) + { + arma_stop_logic_error( arma_incompat_size_string(t, x, "element-wise multiplication") ); + } + } + } + + + +template +template +inline +void +subview_cube::operator/= (const Base& in) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp(in.get_ref()); + + const Mat& x = tmp.M; + subview_cube& t = *this; + + const uword t_n_rows = t.n_rows; + const uword t_n_cols = t.n_cols; + const uword t_n_slices = t.n_slices; + + const uword x_n_rows = x.n_rows; + const uword x_n_cols = x.n_cols; + + if( ((x_n_rows == 1) || (x_n_cols == 1)) && (t_n_rows == 1) && (t_n_cols == 1) && (x.n_elem == t_n_slices) ) + { + Cube& Q = const_cast< Cube& >(t.m); + + const uword t_aux_row1 = t.aux_row1; + const uword t_aux_col1 = t.aux_col1; + const uword t_aux_slice1 = t.aux_slice1; + + const eT* x_mem = x.memptr(); + + uword i,j; + for(i=0, j=1; j < t_n_slices; i+=2, j+=2) + { + const eT tmp_i = x_mem[i]; + const eT tmp_j = x_mem[j]; + + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + i) /= tmp_i; + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + j) /= tmp_j; + } + + if(i < t_n_slices) + { + Q.at(t_aux_row1, t_aux_col1, t_aux_slice1 + i) /= x_mem[i]; + } + } + else + if( (t_n_rows == x_n_rows) && (t_n_cols == x_n_cols) && (t_n_slices == 1) ) + { + for(uword col = 0; col < t_n_cols; ++col) + { + arrayops::inplace_div( t.slice_colptr(0, col), x.colptr(col), t_n_rows ); + } + } + else + if( (t_n_rows == x_n_rows) && (t_n_cols == 1) && (t_n_slices == x_n_cols) ) + { + for(uword i=0; i < t_n_slices; ++i) + { + arrayops::inplace_div( t.slice_colptr(i, 0), x.colptr(i), t_n_rows ); + } + } + else + if( (t_n_rows == 1) && (t_n_cols == x_n_rows) && (t_n_slices == x_n_cols) ) + { + Cube& Q = const_cast< Cube& >(t.m); + + const uword t_aux_row1 = t.aux_row1; + const uword t_aux_col1 = t.aux_col1; + const uword t_aux_slice1 = t.aux_slice1; + + for(uword slice=0; slice < t_n_slices; ++slice) + { + const uword mod_slice = t_aux_slice1 + slice; + + const eT* x_colptr = x.colptr(slice); + + uword i,j; + for(i=0, j=1; j < t_n_cols; i+=2, j+=2) + { + const eT tmp_i = x_colptr[i]; + const eT tmp_j = x_colptr[j]; + + Q.at(t_aux_row1, t_aux_col1 + i, mod_slice) /= tmp_i; + Q.at(t_aux_row1, t_aux_col1 + j, mod_slice) /= tmp_j; + } + + if(i < t_n_cols) + { + Q.at(t_aux_row1, t_aux_col1 + i, mod_slice) /= x_colptr[i]; + } + } + } + else + { + if(arma_config::debug) + { + arma_stop_logic_error( arma_incompat_size_string(t, x, "element-wise division") ); + } + } + } + + + +template +template +inline +void +subview_cube::operator= (const GenCube& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(n_rows, n_cols, n_slices, in.n_rows, in.n_cols, in.n_slices, "copy into subcube"); + + in.apply(*this); + } + + + +//! apply a functor to each element +template +template +inline +void +subview_cube::for_each(functor F) + { + arma_extra_debug_sigprint(); + + Cube& Q = const_cast< Cube& >(m); + + const uword start_col = aux_col1; + const uword start_row = aux_row1; + const uword start_slice = aux_slice1; + + const uword end_col_plus1 = start_col + n_cols; + const uword end_row_plus1 = start_row + n_rows; + const uword end_slice_plus1 = start_slice + n_slices; + + for(uword uslice = start_slice; uslice < end_slice_plus1; ++uslice) + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol ) + for(uword urow = start_row; urow < end_row_plus1; ++urow ) + { + F( Q.at(urow, ucol, uslice) ); + } + } + + + +template +template +inline +void +subview_cube::for_each(functor F) const + { + arma_extra_debug_sigprint(); + + const Cube& Q = m; + + const uword start_col = aux_col1; + const uword start_row = aux_row1; + const uword start_slice = aux_slice1; + + const uword end_col_plus1 = start_col + n_cols; + const uword end_row_plus1 = start_row + n_rows; + const uword end_slice_plus1 = start_slice + n_slices; + + for(uword uslice = start_slice; uslice < end_slice_plus1; ++uslice) + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol ) + for(uword urow = start_row; urow < end_row_plus1; ++urow ) + { + F( Q.at(urow, ucol, uslice) ); + } + } + + + +//! transform each element in the subview using a functor +template +template +inline +void +subview_cube::transform(functor F) + { + arma_extra_debug_sigprint(); + + Cube& Q = const_cast< Cube& >(m); + + const uword start_col = aux_col1; + const uword start_row = aux_row1; + const uword start_slice = aux_slice1; + + const uword end_col_plus1 = start_col + n_cols; + const uword end_row_plus1 = start_row + n_rows; + const uword end_slice_plus1 = start_slice + n_slices; + + for(uword uslice = start_slice; uslice < end_slice_plus1; ++uslice) + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol ) + for(uword urow = start_row; urow < end_row_plus1; ++urow ) + { + Q.at(urow, ucol, uslice) = eT( F( Q.at(urow, ucol, uslice) ) ); + } + } + + + +//! imbue (fill) the subview with values provided by a functor +template +template +inline +void +subview_cube::imbue(functor F) + { + arma_extra_debug_sigprint(); + + Cube& Q = const_cast< Cube& >(m); + + const uword start_col = aux_col1; + const uword start_row = aux_row1; + const uword start_slice = aux_slice1; + + const uword end_col_plus1 = start_col + n_cols; + const uword end_row_plus1 = start_row + n_rows; + const uword end_slice_plus1 = start_slice + n_slices; + + for(uword uslice = start_slice; uslice < end_slice_plus1; ++uslice) + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol ) + for(uword urow = start_row; urow < end_row_plus1; ++urow ) + { + Q.at(urow, ucol, uslice) = eT( F() ); + } + } + + + +//! apply a lambda function to each slice, where each slice is interpreted as a matrix +template +inline +void +subview_cube::each_slice(const std::function< void(Mat&) >& F) + { + arma_extra_debug_sigprint(); + + Mat tmp1(n_rows, n_cols, arma_nozeros_indicator()); + Mat tmp2('j', tmp1.memptr(), n_rows, n_cols); + + for(uword slice_id=0; slice_id < n_slices; ++slice_id) + { + for(uword col_id=0; col_id < n_cols; ++col_id) + { + arrayops::copy( tmp1.colptr(col_id), slice_colptr(slice_id, col_id), n_rows ); + } + + F(tmp2); + + for(uword col_id=0; col_id < n_cols; ++col_id) + { + arrayops::copy( slice_colptr(slice_id, col_id), tmp1.colptr(col_id), n_rows ); + } + } + } + + + +template +inline +void +subview_cube::each_slice(const std::function< void(const Mat&) >& F) const + { + arma_extra_debug_sigprint(); + + Mat tmp1(n_rows, n_cols, arma_nozeros_indicator()); + const Mat tmp2('j', tmp1.memptr(), n_rows, n_cols); + + for(uword slice_id=0; slice_id < n_slices; ++slice_id) + { + for(uword col_id=0; col_id < n_cols; ++col_id) + { + arrayops::copy( tmp1.colptr(col_id), slice_colptr(slice_id, col_id), n_rows ); + } + + F(tmp2); + } + } + + + +template +inline +void +subview_cube::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + arrayops::replace(slice_colptr(slice,col), local_n_rows, old_val, new_val); + } + } + } + + + +template +inline +void +subview_cube::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + arrayops::clean( slice_colptr(slice,col), local_n_rows, threshold ); + } + } + } + + + +template +inline +void +subview_cube::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "subview_cube::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "subview_cube::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "subview_cube::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + arrayops::clamp( slice_colptr(slice,col), local_n_rows, min_val, max_val ); + } + } + } + + + +template +inline +void +subview_cube::fill(const eT val) + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + arrayops::inplace_set( slice_colptr(slice,col), val, local_n_rows ); + } + } + } + + + +template +inline +void +subview_cube::zeros() + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + arrayops::fill_zeros( slice_colptr(slice,col), local_n_rows ); + } + } + } + + + +template +inline +void +subview_cube::ones() + { + arma_extra_debug_sigprint(); + + fill(eT(1)); + } + + + +template +inline +void +subview_cube::randu() + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + arma_rng::randu::fill( slice_colptr(slice,col), local_n_rows ); + } + } + } + + + +template +inline +void +subview_cube::randn() + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + arma_rng::randn::fill( slice_colptr(slice,col), local_n_rows ); + } + } + } + + + +template +inline +bool +subview_cube::is_finite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "is_finite(): detection of non-finite values is not reliable in fast math mode"); } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + if(arrayops::is_finite(slice_colptr(slice,col), local_n_rows) == false) { return false; } + } + } + + return true; + } + + + +template +inline +bool +subview_cube::is_zero(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + if(arrayops::is_zero(slice_colptr(slice,col), local_n_rows, tol) == false) { return false; } + } + } + + return true; + } + + + +template +inline +bool +subview_cube::has_inf() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_inf(): detection of non-finite values is not reliable in fast math mode"); } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + if(arrayops::has_inf(slice_colptr(slice,col), local_n_rows)) { return true; } + } + } + + return false; + } + + + +template +inline +bool +subview_cube::has_nan() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nan(): detection of non-finite values is not reliable in fast math mode"); } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + if(arrayops::has_nan(slice_colptr(slice,col), local_n_rows)) { return true; } + } + } + + return false; + } + + + +template +inline +bool +subview_cube::has_nonfinite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nonfinite(): detection of non-finite values is not reliable in fast math mode"); } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + if(arrayops::is_finite(slice_colptr(slice,col), local_n_rows) == false) { return true; } + } + } + + return false; + } + + + +template +inline +eT +subview_cube::at_alt(const uword i) const + { + return operator[](i); + } + + + +template +inline +eT& +subview_cube::operator[](const uword i) + { + const uword in_slice = i / n_elem_slice; + const uword offset = in_slice * n_elem_slice; + const uword j = i - offset; + + const uword in_col = j / n_rows; + const uword in_row = j % n_rows; + + const uword index = (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return access::rw( (const_cast< Cube& >(m)).mem[index] ); + } + + + +template +inline +eT +subview_cube::operator[](const uword i) const + { + const uword in_slice = i / n_elem_slice; + const uword offset = in_slice * n_elem_slice; + const uword j = i - offset; + + const uword in_col = j / n_rows; + const uword in_row = j % n_rows; + + const uword index = (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return m.mem[index]; + } + + + +template +inline +eT& +subview_cube::operator()(const uword i) + { + arma_debug_check_bounds( (i >= n_elem), "subview_cube::operator(): index out of bounds" ); + + const uword in_slice = i / n_elem_slice; + const uword offset = in_slice * n_elem_slice; + const uword j = i - offset; + + const uword in_col = j / n_rows; + const uword in_row = j % n_rows; + + const uword index = (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return access::rw( (const_cast< Cube& >(m)).mem[index] ); + } + + + +template +inline +eT +subview_cube::operator()(const uword i) const + { + arma_debug_check_bounds( (i >= n_elem), "subview_cube::operator(): index out of bounds" ); + + const uword in_slice = i / n_elem_slice; + const uword offset = in_slice * n_elem_slice; + const uword j = i - offset; + + const uword in_col = j / n_rows; + const uword in_row = j % n_rows; + + const uword index = (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return m.mem[index]; + } + + + +template +arma_inline +eT& +subview_cube::operator()(const uword in_row, const uword in_col, const uword in_slice) + { + arma_debug_check_bounds( ( (in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices) ), "subview_cube::operator(): location out of bounds" ); + + const uword index = (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return access::rw( (const_cast< Cube& >(m)).mem[index] ); + } + + + +template +arma_inline +eT +subview_cube::operator()(const uword in_row, const uword in_col, const uword in_slice) const + { + arma_debug_check_bounds( ( (in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices) ), "subview_cube::operator(): location out of bounds" ); + + const uword index = (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return m.mem[index]; + } + + + +template +arma_inline +eT& +subview_cube::at(const uword in_row, const uword in_col, const uword in_slice) + { + const uword index = (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return access::rw( (const_cast< Cube& >(m)).mem[index] ); + } + + + +template +arma_inline +eT +subview_cube::at(const uword in_row, const uword in_col, const uword in_slice) const + { + const uword index = (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return m.mem[index]; + } + + + +template +arma_inline +eT* +subview_cube::slice_colptr(const uword in_slice, const uword in_col) + { + return & access::rw((const_cast< Cube& >(m)).mem[ (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 ]); + } + + + +template +arma_inline +const eT* +subview_cube::slice_colptr(const uword in_slice, const uword in_col) const + { + return & m.mem[ (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 ]; + } + + + +template +template +inline +bool +subview_cube::check_overlap(const subview_cube& x) const + { + if(is_same_type::value == false) { return false; } + + const subview_cube& t = (*this); + + if(void_ptr(&(t.m)) != void_ptr(&(x.m))) { return false; } + + if( (t.n_elem == 0) || (x.n_elem == 0) ) { return false; } + + const uword t_row_start = t.aux_row1; + const uword t_row_end_p1 = t_row_start + t.n_rows; + + const uword t_col_start = t.aux_col1; + const uword t_col_end_p1 = t_col_start + t.n_cols; + + const uword t_slice_start = t.aux_slice1; + const uword t_slice_end_p1 = t_slice_start + t.n_slices; + + + const uword x_row_start = x.aux_row1; + const uword x_row_end_p1 = x_row_start + x.n_rows; + + const uword x_col_start = x.aux_col1; + const uword x_col_end_p1 = x_col_start + x.n_cols; + + const uword x_slice_start = x.aux_slice1; + const uword x_slice_end_p1 = x_slice_start + x.n_slices; + + + const bool outside_rows = ( (x_row_start >= t_row_end_p1 ) || (t_row_start >= x_row_end_p1 ) ); + const bool outside_cols = ( (x_col_start >= t_col_end_p1 ) || (t_col_start >= x_col_end_p1 ) ); + const bool outside_slices = ( (x_slice_start >= t_slice_end_p1) || (t_slice_start >= x_slice_end_p1) ); + + return ( (outside_rows == false) && (outside_cols == false) && (outside_slices == false) ); + } + + + +template +inline +bool +subview_cube::check_overlap(const Mat& x) const + { + const subview_cube& t = *this; + + const uword t_aux_slice1 = t.aux_slice1; + const uword t_aux_slice2_plus_1 = t_aux_slice1 + t.n_slices; + + for(uword slice = t_aux_slice1; slice < t_aux_slice2_plus_1; ++slice) + { + if(t.m.mat_ptrs[slice] != nullptr) + { + const Mat& y = *(t.m.mat_ptrs[slice]); + + if( x.memptr() == y.memptr() ) { return true; } + } + } + + return false; + } + + + +//! cube X = Y.subcube(...) +template +inline +void +subview_cube::extract(Cube& out, const subview_cube& in) + { + arma_extra_debug_sigprint(); + + // NOTE: we're assuming that the cube has already been set to the correct size and there is no aliasing; + // size setting and alias checking is done by either the Cube contructor or operator=() + + const uword n_rows = in.n_rows; + const uword n_cols = in.n_cols; + const uword n_slices = in.n_slices; + + arma_extra_debug_print(arma_str::format("out.n_rows = %u out.n_cols = %u out.n_slices = %u in.m.n_rows = %u in.m.n_cols = %u in.m.n_slices = %u") % out.n_rows % out.n_cols % out.n_slices % in.m.n_rows % in.m.n_cols % in.m.n_slices); + + if( (in.aux_row1 == 0) && (n_rows == in.m.n_rows) ) + { + for(uword s=0; s < n_slices; ++s) + { + arrayops::copy( out.slice_colptr(s,0), in.slice_colptr(s,0), in.n_elem_slice ); + } + + return; + } + + for(uword s=0; s < n_slices; ++s) + for(uword c=0; c < n_cols; ++c) + { + arrayops::copy( out.slice_colptr(s,c), in.slice_colptr(s,c), n_rows ); + } + } + + + +//! cube X += Y.subcube(...) +template +inline +void +subview_cube::plus_inplace(Cube& out, const subview_cube& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out, in, "addition"); + + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + const uword n_slices = out.n_slices; + + for(uword slice = 0; slice +inline +void +subview_cube::minus_inplace(Cube& out, const subview_cube& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out, in, "subtraction"); + + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + const uword n_slices = out.n_slices; + + for(uword slice = 0; slice +inline +void +subview_cube::schur_inplace(Cube& out, const subview_cube& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out, in, "element-wise multiplication"); + + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + const uword n_slices = out.n_slices; + + for(uword slice = 0; slice +inline +void +subview_cube::div_inplace(Cube& out, const subview_cube& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out, in, "element-wise division"); + + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + const uword n_slices = out.n_slices; + + for(uword slice = 0; slice +inline +void +subview_cube::extract(Mat& out, const subview_cube& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_cube_as_mat(out, in, "copy into matrix", false); + + const uword in_n_rows = in.n_rows; + const uword in_n_cols = in.n_cols; + const uword in_n_slices = in.n_slices; + + const uword out_vec_state = out.vec_state; + + if(in_n_slices == 1) + { + out.set_size(in_n_rows, in_n_cols); + + for(uword col=0; col < in_n_cols; ++col) + { + arrayops::copy( out.colptr(col), in.slice_colptr(0, col), in_n_rows ); + } + } + else + { + if(out_vec_state == 0) + { + if(in_n_cols == 1) + { + out.set_size(in_n_rows, in_n_slices); + + for(uword i=0; i < in_n_slices; ++i) + { + arrayops::copy( out.colptr(i), in.slice_colptr(i, 0), in_n_rows ); + } + } + else + if(in_n_rows == 1) + { + const Cube& Q = in.m; + + const uword in_aux_row1 = in.aux_row1; + const uword in_aux_col1 = in.aux_col1; + const uword in_aux_slice1 = in.aux_slice1; + + out.set_size(in_n_cols, in_n_slices); + + for(uword slice=0; slice < in_n_slices; ++slice) + { + const uword mod_slice = in_aux_slice1 + slice; + + eT* out_colptr = out.colptr(slice); + + uword i,j; + for(i=0, j=1; j < in_n_cols; i+=2, j+=2) + { + const eT tmp_i = Q.at(in_aux_row1, in_aux_col1 + i, mod_slice); + const eT tmp_j = Q.at(in_aux_row1, in_aux_col1 + j, mod_slice); + + out_colptr[i] = tmp_i; + out_colptr[j] = tmp_j; + } + + if(i < in_n_cols) + { + out_colptr[i] = Q.at(in_aux_row1, in_aux_col1 + i, mod_slice); + } + } + } + } + else + { + out.set_size(in_n_slices); + + eT* out_mem = out.memptr(); + + const Cube& Q = in.m; + + const uword in_aux_row1 = in.aux_row1; + const uword in_aux_col1 = in.aux_col1; + const uword in_aux_slice1 = in.aux_slice1; + + for(uword i=0; i +inline +void +subview_cube::plus_inplace(Mat& out, const subview_cube& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_cube_as_mat(out, in, "addition", true); + + const uword in_n_rows = in.n_rows; + const uword in_n_cols = in.n_cols; + const uword in_n_slices = in.n_slices; + + const uword out_n_rows = out.n_rows; + const uword out_n_cols = out.n_cols; + const uword out_vec_state = out.vec_state; + + if(in_n_slices == 1) + { + if( (arma_config::debug) && ((out_n_rows != in_n_rows) || (out_n_cols != in_n_cols)) ) + { + std::ostringstream tmp; + + tmp + << "in-place addition: " + << out_n_rows << 'x' << out_n_cols << " output matrix is incompatible with " + << in_n_rows << 'x' << in_n_cols << 'x' << in_n_slices << " cube interpreted as " + << in_n_rows << 'x' << in_n_cols << " matrix"; + + arma_stop_logic_error(tmp.str()); + } + + for(uword col=0; col < in_n_cols; ++col) + { + arrayops::inplace_plus( out.colptr(col), in.slice_colptr(0, col), in_n_rows ); + } + } + else + { + if(out_vec_state == 0) + { + if( (in_n_rows == out_n_rows) && (in_n_cols == 1) && (in_n_slices == out_n_cols) ) + { + for(uword i=0; i < in_n_slices; ++i) + { + arrayops::inplace_plus( out.colptr(i), in.slice_colptr(i, 0), in_n_rows ); + } + } + else + if( (in_n_rows == 1) && (in_n_cols == out_n_rows) && (in_n_slices == out_n_cols) ) + { + const Cube& Q = in.m; + + const uword in_aux_row1 = in.aux_row1; + const uword in_aux_col1 = in.aux_col1; + const uword in_aux_slice1 = in.aux_slice1; + + for(uword slice=0; slice < in_n_slices; ++slice) + { + const uword mod_slice = in_aux_slice1 + slice; + + eT* out_colptr = out.colptr(slice); + + uword i,j; + for(i=0, j=1; j < in_n_cols; i+=2, j+=2) + { + const eT tmp_i = Q.at(in_aux_row1, in_aux_col1 + i, mod_slice); + const eT tmp_j = Q.at(in_aux_row1, in_aux_col1 + j, mod_slice); + + out_colptr[i] += tmp_i; + out_colptr[j] += tmp_j; + } + + if(i < in_n_cols) + { + out_colptr[i] += Q.at(in_aux_row1, in_aux_col1 + i, mod_slice); + } + } + } + } + else + { + eT* out_mem = out.memptr(); + + const Cube& Q = in.m; + + const uword in_aux_row1 = in.aux_row1; + const uword in_aux_col1 = in.aux_col1; + const uword in_aux_slice1 = in.aux_slice1; + + for(uword i=0; i +inline +void +subview_cube::minus_inplace(Mat& out, const subview_cube& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_cube_as_mat(out, in, "subtraction", true); + + const uword in_n_rows = in.n_rows; + const uword in_n_cols = in.n_cols; + const uword in_n_slices = in.n_slices; + + const uword out_n_rows = out.n_rows; + const uword out_n_cols = out.n_cols; + const uword out_vec_state = out.vec_state; + + if(in_n_slices == 1) + { + if( (arma_config::debug) && ((out_n_rows != in_n_rows) || (out_n_cols != in_n_cols)) ) + { + std::ostringstream tmp; + + tmp + << "in-place subtraction: " + << out_n_rows << 'x' << out_n_cols << " output matrix is incompatible with " + << in_n_rows << 'x' << in_n_cols << 'x' << in_n_slices << " cube interpreted as " + << in_n_rows << 'x' << in_n_cols << " matrix"; + + arma_stop_logic_error(tmp.str()); + } + + for(uword col=0; col < in_n_cols; ++col) + { + arrayops::inplace_minus( out.colptr(col), in.slice_colptr(0, col), in_n_rows ); + } + } + else + { + if(out_vec_state == 0) + { + if( (in_n_rows == out_n_rows) && (in_n_cols == 1) && (in_n_slices == out_n_cols) ) + { + for(uword i=0; i < in_n_slices; ++i) + { + arrayops::inplace_minus( out.colptr(i), in.slice_colptr(i, 0), in_n_rows ); + } + } + else + if( (in_n_rows == 1) && (in_n_cols == out_n_rows) && (in_n_slices == out_n_cols) ) + { + const Cube& Q = in.m; + + const uword in_aux_row1 = in.aux_row1; + const uword in_aux_col1 = in.aux_col1; + const uword in_aux_slice1 = in.aux_slice1; + + for(uword slice=0; slice < in_n_slices; ++slice) + { + const uword mod_slice = in_aux_slice1 + slice; + + eT* out_colptr = out.colptr(slice); + + uword i,j; + for(i=0, j=1; j < in_n_cols; i+=2, j+=2) + { + const eT tmp_i = Q.at(in_aux_row1, in_aux_col1 + i, mod_slice); + const eT tmp_j = Q.at(in_aux_row1, in_aux_col1 + j, mod_slice); + + out_colptr[i] -= tmp_i; + out_colptr[j] -= tmp_j; + } + + if(i < in_n_cols) + { + out_colptr[i] -= Q.at(in_aux_row1, in_aux_col1 + i, mod_slice); + } + } + } + } + else + { + eT* out_mem = out.memptr(); + + const Cube& Q = in.m; + + const uword in_aux_row1 = in.aux_row1; + const uword in_aux_col1 = in.aux_col1; + const uword in_aux_slice1 = in.aux_slice1; + + for(uword i=0; i +inline +void +subview_cube::schur_inplace(Mat& out, const subview_cube& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_cube_as_mat(out, in, "element-wise multiplication", true); + + const uword in_n_rows = in.n_rows; + const uword in_n_cols = in.n_cols; + const uword in_n_slices = in.n_slices; + + const uword out_n_rows = out.n_rows; + const uword out_n_cols = out.n_cols; + const uword out_vec_state = out.vec_state; + + if(in_n_slices == 1) + { + if( (arma_config::debug) && ((out_n_rows != in_n_rows) || (out_n_cols != in_n_cols)) ) + { + std::ostringstream tmp; + + tmp + << "in-place element-wise multiplication: " + << out_n_rows << 'x' << out_n_cols << " output matrix is incompatible with " + << in_n_rows << 'x' << in_n_cols << 'x' << in_n_slices << " cube interpreted as " + << in_n_rows << 'x' << in_n_cols << " matrix"; + + arma_stop_logic_error(tmp.str()); + } + + for(uword col=0; col < in_n_cols; ++col) + { + arrayops::inplace_mul( out.colptr(col), in.slice_colptr(0, col), in_n_rows ); + } + } + else + { + if(out_vec_state == 0) + { + if( (in_n_rows == out_n_rows) && (in_n_cols == 1) && (in_n_slices == out_n_cols) ) + { + for(uword i=0; i < in_n_slices; ++i) + { + arrayops::inplace_mul( out.colptr(i), in.slice_colptr(i, 0), in_n_rows ); + } + } + else + if( (in_n_rows == 1) && (in_n_cols == out_n_rows) && (in_n_slices == out_n_cols) ) + { + const Cube& Q = in.m; + + const uword in_aux_row1 = in.aux_row1; + const uword in_aux_col1 = in.aux_col1; + const uword in_aux_slice1 = in.aux_slice1; + + for(uword slice=0; slice < in_n_slices; ++slice) + { + const uword mod_slice = in_aux_slice1 + slice; + + eT* out_colptr = out.colptr(slice); + + uword i,j; + for(i=0, j=1; j < in_n_cols; i+=2, j+=2) + { + const eT tmp_i = Q.at(in_aux_row1, in_aux_col1 + i, mod_slice); + const eT tmp_j = Q.at(in_aux_row1, in_aux_col1 + j, mod_slice); + + out_colptr[i] *= tmp_i; + out_colptr[j] *= tmp_j; + } + + if(i < in_n_cols) + { + out_colptr[i] *= Q.at(in_aux_row1, in_aux_col1 + i, mod_slice); + } + } + } + } + else + { + eT* out_mem = out.memptr(); + + const Cube& Q = in.m; + + const uword in_aux_row1 = in.aux_row1; + const uword in_aux_col1 = in.aux_col1; + const uword in_aux_slice1 = in.aux_slice1; + + for(uword i=0; i +inline +void +subview_cube::div_inplace(Mat& out, const subview_cube& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_cube_as_mat(out, in, "element-wise division", true); + + const uword in_n_rows = in.n_rows; + const uword in_n_cols = in.n_cols; + const uword in_n_slices = in.n_slices; + + const uword out_n_rows = out.n_rows; + const uword out_n_cols = out.n_cols; + const uword out_vec_state = out.vec_state; + + if(in_n_slices == 1) + { + if( (arma_config::debug) && ((out_n_rows != in_n_rows) || (out_n_cols != in_n_cols)) ) + { + std::ostringstream tmp; + + tmp + << "in-place element-wise division: " + << out_n_rows << 'x' << out_n_cols << " output matrix is incompatible with " + << in_n_rows << 'x' << in_n_cols << 'x' << in_n_slices << " cube interpreted as " + << in_n_rows << 'x' << in_n_cols << " matrix"; + + arma_stop_logic_error(tmp.str()); + } + + for(uword col=0; col < in_n_cols; ++col) + { + arrayops::inplace_div( out.colptr(col), in.slice_colptr(0, col), in_n_rows ); + } + } + else + { + if(out_vec_state == 0) + { + if( (in_n_rows == out_n_rows) && (in_n_cols == 1) && (in_n_slices == out_n_cols) ) + { + for(uword i=0; i < in_n_slices; ++i) + { + arrayops::inplace_div( out.colptr(i), in.slice_colptr(i, 0), in_n_rows ); + } + } + else + if( (in_n_rows == 1) && (in_n_cols == out_n_rows) && (in_n_slices == out_n_cols) ) + { + const Cube& Q = in.m; + + const uword in_aux_row1 = in.aux_row1; + const uword in_aux_col1 = in.aux_col1; + const uword in_aux_slice1 = in.aux_slice1; + + for(uword slice=0; slice < in_n_slices; ++slice) + { + const uword mod_slice = in_aux_slice1 + slice; + + eT* out_colptr = out.colptr(slice); + + uword i,j; + for(i=0, j=1; j < in_n_cols; i+=2, j+=2) + { + const eT tmp_i = Q.at(in_aux_row1, in_aux_col1 + i, mod_slice); + const eT tmp_j = Q.at(in_aux_row1, in_aux_col1 + j, mod_slice); + + out_colptr[i] /= tmp_i; + out_colptr[j] /= tmp_j; + } + + if(i < in_n_cols) + { + out_colptr[i] /= Q.at(in_aux_row1, in_aux_col1 + i, mod_slice); + } + } + } + } + else + { + eT* out_mem = out.memptr(); + + const Cube& Q = in.m; + + const uword in_aux_row1 = in.aux_row1; + const uword in_aux_col1 = in.aux_col1; + const uword in_aux_slice1 = in.aux_slice1; + + for(uword i=0; i +inline +typename subview_cube::iterator +subview_cube::begin() + { + return iterator(*this, aux_row1, aux_col1, aux_slice1); + } + + + +template +inline +typename subview_cube::const_iterator +subview_cube::begin() const + { + return const_iterator(*this, aux_row1, aux_col1, aux_slice1); + } + + + +template +inline +typename subview_cube::const_iterator +subview_cube::cbegin() const + { + return const_iterator(*this, aux_row1, aux_col1, aux_slice1); + } + + + +template +inline +typename subview_cube::iterator +subview_cube::end() + { + return iterator(*this, aux_row1, aux_col1, aux_slice1 + n_slices); + } + + + +template +inline +typename subview_cube::const_iterator +subview_cube::end() const + { + return const_iterator(*this, aux_row1, aux_col1, aux_slice1 + n_slices); + } + + + +template +inline +typename subview_cube::const_iterator +subview_cube::cend() const + { + return const_iterator(*this, aux_row1, aux_col1, aux_slice1 + n_slices); + } + + + +// +// +// + + + +template +inline +subview_cube::iterator::iterator() + : M (nullptr) + , current_ptr (nullptr) + , current_row (0 ) + , current_col (0 ) + , current_slice(0 ) + , aux_row1 (0 ) + , aux_col1 (0 ) + , aux_row2_p1 (0 ) + , aux_col2_p1 (0 ) + { + arma_extra_debug_sigprint(); + // Technically this iterator is invalid (it does not point to a valid element) + } + + + +template +inline +subview_cube::iterator::iterator(const iterator& X) + : M (X.M ) + , current_ptr (X.current_ptr ) + , current_row (X.current_row ) + , current_col (X.current_col ) + , current_slice(X.current_slice) + , aux_row1 (X.aux_row1 ) + , aux_col1 (X.aux_col1 ) + , aux_row2_p1 (X.aux_row2_p1 ) + , aux_col2_p1 (X.aux_col2_p1 ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_cube::iterator::iterator(subview_cube& in_sv, const uword in_row, const uword in_col, const uword in_slice) + : M (&(const_cast< Cube& >(in_sv.m))) + , current_ptr (&(M->at(in_row,in_col,in_slice)) ) + , current_row (in_row ) + , current_col (in_col ) + , current_slice(in_slice ) + , aux_row1 (in_sv.aux_row1 ) + , aux_col1 (in_sv.aux_col1 ) + , aux_row2_p1 (in_sv.aux_row1 + in_sv.n_rows ) + , aux_col2_p1 (in_sv.aux_col1 + in_sv.n_cols ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eT& +subview_cube::iterator::operator*() + { + return (*current_ptr); + } + + + +template +inline +typename subview_cube::iterator& +subview_cube::iterator::operator++() + { + current_row++; + + if(current_row == aux_row2_p1) + { + current_row = aux_row1; + current_col++; + + if(current_col == aux_col2_p1) + { + current_col = aux_col1; + current_slice++; + } + + current_ptr = &( (*M).at(current_row,current_col,current_slice) ); + } + else + { + current_ptr++; + } + + return *this; + } + + + +template +inline +typename subview_cube::iterator +subview_cube::iterator::operator++(int) + { + typename subview_cube::iterator temp(*this); + + ++(*this); + + return temp; + } + + + +template +inline +bool +subview_cube::iterator::operator==(const iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +subview_cube::iterator::operator!=(const iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +template +inline +bool +subview_cube::iterator::operator==(const const_iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +subview_cube::iterator::operator!=(const const_iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +// +// +// + + + +template +inline +subview_cube::const_iterator::const_iterator() + : M (nullptr) + , current_ptr (nullptr) + , current_row (0 ) + , current_col (0 ) + , current_slice(0 ) + , aux_row1 (0 ) + , aux_col1 (0 ) + , aux_row2_p1 (0 ) + , aux_col2_p1 (0 ) + { + arma_extra_debug_sigprint(); + // Technically this iterator is invalid (it does not point to a valid element) + } + + + +template +inline +subview_cube::const_iterator::const_iterator(const iterator& X) + : M (X.M ) + , current_ptr (X.current_ptr ) + , current_row (X.current_row ) + , current_col (X.current_col ) + , current_slice(X.current_slice) + , aux_row1 (X.aux_row1 ) + , aux_col1 (X.aux_col1 ) + , aux_row2_p1 (X.aux_row2_p1 ) + , aux_col2_p1 (X.aux_col2_p1 ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_cube::const_iterator::const_iterator(const const_iterator& X) + : M (X.M ) + , current_ptr (X.current_ptr ) + , current_row (X.current_row ) + , current_col (X.current_col ) + , current_slice(X.current_slice) + , aux_row1 (X.aux_row1 ) + , aux_col1 (X.aux_col1 ) + , aux_row2_p1 (X.aux_row2_p1 ) + , aux_col2_p1 (X.aux_col2_p1 ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_cube::const_iterator::const_iterator(const subview_cube& in_sv, const uword in_row, const uword in_col, const uword in_slice) + : M (&(in_sv.m) ) + , current_ptr (&(M->at(in_row,in_col,in_slice))) + , current_row (in_row ) + , current_col (in_col ) + , current_slice(in_slice ) + , aux_row1 (in_sv.aux_row1 ) + , aux_col1 (in_sv.aux_col1 ) + , aux_row2_p1 (in_sv.aux_row1 + in_sv.n_rows ) + , aux_col2_p1 (in_sv.aux_col1 + in_sv.n_cols ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +const eT& +subview_cube::const_iterator::operator*() + { + return (*current_ptr); + } + + + +template +inline +typename subview_cube::const_iterator& +subview_cube::const_iterator::operator++() + { + current_row++; + + if(current_row == aux_row2_p1) + { + current_row = aux_row1; + current_col++; + + if(current_col == aux_col2_p1) + { + current_col = aux_col1; + current_slice++; + } + + current_ptr = &( (*M).at(current_row,current_col,current_slice) ); + } + else + { + current_ptr++; + } + + return *this; + } + + + +template +inline +typename subview_cube::const_iterator +subview_cube::const_iterator::operator++(int) + { + typename subview_cube::const_iterator temp(*this); + + ++(*this); + + return temp; + } + + + +template +inline +bool +subview_cube::const_iterator::operator==(const iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +subview_cube::const_iterator::operator!=(const iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +template +inline +bool +subview_cube::const_iterator::operator==(const const_iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +subview_cube::const_iterator::operator!=(const const_iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_cube_slices_bones.hpp b/src/armadillo/include/armadillo_bits/subview_cube_slices_bones.hpp new file mode 100644 index 0000000..e19890f --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_cube_slices_bones.hpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_cube_slices +//! @{ + + + +template +class subview_cube_slices : public BaseCube< eT, subview_cube_slices > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + arma_aligned const Cube& m; + arma_aligned const Base& base_si; + + + protected: + + arma_inline subview_cube_slices(const Cube& in_m, const Base& in_si); + + + public: + + inline ~subview_cube_slices(); + inline subview_cube_slices() = delete; + + inline void inplace_rand(const uword rand_mode); + + template inline void inplace_op(const eT val); + template inline void inplace_op(const BaseCube& x); + + inline void fill(const eT val); + inline void zeros(); + inline void ones(); + inline void randu(); + inline void randn(); + + inline void operator+= (const eT val); + inline void operator-= (const eT val); + inline void operator*= (const eT val); + inline void operator/= (const eT val); + + + // deliberately returning void + template inline void operator_equ(const subview_cube_slices& x); + template inline void operator= (const subview_cube_slices& x); + inline void operator= (const subview_cube_slices& x); + + template inline void operator+= (const subview_cube_slices& x); + template inline void operator-= (const subview_cube_slices& x); + template inline void operator%= (const subview_cube_slices& x); + template inline void operator/= (const subview_cube_slices& x); + + template inline void operator= (const BaseCube& x); + template inline void operator+= (const BaseCube& x); + template inline void operator-= (const BaseCube& x); + template inline void operator%= (const BaseCube& x); + template inline void operator/= (const BaseCube& x); + + inline static void extract(Cube& out, const subview_cube_slices& in); + + inline static void plus_inplace(Cube& out, const subview_cube_slices& in); + inline static void minus_inplace(Cube& out, const subview_cube_slices& in); + inline static void schur_inplace(Cube& out, const subview_cube_slices& in); + inline static void div_inplace(Cube& out, const subview_cube_slices& in); + + + friend class Cube; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_cube_slices_meat.hpp b/src/armadillo/include/armadillo_bits/subview_cube_slices_meat.hpp new file mode 100644 index 0000000..f520da0 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_cube_slices_meat.hpp @@ -0,0 +1,555 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_cube_slices +//! @{ + + +template +inline +subview_cube_slices::~subview_cube_slices() + { + arma_extra_debug_sigprint(); + } + + +template +arma_inline +subview_cube_slices::subview_cube_slices + ( + const Cube& in_m, + const Base& in_si + ) + : m (in_m ) + , base_si(in_si) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +subview_cube_slices::inplace_rand(const uword rand_mode) + { + arma_extra_debug_sigprint(); + + Cube& m_local = const_cast< Cube& >(m); + + const uword m_n_slices = m_local.n_slices; + const uword m_n_elem_slice = m_local.n_elem_slice; + + const quasi_unwrap U(base_si.get_ref()); + const umat& si = U.M; + + arma_debug_check + ( + ( (si.is_vec() == false) && (si.is_empty() == false) ), + "Cube::slices(): given object must be a vector" + ); + + const uword* si_mem = si.memptr(); + const uword si_n_elem = si.n_elem; + + for(uword si_count=0; si_count < si_n_elem; ++si_count) + { + const uword i = si_mem[si_count]; + + arma_debug_check_bounds( (i >= m_n_slices), "Cube::slices(): index out of bounds" ); + + eT* m_slice_ptr = m_local.slice_memptr(i); + + if(rand_mode == 0) { arma_rng::randu::fill(m_slice_ptr, m_n_elem_slice); } + if(rand_mode == 1) { arma_rng::randn::fill(m_slice_ptr, m_n_elem_slice); } + } + } + + + +template +template +inline +void +subview_cube_slices::inplace_op(const eT val) + { + arma_extra_debug_sigprint(); + + Cube& m_local = const_cast< Cube& >(m); + + const uword m_n_slices = m_local.n_slices; + const uword m_n_elem_slice = m_local.n_elem_slice; + + const quasi_unwrap U(base_si.get_ref()); + const umat& si = U.M; + + arma_debug_check + ( + ( (si.is_vec() == false) && (si.is_empty() == false) ), + "Cube::slices(): given object must be a vector" + ); + + const uword* si_mem = si.memptr(); + const uword si_n_elem = si.n_elem; + + for(uword si_count=0; si_count < si_n_elem; ++si_count) + { + const uword i = si_mem[si_count]; + + arma_debug_check_bounds( (i >= m_n_slices), "Cube::slices(): index out of bounds" ); + + eT* m_slice_ptr = m_local.slice_memptr(i); + + if(is_same_type::yes) { arrayops::inplace_set (m_slice_ptr, val, m_n_elem_slice); } + if(is_same_type::yes) { arrayops::inplace_plus (m_slice_ptr, val, m_n_elem_slice); } + if(is_same_type::yes) { arrayops::inplace_minus(m_slice_ptr, val, m_n_elem_slice); } + if(is_same_type::yes) { arrayops::inplace_mul (m_slice_ptr, val, m_n_elem_slice); } + if(is_same_type::yes) { arrayops::inplace_div (m_slice_ptr, val, m_n_elem_slice); } + } + } + + + +template +template +inline +void +subview_cube_slices::inplace_op(const BaseCube& x) + { + arma_extra_debug_sigprint(); + + Cube& m_local = const_cast< Cube& >(m); + + const uword m_n_slices = m_local.n_slices; + const uword m_n_elem_slice = m_local.n_elem_slice; + + const quasi_unwrap U(base_si.get_ref()); + const umat& si = U.M; + + arma_debug_check + ( + ( (si.is_vec() == false) && (si.is_empty() == false) ), + "Cube::slices(): given object must be a vector" + ); + + const uword* si_mem = si.memptr(); + const uword si_n_elem = si.n_elem; + + const unwrap_cube_check tmp(x.get_ref(), m_local); + const Cube& X = tmp.M; + + arma_debug_assert_same_size( m_local.n_rows, m_local.n_cols, si_n_elem, X.n_rows, X.n_cols, X.n_slices, "Cube::slices()" ); + + for(uword si_count=0; si_count < si_n_elem; ++si_count) + { + const uword i = si_mem[si_count]; + + arma_debug_check_bounds( (i >= m_n_slices), "Cube::slices(): index out of bounds" ); + + eT* m_slice_ptr = m_local.slice_memptr(i); + const eT* X_slice_ptr = X.slice_memptr(si_count); + + if(is_same_type::yes) { arrayops::copy (m_slice_ptr, X_slice_ptr, m_n_elem_slice); } + if(is_same_type::yes) { arrayops::inplace_plus (m_slice_ptr, X_slice_ptr, m_n_elem_slice); } + if(is_same_type::yes) { arrayops::inplace_minus(m_slice_ptr, X_slice_ptr, m_n_elem_slice); } + if(is_same_type::yes) { arrayops::inplace_mul (m_slice_ptr, X_slice_ptr, m_n_elem_slice); } + if(is_same_type::yes) { arrayops::inplace_div (m_slice_ptr, X_slice_ptr, m_n_elem_slice); } + } + } + + + +// +// + + + +template +inline +void +subview_cube_slices::fill(const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_cube_slices::zeros() + { + arma_extra_debug_sigprint(); + + inplace_op(eT(0)); + } + + + +template +inline +void +subview_cube_slices::ones() + { + arma_extra_debug_sigprint(); + + inplace_op(eT(1)); + } + + + +template +inline +void +subview_cube_slices::randu() + { + arma_extra_debug_sigprint(); + + inplace_rand(0); + } + + + +template +inline +void +subview_cube_slices::randn() + { + arma_extra_debug_sigprint(); + + inplace_rand(1); + } + + + +template +inline +void +subview_cube_slices::operator+= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_cube_slices::operator-= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_cube_slices::operator*= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_cube_slices::operator/= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +// +// + + + +template +template +inline +void +subview_cube_slices::operator_equ(const subview_cube_slices& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + + +template +template +inline +void +subview_cube_slices::operator= (const subview_cube_slices& x) + { + arma_extra_debug_sigprint(); + + (*this).operator_equ(x); + } + + + +//! work around compiler bugs +template +inline +void +subview_cube_slices::operator= (const subview_cube_slices& x) + { + arma_extra_debug_sigprint(); + + (*this).operator_equ(x); + } + + + +template +template +inline +void +subview_cube_slices::operator+= (const subview_cube_slices& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_cube_slices::operator-= (const subview_cube_slices& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_cube_slices::operator%= (const subview_cube_slices& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_cube_slices::operator/= (const subview_cube_slices& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_cube_slices::operator= (const BaseCube& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_cube_slices::operator+= (const BaseCube& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_cube_slices::operator-= (const BaseCube& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_cube_slices::operator%= (const BaseCube& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_cube_slices::operator/= (const BaseCube& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +// +// + + + +template +inline +void +subview_cube_slices::extract(Cube& out, const subview_cube_slices& in) + { + arma_extra_debug_sigprint(); + + const Cube& m_local = in.m; + + const uword m_n_slices = m_local.n_slices; + const uword m_n_elem_slice = m_local.n_elem_slice; + + const quasi_unwrap U(in.base_si.get_ref()); + const umat& si = U.M; + + arma_debug_check + ( + ( (si.is_vec() == false) && (si.is_empty() == false) ), + "Cube::slices(): given object must be a vector" + ); + + const uword* si_mem = si.memptr(); + const uword si_n_elem = si.n_elem; + + out.set_size(m_local.n_rows, m_local.n_cols, si_n_elem); + + for(uword si_count=0; si_count < si_n_elem; ++si_count) + { + const uword i = si_mem[si_count]; + + arma_debug_check_bounds( (i >= m_n_slices), "Cube::slices(): index out of bounds" ); + + eT* out_slice_ptr = out.slice_memptr(si_count); + const eT* m_slice_ptr = m_local.slice_memptr(i); + + arrayops::copy(out_slice_ptr, m_slice_ptr, m_n_elem_slice); + } + } + + + +// TODO: implement a dedicated function instead of creating a temporary +template +inline +void +subview_cube_slices::plus_inplace(Cube& out, const subview_cube_slices& in) + { + arma_extra_debug_sigprint(); + + const Cube tmp(in); + + out += tmp; + } + + + +template +inline +void +subview_cube_slices::minus_inplace(Cube& out, const subview_cube_slices& in) + { + arma_extra_debug_sigprint(); + + const Cube tmp(in); + + out -= tmp; + } + + + +template +inline +void +subview_cube_slices::schur_inplace(Cube& out, const subview_cube_slices& in) + { + arma_extra_debug_sigprint(); + + const Cube tmp(in); + + out %= tmp; + } + + + +template +inline +void +subview_cube_slices::div_inplace(Cube& out, const subview_cube_slices& in) + { + arma_extra_debug_sigprint(); + + const Cube tmp(in); + + out /= tmp; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_each_bones.hpp b/src/armadillo/include/armadillo_bits/subview_each_bones.hpp new file mode 100644 index 0000000..dcb58cd --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_each_bones.hpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_each +//! @{ + + + +template +class subview_each_common + { + public: + + typedef typename parent::elem_type eT; + + const parent& P; + + template + inline void check_size(const Mat& A) const; + + + protected: + + arma_inline subview_each_common(const parent& in_P); + inline subview_each_common() = delete; + + arma_inline const Mat& get_mat_ref_helper(const Mat & X) const; + arma_inline const Mat& get_mat_ref_helper(const subview& X) const; + + arma_inline const Mat& get_mat_ref() const; + + template + arma_cold inline const std::string incompat_size_string(const Mat& A) const; + }; + + + + +template +class subview_each1 : public subview_each_common + { + protected: + + arma_inline subview_each1(const parent& in_P); + + + public: + + typedef typename parent::elem_type eT; + + inline ~subview_each1(); + inline subview_each1() = delete; + + // deliberately returning void + template inline void operator= (const Base& x); + template inline void operator+= (const Base& x); + template inline void operator-= (const Base& x); + template inline void operator%= (const Base& x); + template inline void operator/= (const Base& x); + + + friend class Mat; + friend class subview; + }; + + + +template +class subview_each2 : public subview_each_common + { + protected: + + inline subview_each2(const parent& in_P, const Base& in_indices); + + + public: + + const Base& base_indices; + + typedef typename parent::elem_type eT; + + inline void check_indices(const Mat& indices) const; + + inline ~subview_each2(); + inline subview_each2() = delete; + + // deliberately returning void + template inline void operator= (const Base& x); + template inline void operator+= (const Base& x); + template inline void operator-= (const Base& x); + template inline void operator%= (const Base& x); + template inline void operator/= (const Base& x); + + + friend class Mat; + friend class subview; + }; + + + +class subview_each1_aux + { + public: + + template + static inline Mat operator_plus(const subview_each1& X, const Base& Y); + + template + static inline Mat operator_minus(const subview_each1& X, const Base& Y); + + template + static inline Mat operator_minus(const Base& X, const subview_each1& Y); + + template + static inline Mat operator_schur(const subview_each1& X, const Base& Y); + + template + static inline Mat operator_div(const subview_each1& X,const Base& Y); + + template + static inline Mat operator_div(const Base& X, const subview_each1& Y); + }; + + + +class subview_each2_aux + { + public: + + template + static inline Mat operator_plus(const subview_each2& X, const Base& Y); + + template + static inline Mat operator_minus(const subview_each2& X, const Base& Y); + + template + static inline Mat operator_minus(const Base& X, const subview_each2& Y); + + template + static inline Mat operator_schur(const subview_each2& X, const Base& Y); + + template + static inline Mat operator_div(const subview_each2& X, const Base& Y); + + template + static inline Mat operator_div(const Base& X, const subview_each2& Y); + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_each_meat.hpp b/src/armadillo/include/armadillo_bits/subview_each_meat.hpp new file mode 100644 index 0000000..12d263e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_each_meat.hpp @@ -0,0 +1,1404 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_each +//! @{ + + +// +// +// subview_each_common + +template +inline +subview_each_common::subview_each_common(const parent& in_P) + : P(in_P) + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +const Mat& +subview_each_common::get_mat_ref_helper(const Mat& X) const + { + return X; + } + + + +template +arma_inline +const Mat& +subview_each_common::get_mat_ref_helper(const subview& X) const + { + return X.m; + } + + + +template +arma_inline +const Mat& +subview_each_common::get_mat_ref() const + { + return get_mat_ref_helper(P); + } + + + +template +template +inline +void +subview_each_common::check_size(const Mat& A) const + { + if(arma_config::debug) + { + if(mode == 0) + { + if( (A.n_rows != P.n_rows) || (A.n_cols != 1) ) + { + arma_stop_logic_error( incompat_size_string(A) ); + } + } + else + { + if( (A.n_rows != 1) || (A.n_cols != P.n_cols) ) + { + arma_stop_logic_error( incompat_size_string(A) ); + } + } + } + } + + + +template +template +inline +const std::string +subview_each_common::incompat_size_string(const Mat& A) const + { + std::ostringstream tmp; + + if(mode == 0) + { + tmp << "each_col(): incompatible size; expected " << P.n_rows << "x1" << ", got " << A.n_rows << 'x' << A.n_cols; + } + else + { + tmp << "each_row(): incompatible size; expected 1x" << P.n_cols << ", got " << A.n_rows << 'x' << A.n_cols; + } + + return tmp.str(); + } + + + +// +// +// subview_each1 + + + +template +inline +subview_each1::~subview_each1() + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_each1::subview_each1(const parent& in_P) + : subview_each_common::subview_each_common(in_P) + { + arma_extra_debug_sigprint(); + } + + + +template +template +inline +void +subview_each1::operator= (const Base& in) + { + arma_extra_debug_sigprint(); + + parent& p = access::rw(subview_each_common::P); + + const unwrap_check tmp( in.get_ref(), (*this).get_mat_ref() ); + const Mat& A = tmp.M; + + subview_each_common::check_size(A); + + const eT* A_mem = A.memptr(); + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + if(mode == 0) // each column + { + for(uword i=0; i < p_n_cols; ++i) + { + arrayops::copy( p.colptr(i), A_mem, p_n_rows ); + } + } + else // each row + { + for(uword i=0; i < p_n_cols; ++i) + { + arrayops::inplace_set( p.colptr(i), A_mem[i], p_n_rows); + } + } + } + + + +template +template +inline +void +subview_each1::operator+= (const Base& in) + { + arma_extra_debug_sigprint(); + + parent& p = access::rw(subview_each_common::P); + + const unwrap_check tmp( in.get_ref(), (*this).get_mat_ref() ); + const Mat& A = tmp.M; + + subview_each_common::check_size(A); + + const eT* A_mem = A.memptr(); + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + if(mode == 0) // each column + { + for(uword i=0; i < p_n_cols; ++i) + { + arrayops::inplace_plus( p.colptr(i), A_mem, p_n_rows ); + } + } + else // each row + { + for(uword i=0; i < p_n_cols; ++i) + { + arrayops::inplace_plus( p.colptr(i), A_mem[i], p_n_rows); + } + } + } + + + +template +template +inline +void +subview_each1::operator-= (const Base& in) + { + arma_extra_debug_sigprint(); + + parent& p = access::rw(subview_each_common::P); + + const unwrap_check tmp( in.get_ref(), (*this).get_mat_ref() ); + const Mat& A = tmp.M; + + subview_each_common::check_size(A); + + const eT* A_mem = A.memptr(); + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + if(mode == 0) // each column + { + for(uword i=0; i < p_n_cols; ++i) + { + arrayops::inplace_minus( p.colptr(i), A_mem, p_n_rows ); + } + } + else // each row + { + for(uword i=0; i < p_n_cols; ++i) + { + arrayops::inplace_minus( p.colptr(i), A_mem[i], p_n_rows); + } + } + } + + + +template +template +inline +void +subview_each1::operator%= (const Base& in) + { + arma_extra_debug_sigprint(); + + parent& p = access::rw(subview_each_common::P); + + const unwrap_check tmp( in.get_ref(), (*this).get_mat_ref() ); + const Mat& A = tmp.M; + + subview_each_common::check_size(A); + + const eT* A_mem = A.memptr(); + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + if(mode == 0) // each column + { + for(uword i=0; i < p_n_cols; ++i) + { + arrayops::inplace_mul( p.colptr(i), A_mem, p_n_rows ); + } + } + else // each row + { + for(uword i=0; i < p_n_cols; ++i) + { + arrayops::inplace_mul( p.colptr(i), A_mem[i], p_n_rows); + } + } + } + + + +template +template +inline +void +subview_each1::operator/= (const Base& in) + { + arma_extra_debug_sigprint(); + + parent& p = access::rw(subview_each_common::P); + + const unwrap_check tmp( in.get_ref(), (*this).get_mat_ref() ); + const Mat& A = tmp.M; + + subview_each_common::check_size(A); + + const eT* A_mem = A.memptr(); + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + if(mode == 0) // each column + { + for(uword i=0; i < p_n_cols; ++i) + { + arrayops::inplace_div( p.colptr(i), A_mem, p_n_rows ); + } + } + else // each row + { + for(uword i=0; i < p_n_cols; ++i) + { + arrayops::inplace_div( p.colptr(i), A_mem[i], p_n_rows); + } + } + } + + + +// +// +// subview_each2 + + + +template +inline +subview_each2::~subview_each2() + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_each2::subview_each2(const parent& in_P, const Base& in_indices) + : subview_each_common::subview_each_common(in_P) + , base_indices(in_indices) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +subview_each2::check_indices(const Mat& indices) const + { + if(mode == 0) + { + arma_debug_check( ((indices.is_vec() == false) && (indices.is_empty() == false)), "each_col(): list of indices must be a vector" ); + } + else + { + arma_debug_check( ((indices.is_vec() == false) && (indices.is_empty() == false)), "each_row(): list of indices must be a vector" ); + } + } + + + +template +template +inline +void +subview_each2::operator= (const Base& in) + { + arma_extra_debug_sigprint(); + + parent& p = access::rw(subview_each_common::P); + + const unwrap_check tmp( in.get_ref(), (*this).get_mat_ref() ); + const Mat& A = tmp.M; + + subview_each_common::check_size(A); + + const unwrap_check_mixed U( base_indices.get_ref(), (*this).get_mat_ref() ); + + check_indices(U.M); + + const eT* A_mem = A.memptr(); + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + if(mode == 0) // each column + { + for(uword i=0; i < N; ++i) + { + const uword col = indices_mem[i]; + + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); + + arrayops::copy( p.colptr(col), A_mem, p_n_rows ); + } + } + else // each row + { + for(uword i=0; i < N; ++i) + { + const uword row = indices_mem[i]; + + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); + + for(uword col=0; col < p_n_cols; ++col) + { + p.at(row,col) = A_mem[col]; + } + } + } + } + + + +template +template +inline +void +subview_each2::operator+= (const Base& in) + { + arma_extra_debug_sigprint(); + + parent& p = access::rw(subview_each_common::P); + + const unwrap_check tmp( in.get_ref(), (*this).get_mat_ref() ); + const Mat& A = tmp.M; + + subview_each_common::check_size(A); + + const unwrap_check_mixed U( base_indices.get_ref(), (*this).get_mat_ref() ); + + check_indices(U.M); + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + if(mode == 0) // each column + { + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword col = indices_mem[i]; + + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); + + arrayops::inplace_plus( p.colptr(col), A_mem, p_n_rows ); + } + } + else // each row + { + for(uword i=0; i < N; ++i) + { + const uword row = indices_mem[i]; + + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); + + p.row(row) += A; + } + } + } + + + +template +template +inline +void +subview_each2::operator-= (const Base& in) + { + arma_extra_debug_sigprint(); + + parent& p = access::rw(subview_each_common::P); + + const unwrap_check tmp( in.get_ref(), (*this).get_mat_ref() ); + const Mat& A = tmp.M; + + subview_each_common::check_size(A); + + const unwrap_check_mixed U( base_indices.get_ref(), (*this).get_mat_ref() ); + + check_indices(U.M); + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + if(mode == 0) // each column + { + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword col = indices_mem[i]; + + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); + + arrayops::inplace_minus( p.colptr(col), A_mem, p_n_rows ); + } + } + else // each row + { + for(uword i=0; i < N; ++i) + { + const uword row = indices_mem[i]; + + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); + + p.row(row) -= A; + } + } + } + + + +template +template +inline +void +subview_each2::operator%= (const Base& in) + { + arma_extra_debug_sigprint(); + + parent& p = access::rw(subview_each_common::P); + + const unwrap_check tmp( in.get_ref(), (*this).get_mat_ref() ); + const Mat& A = tmp.M; + + subview_each_common::check_size(A); + + const unwrap_check_mixed U( base_indices.get_ref(), (*this).get_mat_ref() ); + + check_indices(U.M); + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + if(mode == 0) // each column + { + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword col = indices_mem[i]; + + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); + + arrayops::inplace_mul( p.colptr(col), A_mem, p_n_rows ); + } + } + else // each row + { + for(uword i=0; i < N; ++i) + { + const uword row = indices_mem[i]; + + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); + + p.row(row) %= A; + } + } + } + + + +template +template +inline +void +subview_each2::operator/= (const Base& in) + { + arma_extra_debug_sigprint(); + + parent& p = access::rw(subview_each_common::P); + + const unwrap_check tmp( in.get_ref(), (*this).get_mat_ref() ); + const Mat& A = tmp.M; + + subview_each_common::check_size(A); + + const unwrap_check_mixed U( base_indices.get_ref(), (*this).get_mat_ref() ); + + check_indices(U.M); + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + if(mode == 0) // each column + { + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword col = indices_mem[i]; + + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); + + arrayops::inplace_div( p.colptr(col), A_mem, p_n_rows ); + } + } + else // each row + { + for(uword i=0; i < N; ++i) + { + const uword row = indices_mem[i]; + + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); + + p.row(row) /= A; + } + } + } + + + +// +// +// subview_each1_aux + + + +template +inline +Mat +subview_each1_aux::operator_plus + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + X.check_size(A); + + const eT* A_mem = A.memptr(); + + if(mode == 0) // each column + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = p_mem[row] + A_mem[row]; + } + } + } + + if(mode == 1) // each row + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + const eT A_val = A_mem[i]; + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = p_mem[row] + A_val; + } + } + } + + return out; + } + + + +template +inline +Mat +subview_each1_aux::operator_minus + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + X.check_size(A); + + const eT* A_mem = A.memptr(); + + if(mode == 0) // each column + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = p_mem[row] - A_mem[row]; + } + } + } + + if(mode == 1) // each row + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + const eT A_val = A_mem[i]; + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = p_mem[row] - A_val; + } + } + } + + return out; + } + + + +template +inline +Mat +subview_each1_aux::operator_minus + ( + const Base& X, + const subview_each1& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = Y.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); + + const quasi_unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + Y.check_size(A); + + const eT* A_mem = A.memptr(); + + if(mode == 0) // each column + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = A_mem[row] - p_mem[row]; + } + } + } + + if(mode == 1) // each row + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + const eT A_val = A_mem[i]; + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = A_val - p_mem[row]; + } + } + } + + return out; + } + + + +template +inline +Mat +subview_each1_aux::operator_schur + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + X.check_size(A); + + const eT* A_mem = A.memptr(); + + if(mode == 0) // each column + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = p_mem[row] * A_mem[row]; + } + } + } + + if(mode == 1) // each row + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + const eT A_val = A_mem[i]; + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = p_mem[row] * A_val; + } + } + } + + return out; + } + + + +template +inline +Mat +subview_each1_aux::operator_div + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + X.check_size(A); + + const eT* A_mem = A.memptr(); + + if(mode == 0) // each column + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = p_mem[row] / A_mem[row]; + } + } + } + + if(mode == 1) // each row + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + const eT A_val = A_mem[i]; + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = p_mem[row] / A_val; + } + } + } + + return out; + } + + + +template +inline +Mat +subview_each1_aux::operator_div + ( + const Base& X, + const subview_each1& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = Y.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); + + const quasi_unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + Y.check_size(A); + + const eT* A_mem = A.memptr(); + + if(mode == 0) // each column + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = A_mem[row] / p_mem[row]; + } + } + } + + if(mode == 1) // each row + { + for(uword i=0; i < p_n_cols; ++i) + { + const eT* p_mem = p.colptr(i); + eT* out_mem = out.colptr(i); + + const eT A_val = A_mem[i]; + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = A_val / p_mem[row]; + } + } + } + + return out; + } + + + +// +// +// subview_each2_aux + + + +template +inline +Mat +subview_each2_aux::operator_plus + ( + const subview_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out = p; + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(X.base_indices.get_ref()); + + X.check_size(A); + X.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + if(mode == 0) // process columns + { + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword col = indices_mem[i]; + + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); + + arrayops::inplace_plus( out.colptr(col), A_mem, p_n_rows ); + } + } + + if(mode == 1) // process rows + { + for(uword i=0; i < N; ++i) + { + const uword row = indices_mem[i]; + + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); + + out.row(row) += A; + } + } + + return out; + } + + + +template +inline +Mat +subview_each2_aux::operator_minus + ( + const subview_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out = p; + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(X.base_indices.get_ref()); + + X.check_size(A); + X.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + if(mode == 0) // process columns + { + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword col = indices_mem[i]; + + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); + + arrayops::inplace_minus( out.colptr(col), A_mem, p_n_rows ); + } + } + + if(mode == 1) // process rows + { + for(uword i=0; i < N; ++i) + { + const uword row = indices_mem[i]; + + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); + + out.row(row) -= A; + } + } + + return out; + } + + + +template +inline +Mat +subview_each2_aux::operator_minus + ( + const Base& X, + const subview_each2& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = Y.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out = p; + + const quasi_unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(Y.base_indices.get_ref()); + + Y.check_size(A); + Y.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + if(mode == 0) // process columns + { + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword col = indices_mem[i]; + + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); + + const eT* p_mem = p.colptr(col); + eT* out_mem = out.colptr(col); + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = A_mem[row] - p_mem[row]; + } + } + } + + if(mode == 1) // process rows + { + for(uword i=0; i < N; ++i) + { + const uword row = indices_mem[i]; + + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); + + out.row(row) = A - p.row(row); + } + } + + return out; + } + + + +template +inline +Mat +subview_each2_aux::operator_schur + ( + const subview_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out = p; + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(X.base_indices.get_ref()); + + X.check_size(A); + X.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + if(mode == 0) // process columns + { + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword col = indices_mem[i]; + + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); + + arrayops::inplace_mul( out.colptr(col), A_mem, p_n_rows ); + } + } + + if(mode == 1) // process rows + { + for(uword i=0; i < N; ++i) + { + const uword row = indices_mem[i]; + + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); + + out.row(row) %= A; + } + } + + return out; + } + + + +template +inline +Mat +subview_each2_aux::operator_div + ( + const subview_each2& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = X.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out = p; + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(X.base_indices.get_ref()); + + X.check_size(A); + X.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + if(mode == 0) // process columns + { + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword col = indices_mem[i]; + + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); + + arrayops::inplace_div( out.colptr(col), A_mem, p_n_rows ); + } + } + + if(mode == 1) // process rows + { + for(uword i=0; i < N; ++i) + { + const uword row = indices_mem[i]; + + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); + + out.row(row) /= A; + } + } + + return out; + } + + + +template +inline +Mat +subview_each2_aux::operator_div + ( + const Base& X, + const subview_each2& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& p = Y.P; + + const uword p_n_rows = p.n_rows; + const uword p_n_cols = p.n_cols; + + Mat out = p; + + const quasi_unwrap tmp(X.get_ref()); + const Mat& A = tmp.M; + + const unwrap U(Y.base_indices.get_ref()); + + Y.check_size(A); + Y.check_indices(U.M); + + const uword* indices_mem = U.M.memptr(); + const uword N = U.M.n_elem; + + if(mode == 0) // process columns + { + const eT* A_mem = A.memptr(); + + for(uword i=0; i < N; ++i) + { + const uword col = indices_mem[i]; + + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); + + const eT* p_mem = p.colptr(col); + eT* out_mem = out.colptr(col); + + for(uword row=0; row < p_n_rows; ++row) + { + out_mem[row] = A_mem[row] / p_mem[row]; + } + } + } + + if(mode == 1) // process rows + { + for(uword i=0; i < N; ++i) + { + const uword row = indices_mem[i]; + + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); + + out.row(row) = A / p.row(row); + } + } + + return out; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_elem1_bones.hpp b/src/armadillo/include/armadillo_bits/subview_elem1_bones.hpp new file mode 100644 index 0000000..2ac3cda --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_elem1_bones.hpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_elem1 +//! @{ + + + +template +class subview_elem1 : public Base< eT, subview_elem1 > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const Mat fake_m; + arma_aligned const Mat& m; + arma_aligned const Base& a; + + + protected: + + arma_inline subview_elem1(const Mat& in_m, const Base& in_a); + arma_inline subview_elem1(const Cube& in_q, const Base& in_a); + + + public: + + inline ~subview_elem1(); + inline subview_elem1() = delete; + + template inline void inplace_op(const eT val); + template inline void inplace_op(const subview_elem1& x ); + template inline void inplace_op(const Base& x ); + + arma_inline const Op,op_htrans> t() const; + arma_inline const Op,op_htrans> ht() const; + arma_inline const Op,op_strans> st() const; + + inline void replace(const eT old_val, const eT new_val); + + inline void clean(const pod_type threshold); + + inline void clamp(const eT min_val, const eT max_val); + + inline void fill(const eT val); + inline void zeros(); + inline void ones(); + inline void randu(); + inline void randn(); + + inline void operator+= (const eT val); + inline void operator-= (const eT val); + inline void operator*= (const eT val); + inline void operator/= (const eT val); + + + // deliberately returning void + template inline void operator_equ(const subview_elem1& x); + template inline void operator= (const subview_elem1& x); + inline void operator= (const subview_elem1& x); + template inline void operator+= (const subview_elem1& x); + template inline void operator-= (const subview_elem1& x); + template inline void operator%= (const subview_elem1& x); + template inline void operator/= (const subview_elem1& x); + + template inline void operator= (const Base& x); + template inline void operator+= (const Base& x); + template inline void operator-= (const Base& x); + template inline void operator%= (const Base& x); + template inline void operator/= (const Base& x); + + inline static void extract(Mat& out, const subview_elem1& in); + + template inline static void mat_inplace_op(Mat& out, const subview_elem1& in); + + inline static void plus_inplace(Mat& out, const subview_elem1& in); + inline static void minus_inplace(Mat& out, const subview_elem1& in); + inline static void schur_inplace(Mat& out, const subview_elem1& in); + inline static void div_inplace(Mat& out, const subview_elem1& in); + + + friend class Mat; + friend class Cube; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_elem1_meat.hpp b/src/armadillo/include/armadillo_bits/subview_elem1_meat.hpp new file mode 100644 index 0000000..d1b6712 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_elem1_meat.hpp @@ -0,0 +1,953 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_elem1 +//! @{ + + +template +inline +subview_elem1::~subview_elem1() + { + arma_extra_debug_sigprint(); + } + + +template +arma_inline +subview_elem1::subview_elem1(const Mat& in_m, const Base& in_a) + : m(in_m) + , a(in_a) + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +subview_elem1::subview_elem1(const Cube& in_q, const Base& in_a) + : fake_m( const_cast< eT* >(in_q.memptr()), in_q.n_elem, 1, false ) + , m( fake_m ) + , a( in_a ) + { + arma_extra_debug_sigprint(); + } + + + +template +template +inline +void +subview_elem1::inplace_op(const eT val) + { + arma_extra_debug_sigprint(); + + Mat& m_local = const_cast< Mat& >(m); + + eT* m_mem = m_local.memptr(); + const uword m_n_elem = m_local.n_elem; + + const unwrap_check_mixed tmp(a.get_ref(), m_local); + const umat& aa = tmp.M; + + arma_debug_check + ( + ( (aa.is_vec() == false) && (aa.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* aa_mem = aa.memptr(); + const uword aa_n_elem = aa.n_elem; + + uword iq,jq; + for(iq=0, jq=1; jq < aa_n_elem; iq+=2, jq+=2) + { + const uword ii = aa_mem[iq]; + const uword jj = aa_mem[jq]; + + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { m_mem[ii] = val; m_mem[jj] = val; } + if(is_same_type::yes) { m_mem[ii] += val; m_mem[jj] += val; } + if(is_same_type::yes) { m_mem[ii] -= val; m_mem[jj] -= val; } + if(is_same_type::yes) { m_mem[ii] *= val; m_mem[jj] *= val; } + if(is_same_type::yes) { m_mem[ii] /= val; m_mem[jj] /= val; } + } + + if(iq < aa_n_elem) + { + const uword ii = aa_mem[iq]; + + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { m_mem[ii] = val; } + if(is_same_type::yes) { m_mem[ii] += val; } + if(is_same_type::yes) { m_mem[ii] -= val; } + if(is_same_type::yes) { m_mem[ii] *= val; } + if(is_same_type::yes) { m_mem[ii] /= val; } + } + } + + + +template +template +inline +void +subview_elem1::inplace_op(const subview_elem1& x) + { + arma_extra_debug_sigprint(); + + subview_elem1& s = *this; + + if(&(s.m) == &(x.m)) + { + arma_extra_debug_print("subview_elem1::inplace_op(): aliasing detected"); + + const Mat tmp(x); + + if(is_same_type::yes) { s.operator= (tmp); } + if(is_same_type::yes) { s.operator+=(tmp); } + if(is_same_type::yes) { s.operator-=(tmp); } + if(is_same_type::yes) { s.operator%=(tmp); } + if(is_same_type::yes) { s.operator/=(tmp); } + } + else + { + Mat& s_m_local = const_cast< Mat& >(s.m); + const Mat& x_m_local = x.m; + + const unwrap_check_mixed s_tmp(s.a.get_ref(), s_m_local); + const unwrap_check_mixed x_tmp(x.a.get_ref(), s_m_local); + + const umat& s_aa = s_tmp.M; + const umat& x_aa = x_tmp.M; + + arma_debug_check + ( + ( ((s_aa.is_vec() == false) && (s_aa.is_empty() == false)) || ((x_aa.is_vec() == false) && (x_aa.is_empty() == false)) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* s_aa_mem = s_aa.memptr(); + const uword* x_aa_mem = x_aa.memptr(); + + const uword s_aa_n_elem = s_aa.n_elem; + + arma_debug_check( (s_aa_n_elem != x_aa.n_elem), "Mat::elem(): size mismatch" ); + + + eT* s_m_mem = s_m_local.memptr(); + const uword s_m_n_elem = s_m_local.n_elem; + + const eT* x_m_mem = x_m_local.memptr(); + const uword x_m_n_elem = x_m_local.n_elem; + + uword iq,jq; + for(iq=0, jq=1; jq < s_aa_n_elem; iq+=2, jq+=2) + { + const uword s_ii = s_aa_mem[iq]; + const uword s_jj = s_aa_mem[jq]; + + const uword x_ii = x_aa_mem[iq]; + const uword x_jj = x_aa_mem[jq]; + + arma_debug_check_bounds + ( + (s_ii >= s_m_n_elem) || (s_jj >= s_m_n_elem) || (x_ii >= x_m_n_elem) || (x_jj >= x_m_n_elem), + "Mat::elem(): index out of bounds" + ); + + if(is_same_type::yes) { s_m_mem[s_ii] = x_m_mem[x_ii]; s_m_mem[s_jj] = x_m_mem[x_jj]; } + if(is_same_type::yes) { s_m_mem[s_ii] += x_m_mem[x_ii]; s_m_mem[s_jj] += x_m_mem[x_jj]; } + if(is_same_type::yes) { s_m_mem[s_ii] -= x_m_mem[x_ii]; s_m_mem[s_jj] -= x_m_mem[x_jj]; } + if(is_same_type::yes) { s_m_mem[s_ii] *= x_m_mem[x_ii]; s_m_mem[s_jj] *= x_m_mem[x_jj]; } + if(is_same_type::yes) { s_m_mem[s_ii] /= x_m_mem[x_ii]; s_m_mem[s_jj] /= x_m_mem[x_jj]; } + } + + if(iq < s_aa_n_elem) + { + const uword s_ii = s_aa_mem[iq]; + const uword x_ii = x_aa_mem[iq]; + + arma_debug_check_bounds + ( + ( (s_ii >= s_m_n_elem) || (x_ii >= x_m_n_elem) ), + "Mat::elem(): index out of bounds" + ); + + if(is_same_type::yes) { s_m_mem[s_ii] = x_m_mem[x_ii]; } + if(is_same_type::yes) { s_m_mem[s_ii] += x_m_mem[x_ii]; } + if(is_same_type::yes) { s_m_mem[s_ii] -= x_m_mem[x_ii]; } + if(is_same_type::yes) { s_m_mem[s_ii] *= x_m_mem[x_ii]; } + if(is_same_type::yes) { s_m_mem[s_ii] /= x_m_mem[x_ii]; } + } + } + } + + + +template +template +inline +void +subview_elem1::inplace_op(const Base& x) + { + arma_extra_debug_sigprint(); + + Mat& m_local = const_cast< Mat& >(m); + + eT* m_mem = m_local.memptr(); + const uword m_n_elem = m_local.n_elem; + + const unwrap_check_mixed aa_tmp(a.get_ref(), m_local); + const umat& aa = aa_tmp.M; + + arma_debug_check + ( + ( (aa.is_vec() == false) && (aa.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* aa_mem = aa.memptr(); + const uword aa_n_elem = aa.n_elem; + + const Proxy P(x.get_ref()); + + arma_debug_check( (aa_n_elem != P.get_n_elem()), "Mat::elem(): size mismatch" ); + + const bool is_alias = P.is_alias(m); + + if( (is_alias == false) && (Proxy::use_at == false) ) + { + typename Proxy::ea_type X = P.get_ea(); + + uword iq,jq; + for(iq=0, jq=1; jq < aa_n_elem; iq+=2, jq+=2) + { + const uword ii = aa_mem[iq]; + const uword jj = aa_mem[jq]; + + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { m_mem[ii] = X[iq]; m_mem[jj] = X[jq]; } + if(is_same_type::yes) { m_mem[ii] += X[iq]; m_mem[jj] += X[jq]; } + if(is_same_type::yes) { m_mem[ii] -= X[iq]; m_mem[jj] -= X[jq]; } + if(is_same_type::yes) { m_mem[ii] *= X[iq]; m_mem[jj] *= X[jq]; } + if(is_same_type::yes) { m_mem[ii] /= X[iq]; m_mem[jj] /= X[jq]; } + } + + if(iq < aa_n_elem) + { + const uword ii = aa_mem[iq]; + + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { m_mem[ii] = X[iq]; } + if(is_same_type::yes) { m_mem[ii] += X[iq]; } + if(is_same_type::yes) { m_mem[ii] -= X[iq]; } + if(is_same_type::yes) { m_mem[ii] *= X[iq]; } + if(is_same_type::yes) { m_mem[ii] /= X[iq]; } + } + } + else + { + arma_extra_debug_print("subview_elem1::inplace_op(): aliasing or use_at detected"); + + const unwrap_check::stored_type> tmp(P.Q, is_alias); + const Mat& M = tmp.M; + + const eT* X = M.memptr(); + + uword iq,jq; + for(iq=0, jq=1; jq < aa_n_elem; iq+=2, jq+=2) + { + const uword ii = aa_mem[iq]; + const uword jj = aa_mem[jq]; + + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { m_mem[ii] = X[iq]; m_mem[jj] = X[jq]; } + if(is_same_type::yes) { m_mem[ii] += X[iq]; m_mem[jj] += X[jq]; } + if(is_same_type::yes) { m_mem[ii] -= X[iq]; m_mem[jj] -= X[jq]; } + if(is_same_type::yes) { m_mem[ii] *= X[iq]; m_mem[jj] *= X[jq]; } + if(is_same_type::yes) { m_mem[ii] /= X[iq]; m_mem[jj] /= X[jq]; } + } + + if(iq < aa_n_elem) + { + const uword ii = aa_mem[iq]; + + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { m_mem[ii] = X[iq]; } + if(is_same_type::yes) { m_mem[ii] += X[iq]; } + if(is_same_type::yes) { m_mem[ii] -= X[iq]; } + if(is_same_type::yes) { m_mem[ii] *= X[iq]; } + if(is_same_type::yes) { m_mem[ii] /= X[iq]; } + } + } + } + + + +// +// + + + +template +arma_inline +const Op,op_htrans> +subview_elem1::t() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_htrans> +subview_elem1::ht() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_strans> +subview_elem1::st() const + { + return Op,op_strans>(*this); + } + + + +template +inline +void +subview_elem1::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + Mat& m_local = const_cast< Mat& >(m); + + eT* m_mem = m_local.memptr(); + const uword m_n_elem = m_local.n_elem; + + const unwrap_check_mixed tmp(a.get_ref(), m_local); + const umat& aa = tmp.M; + + arma_debug_check + ( + ( (aa.is_vec() == false) && (aa.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* aa_mem = aa.memptr(); + const uword aa_n_elem = aa.n_elem; + + if(arma_isnan(old_val)) + { + for(uword iq=0; iq < aa_n_elem; ++iq) + { + const uword ii = aa_mem[iq]; + + arma_debug_check_bounds( (ii >= m_n_elem), "Mat::elem(): index out of bounds" ); + + eT& val = m_mem[ii]; + + val = (arma_isnan(val)) ? new_val : val; + } + } + else + { + for(uword iq=0; iq < aa_n_elem; ++iq) + { + const uword ii = aa_mem[iq]; + + arma_debug_check_bounds( (ii >= m_n_elem), "Mat::elem(): index out of bounds" ); + + eT& val = m_mem[ii]; + + val = (val == old_val) ? new_val : val; + } + } + } + + + +template +inline +void +subview_elem1::clean(const pod_type threshold) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clean(threshold); + + (*this).operator=(tmp); + } + + + +template +inline +void +subview_elem1::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clamp(min_val, max_val); + + (*this).operator=(tmp); + } + + + +template +inline +void +subview_elem1::fill(const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_elem1::zeros() + { + arma_extra_debug_sigprint(); + + inplace_op(eT(0)); + } + + + +template +inline +void +subview_elem1::ones() + { + arma_extra_debug_sigprint(); + + inplace_op(eT(1)); + } + + + +template +inline +void +subview_elem1::randu() + { + arma_extra_debug_sigprint(); + + Mat& m_local = const_cast< Mat& >(m); + + eT* m_mem = m_local.memptr(); + const uword m_n_elem = m_local.n_elem; + + const unwrap_check_mixed tmp(a.get_ref(), m_local); + const umat& aa = tmp.M; + + arma_debug_check + ( + ( (aa.is_vec() == false) && (aa.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* aa_mem = aa.memptr(); + const uword aa_n_elem = aa.n_elem; + + uword iq,jq; + for(iq=0, jq=1; jq < aa_n_elem; iq+=2, jq+=2) + { + const uword ii = aa_mem[iq]; + const uword jj = aa_mem[jq]; + + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + + const eT val1 = eT(arma_rng::randu()); + const eT val2 = eT(arma_rng::randu()); + + m_mem[ii] = val1; + m_mem[jj] = val2; + } + + if(iq < aa_n_elem) + { + const uword ii = aa_mem[iq]; + + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + + m_mem[ii] = eT(arma_rng::randu()); + } + } + + + +template +inline +void +subview_elem1::randn() + { + arma_extra_debug_sigprint(); + + Mat& m_local = const_cast< Mat& >(m); + + eT* m_mem = m_local.memptr(); + const uword m_n_elem = m_local.n_elem; + + const unwrap_check_mixed tmp(a.get_ref(), m_local); + const umat& aa = tmp.M; + + arma_debug_check + ( + ( (aa.is_vec() == false) && (aa.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* aa_mem = aa.memptr(); + const uword aa_n_elem = aa.n_elem; + + uword iq,jq; + for(iq=0, jq=1; jq < aa_n_elem; iq+=2, jq+=2) + { + const uword ii = aa_mem[iq]; + const uword jj = aa_mem[jq]; + + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + + arma_rng::randn::dual_val( m_mem[ii], m_mem[jj] ); + } + + if(iq < aa_n_elem) + { + const uword ii = aa_mem[iq]; + + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + + m_mem[ii] = eT(arma_rng::randn()); + } + } + + + +template +inline +void +subview_elem1::operator+= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_elem1::operator-= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_elem1::operator*= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_elem1::operator/= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +// +// + + + +template +template +inline +void +subview_elem1::operator_equ(const subview_elem1& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + + +template +template +inline +void +subview_elem1::operator= (const subview_elem1& x) + { + arma_extra_debug_sigprint(); + + (*this).operator_equ(x); + } + + + +//! work around compiler bugs +template +inline +void +subview_elem1::operator= (const subview_elem1& x) + { + arma_extra_debug_sigprint(); + + (*this).operator_equ(x); + } + + + +template +template +inline +void +subview_elem1::operator+= (const subview_elem1& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem1::operator-= (const subview_elem1& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem1::operator%= (const subview_elem1& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem1::operator/= (const subview_elem1& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem1::operator= (const Base& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem1::operator+= (const Base& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem1::operator-= (const Base& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem1::operator%= (const Base& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem1::operator/= (const Base& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +// +// + + + +template +inline +void +subview_elem1::extract(Mat& actual_out, const subview_elem1& in) + { + arma_extra_debug_sigprint(); + + const unwrap_check_mixed tmp1(in.a.get_ref(), actual_out); + const umat& aa = tmp1.M; + + arma_debug_check + ( + ( (aa.is_vec() == false) && (aa.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* aa_mem = aa.memptr(); + const uword aa_n_elem = aa.n_elem; + + const Mat& m_local = in.m; + + const eT* m_mem = m_local.memptr(); + const uword m_n_elem = m_local.n_elem; + + const bool alias = (&actual_out == &m_local); + + if(alias) { arma_extra_debug_print("subview_elem1::extract(): aliasing detected"); } + + Mat* tmp_out = alias ? new Mat() : nullptr; + Mat& out = alias ? *tmp_out : actual_out; + + out.set_size(aa_n_elem, 1); + + eT* out_mem = out.memptr(); + + uword i,j; + for(i=0, j=1; j= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + + out_mem[i] = m_mem[ii]; + out_mem[j] = m_mem[jj]; + } + + if(i < aa_n_elem) + { + const uword ii = aa_mem[i]; + + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + + out_mem[i] = m_mem[ii]; + } + + if(alias) + { + actual_out.steal_mem(out); + delete tmp_out; + } + } + + + +template +template +inline +void +subview_elem1::mat_inplace_op(Mat& out, const subview_elem1& in) + { + arma_extra_debug_sigprint(); + + const unwrap tmp1(in.a.get_ref()); + const umat& aa = tmp1.M; + + arma_debug_check + ( + ( (aa.is_vec() == false) && (aa.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* aa_mem = aa.memptr(); + const uword aa_n_elem = aa.n_elem; + + const unwrap_check< Mat > tmp2(in.m, out); + const Mat& m_local = tmp2.M; + + const eT* m_mem = m_local.memptr(); + const uword m_n_elem = m_local.n_elem; + + arma_debug_check( (out.n_elem != aa_n_elem), "Mat::elem(): size mismatch" ); + + eT* out_mem = out.memptr(); + + uword i,j; + for(i=0, j=1; j= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { out_mem[i] += m_mem[ii]; out_mem[j] += m_mem[jj]; } + if(is_same_type::yes) { out_mem[i] -= m_mem[ii]; out_mem[j] -= m_mem[jj]; } + if(is_same_type::yes) { out_mem[i] *= m_mem[ii]; out_mem[j] *= m_mem[jj]; } + if(is_same_type::yes) { out_mem[i] /= m_mem[ii]; out_mem[j] /= m_mem[jj]; } + } + + if(i < aa_n_elem) + { + const uword ii = aa_mem[i]; + + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { out_mem[i] += m_mem[ii]; } + if(is_same_type::yes) { out_mem[i] -= m_mem[ii]; } + if(is_same_type::yes) { out_mem[i] *= m_mem[ii]; } + if(is_same_type::yes) { out_mem[i] /= m_mem[ii]; } + } + } + + + +template +inline +void +subview_elem1::plus_inplace(Mat& out, const subview_elem1& in) + { + arma_extra_debug_sigprint(); + + mat_inplace_op(out, in); + } + + + +template +inline +void +subview_elem1::minus_inplace(Mat& out, const subview_elem1& in) + { + arma_extra_debug_sigprint(); + + mat_inplace_op(out, in); + } + + + +template +inline +void +subview_elem1::schur_inplace(Mat& out, const subview_elem1& in) + { + arma_extra_debug_sigprint(); + + mat_inplace_op(out, in); + } + + + +template +inline +void +subview_elem1::div_inplace(Mat& out, const subview_elem1& in) + { + arma_extra_debug_sigprint(); + + mat_inplace_op(out, in); + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_elem2_bones.hpp b/src/armadillo/include/armadillo_bits/subview_elem2_bones.hpp new file mode 100644 index 0000000..d4c4cbe --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_elem2_bones.hpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_elem2 +//! @{ + + + +template +class subview_elem2 : public Base< eT, subview_elem2 > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const Mat& m; + + arma_aligned const Base& base_ri; + arma_aligned const Base& base_ci; + + const bool all_rows; + const bool all_cols; + + + protected: + + arma_inline subview_elem2(const Mat& in_m, const Base& in_ri, const Base& in_ci, const bool in_all_rows, const bool in_all_cols); + + + public: + + inline ~subview_elem2(); + inline subview_elem2() = delete; + + template + inline void inplace_op(const eT val); + + template + inline void inplace_op(const Base& x); + + inline void replace(const eT old_val, const eT new_val); + + inline void clean(const pod_type threshold); + + inline void clamp(const eT min_val, const eT max_val); + + inline void fill(const eT val); + inline void zeros(); + inline void ones(); + + inline void operator+= (const eT val); + inline void operator-= (const eT val); + inline void operator*= (const eT val); + inline void operator/= (const eT val); + + + // deliberately returning void + template inline void operator_equ(const subview_elem2& x); + template inline void operator= (const subview_elem2& x); + inline void operator= (const subview_elem2& x); + + template inline void operator+= (const subview_elem2& x); + template inline void operator-= (const subview_elem2& x); + template inline void operator%= (const subview_elem2& x); + template inline void operator/= (const subview_elem2& x); + + template inline void operator= (const Base& x); + template inline void operator+= (const Base& x); + template inline void operator-= (const Base& x); + template inline void operator%= (const Base& x); + template inline void operator/= (const Base& x); + + template inline void operator= (const SpBase& x); + template inline void operator+= (const SpBase& x); + template inline void operator-= (const SpBase& x); + template inline void operator%= (const SpBase& x); + template inline void operator/= (const SpBase& x); + + inline static void extract(Mat& out, const subview_elem2& in); + + inline static void plus_inplace(Mat& out, const subview_elem2& in); + inline static void minus_inplace(Mat& out, const subview_elem2& in); + inline static void schur_inplace(Mat& out, const subview_elem2& in); + inline static void div_inplace(Mat& out, const subview_elem2& in); + + + friend class Mat; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_elem2_meat.hpp b/src/armadillo/include/armadillo_bits/subview_elem2_meat.hpp new file mode 100644 index 0000000..69d5f5d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_elem2_meat.hpp @@ -0,0 +1,873 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_elem2 +//! @{ + + +template +inline +subview_elem2::~subview_elem2() + { + arma_extra_debug_sigprint(); + } + + +template +arma_inline +subview_elem2::subview_elem2 + ( + const Mat& in_m, + const Base& in_ri, + const Base& in_ci, + const bool in_all_rows, + const bool in_all_cols + ) + : m (in_m ) + , base_ri (in_ri ) + , base_ci (in_ci ) + , all_rows (in_all_rows) + , all_cols (in_all_cols) + { + arma_extra_debug_sigprint(); + } + + + +template +template +inline +void +subview_elem2::inplace_op(const eT val) + { + arma_extra_debug_sigprint(); + + Mat& m_local = const_cast< Mat& >(m); + + const uword m_n_rows = m_local.n_rows; + const uword m_n_cols = m_local.n_cols; + + if( (all_rows == false) && (all_cols == false) ) + { + const unwrap_check_mixed tmp1(base_ri.get_ref(), m_local); + const unwrap_check_mixed tmp2(base_ci.get_ref(), m_local); + + const umat& ri = tmp1.M; + const umat& ci = tmp2.M; + + arma_debug_check + ( + ( ((ri.is_vec() == false) && (ri.is_empty() == false)) || ((ci.is_vec() == false) && (ci.is_empty() == false)) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* ri_mem = ri.memptr(); + const uword ri_n_elem = ri.n_elem; + + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword col = ci_mem[ci_count]; + + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); + + for(uword ri_count=0; ri_count < ri_n_elem; ++ri_count) + { + const uword row = ri_mem[ri_count]; + + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { m_local.at(row,col) = val; } + if(is_same_type::yes) { m_local.at(row,col) += val; } + if(is_same_type::yes) { m_local.at(row,col) -= val; } + if(is_same_type::yes) { m_local.at(row,col) *= val; } + if(is_same_type::yes) { m_local.at(row,col) /= val; } + } + } + } + else + if( (all_rows == true) && (all_cols == false) ) + { + const unwrap_check_mixed tmp2(base_ci.get_ref(), m_local); + + const umat& ci = tmp2.M; + + arma_debug_check + ( + ( (ci.is_vec() == false) && (ci.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword col = ci_mem[ci_count]; + + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); + + eT* colptr = m_local.colptr(col); + + if(is_same_type::yes) { arrayops::inplace_set (colptr, val, m_n_rows); } + if(is_same_type::yes) { arrayops::inplace_plus (colptr, val, m_n_rows); } + if(is_same_type::yes) { arrayops::inplace_minus(colptr, val, m_n_rows); } + if(is_same_type::yes) { arrayops::inplace_mul (colptr, val, m_n_rows); } + if(is_same_type::yes) { arrayops::inplace_div (colptr, val, m_n_rows); } + } + } + else + if( (all_rows == false) && (all_cols == true) ) + { + const unwrap_check_mixed tmp1(base_ri.get_ref(), m_local); + + const umat& ri = tmp1.M; + + arma_debug_check + ( + ( (ri.is_vec() == false) && (ri.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* ri_mem = ri.memptr(); + const uword ri_n_elem = ri.n_elem; + + for(uword col=0; col < m_n_cols; ++col) + { + for(uword ri_count=0; ri_count < ri_n_elem; ++ri_count) + { + const uword row = ri_mem[ri_count]; + + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { m_local.at(row,col) = val; } + if(is_same_type::yes) { m_local.at(row,col) += val; } + if(is_same_type::yes) { m_local.at(row,col) -= val; } + if(is_same_type::yes) { m_local.at(row,col) *= val; } + if(is_same_type::yes) { m_local.at(row,col) /= val; } + } + } + } + } + + + +template +template +inline +void +subview_elem2::inplace_op(const Base& x) + { + arma_extra_debug_sigprint(); + + Mat& m_local = const_cast< Mat& >(m); + + const uword m_n_rows = m_local.n_rows; + const uword m_n_cols = m_local.n_cols; + + const unwrap_check tmp(x.get_ref(), m_local); + const Mat& X = tmp.M; + + if( (all_rows == false) && (all_cols == false) ) + { + const unwrap_check_mixed tmp1(base_ri.get_ref(), m_local); + const unwrap_check_mixed tmp2(base_ci.get_ref(), m_local); + + const umat& ri = tmp1.M; + const umat& ci = tmp2.M; + + arma_debug_check + ( + ( ((ri.is_vec() == false) && (ri.is_empty() == false)) || ((ci.is_vec() == false) && (ci.is_empty() == false)) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* ri_mem = ri.memptr(); + const uword ri_n_elem = ri.n_elem; + + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + arma_debug_assert_same_size( ri_n_elem, ci_n_elem, X.n_rows, X.n_cols, "Mat::elem()" ); + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword col = ci_mem[ci_count]; + + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); + + for(uword ri_count=0; ri_count < ri_n_elem; ++ri_count) + { + const uword row = ri_mem[ri_count]; + + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { m_local.at(row,col) = X.at(ri_count, ci_count); } + if(is_same_type::yes) { m_local.at(row,col) += X.at(ri_count, ci_count); } + if(is_same_type::yes) { m_local.at(row,col) -= X.at(ri_count, ci_count); } + if(is_same_type::yes) { m_local.at(row,col) *= X.at(ri_count, ci_count); } + if(is_same_type::yes) { m_local.at(row,col) /= X.at(ri_count, ci_count); } + } + } + } + else + if( (all_rows == true) && (all_cols == false) ) + { + const unwrap_check_mixed tmp2(base_ci.get_ref(), m_local); + + const umat& ci = tmp2.M; + + arma_debug_check + ( + ( (ci.is_vec() == false) && (ci.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + arma_debug_assert_same_size( m_n_rows, ci_n_elem, X.n_rows, X.n_cols, "Mat::elem()" ); + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword col = ci_mem[ci_count]; + + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); + + eT* m_colptr = m_local.colptr(col); + const eT* X_colptr = X.colptr(ci_count); + + if(is_same_type::yes) { arrayops::copy (m_colptr, X_colptr, m_n_rows); } + if(is_same_type::yes) { arrayops::inplace_plus (m_colptr, X_colptr, m_n_rows); } + if(is_same_type::yes) { arrayops::inplace_minus(m_colptr, X_colptr, m_n_rows); } + if(is_same_type::yes) { arrayops::inplace_mul (m_colptr, X_colptr, m_n_rows); } + if(is_same_type::yes) { arrayops::inplace_div (m_colptr, X_colptr, m_n_rows); } + } + } + else + if( (all_rows == false) && (all_cols == true) ) + { + const unwrap_check_mixed tmp1(base_ri.get_ref(), m_local); + + const umat& ri = tmp1.M; + + arma_debug_check + ( + ( (ri.is_vec() == false) && (ri.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* ri_mem = ri.memptr(); + const uword ri_n_elem = ri.n_elem; + + arma_debug_assert_same_size( ri_n_elem, m_n_cols, X.n_rows, X.n_cols, "Mat::elem()" ); + + for(uword col=0; col < m_n_cols; ++col) + { + for(uword ri_count=0; ri_count < ri_n_elem; ++ri_count) + { + const uword row = ri_mem[ri_count]; + + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); + + if(is_same_type::yes) { m_local.at(row,col) = X.at(ri_count, col); } + if(is_same_type::yes) { m_local.at(row,col) += X.at(ri_count, col); } + if(is_same_type::yes) { m_local.at(row,col) -= X.at(ri_count, col); } + if(is_same_type::yes) { m_local.at(row,col) *= X.at(ri_count, col); } + if(is_same_type::yes) { m_local.at(row,col) /= X.at(ri_count, col); } + } + } + } + } + + + +// +// + + + +template +inline +void +subview_elem2::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.replace(old_val, new_val); + + (*this).operator=(tmp); + } + + + +template +inline +void +subview_elem2::clean(const pod_type threshold) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clean(threshold); + + (*this).operator=(tmp); + } + + + +template +inline +void +subview_elem2::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clamp(min_val, max_val); + + (*this).operator=(tmp); + } + + + +template +inline +void +subview_elem2::fill(const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_elem2::zeros() + { + arma_extra_debug_sigprint(); + + inplace_op(eT(0)); + } + + + +template +inline +void +subview_elem2::ones() + { + arma_extra_debug_sigprint(); + + inplace_op(eT(1)); + } + + + +template +inline +void +subview_elem2::operator+= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_elem2::operator-= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_elem2::operator*= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview_elem2::operator/= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +// +// + + + +template +template +inline +void +subview_elem2::operator_equ(const subview_elem2& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + + +template +template +inline +void +subview_elem2::operator= (const subview_elem2& x) + { + arma_extra_debug_sigprint(); + + (*this).operator_equ(x); + } + + + +//! work around compiler bugs +template +inline +void +subview_elem2::operator= (const subview_elem2& x) + { + arma_extra_debug_sigprint(); + + (*this).operator_equ(x); + } + + + +template +template +inline +void +subview_elem2::operator+= (const subview_elem2& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem2::operator-= (const subview_elem2& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem2::operator%= (const subview_elem2& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem2::operator/= (const subview_elem2& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem2::operator= (const Base& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem2::operator+= (const Base& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem2::operator-= (const Base& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem2::operator%= (const Base& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +template +template +inline +void +subview_elem2::operator/= (const Base& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x); + } + + + +// +// + + + +template +template +inline +void +subview_elem2::operator= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp(x); + + inplace_op(tmp); + } + + + +template +template +inline +void +subview_elem2::operator+= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp(x); + + inplace_op(tmp); + } + + + +template +template +inline +void +subview_elem2::operator-= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp(x); + + inplace_op(tmp); + } + + + +template +template +inline +void +subview_elem2::operator%= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp(x); + + inplace_op(tmp); + } + + + +template +template +inline +void +subview_elem2::operator/= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp(x); + + inplace_op(tmp); + } + + + +// +// + + + +template +inline +void +subview_elem2::extract(Mat& actual_out, const subview_elem2& in) + { + arma_extra_debug_sigprint(); + + Mat& m_local = const_cast< Mat& >(in.m); + + const uword m_n_rows = m_local.n_rows; + const uword m_n_cols = m_local.n_cols; + + const bool alias = (&actual_out == &m_local); + + if(alias) { arma_extra_debug_print("subview_elem2::extract(): aliasing detected"); } + + Mat* tmp_out = alias ? new Mat() : nullptr; + Mat& out = alias ? *tmp_out : actual_out; + + if( (in.all_rows == false) && (in.all_cols == false) ) + { + const unwrap_check_mixed tmp1(in.base_ri.get_ref(), actual_out); + const unwrap_check_mixed tmp2(in.base_ci.get_ref(), actual_out); + + const umat& ri = tmp1.M; + const umat& ci = tmp2.M; + + arma_debug_check + ( + ( ((ri.is_vec() == false) && (ri.is_empty() == false)) || ((ci.is_vec() == false) && (ci.is_empty() == false)) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* ri_mem = ri.memptr(); + const uword ri_n_elem = ri.n_elem; + + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + out.set_size(ri_n_elem, ci_n_elem); + + eT* out_mem = out.memptr(); + uword out_count = 0; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword col = ci_mem[ci_count]; + + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); + + for(uword ri_count=0; ri_count < ri_n_elem; ++ri_count) + { + const uword row = ri_mem[ri_count]; + + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); + + out_mem[out_count] = m_local.at(row,col); + ++out_count; + } + } + } + else + if( (in.all_rows == true) && (in.all_cols == false) ) + { + const unwrap_check_mixed tmp2(in.base_ci.get_ref(), m_local); + + const umat& ci = tmp2.M; + + arma_debug_check + ( + ( (ci.is_vec() == false) && (ci.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + out.set_size(m_n_rows, ci_n_elem); + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword col = ci_mem[ci_count]; + + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); + + arrayops::copy( out.colptr(ci_count), m_local.colptr(col), m_n_rows ); + } + } + else + if( (in.all_rows == false) && (in.all_cols == true) ) + { + const unwrap_check_mixed tmp1(in.base_ri.get_ref(), m_local); + + const umat& ri = tmp1.M; + + arma_debug_check + ( + ( (ri.is_vec() == false) && (ri.is_empty() == false) ), + "Mat::elem(): given object must be a vector" + ); + + const uword* ri_mem = ri.memptr(); + const uword ri_n_elem = ri.n_elem; + + out.set_size(ri_n_elem, m_n_cols); + + for(uword col=0; col < m_n_cols; ++col) + { + for(uword ri_count=0; ri_count < ri_n_elem; ++ri_count) + { + const uword row = ri_mem[ri_count]; + + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); + + out.at(ri_count,col) = m_local.at(row,col); + } + } + } + + + if(alias) + { + actual_out.steal_mem(out); + + delete tmp_out; + } + } + + + +// TODO: implement a dedicated function instead of creating a temporary (but lots of potential aliasing issues) +template +inline +void +subview_elem2::plus_inplace(Mat& out, const subview_elem2& in) + { + arma_extra_debug_sigprint(); + + const Mat tmp(in); + + out += tmp; + } + + + +template +inline +void +subview_elem2::minus_inplace(Mat& out, const subview_elem2& in) + { + arma_extra_debug_sigprint(); + + const Mat tmp(in); + + out -= tmp; + } + + + +template +inline +void +subview_elem2::schur_inplace(Mat& out, const subview_elem2& in) + { + arma_extra_debug_sigprint(); + + const Mat tmp(in); + + out %= tmp; + } + + + +template +inline +void +subview_elem2::div_inplace(Mat& out, const subview_elem2& in) + { + arma_extra_debug_sigprint(); + + const Mat tmp(in); + + out /= tmp; + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_field_bones.hpp b/src/armadillo/include/armadillo_bits/subview_field_bones.hpp new file mode 100644 index 0000000..8ea8315 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_field_bones.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_field +//! @{ + + +//! Class for storing data required to construct or apply operations to a subfield +//! (ie. where the subfield starts and ends as well as a reference/pointer to the original field), +template +class subview_field + { + public: + + typedef oT object_type; + + const field& f; + + const uword aux_row1; + const uword aux_col1; + const uword aux_slice1; + + const uword n_rows; + const uword n_cols; + const uword n_slices; + const uword n_elem; + + + protected: + + arma_inline subview_field(const field& in_f, const uword in_row1, const uword in_col1, const uword in_n_rows, const uword in_n_cols); + arma_inline subview_field(const field& in_f, const uword in_row1, const uword in_col1, const uword in_slice1, const uword in_n_rows, const uword in_n_cols, const uword in_n_slices); + + + public: + + inline ~subview_field(); + inline subview_field() = delete; + + inline void operator= (const field& x); + inline void operator= (const subview_field& x); + + arma_warn_unused arma_inline oT& operator[](const uword i); + arma_warn_unused arma_inline const oT& operator[](const uword i) const; + + arma_warn_unused arma_inline oT& operator()(const uword i); + arma_warn_unused arma_inline const oT& operator()(const uword i) const; + + arma_warn_unused arma_inline oT& at(const uword row, const uword col); + arma_warn_unused arma_inline const oT& at(const uword row, const uword col) const; + + arma_warn_unused arma_inline oT& at(const uword row, const uword col, const uword slice); + arma_warn_unused arma_inline const oT& at(const uword row, const uword col, const uword slice) const; + + arma_warn_unused arma_inline oT& operator()(const uword row, const uword col); + arma_warn_unused arma_inline const oT& operator()(const uword row, const uword col) const; + + arma_warn_unused arma_inline oT& operator()(const uword row, const uword col, const uword slice); + arma_warn_unused arma_inline const oT& operator()(const uword row, const uword col, const uword slice) const; + + arma_warn_unused arma_inline bool is_empty() const; + + inline bool check_overlap(const subview_field& x) const; + + inline void print(const std::string extra_text = "") const; + inline void print(std::ostream& user_stream, const std::string extra_text = "") const; + + template inline void for_each(functor F); + template inline void for_each(functor F) const; + + inline void fill(const oT& x); + + inline static void extract(field& out, const subview_field& in); + + + friend class field; + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_field_meat.hpp b/src/armadillo/include/armadillo_bits/subview_field_meat.hpp new file mode 100644 index 0000000..dafc4df --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_field_meat.hpp @@ -0,0 +1,558 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview_field +//! @{ + + +template +inline +subview_field::~subview_field() + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +subview_field::subview_field + ( + const field& in_f, + const uword in_row1, + const uword in_col1, + const uword in_n_rows, + const uword in_n_cols + ) + : f(in_f) + , aux_row1(in_row1) + , aux_col1(in_col1) + , aux_slice1(0) + , n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_slices( (in_f.n_slices > 0) ? uword(1) : uword(0) ) + , n_elem(in_n_rows*in_n_cols*n_slices) + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +subview_field::subview_field + ( + const field& in_f, + const uword in_row1, + const uword in_col1, + const uword in_slice1, + const uword in_n_rows, + const uword in_n_cols, + const uword in_n_slices + ) + : f(in_f) + , aux_row1(in_row1) + , aux_col1(in_col1) + , aux_slice1(in_slice1) + , n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_slices(in_n_slices) + , n_elem(in_n_rows*in_n_cols*in_n_slices) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +subview_field::operator= (const field& x) + { + arma_extra_debug_sigprint(); + + subview_field& t = *this; + + arma_debug_check( (t.n_rows != x.n_rows) || (t.n_cols != x.n_cols) || (t.n_slices != x.n_slices), "incompatible field dimensions" ); + + if(t.n_slices == 1) + { + for(uword col=0; col < t.n_cols; ++col) + for(uword row=0; row < t.n_rows; ++row) + { + t.at(row,col) = x.at(row,col); + } + } + else + { + for(uword slice=0; slice < t.n_slices; ++slice) + for(uword col=0; col < t.n_cols; ++col ) + for(uword row=0; row < t.n_rows; ++row ) + { + t.at(row,col,slice) = x.at(row,col,slice); + } + } + } + + + +//! x.subfield(...) = y.subfield(...) +template +inline +void +subview_field::operator= (const subview_field& x) + { + arma_extra_debug_sigprint(); + + if(check_overlap(x)) + { + const field tmp(x); + + (*this).operator=(tmp); + + return; + } + + subview_field& t = *this; + + arma_debug_check( (t.n_rows != x.n_rows) || (t.n_cols != x.n_cols) || (t.n_slices != x.n_slices), "incompatible field dimensions" ); + + if(t.n_slices == 1) + { + for(uword col=0; col < t.n_cols; ++col) + for(uword row=0; row < t.n_rows; ++row) + { + t.at(row,col) = x.at(row,col); + } + } + else + { + for(uword slice=0; slice < t.n_slices; ++slice) + for(uword col=0; col < t.n_cols; ++col ) + for(uword row=0; row < t.n_rows; ++row ) + { + t.at(row,col,slice) = x.at(row,col,slice); + } + } + } + + + +template +arma_inline +oT& +subview_field::operator[](const uword i) + { + const uword n_elem_slice = n_rows*n_cols; + + const uword in_slice = i / n_elem_slice; + const uword offset = in_slice * n_elem_slice; + const uword j = i - offset; + + const uword in_col = j / n_rows; + const uword in_row = j % n_rows; + + const uword index = (in_slice + aux_slice1)*(f.n_rows*f.n_cols) + (in_col + aux_col1)*f.n_rows + aux_row1 + in_row; + + return *((const_cast< field& >(f)).mem[index]); + } + + + +template +arma_inline +const oT& +subview_field::operator[](const uword i) const + { + const uword n_elem_slice = n_rows*n_cols; + + const uword in_slice = i / n_elem_slice; + const uword offset = in_slice * n_elem_slice; + const uword j = i - offset; + + const uword in_col = j / n_rows; + const uword in_row = j % n_rows; + + const uword index = (in_slice + aux_slice1)*(f.n_rows*f.n_cols) + (in_col + aux_col1)*f.n_rows + aux_row1 + in_row; + + return *(f.mem[index]); + } + + + +template +arma_inline +oT& +subview_field::operator()(const uword i) + { + arma_debug_check_bounds( (i >= n_elem), "subview_field::operator(): index out of bounds" ); + + return operator[](i); + } + + + +template +arma_inline +const oT& +subview_field::operator()(const uword i) const + { + arma_debug_check_bounds( (i >= n_elem), "subview_field::operator(): index out of bounds" ); + + return operator[](i); + } + + + +template +arma_inline +oT& +subview_field::operator()(const uword in_row, const uword in_col) + { + return operator()(in_row, in_col, 0); + } + + + +template +arma_inline +const oT& +subview_field::operator()(const uword in_row, const uword in_col) const + { + return operator()(in_row, in_col, 0); + } + + + +template +arma_inline +oT& +subview_field::operator()(const uword in_row, const uword in_col, const uword in_slice) + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "subview_field::operator(): index out of bounds" ); + + const uword index = (in_slice + aux_slice1)*(f.n_rows*f.n_cols) + (in_col + aux_col1)*f.n_rows + aux_row1 + in_row; + + return *((const_cast< field& >(f)).mem[index]); + } + + + +template +arma_inline +const oT& +subview_field::operator()(const uword in_row, const uword in_col, const uword in_slice) const + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "subview_field::operator(): index out of bounds" ); + + const uword index = (in_slice + aux_slice1)*(f.n_rows*f.n_cols) + (in_col + aux_col1)*f.n_rows + aux_row1 + in_row; + + return *(f.mem[index]); + } + + + +template +arma_inline +oT& +subview_field::at(const uword in_row, const uword in_col) + { + return at(in_row, in_col, 0); + } + + + +template +arma_inline +const oT& +subview_field::at(const uword in_row, const uword in_col) const + { + return at(in_row, in_col, 0); + } + + + +template +arma_inline +oT& +subview_field::at(const uword in_row, const uword in_col, const uword in_slice) + { + const uword index = (in_slice + aux_slice1)*(f.n_rows*f.n_cols) + (in_col + aux_col1)*f.n_rows + aux_row1 + in_row; + + return *((const_cast< field& >(f)).mem[index]); + } + + + +template +arma_inline +const oT& +subview_field::at(const uword in_row, const uword in_col, const uword in_slice) const + { + const uword index = (in_slice + aux_slice1)*(f.n_rows*f.n_cols) + (in_col + aux_col1)*f.n_rows + aux_row1 + in_row; + + return *(f.mem[index]); + } + + + +template +arma_inline +bool +subview_field::is_empty() const + { + return (n_elem == 0); + } + + + +template +inline +bool +subview_field::check_overlap(const subview_field& x) const + { + const subview_field& t = *this; + + if(&t.f != &x.f) + { + return false; + } + else + { + if( (t.n_elem == 0) || (x.n_elem == 0) ) + { + return false; + } + else + { + const uword t_row_start = t.aux_row1; + const uword t_row_end_p1 = t_row_start + t.n_rows; + + const uword t_col_start = t.aux_col1; + const uword t_col_end_p1 = t_col_start + t.n_cols; + + const uword t_slice_start = t.aux_slice1; + const uword t_slice_end_p1 = t_slice_start + t.n_slices; + + const uword x_row_start = x.aux_row1; + const uword x_row_end_p1 = x_row_start + x.n_rows; + + const uword x_col_start = x.aux_col1; + const uword x_col_end_p1 = x_col_start + x.n_cols; + + const uword x_slice_start = x.aux_slice1; + const uword x_slice_end_p1 = x_slice_start + x.n_slices; + + const bool outside_rows = ( (x_row_start >= t_row_end_p1 ) || (t_row_start >= x_row_end_p1 ) ); + const bool outside_cols = ( (x_col_start >= t_col_end_p1 ) || (t_col_start >= x_col_end_p1 ) ); + const bool outside_slices = ( (x_slice_start >= t_slice_end_p1) || (t_slice_start >= x_slice_end_p1) ); + + return ( (outside_rows == false) && (outside_cols == false) && (outside_slices == false) ); + } + } + } + + + +template +inline +void +subview_field::print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), *this); + } + + + +template +inline +void +subview_field::print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, *this); + } + + + +template +template +inline +void +subview_field::for_each(functor F) + { + arma_extra_debug_sigprint(); + + subview_field& t = *this; + + if(t.n_slices == 1) + { + for(uword col=0; col < t.n_cols; ++col) + for(uword row=0; row < t.n_rows; ++row) + { + F( t.at(row,col) ); + } + } + else + { + for(uword slice=0; slice < t.n_slices; ++slice) + for(uword col=0; col < t.n_cols; ++col ) + for(uword row=0; row < t.n_rows; ++row ) + { + F( t.at(row,col,slice) ); + } + } + } + + + +template +template +inline +void +subview_field::for_each(functor F) const + { + arma_extra_debug_sigprint(); + + const subview_field& t = *this; + + if(t.n_slices == 1) + { + for(uword col=0; col < t.n_cols; ++col) + for(uword row=0; row < t.n_rows; ++row) + { + F( t.at(row,col) ); + } + } + else + { + for(uword slice=0; slice < t.n_slices; ++slice) + for(uword col=0; col < t.n_cols; ++col ) + for(uword row=0; row < t.n_rows; ++row ) + { + F( t.at(row,col,slice) ); + } + } + } + + + +template +inline +void +subview_field::fill(const oT& x) + { + arma_extra_debug_sigprint(); + + subview_field& t = *this; + + if(t.n_slices == 1) + { + for(uword col=0; col < t.n_cols; ++col) + for(uword row=0; row < t.n_rows; ++row) + { + t.at(row,col) = x; + } + } + else + { + for(uword slice=0; slice < t.n_slices; ++slice) + for(uword col=0; col < t.n_cols; ++col ) + for(uword row=0; row < t.n_rows; ++row ) + { + t.at(row,col,slice) = x; + } + } + } + + + +//! X = Y.subfield(...) +template +inline +void +subview_field::extract(field& actual_out, const subview_field& in) + { + arma_extra_debug_sigprint(); + + // + const bool alias = (&actual_out == &in.f); + + field* tmp = (alias) ? new field : nullptr; + field& out = (alias) ? (*tmp) : actual_out; + + // + + const uword n_rows = in.n_rows; + const uword n_cols = in.n_cols; + const uword n_slices = in.n_slices; + + out.set_size(n_rows, n_cols, n_slices); + + arma_extra_debug_print(arma_str::format("out.n_rows = %u out.n_cols = %u out.n_slices = %u in.m.n_rows = %u in.m.n_cols = %u in.m.n_slices = %u") % out.n_rows % out.n_cols % out.n_slices % in.f.n_rows % in.f.n_cols % in.f.n_slices); + + if(n_slices == 1) + { + for(uword col = 0; col < n_cols; ++col) + for(uword row = 0; row < n_rows; ++row) + { + out.at(row,col) = in.at(row,col); + } + } + else + { + for(uword slice = 0; slice < n_slices; ++slice) + for(uword col = 0; col < n_cols; ++col ) + for(uword row = 0; row < n_rows; ++row ) + { + out.at(row,col,slice) = in.at(row,col,slice); + } + } + + if(alias) + { + actual_out = out; + delete tmp; + } + + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/subview_meat.hpp b/src/armadillo/include/armadillo_bits/subview_meat.hpp new file mode 100644 index 0000000..543383d --- /dev/null +++ b/src/armadillo/include/armadillo_bits/subview_meat.hpp @@ -0,0 +1,4974 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup subview +//! @{ + + +template +inline +subview::~subview() + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +subview::subview(const Mat& in_m, const uword in_row1, const uword in_col1, const uword in_n_rows, const uword in_n_cols) + : m (in_m ) + , aux_row1(in_row1 ) + , aux_col1(in_col1 ) + , n_rows (in_n_rows) + , n_cols (in_n_cols) + , n_elem (in_n_rows*in_n_cols) + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +subview::subview(const subview& in) + : m (in.m ) + , aux_row1(in.aux_row1) + , aux_col1(in.aux_col1) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem (in.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + } + + + +template +inline +subview::subview(subview&& in) + : m (in.m ) + , aux_row1(in.aux_row1) + , aux_col1(in.aux_col1) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem (in.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + + // for paranoia + + access::rw(in.aux_row1) = 0; + access::rw(in.aux_col1) = 0; + access::rw(in.n_rows ) = 0; + access::rw(in.n_cols ) = 0; + access::rw(in.n_elem ) = 0; + } + + + +template +template +inline +void +subview::inplace_op(const eT val) + { + arma_extra_debug_sigprint(); + + subview& s = *this; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + if(s_n_rows == 1) + { + Mat& A = const_cast< Mat& >(s.m); + + const uword A_n_rows = A.n_rows; + + eT* Aptr = &(A.at(s.aux_row1,s.aux_col1)); + + uword jj; + for(jj=1; jj < s_n_cols; jj+=2) + { + if(is_same_type::yes) { (*Aptr) += val; Aptr += A_n_rows; (*Aptr) += val; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) -= val; Aptr += A_n_rows; (*Aptr) -= val; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) *= val; Aptr += A_n_rows; (*Aptr) *= val; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) /= val; Aptr += A_n_rows; (*Aptr) /= val; Aptr += A_n_rows; } + } + + if((jj-1) < s_n_cols) + { + if(is_same_type::yes) { (*Aptr) += val; } + if(is_same_type::yes) { (*Aptr) -= val; } + if(is_same_type::yes) { (*Aptr) *= val; } + if(is_same_type::yes) { (*Aptr) /= val; } + } + } + else + { + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + if(is_same_type::yes) { arrayops::inplace_plus ( colptr(ucol), val, s_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_minus( colptr(ucol), val, s_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_mul ( colptr(ucol), val, s_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_div ( colptr(ucol), val, s_n_rows ); } + } + } + } + + + +template +template +inline +void +subview::inplace_op(const Base& in, const char* identifier) + { + arma_extra_debug_sigprint(); + + const Proxy P(in.get_ref()); + + subview& s = *this; + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + arma_debug_assert_same_size(s, P, identifier); + + const bool use_mp = arma_config::openmp && Proxy::use_mp && mp_gate::eval(s.n_elem); + const bool has_overlap = P.has_overlap(s); + + if(has_overlap) { arma_extra_debug_print("aliasing or overlap detected"); } + + if( (is_Mat::stored_type>::value) || (use_mp) || (has_overlap) ) + { + const unwrap_check::stored_type> tmp(P.Q, has_overlap); + const Mat& B = tmp.M; + + if(s_n_rows == 1) + { + Mat& A = const_cast< Mat& >(m); + + const uword A_n_rows = A.n_rows; + + eT* Aptr = &(A.at(aux_row1,aux_col1)); + const eT* Bptr = B.memptr(); + + uword jj; + for(jj=1; jj < s_n_cols; jj+=2) + { + const eT tmp1 = (*Bptr); Bptr++; + const eT tmp2 = (*Bptr); Bptr++; + + if(is_same_type::yes) { (*Aptr) = tmp1; Aptr += A_n_rows; (*Aptr) = tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) += tmp1; Aptr += A_n_rows; (*Aptr) += tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) -= tmp1; Aptr += A_n_rows; (*Aptr) -= tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) *= tmp1; Aptr += A_n_rows; (*Aptr) *= tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) /= tmp1; Aptr += A_n_rows; (*Aptr) /= tmp2; Aptr += A_n_rows; } + } + + if((jj-1) < s_n_cols) + { + if(is_same_type::yes) { (*Aptr) = (*Bptr); } + if(is_same_type::yes) { (*Aptr) += (*Bptr); } + if(is_same_type::yes) { (*Aptr) -= (*Bptr); } + if(is_same_type::yes) { (*Aptr) *= (*Bptr); } + if(is_same_type::yes) { (*Aptr) /= (*Bptr); } + } + } + else // not a row vector + { + if((s.aux_row1 == 0) && (s_n_rows == s.m.n_rows)) + { + if(is_same_type::yes) { arrayops::copy ( s.colptr(0), B.memptr(), s.n_elem ); } + if(is_same_type::yes) { arrayops::inplace_plus ( s.colptr(0), B.memptr(), s.n_elem ); } + if(is_same_type::yes) { arrayops::inplace_minus( s.colptr(0), B.memptr(), s.n_elem ); } + if(is_same_type::yes) { arrayops::inplace_mul ( s.colptr(0), B.memptr(), s.n_elem ); } + if(is_same_type::yes) { arrayops::inplace_div ( s.colptr(0), B.memptr(), s.n_elem ); } + } + else + { + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + if(is_same_type::yes) { arrayops::copy ( s.colptr(ucol), B.colptr(ucol), s_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_plus ( s.colptr(ucol), B.colptr(ucol), s_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_minus( s.colptr(ucol), B.colptr(ucol), s_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_mul ( s.colptr(ucol), B.colptr(ucol), s_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_div ( s.colptr(ucol), B.colptr(ucol), s_n_rows ); } + } + } + } + } + else // use the Proxy + { + if(s_n_rows == 1) + { + Mat& A = const_cast< Mat& >(m); + + const uword A_n_rows = A.n_rows; + + eT* Aptr = &(A.at(aux_row1,aux_col1)); + + uword jj; + for(jj=1; jj < s_n_cols; jj+=2) + { + const uword ii = (jj-1); + + const eT tmp1 = (Proxy::use_at) ? P.at(0,ii) : P[ii]; + const eT tmp2 = (Proxy::use_at) ? P.at(0,jj) : P[jj]; + + if(is_same_type::yes) { (*Aptr) = tmp1; Aptr += A_n_rows; (*Aptr) = tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) += tmp1; Aptr += A_n_rows; (*Aptr) += tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) -= tmp1; Aptr += A_n_rows; (*Aptr) -= tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) *= tmp1; Aptr += A_n_rows; (*Aptr) *= tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) /= tmp1; Aptr += A_n_rows; (*Aptr) /= tmp2; Aptr += A_n_rows; } + } + + const uword ii = (jj-1); + if(ii < s_n_cols) + { + if(is_same_type::yes) { (*Aptr) = (Proxy::use_at) ? P.at(0,ii) : P[ii]; } + if(is_same_type::yes) { (*Aptr) += (Proxy::use_at) ? P.at(0,ii) : P[ii]; } + if(is_same_type::yes) { (*Aptr) -= (Proxy::use_at) ? P.at(0,ii) : P[ii]; } + if(is_same_type::yes) { (*Aptr) *= (Proxy::use_at) ? P.at(0,ii) : P[ii]; } + if(is_same_type::yes) { (*Aptr) /= (Proxy::use_at) ? P.at(0,ii) : P[ii]; } + } + } + else // not a row vector + { + if(Proxy::use_at) + { + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + eT* s_col_data = s.colptr(ucol); + + uword jj; + for(jj=1; jj < s_n_rows; jj+=2) + { + const uword ii = (jj-1); + + const eT tmp1 = P.at(ii,ucol); + const eT tmp2 = P.at(jj,ucol); + + if(is_same_type::yes) { (*s_col_data) = tmp1; s_col_data++; (*s_col_data) = tmp2; s_col_data++; } + if(is_same_type::yes) { (*s_col_data) += tmp1; s_col_data++; (*s_col_data) += tmp2; s_col_data++; } + if(is_same_type::yes) { (*s_col_data) -= tmp1; s_col_data++; (*s_col_data) -= tmp2; s_col_data++; } + if(is_same_type::yes) { (*s_col_data) *= tmp1; s_col_data++; (*s_col_data) *= tmp2; s_col_data++; } + if(is_same_type::yes) { (*s_col_data) /= tmp1; s_col_data++; (*s_col_data) /= tmp2; s_col_data++; } + } + + const uword ii = (jj-1); + if(ii < s_n_rows) + { + if(is_same_type::yes) { (*s_col_data) = P.at(ii,ucol); } + if(is_same_type::yes) { (*s_col_data) += P.at(ii,ucol); } + if(is_same_type::yes) { (*s_col_data) -= P.at(ii,ucol); } + if(is_same_type::yes) { (*s_col_data) *= P.at(ii,ucol); } + if(is_same_type::yes) { (*s_col_data) /= P.at(ii,ucol); } + } + } + } + else + { + typename Proxy::ea_type Pea = P.get_ea(); + + uword count = 0; + + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + eT* s_col_data = s.colptr(ucol); + + uword jj; + for(jj=1; jj < s_n_rows; jj+=2) + { + const eT tmp1 = Pea[count]; count++; + const eT tmp2 = Pea[count]; count++; + + if(is_same_type::yes) { (*s_col_data) = tmp1; s_col_data++; (*s_col_data) = tmp2; s_col_data++; } + if(is_same_type::yes) { (*s_col_data) += tmp1; s_col_data++; (*s_col_data) += tmp2; s_col_data++; } + if(is_same_type::yes) { (*s_col_data) -= tmp1; s_col_data++; (*s_col_data) -= tmp2; s_col_data++; } + if(is_same_type::yes) { (*s_col_data) *= tmp1; s_col_data++; (*s_col_data) *= tmp2; s_col_data++; } + if(is_same_type::yes) { (*s_col_data) /= tmp1; s_col_data++; (*s_col_data) /= tmp2; s_col_data++; } + } + + if((jj-1) < s_n_rows) + { + if(is_same_type::yes) { (*s_col_data) = Pea[count]; count++; } + if(is_same_type::yes) { (*s_col_data) += Pea[count]; count++; } + if(is_same_type::yes) { (*s_col_data) -= Pea[count]; count++; } + if(is_same_type::yes) { (*s_col_data) *= Pea[count]; count++; } + if(is_same_type::yes) { (*s_col_data) /= Pea[count]; count++; } + } + } + } + } + } + } + + + +template +template +inline +void +subview::inplace_op(const subview& x, const char* identifier) + { + arma_extra_debug_sigprint(); + + if(check_overlap(x)) + { + const Mat tmp(x); + + if(is_same_type::yes) { (*this).operator= (tmp); } + if(is_same_type::yes) { (*this).operator+=(tmp); } + if(is_same_type::yes) { (*this).operator-=(tmp); } + if(is_same_type::yes) { (*this).operator%=(tmp); } + if(is_same_type::yes) { (*this).operator/=(tmp); } + + return; + } + + subview& s = *this; + + arma_debug_assert_same_size(s, x, identifier); + + const uword s_n_cols = s.n_cols; + const uword s_n_rows = s.n_rows; + + if(s_n_rows == 1) + { + Mat& A = const_cast< Mat& >(s.m); + const Mat& B = x.m; + + const uword A_n_rows = A.n_rows; + const uword B_n_rows = B.n_rows; + + eT* Aptr = &(A.at(s.aux_row1,s.aux_col1)); + const eT* Bptr = &(B.at(x.aux_row1,x.aux_col1)); + + uword jj; + for(jj=1; jj < s_n_cols; jj+=2) + { + const eT tmp1 = (*Bptr); Bptr += B_n_rows; + const eT tmp2 = (*Bptr); Bptr += B_n_rows; + + if(is_same_type::yes) { (*Aptr) = tmp1; Aptr += A_n_rows; (*Aptr) = tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) += tmp1; Aptr += A_n_rows; (*Aptr) += tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) -= tmp1; Aptr += A_n_rows; (*Aptr) -= tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) *= tmp1; Aptr += A_n_rows; (*Aptr) *= tmp2; Aptr += A_n_rows; } + if(is_same_type::yes) { (*Aptr) /= tmp1; Aptr += A_n_rows; (*Aptr) /= tmp2; Aptr += A_n_rows; } + } + + if((jj-1) < s_n_cols) + { + if(is_same_type::yes) { (*Aptr) = (*Bptr); } + if(is_same_type::yes) { (*Aptr) += (*Bptr); } + if(is_same_type::yes) { (*Aptr) -= (*Bptr); } + if(is_same_type::yes) { (*Aptr) *= (*Bptr); } + if(is_same_type::yes) { (*Aptr) /= (*Bptr); } + } + } + else + { + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + if(is_same_type::yes) { arrayops::copy ( s.colptr(ucol), x.colptr(ucol), s_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_plus ( s.colptr(ucol), x.colptr(ucol), s_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_minus( s.colptr(ucol), x.colptr(ucol), s_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_mul ( s.colptr(ucol), x.colptr(ucol), s_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_div ( s.colptr(ucol), x.colptr(ucol), s_n_rows ); } + } + } + } + + + +template +inline +void +subview::operator= (const eT val) + { + arma_extra_debug_sigprint(); + + if(n_elem != 1) + { + arma_debug_assert_same_size(n_rows, n_cols, 1, 1, "copy into submatrix"); + } + + Mat& X = const_cast< Mat& >(m); + + X.at(aux_row1, aux_col1) = val; + } + + + +template +inline +void +subview::operator+= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview::operator-= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview::operator*= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview::operator/= (const eT val) + { + arma_extra_debug_sigprint(); + + inplace_op(val); + } + + + +template +inline +void +subview::operator= (const subview& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x, "copy into submatrix"); + } + + + +template +inline +void +subview::operator+= (const subview& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x, "addition"); + } + + + +template +inline +void +subview::operator-= (const subview& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x, "subtraction"); + } + + + +template +inline +void +subview::operator%= (const subview& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x, "element-wise multiplication"); + } + + + +template +inline +void +subview::operator/= (const subview& x) + { + arma_extra_debug_sigprint(); + + inplace_op(x, "element-wise division"); + } + + + +template +template +inline +void +subview::operator= (const Base& in) + { + arma_extra_debug_sigprint(); + + inplace_op(in, "copy into submatrix"); + } + + + +template +template +inline +void +subview::operator+= (const Base& in) + { + arma_extra_debug_sigprint(); + + inplace_op(in, "addition"); + } + + + +template +template +inline +void +subview::operator-= (const Base& in) + { + arma_extra_debug_sigprint(); + + inplace_op(in, "subtraction"); + } + + + +template +template +inline +void +subview::operator%= (const Base& in) + { + arma_extra_debug_sigprint(); + + inplace_op(in, "element-wise multiplication"); + } + + + +template +template +inline +void +subview::operator/= (const Base& in) + { + arma_extra_debug_sigprint(); + + inplace_op(in, "element-wise division"); + } + + + +template +template +inline +void +subview::operator=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + const SpProxy p(x.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "copy into submatrix"); + + // Clear the subview. + zeros(); + + // Iterate through the sparse subview and set the nonzero values appropriately. + typename SpProxy::const_iterator_type cit = p.begin(); + typename SpProxy::const_iterator_type cit_end = p.end(); + + while(cit != cit_end) + { + at(cit.row(), cit.col()) = *cit; + ++cit; + } + } + + + +template +template +inline +void +subview::operator+=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + const SpProxy p(x.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "addition"); + + // Iterate through the sparse subview and add its values. + typename SpProxy::const_iterator_type cit = p.begin(); + typename SpProxy::const_iterator_type cit_end = p.end(); + + while(cit != cit_end) + { + at(cit.row(), cit.col()) += *cit; + ++cit; + } + } + + + +template +template +inline +void +subview::operator-=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + const SpProxy p(x.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "subtraction"); + + // Iterate through the sparse subview and subtract its values. + typename SpProxy::const_iterator_type cit = p.begin(); + typename SpProxy::const_iterator_type cit_end = p.end(); + + while(cit != cit_end) + { + at(cit.row(), cit.col()) -= *cit; + ++cit; + } + } + + + +template +template +inline +void +subview::operator%=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + const uword s_n_rows = (*this).n_rows; + const uword s_n_cols = (*this).n_cols; + + const SpProxy p(x.get_ref()); + + arma_debug_assert_same_size(s_n_rows, s_n_cols, p.get_n_rows(), p.get_n_cols(), "element-wise multiplication"); + + if(n_elem == 0) { return; } + + if(p.get_n_nonzero() == 0) { (*this).zeros(); return; } + + // Iterate over nonzero values. + // Any zero values in the sparse expression will result in a zero in our subview. + typename SpProxy::const_iterator_type cit = p.begin(); + typename SpProxy::const_iterator_type cit_end = p.end(); + + uword r = 0; + uword c = 0; + + while(cit != cit_end) + { + const uword cit_row = cit.row(); + const uword cit_col = cit.col(); + + while( ((r == cit_row) && (c == cit_col)) == false ) + { + at(r,c) = eT(0); + + r++; if(r >= s_n_rows) { r = 0; c++; } + } + + at(r, c) *= (*cit); + + ++cit; + r++; if(r >= s_n_rows) { r = 0; c++; } + } + } + + + +template +template +inline +void +subview::operator/=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + const SpProxy p(x.get_ref()); + + arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "element-wise division"); + + // This is probably going to fill your subview with a bunch of NaNs, + // so I'm not going to bother to implement it fast. + // You can have slow NaNs. They're fine too. + for(uword c = 0; c < n_cols; ++c) + for(uword r = 0; r < n_rows; ++r) + { + at(r, c) /= p.at(r, c); + } + } + + + +template +template +inline +typename enable_if2< is_same_type::value, void>::result +subview::operator= (const Gen& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(n_rows, n_cols, in.n_rows, in.n_cols, "copy into submatrix"); + + in.apply(*this); + } + + + +template +inline +void +subview::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (is_vec() == false), "copy into submatrix: size mismatch" ); + + const uword N = uword(list.size()); + + if(n_rows == 1) + { + arma_debug_assert_same_size(1, n_cols, 1, N, "copy into submatrix"); + + auto it = list.begin(); + + for(uword ii=0; ii < N; ++ii) { (*this).at(0,ii) = (*it); ++it; } + } + else + if(n_cols == 1) + { + arma_debug_assert_same_size(n_rows, 1, N, 1, "copy into submatrix"); + + arrayops::copy( (*this).colptr(0), list.begin(), N ); + } + } + + + +template +inline +void +subview::operator=(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); + + const Mat tmp(list); + + (*this).operator=(tmp); + } + + + +//! apply a functor to each element +template +template +inline +void +subview::for_each(functor F) + { + arma_extra_debug_sigprint(); + + Mat& X = const_cast< Mat& >(m); + + if(n_rows == 1) + { + const uword urow = aux_row1; + const uword start_col = aux_col1; + const uword end_col_plus1 = start_col + n_cols; + + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol) + { + F( X.at(urow, ucol) ); + } + } + else + { + const uword start_col = aux_col1; + const uword start_row = aux_row1; + + const uword end_col_plus1 = start_col + n_cols; + const uword end_row_plus1 = start_row + n_rows; + + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol) + for(uword urow = start_row; urow < end_row_plus1; ++urow) + { + F( X.at(urow, ucol) ); + } + } + } + + + +template +template +inline +void +subview::for_each(functor F) const + { + arma_extra_debug_sigprint(); + + const Mat& X = m; + + if(n_rows == 1) + { + const uword urow = aux_row1; + const uword start_col = aux_col1; + const uword end_col_plus1 = start_col + n_cols; + + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol) + { + F( X.at(urow, ucol) ); + } + } + else + { + const uword start_col = aux_col1; + const uword start_row = aux_row1; + + const uword end_col_plus1 = start_col + n_cols; + const uword end_row_plus1 = start_row + n_rows; + + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol) + for(uword urow = start_row; urow < end_row_plus1; ++urow) + { + F( X.at(urow, ucol) ); + } + } + } + + + +//! transform each element in the subview using a functor +template +template +inline +void +subview::transform(functor F) + { + arma_extra_debug_sigprint(); + + Mat& X = const_cast< Mat& >(m); + + if(n_rows == 1) + { + const uword urow = aux_row1; + const uword start_col = aux_col1; + const uword end_col_plus1 = start_col + n_cols; + + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol) + { + X.at(urow, ucol) = eT( F( X.at(urow, ucol) ) ); + } + } + else + { + const uword start_col = aux_col1; + const uword start_row = aux_row1; + + const uword end_col_plus1 = start_col + n_cols; + const uword end_row_plus1 = start_row + n_rows; + + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol) + for(uword urow = start_row; urow < end_row_plus1; ++urow) + { + X.at(urow, ucol) = eT( F( X.at(urow, ucol) ) ); + } + } + } + + + +//! imbue (fill) the subview with values provided by a functor +template +template +inline +void +subview::imbue(functor F) + { + arma_extra_debug_sigprint(); + + Mat& X = const_cast< Mat& >(m); + + if(n_rows == 1) + { + const uword urow = aux_row1; + const uword start_col = aux_col1; + const uword end_col_plus1 = start_col + n_cols; + + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol) + { + X.at(urow, ucol) = eT( F() ); + } + } + else + { + const uword start_col = aux_col1; + const uword start_row = aux_row1; + + const uword end_col_plus1 = start_col + n_cols; + const uword end_row_plus1 = start_row + n_rows; + + for(uword ucol = start_col; ucol < end_col_plus1; ++ucol) + for(uword urow = start_row; urow < end_row_plus1; ++urow) + { + X.at(urow, ucol) = eT( F() ); + } + } + } + + + +template +inline +void +subview::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + subview& s = *this; + + const uword s_n_cols = s.n_cols; + const uword s_n_rows = s.n_rows; + + if(s_n_rows == 1) + { + Mat& A = const_cast< Mat& >(s.m); + + const uword A_n_rows = A.n_rows; + + eT* Aptr = &(A.at(s.aux_row1,s.aux_col1)); + + if(arma_isnan(old_val)) + { + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + (*Aptr) = (arma_isnan(*Aptr)) ? new_val : (*Aptr); + + Aptr += A_n_rows; + } + } + else + { + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + (*Aptr) = ((*Aptr) == old_val) ? new_val : (*Aptr); + + Aptr += A_n_rows; + } + } + } + else + { + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + arrayops::replace(s.colptr(ucol), s_n_rows, old_val, new_val); + } + } + } + + + +template +inline +void +subview::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + subview& s = *this; + + const uword s_n_cols = s.n_cols; + const uword s_n_rows = s.n_rows; + + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + arrayops::clean( s.colptr(ucol), s_n_rows, threshold ); + } + } + + + +template +inline +void +subview::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "subview::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "subview::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "subview::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + subview& s = *this; + + const uword s_n_cols = s.n_cols; + const uword s_n_rows = s.n_rows; + + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + arrayops::clamp( s.colptr(ucol), s_n_rows, min_val, max_val ); + } + } + + + +template +inline +void +subview::fill(const eT val) + { + arma_extra_debug_sigprint(); + + subview& s = *this; + + const uword s_n_cols = s.n_cols; + const uword s_n_rows = s.n_rows; + + if(s_n_rows == 1) + { + Mat& A = const_cast< Mat& >(s.m); + + const uword A_n_rows = A.n_rows; + + eT* Aptr = &(A.at(s.aux_row1,s.aux_col1)); + + uword jj; + for(jj=1; jj < s_n_cols; jj+=2) + { + (*Aptr) = val; Aptr += A_n_rows; + (*Aptr) = val; Aptr += A_n_rows; + } + + if((jj-1) < s_n_cols) + { + (*Aptr) = val; + } + } + else + { + if( (s.aux_row1 == 0) && (s_n_rows == s.m.n_rows) ) + { + arrayops::inplace_set( s.colptr(0), val, s.n_elem ); + } + else + { + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + arrayops::inplace_set( s.colptr(ucol), val, s_n_rows ); + } + } + } + } + + + +template +inline +void +subview::zeros() + { + arma_extra_debug_sigprint(); + + (*this).fill(eT(0)); + } + + + +template +inline +void +subview::ones() + { + arma_extra_debug_sigprint(); + + (*this).fill(eT(1)); + } + + + +template +inline +void +subview::eye() + { + arma_extra_debug_sigprint(); + + (*this).zeros(); + + const uword N = (std::min)(n_rows, n_cols); + + for(uword ii=0; ii < N; ++ii) + { + at(ii,ii) = eT(1); + } + } + + + +template +inline +void +subview::randu() + { + arma_extra_debug_sigprint(); + + subview& s = (*this); + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + if(s_n_rows == 1) + { + podarray tmp(s_n_cols); + + eT* tmp_mem = tmp.memptr(); + + arma_rng::randu::fill( tmp_mem, s_n_cols ); + + for(uword ii=0; ii < s_n_cols; ++ii) { at(0,ii) = tmp_mem[ii]; } + } + else + { + if( (s.aux_row1 == 0) && (s_n_rows == s.m.n_rows) ) + { + arma_rng::randu::fill( s.colptr(0), s.n_elem ); + } + else + { + for(uword ii=0; ii < s_n_cols; ++ii) + { + arma_rng::randu::fill( s.colptr(ii), s_n_rows ); + } + } + } + } + + + +template +inline +void +subview::randn() + { + arma_extra_debug_sigprint(); + + subview& s = (*this); + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; + + if(s_n_rows == 1) + { + podarray tmp(s_n_cols); + + eT* tmp_mem = tmp.memptr(); + + arma_rng::randn::fill( tmp_mem, s_n_cols ); + + for(uword ii=0; ii < s_n_cols; ++ii) { at(0,ii) = tmp_mem[ii]; } + } + else + { + if( (s.aux_row1 == 0) && (s_n_rows == s.m.n_rows) ) + { + arma_rng::randn::fill( s.colptr(0), s.n_elem ); + } + else + { + for(uword ii=0; ii < s_n_cols; ++ii) + { + arma_rng::randn::fill( s.colptr(ii), s_n_rows ); + } + } + } + } + + + +template +inline +eT +subview::at_alt(const uword ii) const + { + return operator[](ii); + } + + + +template +inline +eT& +subview::operator[](const uword ii) + { + const uword in_col = ii / n_rows; + const uword in_row = ii % n_rows; + + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return access::rw( (const_cast< Mat& >(m)).mem[index] ); + } + + + +template +inline +eT +subview::operator[](const uword ii) const + { + const uword in_col = ii / n_rows; + const uword in_row = ii % n_rows; + + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return m.mem[index]; + } + + + +template +inline +eT& +subview::operator()(const uword ii) + { + arma_debug_check_bounds( (ii >= n_elem), "subview::operator(): index out of bounds" ); + + const uword in_col = ii / n_rows; + const uword in_row = ii % n_rows; + + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return access::rw( (const_cast< Mat& >(m)).mem[index] ); + } + + + +template +inline +eT +subview::operator()(const uword ii) const + { + arma_debug_check_bounds( (ii >= n_elem), "subview::operator(): index out of bounds" ); + + const uword in_col = ii / n_rows; + const uword in_row = ii % n_rows; + + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return m.mem[index]; + } + + + +template +inline +eT& +subview::operator()(const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "subview::operator(): index out of bounds" ); + + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return access::rw( (const_cast< Mat& >(m)).mem[index] ); + } + + + +template +inline +eT +subview::operator()(const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "subview::operator(): index out of bounds" ); + + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return m.mem[index]; + } + + + +template +inline +eT& +subview::at(const uword in_row, const uword in_col) + { + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return access::rw( (const_cast< Mat& >(m)).mem[index] ); + } + + + +template +inline +eT +subview::at(const uword in_row, const uword in_col) const + { + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return m.mem[index]; + } + + + +template +inline +eT& +subview::front() + { + const uword index = aux_col1*m.n_rows + aux_row1; + + return access::rw( (const_cast< Mat& >(m)).mem[index] ); + } + + + +template +inline +eT +subview::front() const + { + const uword index = aux_col1*m.n_rows + aux_row1; + + return m.mem[index]; + } + + + +template +inline +eT& +subview::back() + { + const uword in_row = n_rows - 1; + const uword in_col = n_cols - 1; + + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return access::rw( (const_cast< Mat& >(m)).mem[index] ); + } + + + +template +inline +eT +subview::back() const + { + const uword in_row = n_rows - 1; + const uword in_col = n_cols - 1; + + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return m.mem[index]; + } + + + +template +arma_inline +eT* +subview::colptr(const uword in_col) + { + return & access::rw((const_cast< Mat& >(m)).mem[ (in_col + aux_col1)*m.n_rows + aux_row1 ]); + } + + + +template +arma_inline +const eT* +subview::colptr(const uword in_col) const + { + return & m.mem[ (in_col + aux_col1)*m.n_rows + aux_row1 ]; + } + + + +template +template +inline +bool +subview::check_overlap(const subview& x) const + { + if(is_same_type::value == false) { return false; } + + const subview& s = (*this); + + if(void_ptr(&(s.m)) != void_ptr(&(x.m))) { return false; } + + if( (s.n_elem == 0) || (x.n_elem == 0) ) { return false; } + + const uword s_row_start = s.aux_row1; + const uword s_row_end_p1 = s_row_start + s.n_rows; + + const uword s_col_start = s.aux_col1; + const uword s_col_end_p1 = s_col_start + s.n_cols; + + + const uword x_row_start = x.aux_row1; + const uword x_row_end_p1 = x_row_start + x.n_rows; + + const uword x_col_start = x.aux_col1; + const uword x_col_end_p1 = x_col_start + x.n_cols; + + + const bool outside_rows = ( (x_row_start >= s_row_end_p1) || (s_row_start >= x_row_end_p1) ); + const bool outside_cols = ( (x_col_start >= s_col_end_p1) || (s_col_start >= x_col_end_p1) ); + + return ( (outside_rows == false) && (outside_cols == false) ); + } + + + +template +inline +bool +subview::is_vec() const + { + return ( (n_rows == 1) || (n_cols == 1) ); + } + + + +template +inline +bool +subview::is_finite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "is_finite(): detection of non-finite values is not reliable in fast math mode"); } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + for(uword ii=0; ii +inline +bool +subview::is_zero(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + for(uword ii=0; ii +inline +bool +subview::has_inf() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_inf(): detection of non-finite values is not reliable in fast math mode"); } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + for(uword ii=0; ii +inline +bool +subview::has_nan() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nan(): detection of non-finite values is not reliable in fast math mode"); } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + for(uword ii=0; ii +inline +bool +subview::has_nonfinite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nonfinite(): detection of non-finite values is not reliable in fast math mode"); } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + for(uword ii=0; ii +inline +void +subview::extract(Mat& out, const subview& in) + { + arma_extra_debug_sigprint(); + + // NOTE: we're assuming that the matrix has already been set to the correct size and there is no aliasing; + // size setting and alias checking is done by either the Mat contructor or operator=() + + const uword n_rows = in.n_rows; // number of rows in the subview + const uword n_cols = in.n_cols; // number of columns in the subview + + arma_extra_debug_print(arma_str::format("out.n_rows = %u out.n_cols = %u in.m.n_rows = %u in.m.n_cols = %u") % out.n_rows % out.n_cols % in.m.n_rows % in.m.n_cols ); + + + if(in.is_vec()) + { + if(n_cols == 1) // a column vector + { + arma_extra_debug_print("subview::extract(): copying col (going across rows)"); + + // in.colptr(0) the first column of the subview, taking into account any row offset + arrayops::copy( out.memptr(), in.colptr(0), n_rows ); + } + else + if(n_rows == 1) // a row vector + { + arma_extra_debug_print("subview::extract(): copying row (going across columns)"); + + eT* out_mem = out.memptr(); + + const uword X_n_rows = in.m.n_rows; + + const eT* Xptr = &(in.m.at(in.aux_row1,in.aux_col1)); + + uword j; + + for(j=1; j < n_cols; j+=2) + { + const eT tmp1 = (*Xptr); Xptr += X_n_rows; + const eT tmp2 = (*Xptr); Xptr += X_n_rows; + + (*out_mem) = tmp1; out_mem++; + (*out_mem) = tmp2; out_mem++; + } + + if((j-1) < n_cols) + { + (*out_mem) = (*Xptr); + } + } + } + else // general submatrix + { + arma_extra_debug_print("subview::extract(): general submatrix"); + + if( (in.aux_row1 == 0) && (n_rows == in.m.n_rows) ) + { + arrayops::copy( out.memptr(), in.colptr(0), in.n_elem ); + } + else + { + for(uword col=0; col < n_cols; ++col) + { + arrayops::copy( out.colptr(col), in.colptr(col), n_rows ); + } + } + } + } + + + +//! X += Y.submat(...) +template +inline +void +subview::plus_inplace(Mat& out, const subview& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out, in, "addition"); + + const uword n_rows = in.n_rows; + const uword n_cols = in.n_cols; + + if(n_rows == 1) + { + eT* out_mem = out.memptr(); + + const Mat& X = in.m; + + const uword row = in.aux_row1; + const uword start_col = in.aux_col1; + + uword i,j; + for(i=0, j=1; j < n_cols; i+=2, j+=2) + { + const eT tmp1 = X.at(row, start_col+i); + const eT tmp2 = X.at(row, start_col+j); + + out_mem[i] += tmp1; + out_mem[j] += tmp2; + } + + if(i < n_cols) + { + out_mem[i] += X.at(row, start_col+i); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + arrayops::inplace_plus(out.colptr(col), in.colptr(col), n_rows); + } + } + } + + + +//! X -= Y.submat(...) +template +inline +void +subview::minus_inplace(Mat& out, const subview& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out, in, "subtraction"); + + const uword n_rows = in.n_rows; + const uword n_cols = in.n_cols; + + if(n_rows == 1) + { + eT* out_mem = out.memptr(); + + const Mat& X = in.m; + + const uword row = in.aux_row1; + const uword start_col = in.aux_col1; + + uword i,j; + for(i=0, j=1; j < n_cols; i+=2, j+=2) + { + const eT tmp1 = X.at(row, start_col+i); + const eT tmp2 = X.at(row, start_col+j); + + out_mem[i] -= tmp1; + out_mem[j] -= tmp2; + } + + if(i < n_cols) + { + out_mem[i] -= X.at(row, start_col+i); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + arrayops::inplace_minus(out.colptr(col), in.colptr(col), n_rows); + } + } + } + + + +//! X %= Y.submat(...) +template +inline +void +subview::schur_inplace(Mat& out, const subview& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out, in, "element-wise multiplication"); + + const uword n_rows = in.n_rows; + const uword n_cols = in.n_cols; + + if(n_rows == 1) + { + eT* out_mem = out.memptr(); + + const Mat& X = in.m; + + const uword row = in.aux_row1; + const uword start_col = in.aux_col1; + + uword i,j; + for(i=0, j=1; j < n_cols; i+=2, j+=2) + { + const eT tmp1 = X.at(row, start_col+i); + const eT tmp2 = X.at(row, start_col+j); + + out_mem[i] *= tmp1; + out_mem[j] *= tmp2; + } + + if(i < n_cols) + { + out_mem[i] *= X.at(row, start_col+i); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + arrayops::inplace_mul(out.colptr(col), in.colptr(col), n_rows); + } + } + } + + + +//! X /= Y.submat(...) +template +inline +void +subview::div_inplace(Mat& out, const subview& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(out, in, "element-wise division"); + + const uword n_rows = in.n_rows; + const uword n_cols = in.n_cols; + + if(n_rows == 1) + { + eT* out_mem = out.memptr(); + + const Mat& X = in.m; + + const uword row = in.aux_row1; + const uword start_col = in.aux_col1; + + uword i,j; + for(i=0, j=1; j < n_cols; i+=2, j+=2) + { + const eT tmp1 = X.at(row, start_col+i); + const eT tmp2 = X.at(row, start_col+j); + + out_mem[i] /= tmp1; + out_mem[j] /= tmp2; + } + + if(i < n_cols) + { + out_mem[i] /= X.at(row, start_col+i); + } + } + else + { + for(uword col=0; col < n_cols; ++col) + { + arrayops::inplace_div(out.colptr(col), in.colptr(col), n_rows); + } + } + } + + + +//! creation of subview (row vector) +template +inline +subview_row +subview::row(const uword row_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( row_num >= n_rows, "subview::row(): out of bounds" ); + + const uword base_row = aux_row1 + row_num; + + return subview_row(m, base_row, aux_col1, n_cols); + } + + + +//! creation of subview (row vector) +template +inline +const subview_row +subview::row(const uword row_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( row_num >= n_rows, "subview::row(): out of bounds" ); + + const uword base_row = aux_row1 + row_num; + + return subview_row(m, base_row, aux_col1, n_cols); + } + + + +template +inline +subview_row +subview::operator()(const uword row_num, const span& col_span) + { + arma_extra_debug_sigprint(); + + const bool col_all = col_span.whole; + + const uword local_n_cols = n_cols; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + const uword base_col1 = aux_col1 + in_col1; + const uword base_row = aux_row1 + row_num; + + arma_debug_check_bounds + ( + (row_num >= n_rows) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "subview::operator(): indices out of bounds or incorrectly used" + ); + + return subview_row(m, base_row, base_col1, submat_n_cols); + } + + + +template +inline +const subview_row +subview::operator()(const uword row_num, const span& col_span) const + { + arma_extra_debug_sigprint(); + + const bool col_all = col_span.whole; + + const uword local_n_cols = n_cols; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + const uword base_col1 = aux_col1 + in_col1; + const uword base_row = aux_row1 + row_num; + + arma_debug_check_bounds + ( + (row_num >= n_rows) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "subview::operator(): indices out of bounds or incorrectly used" + ); + + return subview_row(m, base_row, base_col1, submat_n_cols); + } + + + +//! creation of subview (column vector) +template +inline +subview_col +subview::col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= n_cols, "subview::col(): out of bounds" ); + + const uword base_col = aux_col1 + col_num; + + return subview_col(m, base_col, aux_row1, n_rows); + } + + + +//! creation of subview (column vector) +template +inline +const subview_col +subview::col(const uword col_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= n_cols, "subview::col(): out of bounds" ); + + const uword base_col = aux_col1 + col_num; + + return subview_col(m, base_col, aux_row1, n_rows); + } + + + +template +inline +subview_col +subview::operator()(const span& row_span, const uword col_num) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + + const uword local_n_rows = n_rows; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword base_row1 = aux_row1 + in_row1; + const uword base_col = aux_col1 + col_num; + + arma_debug_check_bounds + ( + (col_num >= n_cols) + || + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + , + "subview::operator(): indices out of bounds or incorrectly used" + ); + + return subview_col(m, base_col, base_row1, submat_n_rows); + } + + + +template +inline +const subview_col +subview::operator()(const span& row_span, const uword col_num) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + + const uword local_n_rows = n_rows; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword base_row1 = aux_row1 + in_row1; + const uword base_col = aux_col1 + col_num; + + arma_debug_check_bounds + ( + (col_num >= n_cols) + || + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + , + "subview::operator(): indices out of bounds or incorrectly used" + ); + + return subview_col(m, base_col, base_row1, submat_n_rows); + } + + + +//! create a Col object which uses memory from an existing matrix object. +//! this approach is currently not alias safe +//! and does not take into account that the parent matrix object could be deleted. +//! if deleted memory is accessed by the created Col object, +//! it will cause memory corruption and/or a crash +template +inline +Col +subview::unsafe_col(const uword col_num) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= n_cols, "subview::unsafe_col(): out of bounds" ); + + return Col(colptr(col_num), n_rows, false, true); + } + + + +//! create a Col object which uses memory from an existing matrix object. +//! this approach is currently not alias safe +//! and does not take into account that the parent matrix object could be deleted. +//! if deleted memory is accessed by the created Col object, +//! it will cause memory corruption and/or a crash +template +inline +const Col +subview::unsafe_col(const uword col_num) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( col_num >= n_cols, "subview::unsafe_col(): out of bounds" ); + + return Col(const_cast(colptr(col_num)), n_rows, false, true); + } + + + +//! creation of subview (submatrix comprised of specified row vectors) +template +inline +subview +subview::rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "subview::rows(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + const uword base_row1 = aux_row1 + in_row1; + + return subview(m, base_row1, aux_col1, subview_n_rows, n_cols ); + } + + + +//! creation of subview (submatrix comprised of specified row vectors) +template +inline +const subview +subview::rows(const uword in_row1, const uword in_row2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_row2 >= n_rows), + "subview::rows(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + const uword base_row1 = aux_row1 + in_row1; + + return subview(m, base_row1, aux_col1, subview_n_rows, n_cols ); + } + + + +//! creation of subview (submatrix comprised of specified column vectors) +template +inline +subview +subview::cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "subview::cols(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + const uword base_col1 = aux_col1 + in_col1; + + return subview(m, aux_row1, base_col1, n_rows, subview_n_cols); + } + + + +//! creation of subview (submatrix comprised of specified column vectors) +template +inline +const subview +subview::cols(const uword in_col1, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 > in_col2) || (in_col2 >= n_cols), + "subview::cols(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + const uword base_col1 = aux_col1 + in_col1; + + return subview(m, aux_row1, base_col1, n_rows, subview_n_cols); + } + + + +//! creation of subview (submatrix) +template +inline +subview +subview::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), + "subview::submat(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + const uword subview_n_cols = in_col2 - in_col1 + 1; + + const uword base_row1 = aux_row1 + in_row1; + const uword base_col1 = aux_col1 + in_col1; + + return subview(m, base_row1, base_col1, subview_n_rows, subview_n_cols); + } + + + +//! creation of subview (generic submatrix) +template +inline +const subview +subview::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), + "subview::submat(): indices out of bounds or incorrectly used" + ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + const uword subview_n_cols = in_col2 - in_col1 + 1; + + const uword base_row1 = aux_row1 + in_row1; + const uword base_col1 = aux_col1 + in_col1; + + return subview(m, base_row1, base_col1, subview_n_rows, subview_n_cols); + } + + + +//! creation of subview (submatrix) +template +inline +subview +subview::submat(const span& row_span, const span& col_span) + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "subview::submat(): indices out of bounds or incorrectly used" + ); + + const uword base_row1 = aux_row1 + in_row1; + const uword base_col1 = aux_col1 + in_col1; + + return subview(m, base_row1, base_col1, submat_n_rows, submat_n_cols); + } + + + +//! creation of subview (generic submatrix) +template +inline +const subview +subview::submat(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + const bool row_all = row_span.whole; + const bool col_all = col_span.whole; + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + const uword in_row1 = row_all ? 0 : row_span.a; + const uword in_row2 = row_span.b; + const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; + + const uword in_col1 = col_all ? 0 : col_span.a; + const uword in_col2 = col_span.b; + const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; + + arma_debug_check_bounds + ( + ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) + || + ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) + , + "subview::submat(): indices out of bounds or incorrectly used" + ); + + const uword base_row1 = aux_row1 + in_row1; + const uword base_col1 = aux_col1 + in_col1; + + return subview(m, base_row1, base_col1, submat_n_rows, submat_n_cols); + } + + + +template +inline +subview +subview::operator()(const span& row_span, const span& col_span) + { + arma_extra_debug_sigprint(); + + return (*this).submat(row_span, col_span); + } + + + +template +inline +const subview +subview::operator()(const span& row_span, const span& col_span) const + { + arma_extra_debug_sigprint(); + + return (*this).submat(row_span, col_span); + } + + + +template +inline +subview_each1< subview, 0 > +subview::each_col() + { + arma_extra_debug_sigprint(); + + return subview_each1< subview, 0 >(*this); + } + + + +template +inline +subview_each1< subview, 1 > +subview::each_row() + { + arma_extra_debug_sigprint(); + + return subview_each1< subview, 1 >(*this); + } + + + +template +template +inline +subview_each2< subview, 0, T1 > +subview::each_col(const Base& indices) + { + arma_extra_debug_sigprint(); + + return subview_each2< subview, 0, T1 >(*this, indices); + } + + + +template +template +inline +subview_each2< subview, 1, T1 > +subview::each_row(const Base& indices) + { + arma_extra_debug_sigprint(); + + return subview_each2< subview, 1, T1 >(*this, indices); + } + + + +//! apply a lambda function to each column, where each column is interpreted as a column vector +template +inline +void +subview::each_col(const std::function< void(Col&) >& F) + { + arma_extra_debug_sigprint(); + + for(uword ii=0; ii < n_cols; ++ii) + { + Col tmp(colptr(ii), n_rows, false, true); + F(tmp); + } + } + + + +template +inline +void +subview::each_col(const std::function< void(const Col&) >& F) const + { + arma_extra_debug_sigprint(); + + for(uword ii=0; ii < n_cols; ++ii) + { + const Col tmp(colptr(ii), n_rows, false, true); + F(tmp); + } + } + + + +//! apply a lambda function to each row, where each row is interpreted as a row vector +template +inline +void +subview::each_row(const std::function< void(Row&) >& F) + { + arma_extra_debug_sigprint(); + + podarray array1(n_cols); + podarray array2(n_cols); + + Row tmp1( array1.memptr(), n_cols, false, true ); + Row tmp2( array2.memptr(), n_cols, false, true ); + + eT* tmp1_mem = tmp1.memptr(); + eT* tmp2_mem = tmp2.memptr(); + + uword ii, jj; + + for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) + { + for(uword col_id = 0; col_id < n_cols; ++col_id) + { + const eT* col_mem = colptr(col_id); + + tmp1_mem[col_id] = col_mem[ii]; + tmp2_mem[col_id] = col_mem[jj]; + } + + F(tmp1); + F(tmp2); + + for(uword col_id = 0; col_id < n_cols; ++col_id) + { + eT* col_mem = colptr(col_id); + + col_mem[ii] = tmp1_mem[col_id]; + col_mem[jj] = tmp2_mem[col_id]; + } + } + + if(ii < n_rows) + { + tmp1 = (*this).row(ii); + + F(tmp1); + + (*this).row(ii) = tmp1; + } + } + + + +template +inline +void +subview::each_row(const std::function< void(const Row&) >& F) const + { + arma_extra_debug_sigprint(); + + podarray array1(n_cols); + podarray array2(n_cols); + + Row tmp1( array1.memptr(), n_cols, false, true ); + Row tmp2( array2.memptr(), n_cols, false, true ); + + eT* tmp1_mem = tmp1.memptr(); + eT* tmp2_mem = tmp2.memptr(); + + uword ii, jj; + + for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) + { + for(uword col_id = 0; col_id < n_cols; ++col_id) + { + const eT* col_mem = colptr(col_id); + + tmp1_mem[col_id] = col_mem[ii]; + tmp2_mem[col_id] = col_mem[jj]; + } + + F(tmp1); + F(tmp2); + } + + if(ii < n_rows) + { + tmp1 = (*this).row(ii); + + F(tmp1); + } + } + + + +//! creation of diagview (diagonal) +template +inline +diagview +subview::diag(const sword in_id) + { + arma_extra_debug_sigprint(); + + const uword row_offset = (in_id < 0) ? uword(-in_id) : 0; + const uword col_offset = (in_id > 0) ? uword( in_id) : 0; + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "subview::diag(): requested diagonal out of bounds" + ); + + const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset); + + const uword base_row_offset = aux_row1 + row_offset; + const uword base_col_offset = aux_col1 + col_offset; + + return diagview(m, base_row_offset, base_col_offset, len); + } + + + +//! creation of diagview (diagonal) +template +inline +const diagview +subview::diag(const sword in_id) const + { + arma_extra_debug_sigprint(); + + const uword row_offset = uword( (in_id < 0) ? -in_id : 0 ); + const uword col_offset = uword( (in_id > 0) ? in_id : 0 ); + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "subview::diag(): requested diagonal out of bounds" + ); + + const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset); + + const uword base_row_offset = aux_row1 + row_offset; + const uword base_col_offset = aux_col1 + col_offset; + + return diagview(m, base_row_offset, base_col_offset, len); + } + + + +template +inline +void +subview::swap_rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_row1 >= n_rows) || (in_row2 >= n_rows), + "subview::swap_rows(): out of bounds" + ); + + eT* mem = (const_cast< Mat& >(m)).memptr(); + + if(n_elem > 0) + { + const uword m_n_rows = m.n_rows; + + for(uword ucol=0; ucol < n_cols; ++ucol) + { + const uword offset = (aux_col1 + ucol) * m_n_rows; + const uword pos1 = aux_row1 + in_row1 + offset; + const uword pos2 = aux_row1 + in_row2 + offset; + + std::swap( access::rw(mem[pos1]), access::rw(mem[pos2]) ); + } + } + } + + + +template +inline +void +subview::swap_cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds + ( + (in_col1 >= n_cols) || (in_col2 >= n_cols), + "subview::swap_cols(): out of bounds" + ); + + if(n_elem > 0) + { + eT* ptr1 = colptr(in_col1); + eT* ptr2 = colptr(in_col2); + + for(uword urow=0; urow < n_rows; ++urow) + { + std::swap( ptr1[urow], ptr2[urow] ); + } + } + } + + + +template +inline +typename subview::iterator +subview::begin() + { + return iterator(*this, aux_row1, aux_col1); + } + + + +template +inline +typename subview::const_iterator +subview::begin() const + { + return const_iterator(*this, aux_row1, aux_col1); + } + + + +template +inline +typename subview::const_iterator +subview::cbegin() const + { + return const_iterator(*this, aux_row1, aux_col1); + } + + + +template +inline +typename subview::iterator +subview::end() + { + return iterator(*this, aux_row1, aux_col1 + n_cols); + } + + + +template +inline +typename subview::const_iterator +subview::end() const + { + return const_iterator(*this, aux_row1, aux_col1 + n_cols); + } + + + +template +inline +typename subview::const_iterator +subview::cend() const + { + return const_iterator(*this, aux_row1, aux_col1 + n_cols); + } + + + +// +// +// + + + +template +inline +subview::iterator::iterator() + : M (nullptr) + , current_ptr(nullptr) + , current_row(0 ) + , current_col(0 ) + , aux_row1 (0 ) + , aux_row2_p1(0 ) + { + arma_extra_debug_sigprint(); + // Technically this iterator is invalid (it does not point to a valid element) + } + + + +template +inline +subview::iterator::iterator(const iterator& X) + : M (X.M ) + , current_ptr(X.current_ptr) + , current_row(X.current_row) + , current_col(X.current_col) + , aux_row1 (X.aux_row1 ) + , aux_row2_p1(X.aux_row2_p1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview::iterator::iterator(subview& in_sv, const uword in_row, const uword in_col) + : M (&(const_cast< Mat& >(in_sv.m))) + , current_ptr(&(M->at(in_row,in_col)) ) + , current_row(in_row ) + , current_col(in_col ) + , aux_row1 (in_sv.aux_row1 ) + , aux_row2_p1(in_sv.aux_row1 + in_sv.n_rows ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eT& +subview::iterator::operator*() + { + return (*current_ptr); + } + + + +template +inline +typename subview::iterator& +subview::iterator::operator++() + { + current_row++; + + if(current_row == aux_row2_p1) + { + current_row = aux_row1; + current_col++; + + current_ptr = &( (*M).at(current_row,current_col) ); + } + else + { + current_ptr++; + } + + return *this; + } + + + +template +inline +typename subview::iterator +subview::iterator::operator++(int) + { + typename subview::iterator temp(*this); + + ++(*this); + + return temp; + } + + + +template +inline +bool +subview::iterator::operator==(const iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +subview::iterator::operator!=(const iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +template +inline +bool +subview::iterator::operator==(const const_iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +subview::iterator::operator!=(const const_iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +// +// +// + + + +template +inline +subview::const_iterator::const_iterator() + : M (nullptr) + , current_ptr(nullptr) + , current_row(0 ) + , current_col(0 ) + , aux_row1 (0 ) + , aux_row2_p1(0 ) + { + arma_extra_debug_sigprint(); + // Technically this iterator is invalid (it does not point to a valid element) + } + + + +template +inline +subview::const_iterator::const_iterator(const iterator& X) + : M (X.M ) + , current_ptr(X.current_ptr) + , current_row(X.current_row) + , current_col(X.current_col) + , aux_row1 (X.aux_row1 ) + , aux_row2_p1(X.aux_row2_p1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview::const_iterator::const_iterator(const const_iterator& X) + : M (X.M ) + , current_ptr(X.current_ptr) + , current_row(X.current_row) + , current_col(X.current_col) + , aux_row1 (X.aux_row1 ) + , aux_row2_p1(X.aux_row2_p1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview::const_iterator::const_iterator(const subview& in_sv, const uword in_row, const uword in_col) + : M (&(in_sv.m) ) + , current_ptr(&(M->at(in_row,in_col)) ) + , current_row(in_row ) + , current_col(in_col ) + , aux_row1 (in_sv.aux_row1 ) + , aux_row2_p1(in_sv.aux_row1 + in_sv.n_rows) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +const eT& +subview::const_iterator::operator*() + { + return (*current_ptr); + } + + + +template +inline +typename subview::const_iterator& +subview::const_iterator::operator++() + { + current_row++; + + if(current_row == aux_row2_p1) + { + current_row = aux_row1; + current_col++; + + current_ptr = &( (*M).at(current_row,current_col) ); + } + else + { + current_ptr++; + } + + return *this; + } + + + +template +inline +typename subview::const_iterator +subview::const_iterator::operator++(int) + { + typename subview::const_iterator temp(*this); + + ++(*this); + + return temp; + } + + + +template +inline +bool +subview::const_iterator::operator==(const iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +subview::const_iterator::operator!=(const iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +template +inline +bool +subview::const_iterator::operator==(const const_iterator& rhs) const + { + return (current_ptr == rhs.current_ptr); + } + + + +template +inline +bool +subview::const_iterator::operator!=(const const_iterator& rhs) const + { + return (current_ptr != rhs.current_ptr); + } + + + +// +// +// + + + +template +inline +subview::row_iterator::row_iterator() + : M (nullptr) + , current_row(0 ) + , current_col(0 ) + , aux_col1 (0 ) + , aux_col2_p1(0 ) + { + arma_extra_debug_sigprint(); + // Technically this iterator is invalid (it does not point to a valid element) + } + + + +template +inline +subview::row_iterator::row_iterator(const row_iterator& X) + : M (X.M ) + , current_row(X.current_row) + , current_col(X.current_col) + , aux_col1 (X.aux_col1 ) + , aux_col2_p1(X.aux_col2_p1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview::row_iterator::row_iterator(subview& in_sv, const uword in_row, const uword in_col) + : M (&(const_cast< Mat& >(in_sv.m))) + , current_row(in_row ) + , current_col(in_col ) + , aux_col1 (in_sv.aux_col1 ) + , aux_col2_p1(in_sv.aux_col1 + in_sv.n_cols ) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +eT& +subview::row_iterator::operator*() + { + return M->at(current_row,current_col); + } + + + +template +inline +typename subview::row_iterator& +subview::row_iterator::operator++() + { + current_col++; + + if(current_col == aux_col2_p1) + { + current_col = aux_col1; + current_row++; + } + + return *this; + } + + + +template +inline +typename subview::row_iterator +subview::row_iterator::operator++(int) + { + typename subview::row_iterator temp(*this); + + ++(*this); + + return temp; + } + + + +template +inline +bool +subview::row_iterator::operator==(const row_iterator& rhs) const + { + return ( (current_row == rhs.current_row) && (current_col == rhs.current_col) ); + } + + + +template +inline +bool +subview::row_iterator::operator!=(const row_iterator& rhs) const + { + return ( (current_row != rhs.current_row) || (current_col != rhs.current_col) ); + } + + + +template +inline +bool +subview::row_iterator::operator==(const const_row_iterator& rhs) const + { + return ( (current_row == rhs.current_row) && (current_col == rhs.current_col) ); + } + + + +template +inline +bool +subview::row_iterator::operator!=(const const_row_iterator& rhs) const + { + return ( (current_row != rhs.current_row) || (current_col != rhs.current_col) ); + } + + + +// +// +// + + + +template +inline +subview::const_row_iterator::const_row_iterator() + : M (nullptr) + , current_row(0 ) + , current_col(0 ) + , aux_col1 (0 ) + , aux_col2_p1(0 ) + { + arma_extra_debug_sigprint(); + // Technically this iterator is invalid (it does not point to a valid element) + } + + + +template +inline +subview::const_row_iterator::const_row_iterator(const row_iterator& X) + : M (X.M ) + , current_row(X.current_row) + , current_col(X.current_col) + , aux_col1 (X.aux_col1 ) + , aux_col2_p1(X.aux_col2_p1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview::const_row_iterator::const_row_iterator(const const_row_iterator& X) + : M (X.M ) + , current_row(X.current_row) + , current_col(X.current_col) + , aux_col1 (X.aux_col1 ) + , aux_col2_p1(X.aux_col2_p1) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview::const_row_iterator::const_row_iterator(const subview& in_sv, const uword in_row, const uword in_col) + : M (&(in_sv.m) ) + , current_row(in_row ) + , current_col(in_col ) + , aux_col1 (in_sv.aux_col1 ) + , aux_col2_p1(in_sv.aux_col1 + in_sv.n_cols) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +const eT& +subview::const_row_iterator::operator*() const + { + return M->at(current_row,current_col); + } + + + +template +inline +typename subview::const_row_iterator& +subview::const_row_iterator::operator++() + { + current_col++; + + if(current_col == aux_col2_p1) + { + current_col = aux_col1; + current_row++; + } + + return *this; + } + + + +template +inline +typename subview::const_row_iterator +subview::const_row_iterator::operator++(int) + { + typename subview::const_row_iterator temp(*this); + + ++(*this); + + return temp; + } + + + +template +inline +bool +subview::const_row_iterator::operator==(const row_iterator& rhs) const + { + return ( (current_row == rhs.current_row) && (current_col == rhs.current_col) ); + } + + + +template +inline +bool +subview::const_row_iterator::operator!=(const row_iterator& rhs) const + { + return ( (current_row != rhs.current_row) || (current_col != rhs.current_col) ); + } + + + +template +inline +bool +subview::const_row_iterator::operator==(const const_row_iterator& rhs) const + { + return ( (current_row == rhs.current_row) && (current_col == rhs.current_col) ); + } + + + +template +inline +bool +subview::const_row_iterator::operator!=(const const_row_iterator& rhs) const + { + return ( (current_row != rhs.current_row) || (current_col != rhs.current_col) ); + } + + + +// +// +// + + + +template +inline +subview_col::subview_col(const Mat& in_m, const uword in_col) + : subview(in_m, 0, in_col, in_m.n_rows, 1) + , colmem(subview::colptr(0)) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_col::subview_col(const Mat& in_m, const uword in_col, const uword in_row1, const uword in_n_rows) + : subview(in_m, in_row1, in_col, in_n_rows, 1) + , colmem(subview::colptr(0)) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_col::subview_col(const subview_col& in) + : subview(in) // interprets 'subview_col' as 'subview' + , colmem(in.colmem) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_col::subview_col(subview_col&& in) + : subview(std::move(in)) // interprets 'subview_col' as 'subview' + , colmem(in.colmem) + { + arma_extra_debug_sigprint(); + + access::rw(in.colmem) = nullptr; + } + + + +template +inline +void +subview_col::operator=(const subview& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X); + } + + + +template +inline +void +subview_col::operator=(const subview_col& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X); // interprets 'subview_col' as 'subview' + } + + + +template +inline +void +subview_col::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + const uword N = uword(list.size()); + + arma_debug_assert_same_size(subview::n_rows, subview::n_cols, N, 1, "copy into submatrix"); + + arrayops::copy( access::rwp(colmem), list.begin(), N ); + } + + + +template +inline +void +subview_col::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + if(subview::n_elem != 1) + { + arma_debug_assert_same_size(subview::n_rows, subview::n_cols, 1, 1, "copy into submatrix"); + } + + access::rw( colmem[0] ) = val; + } + + + +template +template +inline +void +subview_col::operator=(const Base& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X); + } + + + +template +template +inline +void +subview_col::operator=(const SpBase& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X.get_ref()); + } + + + +template +template +inline +typename enable_if2< is_same_type::value, void>::result +subview_col::operator= (const Gen& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(subview::n_rows, uword(1), in.n_rows, (in.is_col ? uword(1) : in.n_cols), "copy into submatrix"); + + in.apply(*this); + } + + + +template +arma_inline +const Op,op_htrans> +subview_col::t() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_htrans> +subview_col::ht() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_strans> +subview_col::st() const + { + return Op,op_strans>(*this); + } + + + +template +arma_inline +const Op,op_strans> +subview_col::as_row() const + { + return Op,op_strans>(*this); + } + + + +template +inline +void +subview_col::fill(const eT val) + { + arma_extra_debug_sigprint(); + + arrayops::inplace_set( access::rwp(colmem), val, subview::n_rows ); + } + + + +template +inline +void +subview_col::zeros() + { + arma_extra_debug_sigprint(); + + arrayops::fill_zeros( access::rwp(colmem), subview::n_rows ); + } + + + +template +inline +void +subview_col::ones() + { + arma_extra_debug_sigprint(); + + arrayops::inplace_set( access::rwp(colmem), eT(1), subview::n_rows ); + } + + + +template +arma_inline +eT +subview_col::at_alt(const uword ii) const + { + const eT* colmem_aligned = colmem; + memory::mark_as_aligned(colmem_aligned); + + return colmem_aligned[ii]; + } + + + +template +arma_inline +eT& +subview_col::operator[](const uword ii) + { + return access::rw( colmem[ii] ); + } + + + +template +arma_inline +eT +subview_col::operator[](const uword ii) const + { + return colmem[ii]; + } + + + +template +inline +eT& +subview_col::operator()(const uword ii) + { + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); + + return access::rw( colmem[ii] ); + } + + + +template +inline +eT +subview_col::operator()(const uword ii) const + { + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); + + return colmem[ii]; + } + + + +template +inline +eT& +subview_col::operator()(const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row >= subview::n_rows) || (in_col > 0)), "subview::operator(): index out of bounds" ); + + return access::rw( colmem[in_row] ); + } + + + +template +inline +eT +subview_col::operator()(const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row >= subview::n_rows) || (in_col > 0)), "subview::operator(): index out of bounds" ); + + return colmem[in_row]; + } + + + +template +inline +eT& +subview_col::at(const uword in_row, const uword) + { + return access::rw( colmem[in_row] ); + } + + + +template +inline +eT +subview_col::at(const uword in_row, const uword) const + { + return colmem[in_row]; + } + + + +template +arma_inline +eT* +subview_col::colptr(const uword) + { + return const_cast(colmem); + } + + +template +arma_inline +const eT* +subview_col::colptr(const uword) const + { + return colmem; + } + + +template +inline +subview_col +subview_col::rows(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::rows(): indices out of bounds or incorrectly used" ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + const uword base_row1 = this->aux_row1 + in_row1; + + return subview_col(this->m, this->aux_col1, base_row1, subview_n_rows); + } + + + +template +inline +const subview_col +subview_col::rows(const uword in_row1, const uword in_row2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::rows(): indices out of bounds or incorrectly used" ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + const uword base_row1 = this->aux_row1 + in_row1; + + return subview_col(this->m, this->aux_col1, base_row1, subview_n_rows); + } + + + +template +inline +subview_col +subview_col::subvec(const uword in_row1, const uword in_row2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::subvec(): indices out of bounds or incorrectly used" ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + const uword base_row1 = this->aux_row1 + in_row1; + + return subview_col(this->m, this->aux_col1, base_row1, subview_n_rows); + } + + + +template +inline +const subview_col +subview_col::subvec(const uword in_row1, const uword in_row2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::subvec(): indices out of bounds or incorrectly used" ); + + const uword subview_n_rows = in_row2 - in_row1 + 1; + + const uword base_row1 = this->aux_row1 + in_row1; + + return subview_col(this->m, this->aux_col1, base_row1, subview_n_rows); + } + + + +template +inline +subview_col +subview_col::subvec(const uword start_row, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (s.n_cols != 1), "subview_col::subvec(): given size does not specify a column vector" ); + + arma_debug_check_bounds( ( (start_row >= subview::n_rows) || ((start_row + s.n_rows) > subview::n_rows) ), "subview_col::subvec(): size out of bounds" ); + + const uword base_row1 = this->aux_row1 + start_row; + + return subview_col(this->m, this->aux_col1, base_row1, s.n_rows); + } + + + +template +inline +const subview_col +subview_col::subvec(const uword start_row, const SizeMat& s) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (s.n_cols != 1), "subview_col::subvec(): given size does not specify a column vector" ); + + arma_debug_check_bounds( ( (start_row >= subview::n_rows) || ((start_row + s.n_rows) > subview::n_rows) ), "subview_col::subvec(): size out of bounds" ); + + const uword base_row1 = this->aux_row1 + start_row; + + return subview_col(this->m, this->aux_col1, base_row1, s.n_rows); + } + + + +template +inline +subview_col +subview_col::head(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > subview::n_rows), "subview_col::head(): size out of bounds" ); + + return subview_col(this->m, this->aux_col1, this->aux_row1, N); + } + + + +template +inline +const subview_col +subview_col::head(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > subview::n_rows), "subview_col::head(): size out of bounds" ); + + return subview_col(this->m, this->aux_col1, this->aux_row1, N); + } + + + +template +inline +subview_col +subview_col::tail(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > subview::n_rows), "subview_col::tail(): size out of bounds" ); + + const uword start_row = subview::aux_row1 + subview::n_rows - N; + + return subview_col(this->m, this->aux_col1, start_row, N); + } + + + +template +inline +const subview_col +subview_col::tail(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > subview::n_rows), "subview_col::tail(): size out of bounds" ); + + const uword start_row = subview::aux_row1 + subview::n_rows - N; + + return subview_col(this->m, this->aux_col1, start_row, N); + } + + + +template +inline +eT +subview_col::min() const + { + arma_extra_debug_sigprint(); + + if(subview::n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + return op_min::direct_min(colmem, subview::n_elem); + } + + + +template +inline +eT +subview_col::max() const + { + arma_extra_debug_sigprint(); + + if(subview::n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + return op_max::direct_max(colmem, subview::n_elem); + } + + + +template +inline +eT +subview_col::min(uword& index_of_min_val) const + { + arma_extra_debug_sigprint(); + + if(subview::n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + index_of_min_val = uword(0); + + return Datum::nan; + } + else + { + return op_min::direct_min(colmem, subview::n_elem, index_of_min_val); + } + } + + + +template +inline +eT +subview_col::max(uword& index_of_max_val) const + { + arma_extra_debug_sigprint(); + + if(subview::n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + index_of_max_val = uword(0); + + return Datum::nan; + } + else + { + return op_max::direct_max(colmem, subview::n_elem, index_of_max_val); + } + } + + + +template +inline +uword +subview_col::index_min() const + { + arma_extra_debug_sigprint(); + + uword index = 0; + + if(subview::n_elem == 0) + { + arma_debug_check(true, "index_min(): object has no elements"); + } + else + { + op_min::direct_min(colmem, subview::n_elem, index); + } + + return index; + } + + + +template +inline +uword +subview_col::index_max() const + { + arma_extra_debug_sigprint(); + + uword index = 0; + + if(subview::n_elem == 0) + { + arma_debug_check(true, "index_max(): object has no elements"); + } + else + { + op_max::direct_max(colmem, subview::n_elem, index); + } + + return index; + } + + + +// +// +// + + +template +inline +subview_cols::subview_cols(const Mat& in_m, const uword in_col1, const uword in_n_cols) + : subview(in_m, 0, in_col1, in_m.n_rows, in_n_cols) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_cols::subview_cols(const subview_cols& in) + : subview(in) // interprets 'subview_cols' as 'subview' + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_cols::subview_cols(subview_cols&& in) + : subview(std::move(in)) // interprets 'subview_cols' as 'subview' + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +subview_cols::operator=(const subview& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X); + } + + + +template +inline +void +subview_cols::operator=(const subview_cols& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X); // interprets 'subview_cols' as 'subview' + } + + + +template +inline +void +subview_cols::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + subview::operator=(list); + } + + + +template +inline +void +subview_cols::operator=(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); + + subview::operator=(list); + } + + + +template +inline +void +subview_cols::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + subview::operator=(val); + } + + + +template +template +inline +void +subview_cols::operator=(const Base& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X.get_ref()); + } + + + +template +template +inline +void +subview_cols::operator=(const SpBase& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X.get_ref()); + } + + + +template +template +inline +typename enable_if2< is_same_type::value, void>::result +subview_cols::operator= (const Gen& in) + { + arma_extra_debug_sigprint(); + + subview::operator=(in); + } + + + +template +arma_inline +const Op,op_htrans> +subview_cols::t() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_htrans> +subview_cols::ht() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_strans> +subview_cols::st() const + { + return Op,op_strans>(*this); + } + + + +template +arma_inline +const Op,op_vectorise_col> +subview_cols::as_col() const + { + return Op,op_vectorise_col>(*this); + } + + + +template +inline +eT +subview_cols::at_alt(const uword ii) const + { + return operator[](ii); + } + + + +template +inline +eT& +subview_cols::operator[](const uword ii) + { + const uword index = subview::aux_col1 * subview::m.n_rows + ii; + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_cols::operator[](const uword ii) const + { + const uword index = subview::aux_col1 * subview::m.n_rows + ii; + + return subview::m.mem[index]; + } + + + +template +inline +eT& +subview_cols::operator()(const uword ii) + { + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); + + const uword index = subview::aux_col1 * subview::m.n_rows + ii; + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_cols::operator()(const uword ii) const + { + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); + + const uword index = subview::aux_col1 * subview::m.n_rows + ii; + + return subview::m.mem[index]; + } + + + +template +inline +eT& +subview_cols::operator()(const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row >= subview::n_rows) || (in_col >= subview::n_cols)), "subview::operator(): index out of bounds" ); + + const uword index = (in_col + subview::aux_col1) * subview::m.n_rows + in_row; + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_cols::operator()(const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row >= subview::n_rows) || (in_col >= subview::n_cols)), "subview::operator(): index out of bounds" ); + + const uword index = (in_col + subview::aux_col1) * subview::m.n_rows + in_row; + + return subview::m.mem[index]; + } + + + +template +inline +eT& +subview_cols::at(const uword in_row, const uword in_col) + { + const uword index = (in_col + subview::aux_col1) * subview::m.n_rows + in_row; + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_cols::at(const uword in_row, const uword in_col) const + { + const uword index = (in_col + subview::aux_col1) * subview::m.n_rows + in_row; + + return subview::m.mem[index]; + } + + + +template +arma_inline +eT* +subview_cols::colptr(const uword in_col) + { + return & access::rw((const_cast< Mat& >(subview::m)).mem[ (in_col + subview::aux_col1) * subview::m.n_rows ]); + } + + + +template +arma_inline +const eT* +subview_cols::colptr(const uword in_col) const + { + return & subview::m.mem[ (in_col + subview::aux_col1) * subview::m.n_rows ]; + } + + + +// +// +// + + + +template +inline +subview_row::subview_row(const Mat& in_m, const uword in_row) + : subview(in_m, in_row, 0, 1, in_m.n_cols) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_row::subview_row(const Mat& in_m, const uword in_row, const uword in_col1, const uword in_n_cols) + : subview(in_m, in_row, in_col1, 1, in_n_cols) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_row::subview_row(const subview_row& in) + : subview(in) // interprets 'subview_row' as 'subview' + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_row::subview_row(subview_row&& in) + : subview(std::move(in)) // interprets 'subview_row' as 'subview' + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +subview_row::operator=(const subview& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X); + } + + + +template +inline +void +subview_row::operator=(const subview_row& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X); // interprets 'subview_row' as 'subview' + } + + + +template +inline +void +subview_row::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + subview::operator=(val); // interprets 'subview_row' as 'subview' + } + + + +template +inline +void +subview_row::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + const uword N = uword(list.size()); + + arma_debug_assert_same_size(subview::n_rows, subview::n_cols, 1, N, "copy into submatrix"); + + auto it = list.begin(); + + for(uword ii=0; ii < N; ++ii) + { + (*this).operator[](ii) = (*it); + ++it; + } + } + + + +template +template +inline +void +subview_row::operator=(const Base& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X); + } + + + +template +template +inline +void +subview_row::operator=(const SpBase& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X.get_ref()); + } + + + +template +template +inline +typename enable_if2< is_same_type::value, void>::result +subview_row::operator= (const Gen& in) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(uword(1), subview::n_cols, (in.is_row ? uword(1) : in.n_rows), in.n_cols, "copy into submatrix"); + + in.apply(*this); + } + + + +template +arma_inline +const Op,op_htrans> +subview_row::t() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_htrans> +subview_row::ht() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_strans> +subview_row::st() const + { + return Op,op_strans>(*this); + } + + + +template +arma_inline +const Op,op_strans> +subview_row::as_col() const + { + return Op,op_strans>(*this); + } + + + +template +inline +eT +subview_row::at_alt(const uword ii) const + { + const uword index = (ii + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); + + return subview::m.mem[index]; + } + + + +template +inline +eT& +subview_row::operator[](const uword ii) + { + const uword index = (ii + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_row::operator[](const uword ii) const + { + const uword index = (ii + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); + + return subview::m.mem[index]; + } + + + +template +inline +eT& +subview_row::operator()(const uword ii) + { + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); + + const uword index = (ii + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_row::operator()(const uword ii) const + { + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); + + const uword index = (ii + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); + + return subview::m.mem[index]; + } + + + +template +inline +eT& +subview_row::operator()(const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row > 0) || (in_col >= subview::n_cols)), "subview::operator(): index out of bounds" ); + + const uword index = (in_col + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_row::operator()(const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row > 0) || (in_col >= subview::n_cols)), "subview::operator(): index out of bounds" ); + + const uword index = (in_col + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); + + return subview::m.mem[index]; + } + + + +template +inline +eT& +subview_row::at(const uword, const uword in_col) + { + const uword index = (in_col + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_row::at(const uword, const uword in_col) const + { + const uword index = (in_col + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); + + return subview::m.mem[index]; + } + + + +template +inline +subview_row +subview_row::cols(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::cols(): indices out of bounds or incorrectly used" ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + const uword base_col1 = this->aux_col1 + in_col1; + + return subview_row(this->m, this->aux_row1, base_col1, subview_n_cols); + } + + + +template +inline +const subview_row +subview_row::cols(const uword in_col1, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::cols(): indices out of bounds or incorrectly used" ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + const uword base_col1 = this->aux_col1 + in_col1; + + return subview_row(this->m, this->aux_row1, base_col1, subview_n_cols); + } + + + +template +inline +subview_row +subview_row::subvec(const uword in_col1, const uword in_col2) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::subvec(): indices out of bounds or incorrectly used" ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + const uword base_col1 = this->aux_col1 + in_col1; + + return subview_row(this->m, this->aux_row1, base_col1, subview_n_cols); + } + + + +template +inline +const subview_row +subview_row::subvec(const uword in_col1, const uword in_col2) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::subvec(): indices out of bounds or incorrectly used" ); + + const uword subview_n_cols = in_col2 - in_col1 + 1; + + const uword base_col1 = this->aux_col1 + in_col1; + + return subview_row(this->m, this->aux_row1, base_col1, subview_n_cols); + } + + + +template +inline +subview_row +subview_row::subvec(const uword start_col, const SizeMat& s) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (s.n_rows != 1), "subview_row::subvec(): given size does not specify a row vector" ); + + arma_debug_check_bounds( ( (start_col >= subview::n_cols) || ((start_col + s.n_cols) > subview::n_cols) ), "subview_row::subvec(): size out of bounds" ); + + const uword base_col1 = this->aux_col1 + start_col; + + return subview_row(this->m, this->aux_row1, base_col1, s.n_cols); + } + + + +template +inline +const subview_row +subview_row::subvec(const uword start_col, const SizeMat& s) const + { + arma_extra_debug_sigprint(); + + arma_debug_check( (s.n_rows != 1), "subview_row::subvec(): given size does not specify a row vector" ); + + arma_debug_check_bounds( ( (start_col >= subview::n_cols) || ((start_col + s.n_cols) > subview::n_cols) ), "subview_row::subvec(): size out of bounds" ); + + const uword base_col1 = this->aux_col1 + start_col; + + return subview_row(this->m, this->aux_row1, base_col1, s.n_cols); + } + + + +template +inline +subview_row +subview_row::head(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > subview::n_cols), "subview_row::head(): size out of bounds" ); + + return subview_row(this->m, this->aux_row1, this->aux_col1, N); + } + + + +template +inline +const subview_row +subview_row::head(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > subview::n_cols), "subview_row::head(): size out of bounds" ); + + return subview_row(this->m, this->aux_row1, this->aux_col1, N); + } + + + +template +inline +subview_row +subview_row::tail(const uword N) + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > subview::n_cols), "subview_row::tail(): size out of bounds" ); + + const uword start_col = subview::aux_col1 + subview::n_cols - N; + + return subview_row(this->m, this->aux_row1, start_col, N); + } + + + +template +inline +const subview_row +subview_row::tail(const uword N) const + { + arma_extra_debug_sigprint(); + + arma_debug_check_bounds( (N > subview::n_cols), "subview_row::tail(): size out of bounds" ); + + const uword start_col = subview::aux_col1 + subview::n_cols - N; + + return subview_row(this->m, this->aux_row1, start_col, N); + } + + + +template +inline +uword +subview_row::index_min() const + { + const Proxy< subview_row > P(*this); + + uword index = 0; + + if(P.get_n_elem() == 0) + { + arma_debug_check(true, "index_min(): object has no elements"); + } + else + { + op_min::min_with_index(P, index); + } + + return index; + } + + + +template +inline +uword +subview_row::index_max() const + { + const Proxy< subview_row > P(*this); + + uword index = 0; + + if(P.get_n_elem() == 0) + { + arma_debug_check(true, "index_max(): object has no elements"); + } + else + { + op_max::max_with_index(P, index); + } + + return index; + } + + + +template +inline +typename subview::row_iterator +subview_row::begin() + { + return typename subview::row_iterator(*this, subview::aux_row1, subview::aux_col1); + } + + + +template +inline +typename subview::const_row_iterator +subview_row::begin() const + { + return typename subview::const_row_iterator(*this, subview::aux_row1, subview::aux_col1); + } + + + +template +inline +typename subview::const_row_iterator +subview_row::cbegin() const + { + return typename subview::const_row_iterator(*this, subview::aux_row1, subview::aux_col1); + } + + + +template +inline +typename subview::row_iterator +subview_row::end() + { + return typename subview::row_iterator(*this, subview::aux_row1 + subview::n_rows, subview::aux_col1); + } + + + +template +inline +typename subview::const_row_iterator +subview_row::end() const + { + return typename subview::const_row_iterator(*this, subview::aux_row1 + subview::n_rows, subview::aux_col1); + } + + + +template +inline +typename subview::const_row_iterator +subview_row::cend() const + { + return typename subview::const_row_iterator(*this, subview::aux_row1 + subview::n_rows, subview::aux_col1); + } + + + +// +// +// + + + +template +inline +subview_row_strans::subview_row_strans(const subview_row& in_sv_row) + : sv_row(in_sv_row ) + , n_rows(in_sv_row.n_cols) + , n_elem(in_sv_row.n_elem) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +subview_row_strans::extract(Mat& out) const + { + arma_extra_debug_sigprint(); + + // NOTE: this function assumes that matrix 'out' has already been set to the correct size + + const Mat& X = sv_row.m; + + eT* out_mem = out.memptr(); + + const uword row = sv_row.aux_row1; + const uword start_col = sv_row.aux_col1; + const uword sv_row_n_cols = sv_row.n_cols; + + uword ii,jj; + + for(ii=0, jj=1; jj < sv_row_n_cols; ii+=2, jj+=2) + { + const eT tmp1 = X.at(row, start_col+ii); + const eT tmp2 = X.at(row, start_col+jj); + + out_mem[ii] = tmp1; + out_mem[jj] = tmp2; + } + + if(ii < sv_row_n_cols) + { + out_mem[ii] = X.at(row, start_col+ii); + } + } + + + +template +inline +eT +subview_row_strans::at_alt(const uword ii) const + { + return sv_row[ii]; + } + + + +template +inline +eT +subview_row_strans::operator[](const uword ii) const + { + return sv_row[ii]; + } + + + +template +inline +eT +subview_row_strans::operator()(const uword ii) const + { + return sv_row(ii); + } + + + +template +inline +eT +subview_row_strans::operator()(const uword in_row, const uword in_col) const + { + return sv_row(in_col, in_row); // deliberately swapped + } + + + +template +inline +eT +subview_row_strans::at(const uword in_row, const uword) const + { + return sv_row.at(0, in_row); // deliberately swapped + } + + + +// +// +// + + + +template +inline +subview_row_htrans::subview_row_htrans(const subview_row& in_sv_row) + : sv_row(in_sv_row ) + , n_rows(in_sv_row.n_cols) + , n_elem(in_sv_row.n_elem) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +subview_row_htrans::extract(Mat& out) const + { + arma_extra_debug_sigprint(); + + // NOTE: this function assumes that matrix 'out' has already been set to the correct size + + const Mat& X = sv_row.m; + + eT* out_mem = out.memptr(); + + const uword row = sv_row.aux_row1; + const uword start_col = sv_row.aux_col1; + const uword sv_row_n_cols = sv_row.n_cols; + + for(uword ii=0; ii < sv_row_n_cols; ++ii) + { + out_mem[ii] = access::alt_conj( X.at(row, start_col+ii) ); + } + } + + + +template +inline +eT +subview_row_htrans::at_alt(const uword ii) const + { + return access::alt_conj( sv_row[ii] ); + } + + + +template +inline +eT +subview_row_htrans::operator[](const uword ii) const + { + return access::alt_conj( sv_row[ii] ); + } + + + +template +inline +eT +subview_row_htrans::operator()(const uword ii) const + { + return access::alt_conj( sv_row(ii) ); + } + + + +template +inline +eT +subview_row_htrans::operator()(const uword in_row, const uword in_col) const + { + return access::alt_conj( sv_row(in_col, in_row) ); // deliberately swapped + } + + + +template +inline +eT +subview_row_htrans::at(const uword in_row, const uword) const + { + return access::alt_conj( sv_row.at(0, in_row) ); // deliberately swapped + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/sym_helper.hpp b/src/armadillo/include/armadillo_bits/sym_helper.hpp new file mode 100644 index 0000000..00555c4 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/sym_helper.hpp @@ -0,0 +1,485 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup sym_helper +//! @{ + + +namespace sym_helper +{ + +// computationally inexpensive algorithm to guess whether a matrix is positive definite: +// (1) ensure the matrix is symmetric/hermitian (within a tolerance) +// (2) ensure the diagonal entries are real and greater than zero +// (3) ensure that the value with largest modulus is on the main diagonal +// (4) ensure rudimentary diagonal dominance: (real(A_ii) + real(A_jj)) > 2*abs(real(A_ij)) +// the above conditions are necessary, but not sufficient; +// doing it properly would be too computationally expensive for our purposes +// more info: +// http://mathworld.wolfram.com/PositiveDefiniteMatrix.html +// http://mathworld.wolfram.com/DiagonallyDominantMatrix.html + +template +inline +typename enable_if2::no, bool>::result +guess_sympd_worker(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming A is square-sized + + const eT tol = eT(100) * std::numeric_limits::epsilon(); // allow some leeway + + const uword N = A.n_rows; + + const eT* A_mem = A.memptr(); + const eT* A_col = A_mem; + + eT max_diag = eT(0); + + for(uword j=0; j < N; ++j) + { + const eT A_jj = A_col[j]; + + if(A_jj <= eT(0)) { return false; } + + max_diag = (A_jj > max_diag) ? A_jj : max_diag; + + A_col += N; + } + + A_col = A_mem; + + const uword Nm1 = N-1; + const uword Np1 = N+1; + + for(uword j=0; j < Nm1; ++j) + { + const eT A_jj = A_col[j]; + + const uword jp1 = j+1; + const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); + const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); + + for(uword i=jp1; i < N; ++i) + { + const eT A_ij = A_col[i]; + const eT A_ji = (*A_ji_ptr); + + const eT A_ij_abs = (std::abs)(A_ij); + const eT A_ji_abs = (std::abs)(A_ji); + + // if( (A_ij_abs >= max_diag) || (A_ji_abs >= max_diag) ) { return false; } + if(A_ij_abs >= max_diag) { return false; } + + const eT A_delta = (std::abs)(A_ij - A_ji); + const eT A_abs_max = (std::max)(A_ij_abs, A_ji_abs); + + if( (A_delta > tol) && (A_delta > (A_abs_max*tol)) ) { return false; } + + const eT A_ii = (*A_ii_ptr); + + if( (A_ij_abs + A_ij_abs) >= (A_ii + A_jj) ) { return false; } + + A_ji_ptr += N; + A_ii_ptr += Np1; + } + + A_col += N; + } + + return true; + } + + + +template +inline +typename enable_if2::yes, bool>::result +guess_sympd_worker(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming A is square-sized + + typedef typename get_pod_type::result T; + + const T tol = T(100) * std::numeric_limits::epsilon(); // allow some leeway + + const uword N = A.n_rows; + + const eT* A_mem = A.memptr(); + const eT* A_col = A_mem; + + T max_diag = T(0); + + for(uword j=0; j < N; ++j) + { + const eT& A_jj = A_col[j]; + const T A_jj_real = std::real(A_jj); + const T A_jj_imag = std::imag(A_jj); + + if( (A_jj_real <= T(0)) || (std::abs(A_jj_imag) > tol) ) { return false; } + + max_diag = (A_jj_real > max_diag) ? A_jj_real : max_diag; + + A_col += N; + } + + const T square_max_diag = max_diag * max_diag; + + if(arma_isfinite(square_max_diag) == false) { return false; } + + A_col = A_mem; + + const uword Nm1 = N-1; + const uword Np1 = N+1; + + for(uword j=0; j < Nm1; ++j) + { + const uword jp1 = j+1; + const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); + const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); + + const T A_jj_real = std::real(A_col[j]); + + for(uword i=jp1; i < N; ++i) + { + const eT& A_ij = A_col[i]; + const T A_ij_real = std::real(A_ij); + const T A_ij_imag = std::imag(A_ij); + + // avoid using std::abs(), as that is time consuming due to division and std::sqrt() + const T square_A_ij_abs = (A_ij_real * A_ij_real) + (A_ij_imag * A_ij_imag); + + if(arma_isfinite(square_A_ij_abs) == false) { return false; } + + if(square_A_ij_abs >= square_max_diag) { return false; } + + const T A_ij_real_abs = (std::abs)(A_ij_real); + const T A_ij_imag_abs = (std::abs)(A_ij_imag); + + + const eT& A_ji = (*A_ji_ptr); + const T A_ji_real = std::real(A_ji); + const T A_ji_imag = std::imag(A_ji); + + const T A_ji_real_abs = (std::abs)(A_ji_real); + const T A_ji_imag_abs = (std::abs)(A_ji_imag); + + const T A_real_delta = (std::abs)(A_ij_real - A_ji_real); + const T A_real_abs_max = (std::max)(A_ij_real_abs, A_ji_real_abs); + + if( (A_real_delta > tol) && (A_real_delta > (A_real_abs_max*tol)) ) { return false; } + + + const T A_imag_delta = (std::abs)(A_ij_imag + A_ji_imag); // take into account complex conjugate + const T A_imag_abs_max = (std::max)(A_ij_imag_abs, A_ji_imag_abs); + + if( (A_imag_delta > tol) && (A_imag_delta > (A_imag_abs_max*tol)) ) { return false; } + + + const T A_ii_real = std::real(*A_ii_ptr); + + if( (A_ij_real_abs + A_ij_real_abs) >= (A_ii_real + A_jj_real) ) { return false; } + + A_ji_ptr += N; + A_ii_ptr += Np1; + } + + A_col += N; + } + + return true; + } + + + +template +inline +bool +guess_sympd(const Mat& A) + { + arma_extra_debug_sigprint(); + + // analyse matrices with size >= 4x4 + + if((A.n_rows != A.n_cols) || (A.n_rows < uword(4))) { return false; } + + return guess_sympd_worker(A); + } + + + +template +inline +bool +guess_sympd(const Mat& A, const uword min_n_rows) + { + arma_extra_debug_sigprint(); + + if((A.n_rows != A.n_cols) || (A.n_rows < min_n_rows)) { return false; } + + return guess_sympd_worker(A); + } + + + +// + + + +template +inline +typename enable_if2::no, void>::result +analyse_matrix_worker(bool& is_approx_sym, bool& is_approx_sympd, const Mat& A) + { + arma_extra_debug_sigprint(); + + is_approx_sym = true; + is_approx_sympd = true; + + const eT tol = eT(100) * std::numeric_limits::epsilon(); // allow some leeway + + const uword N = A.n_rows; + + const eT* A_mem = A.memptr(); + const eT* A_col = A_mem; + + eT max_diag = eT(0); + + for(uword j=0; j < N; ++j) + { + const eT A_jj = A_col[j]; + + if(A_jj <= eT(0)) { is_approx_sympd = false; } + + max_diag = (A_jj > max_diag) ? A_jj : max_diag; + + A_col += N; + } + + A_col = A_mem; + + const uword Nm1 = N-1; + const uword Np1 = N+1; + + for(uword j=0; j < Nm1; ++j) + { + const eT A_jj = A_col[j]; + + const uword jp1 = j+1; + const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); + const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); + + for(uword i=jp1; i < N; ++i) + { + const eT A_ij = A_col[i]; + const eT A_ji = (*A_ji_ptr); + + const eT A_ij_abs = (std::abs)(A_ij); + const eT A_ji_abs = (std::abs)(A_ji); + + const eT A_delta = (std::abs)(A_ij - A_ji); + const eT A_abs_max = (std::max)(A_ij_abs, A_ji_abs); + + if( (A_delta > tol) && (A_delta > (A_abs_max*tol)) ) { is_approx_sym = false; return; } + + if(is_approx_sympd) + { + // if( (A_ij_abs >= max_diag) || (A_ji_abs >= max_diag) ) { is_approx_sympd = false; } + if(A_ij_abs >= max_diag) { is_approx_sympd = false; } + + const eT A_ii = (*A_ii_ptr); + + if( (A_ij_abs + A_ij_abs) >= (A_ii + A_jj) ) { is_approx_sympd = false; } + } + + A_ji_ptr += N; + A_ii_ptr += Np1; + } + + A_col += N; + } + } + + + +template +inline +typename enable_if2::yes, void>::result +analyse_matrix_worker(bool& is_approx_sym, bool& is_approx_sympd, const Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + is_approx_sym = true; + is_approx_sympd = true; + + const T tol = T(100) * std::numeric_limits::epsilon(); // allow some leeway + + const uword N = A.n_rows; + + const eT* A_mem = A.memptr(); + const eT* A_col = A_mem; + + T max_diag = T(0); + + for(uword j=0; j < N; ++j) + { + const eT& A_jj = A_col[j]; + const T A_jj_real = std::real(A_jj); + const T A_jj_imag = std::imag(A_jj); + + if( (A_jj_real <= T(0)) || (std::abs(A_jj_imag) > tol) ) { is_approx_sympd = false; } + + max_diag = (A_jj_real > max_diag) ? A_jj_real : max_diag; + + A_col += N; + } + + const T square_max_diag = max_diag * max_diag; + + if(arma_isfinite(square_max_diag) == false) { is_approx_sympd = false; } + + A_col = A_mem; + + const uword Nm1 = N-1; + const uword Np1 = N+1; + + for(uword j=0; j < Nm1; ++j) + { + const uword jp1 = j+1; + const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); + const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); + + const T A_jj_real = std::real(A_col[j]); + + for(uword i=jp1; i < N; ++i) + { + const eT& A_ij = A_col[i]; + const T A_ij_real = std::real(A_ij); + const T A_ij_imag = std::imag(A_ij); + + const T A_ij_real_abs = (std::abs)(A_ij_real); + const T A_ij_imag_abs = (std::abs)(A_ij_imag); + + const eT& A_ji = (*A_ji_ptr); + const T A_ji_real = std::real(A_ji); + const T A_ji_imag = std::imag(A_ji); + + const T A_ji_real_abs = (std::abs)(A_ji_real); + const T A_ji_imag_abs = (std::abs)(A_ji_imag); + + const T A_real_delta = (std::abs)(A_ij_real - A_ji_real); + const T A_real_abs_max = (std::max)(A_ij_real_abs, A_ji_real_abs); + + if( (A_real_delta > tol) && (A_real_delta > (A_real_abs_max*tol)) ) { is_approx_sym = false; return; } + + const T A_imag_delta = (std::abs)(A_ij_imag + A_ji_imag); // take into account complex conjugate + const T A_imag_abs_max = (std::max)(A_ij_imag_abs, A_ji_imag_abs); + + if( (A_imag_delta > tol) && (A_imag_delta > (A_imag_abs_max*tol)) ) { is_approx_sym = false; return; } + + if(is_approx_sympd) + { + // avoid using std::abs(), as that is time consuming due to division and std::sqrt() + const T square_A_ij_abs = (A_ij_real * A_ij_real) + (A_ij_imag * A_ij_imag); + + if(arma_isfinite(square_A_ij_abs) == false) + { + is_approx_sympd = false; + } + else + { + const T A_ii_real = std::real(*A_ii_ptr); + + if( (A_ij_real_abs + A_ij_real_abs) >= (A_ii_real + A_jj_real) ) { is_approx_sympd = false; } + + if(square_A_ij_abs >= square_max_diag) { is_approx_sympd = false; } + } + } + + A_ji_ptr += N; + A_ii_ptr += Np1; + } + + A_col += N; + } + } + + + +template +inline +void +analyse_matrix(bool& is_approx_sym, bool& is_approx_sympd, const Mat& A) + { + arma_extra_debug_sigprint(); + + if((A.n_rows != A.n_cols) || (A.n_rows < uword(4))) + { + is_approx_sym = false; + is_approx_sympd = false; + return; + } + + analyse_matrix_worker(is_approx_sym, is_approx_sympd, A); + + if(is_approx_sym == false) { is_approx_sympd = false; } + } + + + +template +inline +bool +check_diag_imag(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming matrix A is square-sized + + typedef typename get_pod_type::result T; + + const T tol = T(10000) * std::numeric_limits::epsilon(); // allow some leeway + + const eT* colmem = A.memptr(); + + const uword N = A.n_rows; + + for(uword i=0; i tol) { return false; } + + colmem += N; + } + + return true; + } + + + +} // end of namespace sym_helper + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/traits.hpp b/src/armadillo/include/armadillo_bits/traits.hpp new file mode 100644 index 0000000..bde2200 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/traits.hpp @@ -0,0 +1,1315 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup traits +//! @{ + + +template +struct get_pod_type + { typedef T1 result; }; + +template +struct get_pod_type< std::complex > + { typedef T2 result; }; + + + +template +struct is_Mat_fixed_only + { + typedef char yes[1]; + typedef char no[2]; + + template static yes& check(typename X::Mat_fixed_type*); + template static no& check(...); + + static constexpr bool value = ( sizeof(check(0)) == sizeof(yes) ); + }; + + + +template +struct is_Row_fixed_only + { + typedef char yes[1]; + typedef char no[2]; + + template static yes& check(typename X::Row_fixed_type*); + template static no& check(...); + + static constexpr bool value = ( sizeof(check(0)) == sizeof(yes) ); + }; + + + +template +struct is_Col_fixed_only + { + typedef char yes[1]; + typedef char no[2]; + + template static yes& check(typename X::Col_fixed_type*); + template static no& check(...); + + static constexpr bool value = ( sizeof(check(0)) == sizeof(yes) ); + }; + + + +template +struct is_Mat_fixed + { static constexpr bool value = ( is_Mat_fixed_only::value || is_Row_fixed_only::value || is_Col_fixed_only::value ); }; + + + +template +struct is_Mat_only + { static constexpr bool value = is_Mat_fixed_only::value; }; + +template +struct is_Mat_only< Mat > + { static constexpr bool value = true; }; + +template +struct is_Mat_only< const Mat > + { static constexpr bool value = true; }; + + + +template +struct is_Mat + { static constexpr bool value = ( is_Mat_fixed_only::value || is_Row_fixed_only::value || is_Col_fixed_only::value ); }; + +template +struct is_Mat< Mat > + { static constexpr bool value = true; }; + +template +struct is_Mat< const Mat > + { static constexpr bool value = true; }; + +template +struct is_Mat< Row > + { static constexpr bool value = true; }; + +template +struct is_Mat< const Row > + { static constexpr bool value = true; }; + +template +struct is_Mat< Col > + { static constexpr bool value = true; }; + +template +struct is_Mat< const Col > + { static constexpr bool value = true; }; + + + +template +struct is_Row + { static constexpr bool value = is_Row_fixed_only::value; }; + +template +struct is_Row< Row > + { static constexpr bool value = true; }; + +template +struct is_Row< const Row > + { static constexpr bool value = true; }; + + + +template +struct is_Col + { static constexpr bool value = is_Col_fixed_only::value; }; + +template +struct is_Col< Col > + { static constexpr bool value = true; }; + +template +struct is_Col< const Col > + { static constexpr bool value = true; }; + + + +template +struct is_diagview + { static constexpr bool value = false; }; + +template +struct is_diagview< diagview > + { static constexpr bool value = true; }; + +template +struct is_diagview< const diagview > + { static constexpr bool value = true; }; + + +template +struct is_subview + { static constexpr bool value = false; }; + +template +struct is_subview< subview > + { static constexpr bool value = true; }; + +template +struct is_subview< const subview > + { static constexpr bool value = true; }; + + +template +struct is_subview_row + { static constexpr bool value = false; }; + +template +struct is_subview_row< subview_row > + { static constexpr bool value = true; }; + +template +struct is_subview_row< const subview_row > + { static constexpr bool value = true; }; + + +template +struct is_subview_col + { static constexpr bool value = false; }; + +template +struct is_subview_col< subview_col > + { static constexpr bool value = true; }; + +template +struct is_subview_col< const subview_col > + { static constexpr bool value = true; }; + + +template +struct is_subview_cols + { static constexpr bool value = false; }; + +template +struct is_subview_cols< subview_cols > + { static constexpr bool value = true; }; + +template +struct is_subview_cols< const subview_cols > + { static constexpr bool value = true; }; + + +template +struct is_subview_elem1 + { static constexpr bool value = false; }; + +template +struct is_subview_elem1< subview_elem1 > + { static constexpr bool value = true; }; + +template +struct is_subview_elem1< const subview_elem1 > + { static constexpr bool value = true; }; + + +template +struct is_subview_elem2 + { static constexpr bool value = false; }; + +template +struct is_subview_elem2< subview_elem2 > + { static constexpr bool value = true; }; + +template +struct is_subview_elem2< const subview_elem2 > + { static constexpr bool value = true; }; + + + +// +// +// + + + +template +struct is_Cube + { static constexpr bool value = false; }; + +template +struct is_Cube< Cube > + { static constexpr bool value = true; }; + +template +struct is_Cube< const Cube > + { static constexpr bool value = true; }; + +template +struct is_subview_cube + { static constexpr bool value = false; }; + +template +struct is_subview_cube< subview_cube > + { static constexpr bool value = true; }; + +template +struct is_subview_cube< const subview_cube > + { static constexpr bool value = true; }; + +template +struct is_subview_cube_slices + { static constexpr bool value = false; }; + +template +struct is_subview_cube_slices< subview_cube_slices > + { static constexpr bool value = true; }; + +template +struct is_subview_cube_slices< const subview_cube_slices > + { static constexpr bool value = true; }; + + +// +// +// + + +template +struct is_Gen + { static constexpr bool value = false; }; + +template +struct is_Gen< Gen > + { static constexpr bool value = true; }; + +template +struct is_Gen< const Gen > + { static constexpr bool value = true; }; + + +template +struct is_Op + { static constexpr bool value = false; }; + +template +struct is_Op< Op > + { static constexpr bool value = true; }; + +template +struct is_Op< const Op > + { static constexpr bool value = true; }; + + +template +struct is_CubeToMatOp + { static constexpr bool value = false; }; + +template +struct is_CubeToMatOp< CubeToMatOp > + { static constexpr bool value = true; }; + +template +struct is_CubeToMatOp< const CubeToMatOp > + { static constexpr bool value = true; }; + + +template +struct is_SpToDOp + { static constexpr bool value = false; }; + +template +struct is_SpToDOp< SpToDOp > + { static constexpr bool value = true; }; + +template +struct is_SpToDOp< const SpToDOp > + { static constexpr bool value = true; }; + + +template +struct is_SpToDGlue + { static constexpr bool value = false; }; + +template +struct is_SpToDGlue< SpToDGlue > + { static constexpr bool value = true; }; + +template +struct is_SpToDGlue< const SpToDGlue > + { static constexpr bool value = true; }; + + +template +struct is_eOp + { static constexpr bool value = false; }; + +template +struct is_eOp< eOp > + { static constexpr bool value = true; }; + +template +struct is_eOp< const eOp > + { static constexpr bool value = true; }; + + +template +struct is_mtOp + { static constexpr bool value = false; }; + +template +struct is_mtOp< mtOp > + { static constexpr bool value = true; }; + +template +struct is_mtOp< const mtOp > + { static constexpr bool value = true; }; + + +template +struct is_Glue + { static constexpr bool value = false; }; + +template +struct is_Glue< Glue > + { static constexpr bool value = true; }; + +template +struct is_Glue< const Glue > + { static constexpr bool value = true; }; + + +template +struct is_eGlue + { static constexpr bool value = false; }; + +template +struct is_eGlue< eGlue > + { static constexpr bool value = true; }; + +template +struct is_eGlue< const eGlue > + { static constexpr bool value = true; }; + + +template +struct is_mtGlue + { static constexpr bool value = false; }; + +template +struct is_mtGlue< mtGlue > + { static constexpr bool value = true; }; + +template +struct is_mtGlue< const mtGlue > + { static constexpr bool value = true; }; + + +// +// + + +template +struct is_glue_times + { static constexpr bool value = false; }; + +template +struct is_glue_times< Glue > + { static constexpr bool value = true; }; + +template +struct is_glue_times< const Glue > + { static constexpr bool value = true; }; + + +template +struct is_glue_times_diag + { static constexpr bool value = false; }; + +template +struct is_glue_times_diag< Glue > + { static constexpr bool value = true; }; + +template +struct is_glue_times_diag< const Glue > + { static constexpr bool value = true; }; + + +template +struct is_op_diagmat + { static constexpr bool value = false; }; + +template +struct is_op_diagmat< Op > + { static constexpr bool value = true; }; + +template +struct is_op_diagmat< const Op > + { static constexpr bool value = true; }; + + +// +// + + +template +struct is_Mat_trans + { static constexpr bool value = false; }; + +template +struct is_Mat_trans< Op > + { static constexpr bool value = is_Mat::value; }; + +template +struct is_Mat_trans< Op > + { static constexpr bool value = is_Mat::value; }; + + +// +// + + +template +struct is_GenCube + { static constexpr bool value = false; }; + +template +struct is_GenCube< GenCube > + { static constexpr bool value = true; }; + + +template +struct is_OpCube + { static constexpr bool value = false; }; + +template +struct is_OpCube< OpCube > + { static constexpr bool value = true; }; + + +template +struct is_eOpCube + { static constexpr bool value = false; }; + +template +struct is_eOpCube< eOpCube > + { static constexpr bool value = true; }; + + +template +struct is_mtOpCube + { static constexpr bool value = false; }; + +template +struct is_mtOpCube< mtOpCube > + { static constexpr bool value = true; }; + + +template +struct is_GlueCube + { static constexpr bool value = false; }; + +template +struct is_GlueCube< GlueCube > + { static constexpr bool value = true; }; + + +template +struct is_eGlueCube + { static constexpr bool value = false; }; + +template +struct is_eGlueCube< eGlueCube > + { static constexpr bool value = true; }; + + +template +struct is_mtGlueCube + { static constexpr bool value = false; }; + +template +struct is_mtGlueCube< mtGlueCube > + { static constexpr bool value = true; }; + + +// +// +// + + +template +struct is_arma_type2 + { + static constexpr bool value + = is_Mat::value + || is_Gen::value + || is_Op::value + || is_Glue::value + || is_eOp::value + || is_eGlue::value + || is_mtOp::value + || is_mtGlue::value + || is_diagview::value + || is_subview::value + || is_subview_row::value + || is_subview_col::value + || is_subview_cols::value + || is_subview_elem1::value + || is_subview_elem2::value + || is_CubeToMatOp::value + || is_SpToDOp::value + || is_SpToDGlue::value + ; + }; + + + +// due to rather baroque C++ rules for proving constant expressions, +// certain compilers may get confused with the combination of conditional inheritance, nested classes and the shenanigans in is_Mat_fixed_only. +// below we explicitly ensure the type is forced to be const, which seems to eliminate the confusion. +template +struct is_arma_type + { + static constexpr bool value = is_arma_type2::value; + }; + + + +template +struct is_arma_cube_type + { + static constexpr bool value + = is_Cube::value + || is_GenCube::value + || is_OpCube::value + || is_eOpCube::value + || is_mtOpCube::value + || is_GlueCube::value + || is_eGlueCube::value + || is_mtGlueCube::value + || is_subview_cube::value + || is_subview_cube_slices::value + ; + }; + + + +// +// +// + + + +template +struct is_SpMat + { static constexpr bool value = false; }; + +template +struct is_SpMat< SpMat > + { static constexpr bool value = true; }; + +template +struct is_SpMat< SpCol > + { static constexpr bool value = true; }; + +template +struct is_SpMat< SpRow > + { static constexpr bool value = true; }; + + + +template +struct is_SpRow + { static constexpr bool value = false; }; + +template +struct is_SpRow< SpRow > + { static constexpr bool value = true; }; + + + +template +struct is_SpCol + { static constexpr bool value = false; }; + +template +struct is_SpCol< SpCol > + { static constexpr bool value = true; }; + + + +template +struct is_SpSubview + { static constexpr bool value = false; }; + +template +struct is_SpSubview< SpSubview > + { static constexpr bool value = true; }; + + +template +struct is_SpSubview_col + { static constexpr bool value = false; }; + +template +struct is_SpSubview_col< SpSubview_col > + { static constexpr bool value = true; }; + + +template +struct is_SpSubview_col_list + { static constexpr bool value = false; }; + +template +struct is_SpSubview_col_list< SpSubview_col_list > + { static constexpr bool value = true; }; + + +template +struct is_SpSubview_row + { static constexpr bool value = false; }; + +template +struct is_SpSubview_row< SpSubview_row > + { static constexpr bool value = true; }; + + +template +struct is_spdiagview + { static constexpr bool value = false; }; + +template +struct is_spdiagview< spdiagview > + { static constexpr bool value = true; }; + + +template +struct is_SpOp + { static constexpr bool value = false; }; + +template +struct is_SpOp< SpOp > + { static constexpr bool value = true; }; + + +template +struct is_SpGlue + { static constexpr bool value = false; }; + +template +struct is_SpGlue< SpGlue > + { static constexpr bool value = true; }; + + +template +struct is_mtSpOp + { static constexpr bool value = false; }; + +template +struct is_mtSpOp< mtSpOp > + { static constexpr bool value = true; }; + + +template +struct is_mtSpGlue + { static constexpr bool value = false; }; + +template +struct is_mtSpGlue< mtSpGlue > + { static constexpr bool value = true; }; + + + +template +struct is_arma_sparse_type + { + static constexpr bool value + = is_SpMat::value + || is_SpSubview::value + || is_SpSubview_col::value + || is_SpSubview_col_list::value + || is_SpSubview_row::value + || is_spdiagview::value + || is_SpOp::value + || is_SpGlue::value + || is_mtSpOp::value + || is_mtSpGlue::value + ; + }; + + + +// +// +// + + +template +struct is_same_type + { + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; + }; + + +template +struct is_same_type + { + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; + }; + + + +// +// +// + + +template +struct is_u8 + { static constexpr bool value = false; }; + +template<> +struct is_u8 + { static constexpr bool value = true; }; + + + +template +struct is_s8 + { static constexpr bool value = false; }; + +template<> +struct is_s8 + { static constexpr bool value = true; }; + + + +template +struct is_u16 + { static constexpr bool value = false; }; + +template<> +struct is_u16 + { static constexpr bool value = true; }; + + + +template +struct is_s16 + { static constexpr bool value = false; }; + +template<> +struct is_s16 + { static constexpr bool value = true; }; + + + +template +struct is_u32 + { static constexpr bool value = false; }; + +template<> +struct is_u32 + { static constexpr bool value = true; }; + + + +template +struct is_s32 + { static constexpr bool value = false; }; + +template<> +struct is_s32 + { static constexpr bool value = true; }; + + + +template +struct is_u64 + { static constexpr bool value = false; }; + +template<> +struct is_u64 + { static constexpr bool value = true; }; + + +template +struct is_s64 + { static constexpr bool value = false; }; + +template<> +struct is_s64 + { static constexpr bool value = true; }; + + + +template +struct is_ulng_t + { static constexpr bool value = false; }; + +template<> +struct is_ulng_t + { static constexpr bool value = true; }; + + + +template +struct is_slng_t + { static constexpr bool value = false; }; + +template<> +struct is_slng_t + { static constexpr bool value = true; }; + + + +template +struct is_ulng_t_32 + { static constexpr bool value = false; }; + +template<> +struct is_ulng_t_32 + { static constexpr bool value = (sizeof(ulng_t) == 4); }; + + + +template +struct is_slng_t_32 + { static constexpr bool value = false; }; + +template<> +struct is_slng_t_32 + { static constexpr bool value = (sizeof(slng_t) == 4); }; + + + +template +struct is_ulng_t_64 + { static constexpr bool value = false; }; + +template<> +struct is_ulng_t_64 + { static constexpr bool value = (sizeof(ulng_t) == 8); }; + + + +template +struct is_slng_t_64 + { static constexpr bool value = false; }; + +template<> +struct is_slng_t_64 + { static constexpr bool value = (sizeof(slng_t) == 8); }; + + + +template +struct is_uword + { static constexpr bool value = false; }; + +template<> +struct is_uword + { static constexpr bool value = true; }; + + + +template +struct is_sword + { static constexpr bool value = false; }; + +template<> +struct is_sword + { static constexpr bool value = true; }; + + + +template +struct is_float + { static constexpr bool value = false; }; + +template<> +struct is_float + { static constexpr bool value = true; }; + + + +template +struct is_double + { static constexpr bool value = false; }; + +template<> +struct is_double + { static constexpr bool value = true; }; + + + +template +struct is_real + { + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; + }; + +template<> +struct is_real + { + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; + }; + +template<> +struct is_real + { + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; + }; + + + + +template +struct is_cx + { + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; + }; + +// template<> +template +struct is_cx< std::complex > + { + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; + }; + + + +template +struct is_cx_float + { + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; + }; + +template<> +struct is_cx_float< std::complex > + { + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; + }; + + + +template +struct is_cx_double + { + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; + }; + +template<> +struct is_cx_double< std::complex > + { + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; + }; + + + +template +struct is_supported_elem_type + { + static constexpr bool value = \ + is_u8::value || + is_s8::value || + is_u16::value || + is_s16::value || + is_u32::value || + is_s32::value || + is_u64::value || + is_s64::value || + is_ulng_t::value || + is_slng_t::value || + is_float::value || + is_double::value || + is_cx_float::value || + is_cx_double::value; + }; + + + +template +struct is_supported_blas_type + { + static constexpr bool value = \ + is_float::value || + is_double::value || + is_cx_float::value || + is_cx_double::value; + }; + + + +template +struct has_blas_float_bug + { + #if defined(ARMA_BLAS_FLOAT_BUG) + static constexpr bool value = is_float::result>::value; + #else + static constexpr bool value = false; + #endif + }; + + + +template +struct is_signed + { + static constexpr bool value = true; + }; + + +template<> struct is_signed { static constexpr bool value = false; }; +template<> struct is_signed { static constexpr bool value = false; }; +template<> struct is_signed { static constexpr bool value = false; }; +template<> struct is_signed { static constexpr bool value = false; }; +template<> struct is_signed { static constexpr bool value = false; }; + + +template +struct is_non_integral + { + static constexpr bool value = false; + }; + + +template<> struct is_non_integral< float > { static constexpr bool value = true; }; +template<> struct is_non_integral< double > { static constexpr bool value = true; }; +template<> struct is_non_integral< std::complex > { static constexpr bool value = true; }; +template<> struct is_non_integral< std::complex > { static constexpr bool value = true; }; + + + + +// + +class arma_junk_class; + +template +struct force_different_type + { + typedef T1 T1_result; + typedef T2 T2_result; + }; + + +template +struct force_different_type + { + typedef T1 T1_result; + typedef arma_junk_class T2_result; + }; + + + +// + + +template +struct resolves_to_vector_default + { + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; + }; + +template +struct resolves_to_vector_test + { + static constexpr bool value = (T1::is_col || T1::is_row || T1::is_xvec); + static constexpr bool yes = (T1::is_col || T1::is_row || T1::is_xvec); + static constexpr bool no = ((T1::is_col || T1::is_row || T1::is_xvec) == false); + }; + + +template +struct resolves_to_vector_redirect {}; + +template +struct resolves_to_vector_redirect { typedef resolves_to_vector_default result; }; + +template +struct resolves_to_vector_redirect { typedef resolves_to_vector_test result; }; + + +template +struct resolves_to_vector : public resolves_to_vector_redirect::value>::result {}; + +template +struct resolves_to_sparse_vector : public resolves_to_vector_redirect::value>::result {}; + +// + +template +struct resolves_to_rowvector_default { static constexpr bool value = false; }; + +template +struct resolves_to_rowvector_test { static constexpr bool value = T1::is_row; }; + + +template +struct resolves_to_rowvector_redirect {}; + +template +struct resolves_to_rowvector_redirect { typedef resolves_to_rowvector_default result; }; + +template +struct resolves_to_rowvector_redirect { typedef resolves_to_rowvector_test result; }; + + +template +struct resolves_to_rowvector : public resolves_to_rowvector_redirect::value>::result {}; + +// + +template +struct resolves_to_colvector_default { static constexpr bool value = false; }; + +template +struct resolves_to_colvector_test { static constexpr bool value = T1::is_col; }; + + +template +struct resolves_to_colvector_redirect {}; + +template +struct resolves_to_colvector_redirect { typedef resolves_to_colvector_default result; }; + +template +struct resolves_to_colvector_redirect { typedef resolves_to_colvector_test result; }; + + +template +struct resolves_to_colvector : public resolves_to_colvector_redirect::value>::result {}; + + + +template +struct is_outer_product + { static constexpr bool value = false; }; + +template +struct is_outer_product< Glue > + { static constexpr bool value = (resolves_to_colvector::value && resolves_to_rowvector::value); }; + + + +template +struct has_op_inv_any + { static constexpr bool value = false; }; + +template +struct has_op_inv_any< Op > + { static constexpr bool value = true; }; + +template +struct has_op_inv_any< Op > + { static constexpr bool value = true; }; + +template +struct has_op_inv_any< Op > + { static constexpr bool value = true; }; + +template +struct has_op_inv_any< Op > + { static constexpr bool value = true; }; + +template +struct has_op_inv_any< Glue, T2, glue_times> > + { static constexpr bool value = true; }; + +template +struct has_op_inv_any< Glue, T2, glue_times> > + { static constexpr bool value = true; }; + +template +struct has_op_inv_any< Glue, T2, glue_times> > + { static constexpr bool value = true; }; + +template +struct has_op_inv_any< Glue, T2, glue_times> > + { static constexpr bool value = true; }; + +template +struct has_op_inv_any< Glue, glue_times> > + { static constexpr bool value = true; }; + +template +struct has_op_inv_any< Glue, glue_times> > + { static constexpr bool value = true; }; + +template +struct has_op_inv_any< Glue, glue_times> > + { static constexpr bool value = true; }; + +template +struct has_op_inv_any< Glue, glue_times> > + { static constexpr bool value = true; }; + + + + +template +struct has_nested_op_traits + { + typedef char yes[1]; + typedef char no[2]; + + template static yes& check(typename X::template traits*); + template static no& check(...); + + static constexpr bool value = ( sizeof(check(0)) == sizeof(yes) ); + }; + +template +struct has_nested_glue_traits + { + typedef char yes[1]; + typedef char no[2]; + + template static yes& check(typename X::template traits*); + template static no& check(...); + + static constexpr bool value = ( sizeof(check(0)) == sizeof(yes) ); + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/translate_arpack.hpp b/src/armadillo/include/armadillo_bits/translate_arpack.hpp new file mode 100644 index 0000000..8482892 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/translate_arpack.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +#if defined(ARMA_USE_ARPACK) + +//! \namespace arpack namespace for ARPACK functions +namespace arpack + { + + // If real, then eT == eeT; otherwise, eT == std::complex. + // For real calls, rwork is ignored; it's only necessary in the complex case. + template + inline + void + naupd(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, eeT* tol, eT* resid, blas_int* ncv, eT* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, eT* workd, eT* workl, blas_int* lworkl, eeT* rwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_ignore(rwork); arma_fortran(arma_snaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info, 1, 1); } + else if( is_double::value) { typedef double T; arma_ignore(rwork); arma_fortran(arma_dnaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info, 1, 1); } + else if( is_cx_float::value) { typedef cx_float T; typedef float xT; arma_fortran(arma_cnaupd)(ido, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info, 1, 1); } + else if(is_cx_double::value) { typedef cx_double T; typedef double xT; arma_fortran(arma_znaupd)(ido, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_ignore(rwork); arma_fortran(arma_snaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info); } + else if( is_double::value) { typedef double T; arma_ignore(rwork); arma_fortran(arma_dnaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info); } + else if( is_cx_float::value) { typedef cx_float T; typedef float xT; arma_fortran(arma_cnaupd)(ido, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info); } + else if(is_cx_double::value) { typedef cx_double T; typedef double xT; arma_fortran(arma_znaupd)(ido, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info); } + #endif + } + + + //! The use of two template types is necessary here because the compiler will + //! instantiate this method for complex types (where eT != eeT) but that in + //! practice that is never actually used. + template + inline + void + saupd(blas_int* ido, char* bmat, blas_int* n, char* which, blas_int* nev, eeT* tol, eT* resid, blas_int* ncv, eT* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, eT* workd, eT* workl, blas_int* lworkl, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_ssaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_ssaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info); } + #endif + } + + + + template + inline + void + seupd(blas_int* rvec, char* howmny, blas_int* select, eT* d, eT* z, blas_int* ldz, eT* sigma, char* bmat, blas_int* n, char* which, blas_int* nev, eT* tol, eT* resid, blas_int* ncv, eT* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, eT* workd, eT* workl, blas_int* lworkl, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sseupd)(rvec, howmny, select, (T*) d, (T*) z, ldz, (T*) sigma, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info, 1, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dseupd)(rvec, howmny, select, (T*) d, (T*) z, ldz, (T*) sigma, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info, 1, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sseupd)(rvec, howmny, select, (T*) d, (T*) z, ldz, (T*) sigma, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dseupd)(rvec, howmny, select, (T*) d, (T*) z, ldz, (T*) sigma, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info); } + #endif + } + + + + // for complex versions, pass d for dr, and null for di; pass sigma for + // sigmar, and null for sigmai; rwork isn't used for non-complex versions + template + inline + void + neupd(blas_int* rvec, char* howmny, blas_int* select, eT* dr, eT* di, eT* z, blas_int* ldz, eT* sigmar, eT* sigmai, eT* workev, char* bmat, blas_int* n, char* which, blas_int* nev, eeT* tol, eT* resid, blas_int* ncv, eT* v, blas_int* ldv, blas_int* iparam, blas_int* ipntr, eT* workd, eT* workl, blas_int* lworkl, eeT* rwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_ignore(rwork); arma_fortran(arma_sneupd)(rvec, howmny, select, (T*) dr, (T*) di, (T*) z, ldz, (T*) sigmar, (T*) sigmai, (T*) workev, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info, 1, 1, 1); } + else if( is_double::value) { typedef double T; arma_ignore(rwork); arma_fortran(arma_dneupd)(rvec, howmny, select, (T*) dr, (T*) di, (T*) z, ldz, (T*) sigmar, (T*) sigmai, (T*) workev, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info, 1, 1, 1); } + else if( is_cx_float::value) { typedef cx_float T; typedef float xT; arma_fortran(arma_cneupd)(rvec, howmny, select, (T*) dr, (T*) z, ldz, (T*) sigmar, (T*) workev, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info, 1, 1, 1); } + else if(is_cx_double::value) { typedef cx_double T; typedef double xT; arma_fortran(arma_zneupd)(rvec, howmny, select, (T*) dr, (T*) z, ldz, (T*) sigmar, (T*) workev, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info, 1, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_ignore(rwork); arma_fortran(arma_sneupd)(rvec, howmny, select, (T*) dr, (T*) di, (T*) z, ldz, (T*) sigmar, (T*) sigmai, (T*) workev, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info); } + else if( is_double::value) { typedef double T; arma_ignore(rwork); arma_fortran(arma_dneupd)(rvec, howmny, select, (T*) dr, (T*) di, (T*) z, ldz, (T*) sigmar, (T*) sigmai, (T*) workev, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info); } + else if( is_cx_float::value) { typedef cx_float T; typedef float xT; arma_fortran(arma_cneupd)(rvec, howmny, select, (T*) dr, (T*) z, ldz, (T*) sigmar, (T*) workev, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info); } + else if(is_cx_double::value) { typedef cx_double T; typedef double xT; arma_fortran(arma_zneupd)(rvec, howmny, select, (T*) dr, (T*) z, ldz, (T*) sigmar, (T*) workev, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info); } + #endif + } + + + } // namespace arpack + + +#endif diff --git a/src/armadillo/include/armadillo_bits/translate_atlas.hpp b/src/armadillo/include/armadillo_bits/translate_atlas.hpp new file mode 100644 index 0000000..95d43d5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/translate_atlas.hpp @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +#if defined(ARMA_USE_ATLAS) + + +// TODO: remove support for ATLAS in next major version + +//! \namespace atlas namespace for ATLAS functions +namespace atlas + { + + template + inline static const eT& tmp_real(const eT& X) { return X; } + + template + inline static const T tmp_real(const std::complex& X) { return X.real(); } + + + + template + arma_inline + eT + cblas_asum(const int N, const eT* X) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + return eT( arma_wrapper(cblas_sasum)(N, (const T*)X, 1) ); + } + else + if(is_double::value) + { + typedef double T; + return eT( arma_wrapper(cblas_dasum)(N, (const T*)X, 1) ); + } + + return eT(0); + } + + + + template + arma_inline + eT + cblas_nrm2(const int N, const eT* X) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + return eT( arma_wrapper(cblas_snrm2)(N, (const T*)X, 1) ); + } + else + if(is_double::value) + { + typedef double T; + return eT( arma_wrapper(cblas_dnrm2)(N, (const T*)X, 1) ); + } + + return eT(0); + } + + + + template + arma_inline + eT + cblas_dot(const int N, const eT* X, const eT* Y) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + return eT( arma_wrapper(cblas_sdot)(N, (const T*)X, 1, (const T*)Y, 1) ); + } + else + if(is_double::value) + { + typedef double T; + return eT( arma_wrapper(cblas_ddot)(N, (const T*)X, 1, (const T*)Y, 1) ); + } + + return eT(0); + } + + + + template + arma_inline + eT + cblas_cx_dot(const int N, const eT* X, const eT* Y) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_cx_float::value) + { + typedef typename std::complex T; + + T out; + arma_wrapper(cblas_cdotu_sub)(N, (const T*)X, 1, (const T*)Y, 1, &out); + + return eT(out); + } + else + if(is_cx_double::value) + { + typedef typename std::complex T; + + T out; + arma_wrapper(cblas_zdotu_sub)(N, (const T*)X, 1, (const T*)Y, 1, &out); + + return eT(out); + } + + return eT(0); + } + + + + template + inline + void + cblas_gemv + ( + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, + const int M, const int N, + const eT alpha, + const eT *A, const int lda, + const eT *X, const int incX, + const eT beta, + eT *Y, const int incY + ) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + arma_wrapper(cblas_sgemv)(layout, TransA, M, N, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)X, incX, (const T)tmp_real(beta), (T*)Y, incY); + } + else + if(is_double::value) + { + typedef double T; + arma_wrapper(cblas_dgemv)(layout, TransA, M, N, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)X, incX, (const T)tmp_real(beta), (T*)Y, incY); + } + else + if(is_cx_float::value) + { + typedef std::complex T; + arma_wrapper(cblas_cgemv)(layout, TransA, M, N, (const T*)&alpha, (const T*)A, lda, (const T*)X, incX, (const T*)&beta, (T*)Y, incY); + } + else + if(is_cx_double::value) + { + typedef std::complex T; + arma_wrapper(cblas_zgemv)(layout, TransA, M, N, (const T*)&alpha, (const T*)A, lda, (const T*)X, incX, (const T*)&beta, (T*)Y, incY); + } + } + + + + template + inline + void + cblas_gemm + ( + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, + const atlas_CBLAS_TRANS TransB, const int M, const int N, + const int K, const eT alpha, const eT *A, + const int lda, const eT *B, const int ldb, + const eT beta, eT *C, const int ldc + ) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + arma_wrapper(cblas_sgemm)(layout, TransA, TransB, M, N, K, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)B, ldb, (const T)tmp_real(beta), (T*)C, ldc); + } + else + if(is_double::value) + { + typedef double T; + arma_wrapper(cblas_dgemm)(layout, TransA, TransB, M, N, K, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)B, ldb, (const T)tmp_real(beta), (T*)C, ldc); + } + else + if(is_cx_float::value) + { + typedef std::complex T; + arma_wrapper(cblas_cgemm)(layout, TransA, TransB, M, N, K, (const T*)&alpha, (const T*)A, lda, (const T*)B, ldb, (const T*)&beta, (T*)C, ldc); + } + else + if(is_cx_double::value) + { + typedef std::complex T; + arma_wrapper(cblas_zgemm)(layout, TransA, TransB, M, N, K, (const T*)&alpha, (const T*)A, lda, (const T*)B, ldb, (const T*)&beta, (T*)C, ldc); + } + } + + + + template + inline + void + cblas_syrk + ( + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, + const int N, const int K, const eT alpha, + const eT* A, const int lda, const eT beta, eT* C, const int ldc + ) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + arma_wrapper(cblas_ssyrk)(layout, Uplo, Trans, N, K, (const T)alpha, (const T*)A, lda, (const T)beta, (T*)C, ldc); + } + else + if(is_double::value) + { + typedef double T; + arma_wrapper(cblas_dsyrk)(layout, Uplo, Trans, N, K, (const T)alpha, (const T*)A, lda, (const T)beta, (T*)C, ldc); + } + } + + + + template + inline + void + cblas_herk + ( + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, + const int N, const int K, const T alpha, + const std::complex* A, const int lda, const T beta, std::complex* C, const int ldc + ) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float TT; + typedef std::complex cx_TT; + + arma_wrapper(cblas_cherk)(layout, Uplo, Trans, N, K, (const TT)alpha, (const cx_TT*)A, lda, (const TT)beta, (cx_TT*)C, ldc); + } + else + if(is_double::value) + { + typedef double TT; + typedef std::complex cx_TT; + + arma_wrapper(cblas_zherk)(layout, Uplo, Trans, N, K, (const TT)alpha, (const cx_TT*)A, lda, (const TT)beta, (cx_TT*)C, ldc); + } + } + + } + +#endif diff --git a/src/armadillo/include/armadillo_bits/translate_blas.hpp b/src/armadillo/include/armadillo_bits/translate_blas.hpp new file mode 100644 index 0000000..91fb6a2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/translate_blas.hpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +#if defined(ARMA_USE_BLAS) + + +//! \namespace blas namespace for BLAS functions +namespace blas + { + + + template + inline + void + gemv(const char* transA, const blas_int* m, const blas_int* n, const eT* alpha, const eT* A, const blas_int* ldA, const eT* x, const blas_int* incx, const eT* beta, eT* y, const blas_int* incy) + { + arma_type_check((is_supported_blas_type::value == false)); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + { + if( is_float::value) { typedef float T; arma_fortran(arma_sgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy, 1); } + } + #else + { + if( is_float::value) { typedef float T; arma_fortran(arma_sgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy); } + } + #endif + } + + + + template + inline + void + gemm(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const eT* alpha, const eT* A, const blas_int* ldA, const eT* B, const blas_int* ldB, const eT* beta, eT* C, const blas_int* ldC) + { + arma_type_check((is_supported_blas_type::value == false)); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + { + if( is_float::value) { typedef float T; arma_fortran(arma_sgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC, 1, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC, 1, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC, 1, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC, 1, 1); } + } + #else + { + if( is_float::value) { typedef float T; arma_fortran(arma_sgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC); } + } + #endif + } + + + + template + inline + void + syrk(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const eT* alpha, const eT* A, const blas_int* ldA, const eT* beta, eT* C, const blas_int* ldC) + { + arma_type_check((is_supported_blas_type::value == false)); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + { + if( is_float::value) { typedef float T; arma_fortran(arma_ssyrk)(uplo, transA, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)beta, (T*)C, ldC, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsyrk)(uplo, transA, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)beta, (T*)C, ldC, 1, 1); } + } + #else + { + if( is_float::value) { typedef float T; arma_fortran(arma_ssyrk)(uplo, transA, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)beta, (T*)C, ldC); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsyrk)(uplo, transA, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)beta, (T*)C, ldC); } + } + #endif + } + + + + template + inline + void + herk(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const T* alpha, const std::complex* A, const blas_int* ldA, const T* beta, std::complex* C, const blas_int* ldC) + { + arma_type_check((is_supported_blas_type::value == false)); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + { + if( is_float::value) { typedef float TT; typedef blas_cxf cx_TT; arma_fortran(arma_cherk)(uplo, transA, n, k, (const TT*)alpha, (const cx_TT*)A, ldA, (const TT*)beta, (cx_TT*)C, ldC, 1, 1); } + else if(is_double::value) { typedef double TT; typedef blas_cxd cx_TT; arma_fortran(arma_zherk)(uplo, transA, n, k, (const TT*)alpha, (const cx_TT*)A, ldA, (const TT*)beta, (cx_TT*)C, ldC, 1, 1); } + } + #else + { + if( is_float::value) { typedef float TT; typedef blas_cxf cx_TT; arma_fortran(arma_cherk)(uplo, transA, n, k, (const TT*)alpha, (const cx_TT*)A, ldA, (const TT*)beta, (cx_TT*)C, ldC); } + else if(is_double::value) { typedef double TT; typedef blas_cxd cx_TT; arma_fortran(arma_zherk)(uplo, transA, n, k, (const TT*)alpha, (const cx_TT*)A, ldA, (const TT*)beta, (cx_TT*)C, ldC); } + } + #endif + } + + + + template + inline + eT + dot(const uword n_elem, const eT* x, const eT* y) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + #if defined(ARMA_BLAS_FLOAT_BUG) + { + if(n_elem == 0) { return eT(0); } + + const char trans = 'T'; + + const blas_int m = blas_int(n_elem); + const blas_int n = 1; + const blas_int inc = 1; + + const eT alpha = eT(1); + const eT beta = eT(0); + + eT result[2]; // paranoia: using two elements instead of one + + blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0], &inc); + + return result[0]; + } + #else + { + blas_int n = blas_int(n_elem); + blas_int inc = 1; + + typedef float T; + return eT( arma_fortran(arma_sdot)(&n, (const T*)x, &inc, (const T*)y, &inc) ); + } + #endif + } + else + if(is_double::value) + { + blas_int n = blas_int(n_elem); + blas_int inc = 1; + + typedef double T; + return eT( arma_fortran(arma_ddot)(&n, (const T*)x, &inc, (const T*)y, &inc) ); + } + else + if( (is_cx_float::value) || (is_cx_double::value) ) + { + if(n_elem == 0) { return eT(0); } + + // using gemv() workaround due to compatibility issues with cdotu() and zdotu() + + const char trans = 'T'; + + const blas_int m = blas_int(n_elem); + const blas_int n = 1; + const blas_int inc = 1; + + const eT alpha = eT(1); + const eT beta = eT(0); + + eT result[2]; // paranoia: using two elements instead of one + + blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0], &inc); + + return result[0]; + } + + return eT(0); + } + + + + template + arma_inline + eT + asum(const uword n_elem, const eT* x) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + blas_int n = blas_int(n_elem); + blas_int inc = 1; + + typedef float T; + return arma_fortran(arma_sasum)(&n, (const T*)x, &inc); + } + else + if(is_double::value) + { + blas_int n = blas_int(n_elem); + blas_int inc = 1; + + typedef double T; + return arma_fortran(arma_dasum)(&n, (const T*)x, &inc); + } + + return eT(0); + } + + + + template + arma_inline + eT + nrm2(const uword n_elem, const eT* x) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + blas_int n = blas_int(n_elem); + blas_int inc = 1; + + typedef float T; + return arma_fortran(arma_snrm2)(&n, (const T*)x, &inc); + } + else + if(is_double::value) + { + blas_int n = blas_int(n_elem); + blas_int inc = 1; + + typedef double T; + return arma_fortran(arma_dnrm2)(&n, (const T*)x, &inc); + } + + return eT(0); + } + + + } // namespace blas + + +#endif diff --git a/src/armadillo/include/armadillo_bits/translate_fftw3.hpp b/src/armadillo/include/armadillo_bits/translate_fftw3.hpp new file mode 100644 index 0000000..1edd727 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/translate_fftw3.hpp @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +#if defined(ARMA_USE_FFTW3) + + +namespace fftw3 + { + template + arma_inline + void_ptr + plan_dft_1d(int N, eT* input, eT* output, int fftw3_sign, unsigned int fftw3_flags) + { + arma_type_check((is_cx::value == false)); + + if(is_cx_float::value) + { + return fftwf_plan_dft_1d(N, (cx_float*)input, (cx_float*)output, fftw3_sign, fftw3_flags); + } + else + if(is_cx_double::value) + { + return fftw_plan_dft_1d(N, (cx_double*)input, (cx_double*)output, fftw3_sign, fftw3_flags); + } + + return nullptr; + } + + + + template + arma_inline + void + execute(void_ptr plan) + { + arma_type_check((is_cx::value == false)); + + if(is_cx_float::value) + { + fftwf_execute(plan); + } + else + if(is_cx_double::value) + { + fftw_execute(plan); + } + } + + + + template + arma_inline + void + destroy_plan(void_ptr plan) + { + arma_type_check((is_cx::value == false)); + + if(is_cx_float::value) + { + fftwf_destroy_plan(plan); + } + else + if(is_cx_double::value) + { + fftw_destroy_plan(plan); + } + } + + + + template + arma_inline + void + cleanup() + { + arma_type_check((is_cx::value == false)); + + if(is_cx_float::value) + { + fftwf_cleanup(); + } + else + if(is_cx_double::value) + { + fftw_cleanup(); + } + } + } + + +#endif diff --git a/src/armadillo/include/armadillo_bits/translate_lapack.hpp b/src/armadillo/include/armadillo_bits/translate_lapack.hpp new file mode 100644 index 0000000..7ed4c0e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/translate_lapack.hpp @@ -0,0 +1,1347 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +#if defined(ARMA_USE_LAPACK) + + +//! \namespace lapack namespace for LAPACK functions +namespace lapack + { + + template + inline + void + getrf(blas_int* m, blas_int* n, eT* a, blas_int* lda, blas_int* ipiv, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sgetrf)(m, n, (T*)a, lda, ipiv, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgetrf)(m, n, (T*)a, lda, ipiv, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgetrf)(m, n, (T*)a, lda, ipiv, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgetrf)(m, n, (T*)a, lda, ipiv, info); } + } + + + + template + inline + void + getrs(char* trans, blas_int* n, blas_int* nrhs, eT* a, blas_int* lda, blas_int* ipiv, eT* b, blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgetrs)(trans, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgetrs)(trans, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgetrs)(trans, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgetrs)(trans, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgetrs)(trans, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgetrs)(trans, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgetrs)(trans, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgetrs)(trans, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + #endif + } + + + + template + inline + void + getri(blas_int* n, eT* a, blas_int* lda, blas_int* ipiv, eT* work, blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sgetri)(n, (T*)a, lda, ipiv, (T*)work, lwork, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgetri)(n, (T*)a, lda, ipiv, (T*)work, lwork, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgetri)(n, (T*)a, lda, ipiv, (T*)work, lwork, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgetri)(n, (T*)a, lda, ipiv, (T*)work, lwork, info); } + } + + + + template + inline + void + trtri(char* uplo, char* diag, blas_int* n, eT* a, blas_int* lda, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_strtri)(uplo, diag, n, (T*)a, lda, info, 1, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dtrtri)(uplo, diag, n, (T*)a, lda, info, 1, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_ctrtri)(uplo, diag, n, (T*)a, lda, info, 1, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_ztrtri)(uplo, diag, n, (T*)a, lda, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_strtri)(uplo, diag, n, (T*)a, lda, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dtrtri)(uplo, diag, n, (T*)a, lda, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_ctrtri)(uplo, diag, n, (T*)a, lda, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_ztrtri)(uplo, diag, n, (T*)a, lda, info); } + #endif + } + + + + template + inline + void + geev(char* jobvl, char* jobvr, blas_int* n, eT* a, blas_int* lda, eT* wr, eT* wi, eT* vl, blas_int* ldvl, eT* vr, blas_int* ldvr, eT* work, blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgeev)(jobvl, jobvr, n, (T*)a, lda, (T*)wr, (T*)wi, (T*)vl, ldvl, (T*)vr, ldvr, (T*)work, lwork, info, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgeev)(jobvl, jobvr, n, (T*)a, lda, (T*)wr, (T*)wi, (T*)vl, ldvl, (T*)vr, ldvr, (T*)work, lwork, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgeev)(jobvl, jobvr, n, (T*)a, lda, (T*)wr, (T*)wi, (T*)vl, ldvl, (T*)vr, ldvr, (T*)work, lwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgeev)(jobvl, jobvr, n, (T*)a, lda, (T*)wr, (T*)wi, (T*)vl, ldvl, (T*)vr, ldvr, (T*)work, lwork, info); } + #endif + } + + + + template + inline + void + cx_geev(char* jobvl, char* jobvr, blas_int* n, eT* a, blas_int* lda, eT* w, eT* vl, blas_int* ldvl, eT* vr, blas_int* ldvr, eT* work, blas_int* lwork, typename eT::value_type* rwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cgeev)(jobvl, jobvr, n, (cx_T*)a, lda, (cx_T*)w, (cx_T*)vl, ldvl, (cx_T*)vr, ldvr, (cx_T*)work, lwork, (T*)rwork, info, 1, 1); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zgeev)(jobvl, jobvr, n, (cx_T*)a, lda, (cx_T*)w, (cx_T*)vl, ldvl, (cx_T*)vr, ldvr, (cx_T*)work, lwork, (T*)rwork, info, 1, 1); } + #else + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cgeev)(jobvl, jobvr, n, (cx_T*)a, lda, (cx_T*)w, (cx_T*)vl, ldvl, (cx_T*)vr, ldvr, (cx_T*)work, lwork, (T*)rwork, info); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zgeev)(jobvl, jobvr, n, (cx_T*)a, lda, (cx_T*)w, (cx_T*)vl, ldvl, (cx_T*)vr, ldvr, (cx_T*)work, lwork, (T*)rwork, info); } + #endif + } + + + + template + inline + void + geevx(char* balanc, char* jobvl, char* jobvr, char* sense, blas_int* n, eT* a, blas_int* lda, eT* wr, eT* wi, eT* vl, blas_int* ldvl, eT* vr, blas_int* ldvr, blas_int* ilo, blas_int* ihi, eT* scale, eT* abnrm, eT* rconde, eT* rcondv, eT* work, blas_int* lwork, blas_int* iwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgeevx)(balanc, jobvl, jobvr, sense, n, (T*)(a), lda, (T*)(wr), (T*)(wi), (T*)(vl), ldvl, (T*)(vr), ldvr, ilo, ihi, (T*)(scale), (T*)(abnrm), (T*)(rconde), (T*)(rcondv), (T*)(work), lwork, iwork, info, 1, 1, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgeevx)(balanc, jobvl, jobvr, sense, n, (T*)(a), lda, (T*)(wr), (T*)(wi), (T*)(vl), ldvl, (T*)(vr), ldvr, ilo, ihi, (T*)(scale), (T*)(abnrm), (T*)(rconde), (T*)(rcondv), (T*)(work), lwork, iwork, info, 1, 1, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgeevx)(balanc, jobvl, jobvr, sense, n, (T*)(a), lda, (T*)(wr), (T*)(wi), (T*)(vl), ldvl, (T*)(vr), ldvr, ilo, ihi, (T*)(scale), (T*)(abnrm), (T*)(rconde), (T*)(rcondv), (T*)(work), lwork, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgeevx)(balanc, jobvl, jobvr, sense, n, (T*)(a), lda, (T*)(wr), (T*)(wi), (T*)(vl), ldvl, (T*)(vr), ldvr, ilo, ihi, (T*)(scale), (T*)(abnrm), (T*)(rconde), (T*)(rcondv), (T*)(work), lwork, iwork, info); } + #endif + } + + + + template + inline + void + cx_geevx(char* balanc, char* jobvl, char* jobvr, char* sense, blas_int* n, eT* a, blas_int* lda, eT* w, eT* vl, blas_int* ldvl, eT* vr, blas_int* ldvr, blas_int* ilo, blas_int* ihi, typename eT::value_type* scale, typename eT::value_type* abnrm, typename eT::value_type* rconde, typename eT::value_type* rcondv, eT* work, blas_int* lwork, typename eT::value_type* rwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cgeevx)(balanc, jobvl, jobvr, sense, n, (cx_T*)(a), lda, (cx_T*)(w), (cx_T*)(vl), ldvl, (cx_T*)(vr), ldvr, ilo, ihi, (T*)(scale), (T*)(abnrm), (T*)(rconde), (T*)(rcondv), (cx_T*)(work), lwork, (T*)(rwork), info, 1, 1, 1, 1); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zgeevx)(balanc, jobvl, jobvr, sense, n, (cx_T*)(a), lda, (cx_T*)(w), (cx_T*)(vl), ldvl, (cx_T*)(vr), ldvr, ilo, ihi, (T*)(scale), (T*)(abnrm), (T*)(rconde), (T*)(rcondv), (cx_T*)(work), lwork, (T*)(rwork), info, 1, 1, 1, 1); } + #else + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cgeevx)(balanc, jobvl, jobvr, sense, n, (cx_T*)(a), lda, (cx_T*)(w), (cx_T*)(vl), ldvl, (cx_T*)(vr), ldvr, ilo, ihi, (T*)(scale), (T*)(abnrm), (T*)(rconde), (T*)(rcondv), (cx_T*)(work), lwork, (T*)(rwork), info); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zgeevx)(balanc, jobvl, jobvr, sense, n, (cx_T*)(a), lda, (cx_T*)(w), (cx_T*)(vl), ldvl, (cx_T*)(vr), ldvr, ilo, ihi, (T*)(scale), (T*)(abnrm), (T*)(rconde), (T*)(rcondv), (cx_T*)(work), lwork, (T*)(rwork), info); } + #endif + } + + + + template + inline + void + syev(char* jobz, char* uplo, blas_int* n, eT* a, blas_int* lda, eT* w, eT* work, blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_ssyev)(jobz, uplo, n, (T*)a, lda, (T*)w, (T*)work, lwork, info, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsyev)(jobz, uplo, n, (T*)a, lda, (T*)w, (T*)work, lwork, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_ssyev)(jobz, uplo, n, (T*)a, lda, (T*)w, (T*)work, lwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsyev)(jobz, uplo, n, (T*)a, lda, (T*)w, (T*)work, lwork, info); } + #endif + } + + + + template + inline + void + heev + ( + char* jobz, char* uplo, blas_int* n, + eT* a, blas_int* lda, typename eT::value_type* w, + eT* work, blas_int* lwork, typename eT::value_type* rwork, + blas_int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cheev)(jobz, uplo, n, (cx_T*)a, lda, (T*)w, (cx_T*)work, lwork, (T*)rwork, info, 1, 1); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zheev)(jobz, uplo, n, (cx_T*)a, lda, (T*)w, (cx_T*)work, lwork, (T*)rwork, info, 1, 1); } + #else + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cheev)(jobz, uplo, n, (cx_T*)a, lda, (T*)w, (cx_T*)work, lwork, (T*)rwork, info); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zheev)(jobz, uplo, n, (cx_T*)a, lda, (T*)w, (cx_T*)work, lwork, (T*)rwork, info); } + #endif + } + + + + template + inline + void + syevd(char* jobz, char* uplo, blas_int* n, eT* a, blas_int* lda, eT* w, eT* work, blas_int* lwork, blas_int* iwork, blas_int* liwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_ssyevd)(jobz, uplo, n, (T*)a, lda, (T*)w, (T*)work, lwork, iwork, liwork, info, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsyevd)(jobz, uplo, n, (T*)a, lda, (T*)w, (T*)work, lwork, iwork, liwork, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_ssyevd)(jobz, uplo, n, (T*)a, lda, (T*)w, (T*)work, lwork, iwork, liwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsyevd)(jobz, uplo, n, (T*)a, lda, (T*)w, (T*)work, lwork, iwork, liwork, info); } + #endif + } + + + + template + inline + void + heevd + ( + char* jobz, char* uplo, blas_int* n, + eT* a, blas_int* lda, typename eT::value_type* w, + eT* work, blas_int* lwork, typename eT::value_type* rwork, + blas_int* lrwork, blas_int* iwork, blas_int* liwork, + blas_int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cheevd)(jobz, uplo, n, (cx_T*)a, lda, (T*)w, (cx_T*)work, lwork, (T*)rwork, lrwork, iwork, liwork, info, 1, 1); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zheevd)(jobz, uplo, n, (cx_T*)a, lda, (T*)w, (cx_T*)work, lwork, (T*)rwork, lrwork, iwork, liwork, info, 1, 1); } + #else + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cheevd)(jobz, uplo, n, (cx_T*)a, lda, (T*)w, (cx_T*)work, lwork, (T*)rwork, lrwork, iwork, liwork, info); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zheevd)(jobz, uplo, n, (cx_T*)a, lda, (T*)w, (cx_T*)work, lwork, (T*)rwork, lrwork, iwork, liwork, info); } + #endif + } + + + + template + inline + void + ggev + ( + char* jobvl, char* jobvr, blas_int* n, + eT* a, blas_int* lda, eT* b, blas_int* ldb, + eT* alphar, eT* alphai, eT* beta, + eT* vl, blas_int* ldvl, eT* vr, blas_int* ldvr, + eT* work, blas_int* lwork, + blas_int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sggev)(jobvl, jobvr, n, (T*)a, lda, (T*)b, ldb, (T*)alphar, (T*)alphai, (T*)beta, (T*)vl, ldvl, (T*)vr, ldvr, (T*)work, lwork, info, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dggev)(jobvl, jobvr, n, (T*)a, lda, (T*)b, ldb, (T*)alphar, (T*)alphai, (T*)beta, (T*)vl, ldvl, (T*)vr, ldvr, (T*)work, lwork, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sggev)(jobvl, jobvr, n, (T*)a, lda, (T*)b, ldb, (T*)alphar, (T*)alphai, (T*)beta, (T*)vl, ldvl, (T*)vr, ldvr, (T*)work, lwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dggev)(jobvl, jobvr, n, (T*)a, lda, (T*)b, ldb, (T*)alphar, (T*)alphai, (T*)beta, (T*)vl, ldvl, (T*)vr, ldvr, (T*)work, lwork, info); } + #endif + } + + + + template + inline + void + cx_ggev + ( + char* jobvl, char* jobvr, blas_int* n, + eT* a, blas_int* lda, eT* b, blas_int* ldb, + eT* alpha, eT* beta, + eT* vl, blas_int* ldvl, eT* vr, blas_int* ldvr, + eT* work, blas_int* lwork, typename eT::value_type* rwork, + blas_int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cggev)(jobvl, jobvr, n, (cx_T*)a, lda, (cx_T*)b, ldb, (cx_T*)alpha, (cx_T*)beta, (cx_T*)vl, ldvl, (cx_T*)vr, ldvr, (cx_T*)work, lwork, (T*)rwork, info, 1, 1); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zggev)(jobvl, jobvr, n, (cx_T*)a, lda, (cx_T*)b, ldb, (cx_T*)alpha, (cx_T*)beta, (cx_T*)vl, ldvl, (cx_T*)vr, ldvr, (cx_T*)work, lwork, (T*)rwork, info, 1, 1); } + #else + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cggev)(jobvl, jobvr, n, (cx_T*)a, lda, (cx_T*)b, ldb, (cx_T*)alpha, (cx_T*)beta, (cx_T*)vl, ldvl, (cx_T*)vr, ldvr, (cx_T*)work, lwork, (T*)rwork, info); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zggev)(jobvl, jobvr, n, (cx_T*)a, lda, (cx_T*)b, ldb, (cx_T*)alpha, (cx_T*)beta, (cx_T*)vl, ldvl, (cx_T*)vr, ldvr, (cx_T*)work, lwork, (T*)rwork, info); } + #endif + } + + + + template + inline + void + potrf(char* uplo, blas_int* n, eT* a, blas_int* lda, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_spotrf)(uplo, n, (T*)a, lda, info, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dpotrf)(uplo, n, (T*)a, lda, info, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cpotrf)(uplo, n, (T*)a, lda, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zpotrf)(uplo, n, (T*)a, lda, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_spotrf)(uplo, n, (T*)a, lda, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dpotrf)(uplo, n, (T*)a, lda, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cpotrf)(uplo, n, (T*)a, lda, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zpotrf)(uplo, n, (T*)a, lda, info); } + #endif + } + + + + template + inline + void + potrs(char* uplo, blas_int* n, const blas_int* nrhs, eT* a, blas_int* lda, eT* b, blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_spotrs)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dpotrs)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cpotrs)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zpotrs)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_spotrs)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dpotrs)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cpotrs)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zpotrs)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + #endif + } + + + + template + inline + void + pbtrf(char* uplo, blas_int* n, blas_int* kd, eT* ab, blas_int* ldab, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_spbtrf)(uplo, n, kd, (T*)ab, ldab, info, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dpbtrf)(uplo, n, kd, (T*)ab, ldab, info, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cpbtrf)(uplo, n, kd, (T*)ab, ldab, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zpbtrf)(uplo, n, kd, (T*)ab, ldab, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_spbtrf)(uplo, n, kd, (T*)ab, ldab, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dpbtrf)(uplo, n, kd, (T*)ab, ldab, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cpbtrf)(uplo, n, kd, (T*)ab, ldab, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zpbtrf)(uplo, n, kd, (T*)ab, ldab, info); } + #endif + } + + + + template + inline + void + potri(char* uplo, blas_int* n, eT* a, blas_int* lda, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_spotri)(uplo, n, (T*)a, lda, info, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dpotri)(uplo, n, (T*)a, lda, info, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cpotri)(uplo, n, (T*)a, lda, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zpotri)(uplo, n, (T*)a, lda, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_spotri)(uplo, n, (T*)a, lda, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dpotri)(uplo, n, (T*)a, lda, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cpotri)(uplo, n, (T*)a, lda, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zpotri)(uplo, n, (T*)a, lda, info); } + #endif + } + + + + template + inline + void + geqrf(blas_int* m, blas_int* n, eT* a, blas_int* lda, eT* tau, eT* work, blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sgeqrf)(m, n, (T*)a, lda, (T*)tau, (T*)work, lwork, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgeqrf)(m, n, (T*)a, lda, (T*)tau, (T*)work, lwork, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgeqrf)(m, n, (T*)a, lda, (T*)tau, (T*)work, lwork, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgeqrf)(m, n, (T*)a, lda, (T*)tau, (T*)work, lwork, info); } + } + + + + template + inline + void + geqp3(blas_int* m, blas_int* n, eT* a, blas_int* lda, blas_int* jpvt, eT* tau, eT* work, blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sgeqp3)(m, n, (T*)a, lda, jpvt, (T*)tau, (T*)work, lwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgeqp3)(m, n, (T*)a, lda, jpvt, (T*)tau, (T*)work, lwork, info); } + } + + + + template + inline + void + cx_geqp3(blas_int* m, blas_int* n, eT* a, blas_int* lda, blas_int* jpvt, eT* tau, eT* work, blas_int* lwork, typename eT::value_type* rwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cgeqp3)(m, n, (cx_T*)a, lda, jpvt, (cx_T*)tau, (cx_T*)work, lwork, (T*)rwork, info); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zgeqp3)(m, n, (cx_T*)a, lda, jpvt, (cx_T*)tau, (cx_T*)work, lwork, (T*)rwork, info); } + } + + + + template + inline + void + orgqr(blas_int* m, blas_int* n, blas_int* k, eT* a, blas_int* lda, eT* tau, eT* work, blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sorgqr)(m, n, k, (T*)a, lda, (T*)tau, (T*)work, lwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dorgqr)(m, n, k, (T*)a, lda, (T*)tau, (T*)work, lwork, info); } + } + + + + template + inline + void + ungqr(blas_int* m, blas_int* n, blas_int* k, eT* a, blas_int* lda, eT* tau, eT* work, blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cungqr)(m, n, k, (T*)a, lda, (T*)tau, (T*)work, lwork, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zungqr)(m, n, k, (T*)a, lda, (T*)tau, (T*)work, lwork, info); } + } + + + + template + inline + void + gesvd + ( + char* jobu, char* jobvt, blas_int* m, blas_int* n, eT* a, blas_int* lda, + eT* s, eT* u, blas_int* ldu, eT* vt, blas_int* ldvt, + eT* work, blas_int* lwork, blas_int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgesvd)(jobu, jobvt, m, n, (T*)a, lda, (T*)s, (T*)u, ldu, (T*)vt, ldvt, (T*)work, lwork, info, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgesvd)(jobu, jobvt, m, n, (T*)a, lda, (T*)s, (T*)u, ldu, (T*)vt, ldvt, (T*)work, lwork, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgesvd)(jobu, jobvt, m, n, (T*)a, lda, (T*)s, (T*)u, ldu, (T*)vt, ldvt, (T*)work, lwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgesvd)(jobu, jobvt, m, n, (T*)a, lda, (T*)s, (T*)u, ldu, (T*)vt, ldvt, (T*)work, lwork, info); } + #endif + } + + + + template + inline + void + cx_gesvd + ( + char* jobu, char* jobvt, blas_int* m, blas_int* n, std::complex* a, blas_int* lda, + T* s, std::complex* u, blas_int* ldu, std::complex* vt, blas_int* ldvt, + std::complex* work, blas_int* lwork, T* rwork, blas_int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + arma_type_check(( is_supported_blas_type< std::complex >::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float bT; typedef blas_cxf cx_bT; arma_fortran(arma_cgesvd)(jobu, jobvt, m, n, (cx_bT*)a, lda, (bT*)s, (cx_bT*)u, ldu, (cx_bT*)vt, ldvt, (cx_bT*)work, lwork, (bT*)rwork, info, 1, 1); } + else if(is_double::value) { typedef double bT; typedef blas_cxd cx_bT; arma_fortran(arma_zgesvd)(jobu, jobvt, m, n, (cx_bT*)a, lda, (bT*)s, (cx_bT*)u, ldu, (cx_bT*)vt, ldvt, (cx_bT*)work, lwork, (bT*)rwork, info, 1, 1); } + #else + if( is_float::value) { typedef float bT; typedef blas_cxf cx_bT; arma_fortran(arma_cgesvd)(jobu, jobvt, m, n, (cx_bT*)a, lda, (bT*)s, (cx_bT*)u, ldu, (cx_bT*)vt, ldvt, (cx_bT*)work, lwork, (bT*)rwork, info); } + else if(is_double::value) { typedef double bT; typedef blas_cxd cx_bT; arma_fortran(arma_zgesvd)(jobu, jobvt, m, n, (cx_bT*)a, lda, (bT*)s, (cx_bT*)u, ldu, (cx_bT*)vt, ldvt, (cx_bT*)work, lwork, (bT*)rwork, info); } + #endif + } + + + + template + inline + void + gesdd + ( + char* jobz, blas_int* m, blas_int* n, + eT* a, blas_int* lda, eT* s, eT* u, blas_int* ldu, eT* vt, blas_int* ldvt, + eT* work, blas_int* lwork, blas_int* iwork, blas_int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgesdd)(jobz, m, n, (T*)a, lda, (T*)s, (T*)u, ldu, (T*)vt, ldvt, (T*)work, lwork, iwork, info, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgesdd)(jobz, m, n, (T*)a, lda, (T*)s, (T*)u, ldu, (T*)vt, ldvt, (T*)work, lwork, iwork, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgesdd)(jobz, m, n, (T*)a, lda, (T*)s, (T*)u, ldu, (T*)vt, ldvt, (T*)work, lwork, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgesdd)(jobz, m, n, (T*)a, lda, (T*)s, (T*)u, ldu, (T*)vt, ldvt, (T*)work, lwork, iwork, info); } + #endif + } + + + + template + inline + void + cx_gesdd + ( + char* jobz, blas_int* m, blas_int* n, + std::complex* a, blas_int* lda, T* s, std::complex* u, blas_int* ldu, std::complex* vt, blas_int* ldvt, + std::complex* work, blas_int* lwork, T* rwork, blas_int* iwork, blas_int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + arma_type_check(( is_supported_blas_type< std::complex >::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float bT; typedef blas_cxf cx_bT; arma_fortran(arma_cgesdd)(jobz, m, n, (cx_bT*)a, lda, (bT*)s, (cx_bT*)u, ldu, (cx_bT*)vt, ldvt, (cx_bT*)work, lwork, (bT*)rwork, iwork, info, 1); } + else if(is_double::value) { typedef double bT; typedef blas_cxd cx_bT; arma_fortran(arma_zgesdd)(jobz, m, n, (cx_bT*)a, lda, (bT*)s, (cx_bT*)u, ldu, (cx_bT*)vt, ldvt, (cx_bT*)work, lwork, (bT*)rwork, iwork, info, 1); } + #else + if( is_float::value) { typedef float bT; typedef blas_cxf cx_bT; arma_fortran(arma_cgesdd)(jobz, m, n, (cx_bT*)a, lda, (bT*)s, (cx_bT*)u, ldu, (cx_bT*)vt, ldvt, (cx_bT*)work, lwork, (bT*)rwork, iwork, info); } + else if(is_double::value) { typedef double bT; typedef blas_cxd cx_bT; arma_fortran(arma_zgesdd)(jobz, m, n, (cx_bT*)a, lda, (bT*)s, (cx_bT*)u, ldu, (cx_bT*)vt, ldvt, (cx_bT*)work, lwork, (bT*)rwork, iwork, info); } + #endif + } + + + + template + inline + void + gesv(blas_int* n, blas_int* nrhs, eT* a, blas_int* lda, blas_int* ipiv, eT* b, blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sgesv)(n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgesv)(n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgesv)(n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgesv)(n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + } + + + + template + inline + void + gesvx(char* fact, char* trans, blas_int* n, blas_int* nrhs, eT* a, blas_int* lda, eT* af, blas_int* ldaf, blas_int* ipiv, char* equed, eT* r, eT* c, eT* b, blas_int* ldb, eT* x, blas_int* ldx, eT* rcond, eT* ferr, eT* berr, eT* work, blas_int* iwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgesvx)(fact, trans, n, nrhs, (T*)a, lda, (T*)af, ldaf, ipiv, equed, (T*)r, (T*)c, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info, 1, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgesvx)(fact, trans, n, nrhs, (T*)a, lda, (T*)af, ldaf, ipiv, equed, (T*)r, (T*)c, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info, 1, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgesvx)(fact, trans, n, nrhs, (T*)a, lda, (T*)af, ldaf, ipiv, equed, (T*)r, (T*)c, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgesvx)(fact, trans, n, nrhs, (T*)a, lda, (T*)af, ldaf, ipiv, equed, (T*)r, (T*)c, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info); } + #endif + } + + + + template + inline + void + cx_gesvx(char* fact, char* trans, blas_int* n, blas_int* nrhs, eT* a, blas_int* lda, eT* af, blas_int* ldaf, blas_int* ipiv, char* equed, T* r, T* c, eT* b, blas_int* ldb, eT* x, blas_int* ldx, T* rcond, T* ferr, T* berr, eT* work, T* rwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cgesvx)(fact, trans, n, nrhs, (cx_T*)a, lda, (cx_T*)af, ldaf, ipiv, equed, (pod_T*)r, (pod_T*)c, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info, 1, 1, 1); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zgesvx)(fact, trans, n, nrhs, (cx_T*)a, lda, (cx_T*)af, ldaf, ipiv, equed, (pod_T*)r, (pod_T*)c, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info, 1, 1, 1); } + #else + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cgesvx)(fact, trans, n, nrhs, (cx_T*)a, lda, (cx_T*)af, ldaf, ipiv, equed, (pod_T*)r, (pod_T*)c, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zgesvx)(fact, trans, n, nrhs, (cx_T*)a, lda, (cx_T*)af, ldaf, ipiv, equed, (pod_T*)r, (pod_T*)c, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info); } + #endif + } + + + + template + inline + void + posv(char* uplo, blas_int* n, blas_int* nrhs, eT* a, blas_int* lda, eT* b, blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sposv)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dposv)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cposv)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zposv)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sposv)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dposv)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cposv)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zposv)(uplo, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + #endif + } + + + + template + inline + void + posvx(char* fact, char* uplo, blas_int* n, blas_int* nrhs, eT* a, blas_int* lda, eT* af, blas_int* ldaf, char* equed, eT* s, eT* b, blas_int* ldb, eT* x, blas_int* ldx, eT* rcond, eT* ferr, eT* berr, eT* work, blas_int* iwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sposvx)(fact, uplo, n, nrhs, (T*)a, lda, (T*)af, ldaf, equed, (T*)s, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info, 1, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dposvx)(fact, uplo, n, nrhs, (T*)a, lda, (T*)af, ldaf, equed, (T*)s, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info, 1, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sposvx)(fact, uplo, n, nrhs, (T*)a, lda, (T*)af, ldaf, equed, (T*)s, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dposvx)(fact, uplo, n, nrhs, (T*)a, lda, (T*)af, ldaf, equed, (T*)s, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info); } + #endif + } + + + + template + inline + void + cx_posvx(char* fact, char* uplo, blas_int* n, blas_int* nrhs, eT* a, blas_int* lda, eT* af, blas_int* ldaf, char* equed, T* s, eT* b, blas_int* ldb, eT* x, blas_int* ldx, T* rcond, T* ferr, T* berr, eT* work, T* rwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cposvx)(fact, uplo, n, nrhs, (cx_T*)a, lda, (cx_T*)af, ldaf, equed, (pod_T*)s, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info, 1, 1, 1); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zposvx)(fact, uplo, n, nrhs, (cx_T*)a, lda, (cx_T*)af, ldaf, equed, (pod_T*)s, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info, 1, 1, 1); } + #else + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cposvx)(fact, uplo, n, nrhs, (cx_T*)a, lda, (cx_T*)af, ldaf, equed, (pod_T*)s, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zposvx)(fact, uplo, n, nrhs, (cx_T*)a, lda, (cx_T*)af, ldaf, equed, (pod_T*)s, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info); } + #endif + } + + + + template + inline + void + gels(char* trans, blas_int* m, blas_int* n, blas_int* nrhs, eT* a, blas_int* lda, eT* b, blas_int* ldb, eT* work, blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgels)(trans, m, n, nrhs, (T*)a, lda, (T*)b, ldb, (T*)work, lwork, info, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgels)(trans, m, n, nrhs, (T*)a, lda, (T*)b, ldb, (T*)work, lwork, info, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgels)(trans, m, n, nrhs, (T*)a, lda, (T*)b, ldb, (T*)work, lwork, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgels)(trans, m, n, nrhs, (T*)a, lda, (T*)b, ldb, (T*)work, lwork, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgels)(trans, m, n, nrhs, (T*)a, lda, (T*)b, ldb, (T*)work, lwork, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgels)(trans, m, n, nrhs, (T*)a, lda, (T*)b, ldb, (T*)work, lwork, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgels)(trans, m, n, nrhs, (T*)a, lda, (T*)b, ldb, (T*)work, lwork, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgels)(trans, m, n, nrhs, (T*)a, lda, (T*)b, ldb, (T*)work, lwork, info); } + #endif + } + + + + template + inline + void + gelsd(blas_int* m, blas_int* n, blas_int* nrhs, eT* a, blas_int* lda, eT* b, blas_int* ldb, eT* S, eT* rcond, blas_int* rank, eT* work, blas_int* lwork, blas_int* iwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sgelsd)(m, n, nrhs, (T*)a, lda, (T*)b, ldb, (T*)S, (T*)rcond, rank, (T*)work, lwork, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgelsd)(m, n, nrhs, (T*)a, lda, (T*)b, ldb, (T*)S, (T*)rcond, rank, (T*)work, lwork, iwork, info); } + } + + + + template + inline + void + cx_gelsd(blas_int* m, blas_int* n, blas_int* nrhs, std::complex* a, blas_int* lda, std::complex* b, blas_int* ldb, T* S, T* rcond, blas_int* rank, std::complex* work, blas_int* lwork, T* rwork, blas_int* iwork, blas_int* info) + { + typedef typename std::complex eT; + + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cgelsd)(m, n, nrhs, (cx_T*)a, lda, (cx_T*)b, ldb, (pod_T*)S, (pod_T*)rcond, rank, (cx_T*)work, lwork, (pod_T*)rwork, iwork, info); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zgelsd)(m, n, nrhs, (cx_T*)a, lda, (cx_T*)b, ldb, (pod_T*)S, (pod_T*)rcond, rank, (cx_T*)work, lwork, (pod_T*)rwork, iwork, info); } + } + + + + template + inline + void + trtrs(char* uplo, char* trans, char* diag, blas_int* n, blas_int* nrhs, const eT* a, blas_int* lda, eT* b, blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_strtrs)(uplo, trans, diag, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1, 1, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dtrtrs)(uplo, trans, diag, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1, 1, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_ctrtrs)(uplo, trans, diag, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1, 1, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_ztrtrs)(uplo, trans, diag, n, nrhs, (T*)a, lda, (T*)b, ldb, info, 1, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_strtrs)(uplo, trans, diag, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dtrtrs)(uplo, trans, diag, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_ctrtrs)(uplo, trans, diag, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_ztrtrs)(uplo, trans, diag, n, nrhs, (T*)a, lda, (T*)b, ldb, info); } + #endif + } + + + + template + inline + void + gbtrf(blas_int* m, blas_int* n, blas_int* kl, blas_int* ku, eT* ab, blas_int* ldab, blas_int* ipiv, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sgbtrf)(m, n, kl, ku, (T*)ab, ldab, ipiv, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgbtrf)(m, n, kl, ku, (T*)ab, ldab, ipiv, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgbtrf)(m, n, kl, ku, (T*)ab, ldab, ipiv, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgbtrf)(m, n, kl, ku, (T*)ab, ldab, ipiv, info); } + } + + + + template + inline + void + gbtrs(char* trans, blas_int* n, blas_int* kl, blas_int* ku, blas_int* nrhs, eT* ab, blas_int* ldab, blas_int* ipiv, eT* b, blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgbtrs)(trans, n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgbtrs)(trans, n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgbtrs)(trans, n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgbtrs)(trans, n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgbtrs)(trans, n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgbtrs)(trans, n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgbtrs)(trans, n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgbtrs)(trans, n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info); } + #endif + } + + + + template + inline + void + gbsv(blas_int* n, blas_int* kl, blas_int* ku, blas_int* nrhs, eT* ab, blas_int* ldab, blas_int* ipiv, eT* b, blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sgbsv)(n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgbsv)(n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgbsv)(n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgbsv)(n, kl, ku, nrhs, (T*)ab, ldab, ipiv, (T*)b, ldb, info); } + } + + + + template + inline + void + gbsvx(char* fact, char* trans, blas_int* n, blas_int* kl, blas_int* ku, blas_int* nrhs, eT* ab, blas_int* ldab, eT* afb, blas_int* ldafb, blas_int* ipiv, char* equed, eT* r, eT* c, eT* b, blas_int* ldb, eT* x, blas_int* ldx, eT* rcond, eT* ferr, eT* berr, eT* work, blas_int* iwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgbsvx)(fact, trans, n, kl, ku, nrhs, (T*)ab, ldab, (T*)afb, ldafb, ipiv, equed, (T*)r, (T*)c, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info, 1, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgbsvx)(fact, trans, n, kl, ku, nrhs, (T*)ab, ldab, (T*)afb, ldafb, ipiv, equed, (T*)r, (T*)c, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info, 1, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgbsvx)(fact, trans, n, kl, ku, nrhs, (T*)ab, ldab, (T*)afb, ldafb, ipiv, equed, (T*)r, (T*)c, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgbsvx)(fact, trans, n, kl, ku, nrhs, (T*)ab, ldab, (T*)afb, ldafb, ipiv, equed, (T*)r, (T*)c, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info); } + #endif + } + + + + template + inline + void + cx_gbsvx(char* fact, char* trans, blas_int* n, blas_int* kl, blas_int* ku, blas_int* nrhs, eT* ab, blas_int* ldab, eT* afb, blas_int* ldafb, blas_int* ipiv, char* equed, T* r, T* c, eT* b, blas_int* ldb, eT* x, blas_int* ldx, T* rcond, T* ferr, T* berr, eT* work, T* rwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cgbsvx)(fact, trans, n, kl, ku, nrhs, (cx_T*)ab, ldab, (cx_T*)afb, ldafb, ipiv, equed, (pod_T*)r, (pod_T*)c, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info, 1, 1, 1); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zgbsvx)(fact, trans, n, kl, ku, nrhs, (cx_T*)ab, ldab, (cx_T*)afb, ldafb, ipiv, equed, (pod_T*)r, (pod_T*)c, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info, 1, 1, 1); } + #else + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cgbsvx)(fact, trans, n, kl, ku, nrhs, (cx_T*)ab, ldab, (cx_T*)afb, ldafb, ipiv, equed, (pod_T*)r, (pod_T*)c, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zgbsvx)(fact, trans, n, kl, ku, nrhs, (cx_T*)ab, ldab, (cx_T*)afb, ldafb, ipiv, equed, (pod_T*)r, (pod_T*)c, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info); } + #endif + } + + + + template + inline + void + gtsv(blas_int* n, blas_int* nrhs, eT* dl, eT* d, eT* du, eT* b, blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sgtsv)(n, nrhs, (T*)dl, (T*)d, (T*)du, (T*)b, ldb, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgtsv)(n, nrhs, (T*)dl, (T*)d, (T*)du, (T*)b, ldb, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgtsv)(n, nrhs, (T*)dl, (T*)d, (T*)du, (T*)b, ldb, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgtsv)(n, nrhs, (T*)dl, (T*)d, (T*)du, (T*)b, ldb, info); } + } + + + + template + inline + void + gtsvx(char* fact, char* trans, blas_int* n, blas_int* nrhs, eT* dl, eT* d, eT* du, eT* dlf, eT* df, eT* duf, eT* du2, blas_int* ipiv, eT* b, blas_int* ldb, eT* x, blas_int* ldx, eT* rcond, eT* ferr, eT* berr, eT* work, blas_int* iwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgtsvx)(fact, trans, n, nrhs, (T*)dl, (T*)d, (T*)du, (T*)dlf, (T*)df, (T*)duf, (T*)du2, ipiv, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgtsvx)(fact, trans, n, nrhs, (T*)dl, (T*)d, (T*)du, (T*)dlf, (T*)df, (T*)duf, (T*)du2, ipiv, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgtsvx)(fact, trans, n, nrhs, (T*)dl, (T*)d, (T*)du, (T*)dlf, (T*)df, (T*)duf, (T*)du2, ipiv, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgtsvx)(fact, trans, n, nrhs, (T*)dl, (T*)d, (T*)du, (T*)dlf, (T*)df, (T*)duf, (T*)du2, ipiv, (T*)b, ldb, (T*)x, ldx, (T*)rcond, (T*)ferr, (T*)berr, (T*)work, iwork, info); } + #endif + } + + + + template + inline + void + cx_gtsvx(char* fact, char* trans, blas_int* n, blas_int* nrhs, eT* dl, eT* d, eT* du, eT* dlf, eT* df, eT* duf, eT* du2, blas_int* ipiv, eT* b, blas_int* ldb, eT* x, blas_int* ldx, T* rcond, T* ferr, T* berr, eT* work, T* rwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cgtsvx)(fact, trans, n, nrhs, (cx_T*)dl, (cx_T*)d, (cx_T*)du, (cx_T*)dlf, (cx_T*)df, (cx_T*)duf, (cx_T*)du2, ipiv, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info, 1, 1); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zgtsvx)(fact, trans, n, nrhs, (cx_T*)dl, (cx_T*)d, (cx_T*)du, (cx_T*)dlf, (cx_T*)df, (cx_T*)duf, (cx_T*)du2, ipiv, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info, 1, 1); } + #else + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cgtsvx)(fact, trans, n, nrhs, (cx_T*)dl, (cx_T*)d, (cx_T*)du, (cx_T*)dlf, (cx_T*)df, (cx_T*)duf, (cx_T*)du2, ipiv, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zgtsvx)(fact, trans, n, nrhs, (cx_T*)dl, (cx_T*)d, (cx_T*)du, (cx_T*)dlf, (cx_T*)df, (cx_T*)duf, (cx_T*)du2, ipiv, (cx_T*)b, ldb, (cx_T*)x, ldx, (pod_T*)rcond, (pod_T*)ferr, (pod_T*)berr, (cx_T*)work, (pod_T*)rwork, info); } + #endif + } + + + + template + inline + void + gees(char* jobvs, char* sort, void* select, blas_int* n, eT* a, blas_int* lda, blas_int* sdim, eT* wr, eT* wi, eT* vs, blas_int* ldvs, eT* work, blas_int* lwork, blas_int* bwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgees)(jobvs, sort, (fn_select_s2)select, n, (T*)a, lda, sdim, (T*)wr, (T*)wi, (T*)vs, ldvs, (T*)work, lwork, bwork, info, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgees)(jobvs, sort, (fn_select_d2)select, n, (T*)a, lda, sdim, (T*)wr, (T*)wi, (T*)vs, ldvs, (T*)work, lwork, bwork, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgees)(jobvs, sort, (fn_select_s2)select, n, (T*)a, lda, sdim, (T*)wr, (T*)wi, (T*)vs, ldvs, (T*)work, lwork, bwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgees)(jobvs, sort, (fn_select_d2)select, n, (T*)a, lda, sdim, (T*)wr, (T*)wi, (T*)vs, ldvs, (T*)work, lwork, bwork, info); } + #endif + } + + + + template + inline + void + cx_gees(char* jobvs, char* sort, void* select, blas_int* n, std::complex* a, blas_int* lda, blas_int* sdim, std::complex* w, std::complex* vs, blas_int* ldvs, std::complex* work, blas_int* lwork, T* rwork, blas_int* bwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + arma_type_check(( is_supported_blas_type< std::complex >::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float bT; typedef blas_cxf cT; arma_fortran(arma_cgees)(jobvs, sort, (fn_select_c1)select, n, (cT*)a, lda, sdim, (cT*)w, (cT*)vs, ldvs, (cT*)work, lwork, (bT*)rwork, bwork, info, 1, 1); } + else if(is_double::value) { typedef double bT; typedef blas_cxd cT; arma_fortran(arma_zgees)(jobvs, sort, (fn_select_z1)select, n, (cT*)a, lda, sdim, (cT*)w, (cT*)vs, ldvs, (cT*)work, lwork, (bT*)rwork, bwork, info, 1, 1); } + #else + if( is_float::value) { typedef float bT; typedef blas_cxf cT; arma_fortran(arma_cgees)(jobvs, sort, (fn_select_c1)select, n, (cT*)a, lda, sdim, (cT*)w, (cT*)vs, ldvs, (cT*)work, lwork, (bT*)rwork, bwork, info); } + else if(is_double::value) { typedef double bT; typedef blas_cxd cT; arma_fortran(arma_zgees)(jobvs, sort, (fn_select_z1)select, n, (cT*)a, lda, sdim, (cT*)w, (cT*)vs, ldvs, (cT*)work, lwork, (bT*)rwork, bwork, info); } + #endif + } + + + + template + inline + void + trsyl(char* transa, char* transb, blas_int* isgn, blas_int* m, blas_int* n, const eT* a, blas_int* lda, const eT* b, blas_int* ldb, eT* c, blas_int* ldc, eT* scale, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_strsyl)(transa, transb, isgn, m, n, (T*)a, lda, (T*)b, ldb, (T*)c, ldc, (T*)scale, info, 1, 1); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dtrsyl)(transa, transb, isgn, m, n, (T*)a, lda, (T*)b, ldb, (T*)c, ldc, (T*)scale, info, 1, 1); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_ctrsyl)(transa, transb, isgn, m, n, (T*)a, lda, (T*)b, ldb, (T*)c, ldc, (float*)scale, info, 1, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_ztrsyl)(transa, transb, isgn, m, n, (T*)a, lda, (T*)b, ldb, (T*)c, ldc, (double*)scale, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_strsyl)(transa, transb, isgn, m, n, (T*)a, lda, (T*)b, ldb, (T*)c, ldc, (T*)scale, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dtrsyl)(transa, transb, isgn, m, n, (T*)a, lda, (T*)b, ldb, (T*)c, ldc, (T*)scale, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_ctrsyl)(transa, transb, isgn, m, n, (T*)a, lda, (T*)b, ldb, (T*)c, ldc, (float*)scale, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_ztrsyl)(transa, transb, isgn, m, n, (T*)a, lda, (T*)b, ldb, (T*)c, ldc, (double*)scale, info); } + #endif + } + + + + template + inline + void + gges + ( + char* jobvsl, char* jobvsr, char* sort, void* selctg, blas_int* n, + eT* a, blas_int* lda, eT* b, blas_int* ldb, blas_int* sdim, + eT* alphar, eT* alphai, eT* beta, + eT* vsl, blas_int* ldvsl, eT* vsr, blas_int* ldvsr, + eT* work, blas_int* lwork, + blas_int* bwork, + blas_int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgges)(jobvsl, jobvsr, sort, (fn_select_s3)selctg, n, (T*)a, lda, (T*)b, ldb, sdim, (T*)alphar, (T*)alphai, (T*)beta, (T*)vsl, ldvsl, (T*)vsr, ldvsr, (T*)work, lwork, bwork, info, 1, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgges)(jobvsl, jobvsr, sort, (fn_select_d3)selctg, n, (T*)a, lda, (T*)b, ldb, sdim, (T*)alphar, (T*)alphai, (T*)beta, (T*)vsl, ldvsl, (T*)vsr, ldvsr, (T*)work, lwork, bwork, info, 1, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgges)(jobvsl, jobvsr, sort, (fn_select_s3)selctg, n, (T*)a, lda, (T*)b, ldb, sdim, (T*)alphar, (T*)alphai, (T*)beta, (T*)vsl, ldvsl, (T*)vsr, ldvsr, (T*)work, lwork, bwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgges)(jobvsl, jobvsr, sort, (fn_select_d3)selctg, n, (T*)a, lda, (T*)b, ldb, sdim, (T*)alphar, (T*)alphai, (T*)beta, (T*)vsl, ldvsl, (T*)vsr, ldvsr, (T*)work, lwork, bwork, info); } + #endif + } + + + + template + inline + void + cx_gges + ( + char* jobvsl, char* jobvsr, char* sort, void* selctg, blas_int* n, + eT* a, blas_int* lda, eT* b, blas_int* ldb, blas_int* sdim, + eT* alpha, eT* beta, + eT* vsl, blas_int* ldvsl, eT* vsr, blas_int* ldvsr, + eT* work, blas_int* lwork, typename eT::value_type* rwork, + blas_int* bwork, + blas_int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cgges)(jobvsl, jobvsr, sort, (fn_select_c2)selctg, n, (cx_T*)a, lda, (cx_T*)b, ldb, sdim, (cx_T*)alpha, (cx_T*)beta, (cx_T*)vsl, ldvsl, (cx_T*)vsr, ldvsr, (cx_T*)work, lwork, (T*)rwork, bwork, info, 1, 1, 1); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zgges)(jobvsl, jobvsr, sort, (fn_select_z2)selctg, n, (cx_T*)a, lda, (cx_T*)b, ldb, sdim, (cx_T*)alpha, (cx_T*)beta, (cx_T*)vsl, ldvsl, (cx_T*)vsr, ldvsr, (cx_T*)work, lwork, (T*)rwork, bwork, info, 1, 1, 1); } + #else + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cgges)(jobvsl, jobvsr, sort, (fn_select_c2)selctg, n, (cx_T*)a, lda, (cx_T*)b, ldb, sdim, (cx_T*)alpha, (cx_T*)beta, (cx_T*)vsl, ldvsl, (cx_T*)vsr, ldvsr, (cx_T*)work, lwork, (T*)rwork, bwork, info); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zgges)(jobvsl, jobvsr, sort, (fn_select_z2)selctg, n, (cx_T*)a, lda, (cx_T*)b, ldb, sdim, (cx_T*)alpha, (cx_T*)beta, (cx_T*)vsl, ldvsl, (cx_T*)vsr, ldvsr, (cx_T*)work, lwork, (T*)rwork, bwork, info); } + #endif + } + + + + template + inline + typename get_pod_type::result + lange(char* norm, blas_int* m, blas_int* n, eT* a, blas_int* lda, typename get_pod_type::result* work) + { + arma_type_check(( is_supported_blas_type::value == false )); + + typedef typename get_pod_type::result out_T; + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float pod_T; typedef float T; return out_T( arma_fortran(arma_slange)(norm, m, n, (T*)a, lda, (pod_T*)work, 1) ); } + else if( is_double::value) { typedef double pod_T; typedef double T; return out_T( arma_fortran(arma_dlange)(norm, m, n, (T*)a, lda, (pod_T*)work, 1) ); } + else if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; return out_T( arma_fortran(arma_clange)(norm, m, n, (T*)a, lda, (pod_T*)work, 1) ); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; return out_T( arma_fortran(arma_zlange)(norm, m, n, (T*)a, lda, (pod_T*)work, 1) ); } + #else + if( is_float::value) { typedef float pod_T; typedef float T; return out_T( arma_fortran(arma_slange)(norm, m, n, (T*)a, lda, (pod_T*)work) ); } + else if( is_double::value) { typedef double pod_T; typedef double T; return out_T( arma_fortran(arma_dlange)(norm, m, n, (T*)a, lda, (pod_T*)work) ); } + else if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; return out_T( arma_fortran(arma_clange)(norm, m, n, (T*)a, lda, (pod_T*)work) ); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; return out_T( arma_fortran(arma_zlange)(norm, m, n, (T*)a, lda, (pod_T*)work) ); } + #endif + + return out_T(0); + } + + + + template + inline + typename get_pod_type::result + lansy(char* norm, char* uplo, blas_int* n, eT* a, blas_int* lda, typename get_pod_type::result* work) + { + arma_type_check(( is_supported_blas_type::value == false )); + + typedef typename get_pod_type::result out_T; + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float pod_T; typedef float T; return out_T( arma_fortran(arma_slansy)(norm, uplo, n, (T*)a, lda, (pod_T*)work, 1, 1) ); } + else if( is_double::value) { typedef double pod_T; typedef double T; return out_T( arma_fortran(arma_dlansy)(norm, uplo, n, (T*)a, lda, (pod_T*)work, 1, 1) ); } + else if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; return out_T( arma_fortran(arma_clansy)(norm, uplo, n, (T*)a, lda, (pod_T*)work, 1, 1) ); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; return out_T( arma_fortran(arma_zlansy)(norm, uplo, n, (T*)a, lda, (pod_T*)work, 1, 1) ); } + #else + if( is_float::value) { typedef float pod_T; typedef float T; return out_T( arma_fortran(arma_slansy)(norm, uplo, n, (T*)a, lda, (pod_T*)work) ); } + else if( is_double::value) { typedef double pod_T; typedef double T; return out_T( arma_fortran(arma_dlansy)(norm, uplo, n, (T*)a, lda, (pod_T*)work) ); } + else if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; return out_T( arma_fortran(arma_clansy)(norm, uplo, n, (T*)a, lda, (pod_T*)work) ); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; return out_T( arma_fortran(arma_zlansy)(norm, uplo, n, (T*)a, lda, (pod_T*)work) ); } + #endif + + return out_T(0); + } + + + + template + inline + typename get_pod_type::result + lanhe(char* norm, char* uplo, blas_int* n, eT* a, blas_int* lda, typename get_pod_type::result* work) + { + arma_type_check(( is_supported_blas_type::value == false )); + + typedef typename get_pod_type::result out_T; + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; return out_T( arma_fortran(arma_clanhe)(norm, uplo, n, (T*)a, lda, (pod_T*)work, 1, 1) ); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; return out_T( arma_fortran(arma_zlanhe)(norm, uplo, n, (T*)a, lda, (pod_T*)work, 1, 1) ); } + #else + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; return out_T( arma_fortran(arma_clanhe)(norm, uplo, n, (T*)a, lda, (pod_T*)work) ); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; return out_T( arma_fortran(arma_zlanhe)(norm, uplo, n, (T*)a, lda, (pod_T*)work) ); } + #endif + + return out_T(0); + } + + + + template + inline + typename get_pod_type::result + langb(char* norm, blas_int* n, blas_int* kl, blas_int* ku, eT* ab, blas_int* ldab, typename get_pod_type::result* work) + { + arma_type_check(( is_supported_blas_type::value == false )); + + typedef typename get_pod_type::result out_T; + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float pod_T; typedef float T; return out_T( arma_fortran(arma_slangb)(norm, n, kl, ku, (T*)ab, ldab, (pod_T*)work, 1) ); } + else if( is_double::value) { typedef double pod_T; typedef double T; return out_T( arma_fortran(arma_dlangb)(norm, n, kl, ku, (T*)ab, ldab, (pod_T*)work, 1) ); } + else if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; return out_T( arma_fortran(arma_clangb)(norm, n, kl, ku, (T*)ab, ldab, (pod_T*)work, 1) ); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; return out_T( arma_fortran(arma_zlangb)(norm, n, kl, ku, (T*)ab, ldab, (pod_T*)work, 1) ); } + #else + if( is_float::value) { typedef float pod_T; typedef float T; return out_T( arma_fortran(arma_slangb)(norm, n, kl, ku, (T*)ab, ldab, (pod_T*)work) ); } + else if( is_double::value) { typedef double pod_T; typedef double T; return out_T( arma_fortran(arma_dlangb)(norm, n, kl, ku, (T*)ab, ldab, (pod_T*)work) ); } + else if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; return out_T( arma_fortran(arma_clangb)(norm, n, kl, ku, (T*)ab, ldab, (pod_T*)work) ); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; return out_T( arma_fortran(arma_zlangb)(norm, n, kl, ku, (T*)ab, ldab, (pod_T*)work) ); } + #endif + + return out_T(0); + } + + + + template + inline + void + gecon(char* norm, blas_int* n, const eT* a, blas_int* lda, const eT* anorm, eT* rcond, eT* work, blas_int* iwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgecon)(norm, n, (T*)a, lda, (T*)anorm, (T*)rcond, (T*)work, iwork, info, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgecon)(norm, n, (T*)a, lda, (T*)anorm, (T*)rcond, (T*)work, iwork, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgecon)(norm, n, (T*)a, lda, (T*)anorm, (T*)rcond, (T*)work, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgecon)(norm, n, (T*)a, lda, (T*)anorm, (T*)rcond, (T*)work, iwork, info); } + #endif + } + + + + template + inline + void + cx_gecon(char* norm, blas_int* n, const std::complex* a, blas_int* lda, const T* anorm, T* rcond, std::complex* work, T* rwork, blas_int* info) + { + typedef typename std::complex eT; + + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cgecon)(norm, n, (cx_T*)a, lda, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info, 1); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zgecon)(norm, n, (cx_T*)a, lda, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info, 1); } + #else + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cgecon)(norm, n, (cx_T*)a, lda, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zgecon)(norm, n, (cx_T*)a, lda, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info); } + #endif + } + + + + template + inline + void + pocon(char* uplo, blas_int* n, const eT* a, blas_int* lda, const eT* anorm, eT* rcond, eT* work, blas_int* iwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_spocon)(uplo, n, (T*)a, lda, (T*)anorm, (T*)rcond, (T*)work, iwork, info, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dpocon)(uplo, n, (T*)a, lda, (T*)anorm, (T*)rcond, (T*)work, iwork, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_spocon)(uplo, n, (T*)a, lda, (T*)anorm, (T*)rcond, (T*)work, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dpocon)(uplo, n, (T*)a, lda, (T*)anorm, (T*)rcond, (T*)work, iwork, info); } + #endif + } + + + + template + inline + void + cx_pocon(char* uplo, blas_int* n, const std::complex* a, blas_int* lda, const T* anorm, T* rcond, std::complex* work, T* rwork, blas_int* info) + { + typedef typename std::complex eT; + + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cpocon)(uplo, n, (cx_T*)a, lda, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info, 1); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zpocon)(uplo, n, (cx_T*)a, lda, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info, 1); } + #else + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cpocon)(uplo, n, (cx_T*)a, lda, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zpocon)(uplo, n, (cx_T*)a, lda, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info); } + #endif + } + + + + template + inline + void + trcon(char* norm, char* uplo, char* diag, blas_int* n, const eT* a, blas_int* lda, eT* rcond, eT* work, blas_int* iwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_strcon)(norm, uplo, diag, n, (T*)a, lda, (T*)rcond, (T*)work, iwork, info, 1, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dtrcon)(norm, uplo, diag, n, (T*)a, lda, (T*)rcond, (T*)work, iwork, info, 1, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_strcon)(norm, uplo, diag, n, (T*)a, lda, (T*)rcond, (T*)work, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dtrcon)(norm, uplo, diag, n, (T*)a, lda, (T*)rcond, (T*)work, iwork, info); } + #endif + } + + + + template + inline + void + cx_trcon(char* norm, char* uplo, char* diag, blas_int* n, const std::complex* a, blas_int* lda, T* rcond, std::complex* work, T* rwork, blas_int* info) + { + typedef typename std::complex eT; + + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_ctrcon)(norm, uplo, diag, n, (cx_T*)a, lda, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info, 1, 1, 1); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_ztrcon)(norm, uplo, diag, n, (cx_T*)a, lda, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info, 1, 1, 1); } + #else + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_ctrcon)(norm, uplo, diag, n, (cx_T*)a, lda, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_ztrcon)(norm, uplo, diag, n, (cx_T*)a, lda, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info); } + #endif + } + + + + template + inline + void + gbcon(char* norm, blas_int* n, blas_int* kl, blas_int* ku, const eT* ab, blas_int* ldab, const blas_int* ipiv, const eT* anorm, eT* rcond, eT* work, blas_int* iwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sgbcon)(norm, n, kl, ku, (T*)ab, ldab, ipiv, (T*)anorm, (T*)rcond, (T*)work, iwork, info, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgbcon)(norm, n, kl, ku, (T*)ab, ldab, ipiv, (T*)anorm, (T*)rcond, (T*)work, iwork, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sgbcon)(norm, n, kl, ku, (T*)ab, ldab, ipiv, (T*)anorm, (T*)rcond, (T*)work, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgbcon)(norm, n, kl, ku, (T*)ab, ldab, ipiv, (T*)anorm, (T*)rcond, (T*)work, iwork, info); } + #endif + } + + + + template + inline + void + cx_gbcon(char* norm, blas_int* n, blas_int* kl, blas_int* ku, const std::complex* ab, blas_int* ldab, const blas_int* ipiv, const T* anorm, T* rcond, std::complex* work, T* rwork, blas_int* info) + { + typedef typename std::complex eT; + + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cgbcon)(norm, n, kl, ku, (cx_T*)ab, ldab, ipiv, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info, 1); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zgbcon)(norm, n, kl, ku, (cx_T*)ab, ldab, ipiv, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info, 1); } + #else + if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_cgbcon)(norm, n, kl, ku, (cx_T*)ab, ldab, ipiv, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zgbcon)(norm, n, kl, ku, (cx_T*)ab, ldab, ipiv, (pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, (pod_T*)rwork, info); } + #endif + } + + + + inline + blas_int + laenv(blas_int* ispec, char* name, char* opts, blas_int* n1, blas_int* n2, blas_int* n3, blas_int* n4, blas_len name_len, blas_len opts_len) + { + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + return arma_fortran(arma_ilaenv)(ispec, name, opts, n1, n2, n3, n4, name_len, opts_len); + #else + arma_ignore(name_len); + arma_ignore(opts_len); + return arma_fortran(arma_ilaenv)(ispec, name, opts, n1, n2, n3, n4); // not advised! + #endif + } + + + + template + inline + void + lahqr(blas_int* wantt, blas_int* wantz, blas_int* n, blas_int* ilo, blas_int* ihi, eT* h, blas_int* ldh, eT* wr, eT* wi, blas_int* iloz, blas_int* ihiz, eT* z, blas_int* ldz, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_slahqr)(wantt, wantz, n, ilo, ihi, (T*)h, ldh, (T*)wr, (T*)wi, iloz, ihiz, (T*)z, ldz, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dlahqr)(wantt, wantz, n, ilo, ihi, (T*)h, ldh, (T*)wr, (T*)wi, iloz, ihiz, (T*)z, ldz, info); } + } + + + + template + inline + void + stedc(char* compz, blas_int* n, eT* d, eT* e, eT* z, blas_int* ldz, eT* work, blas_int* lwork, blas_int* iwork, blas_int* liwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_sstedc)(compz, n, (T*)d, (T*)e, (T*)z, ldz, (T*)work, lwork, iwork, liwork, info, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dstedc)(compz, n, (T*)d, (T*)e, (T*)z, ldz, (T*)work, lwork, iwork, liwork, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_sstedc)(compz, n, (T*)d, (T*)e, (T*)z, ldz, (T*)work, lwork, iwork, liwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dstedc)(compz, n, (T*)d, (T*)e, (T*)z, ldz, (T*)work, lwork, iwork, liwork, info); } + #endif + } + + + + template + inline + void + trevc(char* side, char* howmny, blas_int* select, blas_int* n, eT* t, blas_int* ldt, eT* vl, blas_int* ldvl, eT* vr, blas_int* ldvr, blas_int* mm, blas_int* m, eT* work, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_strevc)(side, howmny, select, n, (T*)t, ldt, (T*)vl, ldvl, (T*)vr, ldvr, mm, m, (T*)work, info, 1, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dtrevc)(side, howmny, select, n, (T*)t, ldt, (T*)vl, ldvl, (T*)vr, ldvr, mm, m, (T*)work, info, 1, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_strevc)(side, howmny, select, n, (T*)t, ldt, (T*)vl, ldvl, (T*)vr, ldvr, mm, m, (T*)work, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dtrevc)(side, howmny, select, n, (T*)t, ldt, (T*)vl, ldvl, (T*)vr, ldvr, mm, m, (T*)work, info); } + #endif + } + + + + template + inline + void + gehrd(blas_int* n, blas_int* ilo, blas_int* ihi, eT* a, blas_int* lda, eT* tao, eT* work, blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } + } + + + + template + inline + void + pstrf(const char* uplo, const blas_int* n, eT* a, const blas_int* lda, blas_int* piv, blas_int* rank, const typename get_pod_type::result* tol, const typename get_pod_type::result* work, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float pod_T; typedef float T; arma_fortran(arma_spstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info, 1); } + else if( is_double::value) { typedef double pod_T; typedef double T; arma_fortran(arma_dpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info, 1); } + else if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; arma_fortran(arma_cpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info, 1); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; arma_fortran(arma_zpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info, 1); } + #else + if( is_float::value) { typedef float pod_T; typedef float T; arma_fortran(arma_spstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info); } + else if( is_double::value) { typedef double pod_T; typedef double T; arma_fortran(arma_dpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info); } + else if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; arma_fortran(arma_cpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; arma_fortran(arma_zpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info); } + #endif + } + + + } + + +#endif diff --git a/src/armadillo/include/armadillo_bits/translate_superlu.hpp b/src/armadillo/include/armadillo_bits/translate_superlu.hpp new file mode 100644 index 0000000..a04f01e --- /dev/null +++ b/src/armadillo/include/armadillo_bits/translate_superlu.hpp @@ -0,0 +1,348 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +#if defined(ARMA_USE_SUPERLU) + +//! \namespace superlu namespace for SuperLU functions +namespace superlu + { + + template + inline + void + gssv(superlu_options_t* options, SuperMatrix* A, int* perm_c, int* perm_r, SuperMatrix* L, SuperMatrix* U, SuperMatrix* B, SuperLUStat_t* stat, int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if(is_float::value) + { + arma_wrapper(sgssv)(options, A, perm_c, perm_r, L, U, B, stat, info); + } + else + if(is_double::value) + { + arma_wrapper(dgssv)(options, A, perm_c, perm_r, L, U, B, stat, info); + } + else + if(is_cx_float::value) + { + arma_wrapper(cgssv)(options, A, perm_c, perm_r, L, U, B, stat, info); + } + else + if(is_cx_double::value) + { + arma_wrapper(zgssv)(options, A, perm_c, perm_r, L, U, B, stat, info); + } + } + + + + template + inline + void + gssvx( + superlu_options_t* opts, + SuperMatrix* A, + int* perm_c, int* perm_r, + int* etree, char* equed, + typename get_pod_type::result* R, typename get_pod_type::result* C, + SuperMatrix* L, SuperMatrix* U, + void* work, int lwork, + SuperMatrix* B, SuperMatrix* X, + typename get_pod_type::result* rpg, typename get_pod_type::result* rcond, + typename get_pod_type::result* ferr, typename get_pod_type::result* berr, + GlobalLU_t* glu, mem_usage_t* mu, SuperLUStat_t* stat, int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if(is_float::value) + { + typedef float T; + arma_wrapper(sgssvx)(opts, A, perm_c, perm_r, etree, equed, (T*)R, (T*)C, L, U, work, lwork, B, X, (T*)rpg, (T*)rcond, (T*)ferr, (T*)berr, glu, mu, stat, info); + } + else + if(is_double::value) + { + typedef double T; + arma_wrapper(dgssvx)(opts, A, perm_c, perm_r, etree, equed, (T*)R, (T*)C, L, U, work, lwork, B, X, (T*)rpg, (T*)rcond, (T*)ferr, (T*)berr, glu, mu, stat, info); + } + else + if(is_cx_float::value) + { + typedef float T; + arma_wrapper(cgssvx)(opts, A, perm_c, perm_r, etree, equed, (T*)R, (T*)C, L, U, work, lwork, B, X, (T*)rpg, (T*)rcond, (T*)ferr, (T*)berr, glu, mu, stat, info); + } + else + if(is_cx_double::value) + { + typedef double T; + arma_wrapper(zgssvx)(opts, A, perm_c, perm_r, etree, equed, (T*)R, (T*)C, L, U, work, lwork, B, X, (T*)rpg, (T*)rcond, (T*)ferr, (T*)berr, glu, mu, stat, info); + } + } + + + + template + inline + void + gstrf(superlu_options_t* options, + SuperMatrix* A, + int relax, + int panel_size, int *etree, + void *work, int lwork, + int* perm_c, int* perm_r, + SuperMatrix* L, SuperMatrix* U, + GlobalLU_t* Glu, SuperLUStat_t* stat, int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if(is_float::value) + { + arma_wrapper(sgstrf)(options, A, relax, panel_size, etree, work, lwork, perm_c, perm_r, L, U, Glu, stat, info); + } + else + if(is_double::value) + { + arma_wrapper(dgstrf)(options, A, relax, panel_size, etree, work, lwork, perm_c, perm_r, L, U, Glu, stat, info); + } + else + if(is_cx_float::value) + { + arma_wrapper(cgstrf)(options, A, relax, panel_size, etree, work, lwork, perm_c, perm_r, L, U, Glu, stat, info); + } + else + if(is_cx_double::value) + { + arma_wrapper(zgstrf)(options, A, relax, panel_size, etree, work, lwork, perm_c, perm_r, L, U, Glu, stat, info); + } + } + + + + template + inline + void + gstrs(trans_t trans, + SuperMatrix* L, SuperMatrix* U, + int* perm_c, int* perm_r, + SuperMatrix* B, SuperLUStat_t* stat, int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if(is_float::value) + { + arma_wrapper(sgstrs)(trans, L, U, perm_c, perm_r, B, stat, info); + } + else + if(is_double::value) + { + arma_wrapper(dgstrs)(trans, L, U, perm_c, perm_r, B, stat, info); + } + else + if(is_cx_float::value) + { + arma_wrapper(cgstrs)(trans, L, U, perm_c, perm_r, B, stat, info); + } + else + if(is_cx_double::value) + { + arma_wrapper(zgstrs)(trans, L, U, perm_c, perm_r, B, stat, info); + } + } + + + + template + inline + typename get_pod_type::result + langs(char* norm, superlu::SuperMatrix* A) + { + arma_type_check(( is_supported_blas_type::value == false )); + + typedef typename get_pod_type::result T; + + if(is_float::value) + { + return arma_wrapper(slangs)(norm, A); + } + else + if(is_double::value) + { + return arma_wrapper(dlangs)(norm, A); + } + else + if(is_cx_float::value) + { + return arma_wrapper(clangs)(norm, A); + } + else + if(is_cx_double::value) + { + return arma_wrapper(zlangs)(norm, A); + } + + return T(0); // to avoid false warnigns from the compiler + } + + + + template + inline + void + gscon(char* norm, superlu::SuperMatrix* L, superlu::SuperMatrix* U, typename get_pod_type::result anorm, typename get_pod_type::result* rcond, superlu::SuperLUStat_t* stat, int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if(is_float::value) + { + typedef float T; + arma_wrapper(sgscon)(norm, L, U, (T)anorm, (T*)rcond, stat, info); + } + else + if(is_double::value) + { + typedef double T; + arma_wrapper(dgscon)(norm, L, U, (T)anorm, (T*)rcond, stat, info); + } + else + if(is_cx_float::value) + { + typedef float T; + arma_wrapper(cgscon)(norm, L, U, (T)anorm, (T*)rcond, stat, info); + } + else + if(is_cx_double::value) + { + typedef double T; + arma_wrapper(zgscon)(norm, L, U, (T)anorm, (T*)rcond, stat, info); + } + } + + + + inline + void + init_stat(SuperLUStat_t* stat) + { + arma_wrapper(StatInit)(stat); + } + + + inline + void + free_stat(SuperLUStat_t* stat) + { + arma_wrapper(StatFree)(stat); + } + + + + inline + void + set_default_opts(superlu_options_t* opts) + { + arma_wrapper(set_default_options)(opts); + } + + + inline + void + get_permutation_c(int ispec, SuperMatrix* A, int* perm_c) + { + arma_wrapper(get_perm_c)(ispec, A, perm_c); + } + + + + inline + void + sp_preorder_mat(superlu_options_t* opts, SuperMatrix* A, int* perm_c, int* etree, SuperMatrix* AC) + { + arma_wrapper(sp_preorder)(opts, A, perm_c, etree, AC); + } + + + + inline + int + sp_ispec_environ(int ispec) + { + return arma_wrapper(sp_ienv)(ispec); + } + + + + inline + void + destroy_supernode_mat(SuperMatrix* a) + { + arma_wrapper(Destroy_SuperNode_Matrix)(a); + } + + + + inline + void + destroy_compcol_mat(SuperMatrix* a) + { + arma_wrapper(Destroy_CompCol_Matrix)(a); + } + + + + inline + void + destroy_compcolperm_mat(SuperMatrix* a) + { + arma_wrapper(Destroy_CompCol_Permuted)(a); + } + + + + inline + void + destroy_dense_mat(SuperMatrix* a) + { + arma_wrapper(Destroy_SuperMatrix_Store)(a); + } + + + + inline + void* + malloc(size_t N) + { + return arma_wrapper(superlu_malloc)(N); + } + + + + inline + void + free(void* mem) + { + arma_wrapper(superlu_free)(mem); + } + + } // namespace superlu + +#endif diff --git a/src/armadillo/include/armadillo_bits/trimat_helper.hpp b/src/armadillo/include/armadillo_bits/trimat_helper.hpp new file mode 100644 index 0000000..9242083 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/trimat_helper.hpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup trimat_helper +//! @{ + + +namespace trimat_helper +{ + + + +template +inline +bool +is_triu(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that A has a square size + + const uword N = A.n_rows; + const uword Nm1 = N-1; + + if(N < 2) { return false; } + + const eT* A_col = A.memptr(); + const eT eT_zero = eT(0); + + // quickly check element at bottom-left + + if(A_col[Nm1] != eT_zero) { return false; } + + // if we got to this point, do a thorough check + + for(uword j=0; j < Nm1; ++j) + { + for(uword i=(j+1); i < N; ++i) + { + const eT A_ij = A_col[i]; + + if(A_ij != eT_zero) { return false; } + } + + A_col += N; + } + + return true; + } + + + +template +inline +bool +is_tril(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that A has a square size + + const uword N = A.n_rows; + + if(N < 2) { return false; } + + const eT eT_zero = eT(0); + + // quickly check element at top-right + + const eT* A_colNm1 = A.colptr(N-1); + + if(A_colNm1[0] != eT_zero) { return false; } + + // if we got to this point, do a thorough check + + const eT* A_col = A.memptr() + N; + + for(uword j=1; j < N; ++j) + { + for(uword i=0; i < j; ++i) + { + const eT A_ij = A_col[i]; + + if(A_ij != eT_zero) { return false; } + } + + A_col += N; + } + + return true; + } + + + +template +inline +bool +has_nonfinite_tril(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that A has a square size + + const eT* colptr = A.memptr(); + const uword N = A.n_rows; + + for(uword i=0; i +inline +bool +has_nonfinite_triu(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that A has a square size + + const eT* colptr = A.memptr(); + const uword N = A.n_rows; + + for(uword i=0; i= 0xff + typedef unsigned char u8; + typedef char s8; + #elif defined(UINT8_MAX) + typedef uint8_t u8; + typedef int8_t s8; + #else + #error "don't know how to typedef 'u8' on this system" + #endif +#endif + +// NOTE: "char" can be either "signed char" or "unsigned char" +// NOTE: https://en.wikipedia.org/wiki/C_data_types + + +#if USHRT_MAX >= 0xffff + typedef unsigned short u16; + typedef short s16; +#elif defined(UINT16_MAX) + typedef uint16_t u16; + typedef int16_t s16; +#else + #error "don't know how to typedef 'u16' on this system" +#endif + + +#if UINT_MAX >= 0xffffffff + typedef unsigned int u32; + typedef int s32; +#elif defined(UINT32_MAX) + typedef uint32_t u32; + typedef int32_t s32; +#else + #error "don't know how to typedef 'u32' on this system" +#endif + + +#if ULLONG_MAX >= 0xffffffffffffffff + typedef unsigned long long u64; + typedef long long s64; +#elif defined(UINT64_MAX) + typedef uint64_t u64; + typedef int64_t s64; +#else + #error "don't know how to typedef 'u64' on this system" +#endif + + +// for compatibility with earlier versions of Armadillo +typedef unsigned long ulng_t; +typedef long slng_t; + + +#if defined(ARMA_64BIT_WORD) + typedef u64 uword; + typedef s64 sword; + + typedef u32 uhword; + typedef s32 shword; + + #define ARMA_MAX_UWORD 0xffffffffffffffff + #define ARMA_MAX_UHWORD 0xffffffff +#else + typedef u32 uword; + typedef s32 sword; + + typedef u16 uhword; + typedef s16 shword; + + #define ARMA_MAX_UWORD 0xffffffff + #define ARMA_MAX_UHWORD 0xffff +#endif + + +typedef std::complex cx_float; +typedef std::complex cx_double; + +typedef void* void_ptr; + + +// + + +#if defined(ARMA_BLAS_LONG_LONG) + typedef long long blas_int; + #define ARMA_MAX_BLAS_INT 0x7fffffffffffffffULL +#elif defined(ARMA_BLAS_LONG) + typedef long blas_int; + #define ARMA_MAX_BLAS_INT 0x7fffffffffffffffUL +#else + typedef int blas_int; + #define ARMA_MAX_BLAS_INT 0x7fffffffU +#endif + + +// + + +#if defined(ARMA_USE_MKL_TYPES) + // for compatibility with MKL + typedef MKL_Complex8 blas_cxf; + typedef MKL_Complex16 blas_cxd; +#else + // standard BLAS and LAPACK prototypes use "void*" pointers for complex arrays + typedef void blas_cxf; + typedef void blas_cxd; +#endif + + +// + + +// NOTE: blas_len is the fortran type for "hidden" arguments that specify the length of character arguments; +// NOTE: it varies across compilers, compiler versions and systems (eg. 32 bit vs 64 bit); +// NOTE: the default setting of "size_t" is an educated guess. +// NOTE: --- +// NOTE: for gcc / gfortran: https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html +// NOTE: gcc 7 and earlier: int +// NOTE: gcc 8 and 9: size_t +// NOTE: --- +// NOTE: for ifort (intel fortran compiler): +// NOTE: "Intel Fortran Compiler User and Reference Guides", Document Number: 304970-006US, 2009, p. 301 +// NOTE: http://www.complexfluids.ethz.ch/MK/ifort.pdf +// NOTE: the type is unsigned 4-byte integer on 32 bit systems +// NOTE: the type is unsigned 8-byte integer on 64 bit systems +// NOTE: --- +// NOTE: for NAG fortran: https://www.nag.co.uk/nagware/np/r62_doc/manual/compiler_11_1.html#AUTOTOC_11_1 +// NOTE: Chrlen = usually int, or long long on 64-bit Windows +// NOTE: --- +// TODO: flang: https://github.com/flang-compiler/flang/wiki +// TODO: other compilers: http://fortranwiki.org/fortran/show/Compilers + +#if !defined(ARMA_FORTRAN_CHARLEN_TYPE) + #if defined(__GNUC__) && !defined(__clang__) + #if (__GNUC__ <= 7) + #define ARMA_FORTRAN_CHARLEN_TYPE int + #else + #define ARMA_FORTRAN_CHARLEN_TYPE size_t + #endif + #else + // TODO: determine the type for other compilers + #define ARMA_FORTRAN_CHARLEN_TYPE size_t + #endif +#endif + +typedef ARMA_FORTRAN_CHARLEN_TYPE blas_len; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/typedef_elem_check.hpp b/src/armadillo/include/armadillo_bits/typedef_elem_check.hpp new file mode 100644 index 0000000..db462ab --- /dev/null +++ b/src/armadillo/include/armadillo_bits/typedef_elem_check.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup typedef_elem +//! @{ + + +namespace junk + { + struct arma_elem_size_test + { + arma_static_check( (sizeof(u8) != 1), "error: type 'u8' has unsupported size" ); + arma_static_check( (sizeof(s8) != 1), "error: type 's8' has unsupported size" ); + + arma_static_check( (sizeof(u16) != 2), "error: type 'u16' has unsupported size" ); + arma_static_check( (sizeof(s16) != 2), "error: type 's16' has unsupported size" ); + + arma_static_check( (sizeof(u32) != 4), "error: type 'u32' has unsupported size" ); + arma_static_check( (sizeof(s32) != 4), "error: type 's32' has unsupported size" ); + + arma_static_check( (sizeof(u64) != 8), "error: type 'u64' has unsupported size" ); + arma_static_check( (sizeof(s64) != 8), "error: type 's64' has unsupported size" ); + + arma_static_check( (sizeof(float) != 4), "error: type 'float' has unsupported size" ); + arma_static_check( (sizeof(double) != 8), "error: type 'double' has unsupported size" ); + + arma_static_check( (sizeof(std::complex) != 8), "type 'std::complex' has unsupported size" ); + arma_static_check( (sizeof(std::complex) != 16), "type 'std::complex' has unsupported size" ); + }; + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/typedef_mat.hpp b/src/armadillo/include/armadillo_bits/typedef_mat.hpp new file mode 100644 index 0000000..69a4c90 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/typedef_mat.hpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup typedef_mat +//! @{ + + +typedef Mat uchar_mat; +typedef Col uchar_vec; +typedef Col uchar_colvec; +typedef Row uchar_rowvec; +typedef Cube uchar_cube; + +typedef Mat u32_mat; +typedef Col u32_vec; +typedef Col u32_colvec; +typedef Row u32_rowvec; +typedef Cube u32_cube; + +typedef Mat s32_mat; +typedef Col s32_vec; +typedef Col s32_colvec; +typedef Row s32_rowvec; +typedef Cube s32_cube; + +typedef Mat u64_mat; +typedef Col u64_vec; +typedef Col u64_colvec; +typedef Row u64_rowvec; +typedef Cube u64_cube; + +typedef Mat s64_mat; +typedef Col s64_vec; +typedef Col s64_colvec; +typedef Row s64_rowvec; +typedef Cube s64_cube; + +typedef Mat umat; +typedef Col uvec; +typedef Col ucolvec; +typedef Row urowvec; +typedef Cube ucube; + +typedef Mat imat; +typedef Col ivec; +typedef Col icolvec; +typedef Row irowvec; +typedef Cube icube; + +typedef Mat fmat; +typedef Col fvec; +typedef Col fcolvec; +typedef Row frowvec; +typedef Cube fcube; + +typedef Mat dmat; +typedef Col dvec; +typedef Col dcolvec; +typedef Row drowvec; +typedef Cube dcube; + +typedef Mat mat; +typedef Col vec; +typedef Col colvec; +typedef Row rowvec; +typedef Cube cube; + +typedef Mat cx_fmat; +typedef Col cx_fvec; +typedef Col cx_fcolvec; +typedef Row cx_frowvec; +typedef Cube cx_fcube; + +typedef Mat cx_dmat; +typedef Col cx_dvec; +typedef Col cx_dcolvec; +typedef Row cx_drowvec; +typedef Cube cx_dcube; + +typedef Mat cx_mat; +typedef Col cx_vec; +typedef Col cx_colvec; +typedef Row cx_rowvec; +typedef Cube cx_cube; + + + +typedef SpMat sp_umat; +typedef SpCol sp_uvec; +typedef SpCol sp_ucolvec; +typedef SpRow sp_urowvec; + +typedef SpMat sp_imat; +typedef SpCol sp_ivec; +typedef SpCol sp_icolvec; +typedef SpRow sp_irowvec; + +typedef SpMat sp_fmat; +typedef SpCol sp_fvec; +typedef SpCol sp_fcolvec; +typedef SpRow sp_frowvec; + +typedef SpMat sp_dmat; +typedef SpCol sp_dvec; +typedef SpCol sp_dcolvec; +typedef SpRow sp_drowvec; + +typedef SpMat sp_mat; +typedef SpCol sp_vec; +typedef SpCol sp_colvec; +typedef SpRow sp_rowvec; + +typedef SpMat sp_cx_fmat; +typedef SpCol sp_cx_fvec; +typedef SpCol sp_cx_fcolvec; +typedef SpRow sp_cx_frowvec; + +typedef SpMat sp_cx_dmat; +typedef SpCol sp_cx_dvec; +typedef SpCol sp_cx_dcolvec; +typedef SpRow sp_cx_drowvec; + +typedef SpMat sp_cx_mat; +typedef SpCol sp_cx_vec; +typedef SpCol sp_cx_colvec; +typedef SpRow sp_cx_rowvec; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/typedef_mat_fixed.hpp b/src/armadillo/include/armadillo_bits/typedef_mat_fixed.hpp new file mode 100644 index 0000000..bd45615 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/typedef_mat_fixed.hpp @@ -0,0 +1,326 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup typedef_mat_fixed +//! @{ + + + +typedef umat::fixed<2,2> umat22; +typedef umat::fixed<3,3> umat33; +typedef umat::fixed<4,4> umat44; +typedef umat::fixed<5,5> umat55; +typedef umat::fixed<6,6> umat66; +typedef umat::fixed<7,7> umat77; +typedef umat::fixed<8,8> umat88; +typedef umat::fixed<9,9> umat99; + +typedef imat::fixed<2,2> imat22; +typedef imat::fixed<3,3> imat33; +typedef imat::fixed<4,4> imat44; +typedef imat::fixed<5,5> imat55; +typedef imat::fixed<6,6> imat66; +typedef imat::fixed<7,7> imat77; +typedef imat::fixed<8,8> imat88; +typedef imat::fixed<9,9> imat99; + +typedef fmat::fixed<2,2> fmat22; +typedef fmat::fixed<3,3> fmat33; +typedef fmat::fixed<4,4> fmat44; +typedef fmat::fixed<5,5> fmat55; +typedef fmat::fixed<6,6> fmat66; +typedef fmat::fixed<7,7> fmat77; +typedef fmat::fixed<8,8> fmat88; +typedef fmat::fixed<9,9> fmat99; + +typedef dmat::fixed<2,2> dmat22; +typedef dmat::fixed<3,3> dmat33; +typedef dmat::fixed<4,4> dmat44; +typedef dmat::fixed<5,5> dmat55; +typedef dmat::fixed<6,6> dmat66; +typedef dmat::fixed<7,7> dmat77; +typedef dmat::fixed<8,8> dmat88; +typedef dmat::fixed<9,9> dmat99; + +typedef mat::fixed<2,2> mat22; +typedef mat::fixed<3,3> mat33; +typedef mat::fixed<4,4> mat44; +typedef mat::fixed<5,5> mat55; +typedef mat::fixed<6,6> mat66; +typedef mat::fixed<7,7> mat77; +typedef mat::fixed<8,8> mat88; +typedef mat::fixed<9,9> mat99; + +typedef cx_fmat::fixed<2,2> cx_fmat22; +typedef cx_fmat::fixed<3,3> cx_fmat33; +typedef cx_fmat::fixed<4,4> cx_fmat44; +typedef cx_fmat::fixed<5,5> cx_fmat55; +typedef cx_fmat::fixed<6,6> cx_fmat66; +typedef cx_fmat::fixed<7,7> cx_fmat77; +typedef cx_fmat::fixed<8,8> cx_fmat88; +typedef cx_fmat::fixed<9,9> cx_fmat99; + +typedef cx_dmat::fixed<2,2> cx_dmat22; +typedef cx_dmat::fixed<3,3> cx_dmat33; +typedef cx_dmat::fixed<4,4> cx_dmat44; +typedef cx_dmat::fixed<5,5> cx_dmat55; +typedef cx_dmat::fixed<6,6> cx_dmat66; +typedef cx_dmat::fixed<7,7> cx_dmat77; +typedef cx_dmat::fixed<8,8> cx_dmat88; +typedef cx_dmat::fixed<9,9> cx_dmat99; + +typedef cx_mat::fixed<2,2> cx_mat22; +typedef cx_mat::fixed<3,3> cx_mat33; +typedef cx_mat::fixed<4,4> cx_mat44; +typedef cx_mat::fixed<5,5> cx_mat55; +typedef cx_mat::fixed<6,6> cx_mat66; +typedef cx_mat::fixed<7,7> cx_mat77; +typedef cx_mat::fixed<8,8> cx_mat88; +typedef cx_mat::fixed<9,9> cx_mat99; + + +// + + +typedef uvec::fixed<2> uvec2; +typedef uvec::fixed<3> uvec3; +typedef uvec::fixed<4> uvec4; +typedef uvec::fixed<5> uvec5; +typedef uvec::fixed<6> uvec6; +typedef uvec::fixed<7> uvec7; +typedef uvec::fixed<8> uvec8; +typedef uvec::fixed<9> uvec9; + +typedef ivec::fixed<2> ivec2; +typedef ivec::fixed<3> ivec3; +typedef ivec::fixed<4> ivec4; +typedef ivec::fixed<5> ivec5; +typedef ivec::fixed<6> ivec6; +typedef ivec::fixed<7> ivec7; +typedef ivec::fixed<8> ivec8; +typedef ivec::fixed<9> ivec9; + +typedef fvec::fixed<2> fvec2; +typedef fvec::fixed<3> fvec3; +typedef fvec::fixed<4> fvec4; +typedef fvec::fixed<5> fvec5; +typedef fvec::fixed<6> fvec6; +typedef fvec::fixed<7> fvec7; +typedef fvec::fixed<8> fvec8; +typedef fvec::fixed<9> fvec9; + +typedef dvec::fixed<2> dvec2; +typedef dvec::fixed<3> dvec3; +typedef dvec::fixed<4> dvec4; +typedef dvec::fixed<5> dvec5; +typedef dvec::fixed<6> dvec6; +typedef dvec::fixed<7> dvec7; +typedef dvec::fixed<8> dvec8; +typedef dvec::fixed<9> dvec9; + +typedef vec::fixed<2> vec2; +typedef vec::fixed<3> vec3; +typedef vec::fixed<4> vec4; +typedef vec::fixed<5> vec5; +typedef vec::fixed<6> vec6; +typedef vec::fixed<7> vec7; +typedef vec::fixed<8> vec8; +typedef vec::fixed<9> vec9; + +typedef cx_fvec::fixed<2> cx_fvec2; +typedef cx_fvec::fixed<3> cx_fvec3; +typedef cx_fvec::fixed<4> cx_fvec4; +typedef cx_fvec::fixed<5> cx_fvec5; +typedef cx_fvec::fixed<6> cx_fvec6; +typedef cx_fvec::fixed<7> cx_fvec7; +typedef cx_fvec::fixed<8> cx_fvec8; +typedef cx_fvec::fixed<9> cx_fvec9; + +typedef cx_dvec::fixed<2> cx_dvec2; +typedef cx_dvec::fixed<3> cx_dvec3; +typedef cx_dvec::fixed<4> cx_dvec4; +typedef cx_dvec::fixed<5> cx_dvec5; +typedef cx_dvec::fixed<6> cx_dvec6; +typedef cx_dvec::fixed<7> cx_dvec7; +typedef cx_dvec::fixed<8> cx_dvec8; +typedef cx_dvec::fixed<9> cx_dvec9; + +typedef cx_vec::fixed<2> cx_vec2; +typedef cx_vec::fixed<3> cx_vec3; +typedef cx_vec::fixed<4> cx_vec4; +typedef cx_vec::fixed<5> cx_vec5; +typedef cx_vec::fixed<6> cx_vec6; +typedef cx_vec::fixed<7> cx_vec7; +typedef cx_vec::fixed<8> cx_vec8; +typedef cx_vec::fixed<9> cx_vec9; + + +// + + +typedef ucolvec::fixed<2> ucolvec2; +typedef ucolvec::fixed<3> ucolvec3; +typedef ucolvec::fixed<4> ucolvec4; +typedef ucolvec::fixed<5> ucolvec5; +typedef ucolvec::fixed<6> ucolvec6; +typedef ucolvec::fixed<7> ucolvec7; +typedef ucolvec::fixed<8> ucolvec8; +typedef ucolvec::fixed<9> ucolvec9; + +typedef icolvec::fixed<2> icolvec2; +typedef icolvec::fixed<3> icolvec3; +typedef icolvec::fixed<4> icolvec4; +typedef icolvec::fixed<5> icolvec5; +typedef icolvec::fixed<6> icolvec6; +typedef icolvec::fixed<7> icolvec7; +typedef icolvec::fixed<8> icolvec8; +typedef icolvec::fixed<9> icolvec9; + +typedef fcolvec::fixed<2> fcolvec2; +typedef fcolvec::fixed<3> fcolvec3; +typedef fcolvec::fixed<4> fcolvec4; +typedef fcolvec::fixed<5> fcolvec5; +typedef fcolvec::fixed<6> fcolvec6; +typedef fcolvec::fixed<7> fcolvec7; +typedef fcolvec::fixed<8> fcolvec8; +typedef fcolvec::fixed<9> fcolvec9; + +typedef dcolvec::fixed<2> dcolvec2; +typedef dcolvec::fixed<3> dcolvec3; +typedef dcolvec::fixed<4> dcolvec4; +typedef dcolvec::fixed<5> dcolvec5; +typedef dcolvec::fixed<6> dcolvec6; +typedef dcolvec::fixed<7> dcolvec7; +typedef dcolvec::fixed<8> dcolvec8; +typedef dcolvec::fixed<9> dcolvec9; + +typedef colvec::fixed<2> colvec2; +typedef colvec::fixed<3> colvec3; +typedef colvec::fixed<4> colvec4; +typedef colvec::fixed<5> colvec5; +typedef colvec::fixed<6> colvec6; +typedef colvec::fixed<7> colvec7; +typedef colvec::fixed<8> colvec8; +typedef colvec::fixed<9> colvec9; + +typedef cx_fcolvec::fixed<2> cx_fcolvec2; +typedef cx_fcolvec::fixed<3> cx_fcolvec3; +typedef cx_fcolvec::fixed<4> cx_fcolvec4; +typedef cx_fcolvec::fixed<5> cx_fcolvec5; +typedef cx_fcolvec::fixed<6> cx_fcolvec6; +typedef cx_fcolvec::fixed<7> cx_fcolvec7; +typedef cx_fcolvec::fixed<8> cx_fcolvec8; +typedef cx_fcolvec::fixed<9> cx_fcolvec9; + +typedef cx_dcolvec::fixed<2> cx_dcolvec2; +typedef cx_dcolvec::fixed<3> cx_dcolvec3; +typedef cx_dcolvec::fixed<4> cx_dcolvec4; +typedef cx_dcolvec::fixed<5> cx_dcolvec5; +typedef cx_dcolvec::fixed<6> cx_dcolvec6; +typedef cx_dcolvec::fixed<7> cx_dcolvec7; +typedef cx_dcolvec::fixed<8> cx_dcolvec8; +typedef cx_dcolvec::fixed<9> cx_dcolvec9; + +typedef cx_colvec::fixed<2> cx_colvec2; +typedef cx_colvec::fixed<3> cx_colvec3; +typedef cx_colvec::fixed<4> cx_colvec4; +typedef cx_colvec::fixed<5> cx_colvec5; +typedef cx_colvec::fixed<6> cx_colvec6; +typedef cx_colvec::fixed<7> cx_colvec7; +typedef cx_colvec::fixed<8> cx_colvec8; +typedef cx_colvec::fixed<9> cx_colvec9; + + +// + + +typedef urowvec::fixed<2> urowvec2; +typedef urowvec::fixed<3> urowvec3; +typedef urowvec::fixed<4> urowvec4; +typedef urowvec::fixed<5> urowvec5; +typedef urowvec::fixed<6> urowvec6; +typedef urowvec::fixed<7> urowvec7; +typedef urowvec::fixed<8> urowvec8; +typedef urowvec::fixed<9> urowvec9; + +typedef irowvec::fixed<2> irowvec2; +typedef irowvec::fixed<3> irowvec3; +typedef irowvec::fixed<4> irowvec4; +typedef irowvec::fixed<5> irowvec5; +typedef irowvec::fixed<6> irowvec6; +typedef irowvec::fixed<7> irowvec7; +typedef irowvec::fixed<8> irowvec8; +typedef irowvec::fixed<9> irowvec9; + +typedef frowvec::fixed<2> frowvec2; +typedef frowvec::fixed<3> frowvec3; +typedef frowvec::fixed<4> frowvec4; +typedef frowvec::fixed<5> frowvec5; +typedef frowvec::fixed<6> frowvec6; +typedef frowvec::fixed<7> frowvec7; +typedef frowvec::fixed<8> frowvec8; +typedef frowvec::fixed<9> frowvec9; + +typedef drowvec::fixed<2> drowvec2; +typedef drowvec::fixed<3> drowvec3; +typedef drowvec::fixed<4> drowvec4; +typedef drowvec::fixed<5> drowvec5; +typedef drowvec::fixed<6> drowvec6; +typedef drowvec::fixed<7> drowvec7; +typedef drowvec::fixed<8> drowvec8; +typedef drowvec::fixed<9> drowvec9; + +typedef rowvec::fixed<2> rowvec2; +typedef rowvec::fixed<3> rowvec3; +typedef rowvec::fixed<4> rowvec4; +typedef rowvec::fixed<5> rowvec5; +typedef rowvec::fixed<6> rowvec6; +typedef rowvec::fixed<7> rowvec7; +typedef rowvec::fixed<8> rowvec8; +typedef rowvec::fixed<9> rowvec9; + +typedef cx_frowvec::fixed<2> cx_frowvec2; +typedef cx_frowvec::fixed<3> cx_frowvec3; +typedef cx_frowvec::fixed<4> cx_frowvec4; +typedef cx_frowvec::fixed<5> cx_frowvec5; +typedef cx_frowvec::fixed<6> cx_frowvec6; +typedef cx_frowvec::fixed<7> cx_frowvec7; +typedef cx_frowvec::fixed<8> cx_frowvec8; +typedef cx_frowvec::fixed<9> cx_frowvec9; + +typedef cx_drowvec::fixed<2> cx_drowvec2; +typedef cx_drowvec::fixed<3> cx_drowvec3; +typedef cx_drowvec::fixed<4> cx_drowvec4; +typedef cx_drowvec::fixed<5> cx_drowvec5; +typedef cx_drowvec::fixed<6> cx_drowvec6; +typedef cx_drowvec::fixed<7> cx_drowvec7; +typedef cx_drowvec::fixed<8> cx_drowvec8; +typedef cx_drowvec::fixed<9> cx_drowvec9; + +typedef cx_rowvec::fixed<2> cx_rowvec2; +typedef cx_rowvec::fixed<3> cx_rowvec3; +typedef cx_rowvec::fixed<4> cx_rowvec4; +typedef cx_rowvec::fixed<5> cx_rowvec5; +typedef cx_rowvec::fixed<6> cx_rowvec6; +typedef cx_rowvec::fixed<7> cx_rowvec7; +typedef cx_rowvec::fixed<8> cx_rowvec8; +typedef cx_rowvec::fixed<9> cx_rowvec9; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/unwrap.hpp b/src/armadillo/include/armadillo_bits/unwrap.hpp new file mode 100644 index 0000000..4e93506 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/unwrap.hpp @@ -0,0 +1,3421 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup unwrap +//! @{ + + +// TODO: document the conditions and restrictions for the use of each unwrap variant: +// TODO: unwrap, unwrap_check, quasi_unwrap, partial_unwrap, partial_unwrap_check + + +template +struct unwrap_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + unwrap_default(const T1& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Mat M; + }; + + + +template +struct unwrap_fixed + { + typedef T1 stored_type; + + inline explicit + unwrap_fixed(const T1& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const T1& M; + }; + + + +template +struct unwrap_redirect {}; + +template +struct unwrap_redirect { typedef unwrap_default result; }; + +template +struct unwrap_redirect { typedef unwrap_fixed result; }; + + +template +struct unwrap : public unwrap_redirect::value>::result + { + inline + unwrap(const T1& A) + : unwrap_redirect::value>::result(A) + { + } + }; + + + +template +struct unwrap< Mat > + { + typedef Mat stored_type; + + inline + unwrap(const Mat& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Mat& M; + }; + + + +template +struct unwrap< Row > + { + typedef Row stored_type; + + inline + unwrap(const Row& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Row& M; + }; + + + +template +struct unwrap< Col > + { + typedef Col stored_type; + + inline + unwrap(const Col& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Col& M; + }; + + + +template +struct unwrap< subview_col > + { + typedef Col stored_type; + + inline + unwrap(const subview_col& A) + : M(A.colmem, A.n_rows) + { + arma_extra_debug_sigprint(); + } + + const Col M; + }; + + + +template +struct unwrap< subview_cols > + { + typedef Mat stored_type; + + inline + unwrap(const subview_cols& A) + : M(A.colptr(0), A.n_rows, A.n_cols) + { + arma_extra_debug_sigprint(); + } + + const Mat M; + }; + + + +template +struct unwrap< mtGlue > + { + typedef Mat stored_type; + + inline + unwrap(const mtGlue& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Mat M; + }; + + + +template +struct unwrap< mtOp > + { + typedef Mat stored_type; + + inline + unwrap(const mtOp& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Mat M; + }; + + + +// +// +// + + + +template +struct quasi_unwrap_default + { + typedef typename T1::elem_type eT; + + inline + quasi_unwrap_default(const T1& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + // NOTE: DO NOT DIRECTLY CHECK FOR ALIASING BY TAKING THE ADDRESS OF THE "M" OBJECT IN ANY quasi_unwrap CLASS !!! + Mat M; + + static constexpr bool is_const = false; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = false; + + template + constexpr bool is_alias(const Mat&) const { return false; } + }; + + + +template +struct quasi_unwrap_fixed + { + typedef typename T1::elem_type eT; + + inline explicit + quasi_unwrap_fixed(const T1& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const T1& M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&M) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap_redirect {}; + +template +struct quasi_unwrap_redirect { typedef quasi_unwrap_default result; }; + +template +struct quasi_unwrap_redirect { typedef quasi_unwrap_fixed result; }; + + +template +struct quasi_unwrap : public quasi_unwrap_redirect::value>::result + { + typedef typename quasi_unwrap_redirect::value>::result quasi_unwrap_extra; + + inline + quasi_unwrap(const T1& A) + : quasi_unwrap_extra(A) + { + } + + static constexpr bool is_const = quasi_unwrap_extra::is_const; + static constexpr bool has_subview = quasi_unwrap_extra::has_subview; + static constexpr bool has_orig_mem = quasi_unwrap_extra::has_orig_mem; + + using quasi_unwrap_extra::M; + using quasi_unwrap_extra::is_alias; + }; + + + +template +struct quasi_unwrap< Mat > + { + inline + quasi_unwrap(const Mat& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Mat& M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&M) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap< Row > + { + + inline + quasi_unwrap(const Row& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Row& M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&M) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap< Col > + { + inline + quasi_unwrap(const Col& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Col& M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&M) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap< subview > + { + inline + quasi_unwrap(const subview& A) + : sv( A ) + , M ( A, ((A.aux_row1 == 0) && (A.n_rows == A.m.n_rows)) ) // reuse memory if the subview is a contiguous chunk + { + arma_extra_debug_sigprint(); + } + + const subview& sv; + const Mat M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = false; // NOTE: set to false as this is the general case; original memory is only used when the subview is a contiguous chunk + + template + arma_inline bool is_alias(const Mat& X) const { return ( ((sv.aux_row1 == 0) && (sv.n_rows == sv.m.n_rows)) ? (void_ptr(&(sv.m)) == void_ptr(&X)) : false ); } + }; + + + +template +struct quasi_unwrap< subview_row > + { + inline + quasi_unwrap(const subview_row& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + Row M; + + static constexpr bool is_const = false; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = false; + + template + constexpr bool is_alias(const Mat&) const { return false; } + }; + + + +template +struct quasi_unwrap< subview_col > + { + inline + quasi_unwrap(const subview_col& A) + : orig( A.m ) + , M ( const_cast( A.colmem ), A.n_rows, false, false ) + { + arma_extra_debug_sigprint(); + } + + const Mat& orig; + const Col M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap< subview_cols > + { + inline + quasi_unwrap(const subview_cols& A) + : orig( A.m ) + , M ( const_cast( A.colptr(0) ), A.n_rows, A.n_cols, false, false ) + { + arma_extra_debug_sigprint(); + } + + const Mat& orig; + const Mat M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap< mtGlue > + { + inline + quasi_unwrap(const mtGlue& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + Mat M; + + static constexpr bool is_const = false; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = false; + + template + constexpr bool is_alias(const Mat&) const { return false; } + }; + + + +template +struct quasi_unwrap< mtOp > + { + inline + quasi_unwrap(const mtOp& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + Mat M; + + static constexpr bool is_const = false; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = false; + + template + constexpr bool is_alias(const Mat&) const { return false; } + }; + + + +template +struct quasi_unwrap< Op > + { + typedef typename T1::elem_type eT; + + inline + quasi_unwrap(const Op& A) + : U( A.m ) + , M( const_cast(U.M.memptr()), U.M.n_elem, 1, false, false ) + { + arma_extra_debug_sigprint(); + } + + const quasi_unwrap U; + const Mat M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return U.is_alias(X); } + }; + + + +template +struct quasi_unwrap< Op, op_strans> > + { + inline + quasi_unwrap(const Op, op_strans>& A) + : orig(A.m) + , M (const_cast(A.m.memptr()), A.m.n_elem, false, false) + { + arma_extra_debug_sigprint(); + } + + const Col& orig; + const Row M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap< Op, op_strans> > + { + inline + quasi_unwrap(const Op, op_strans>& A) + : orig(A.m) + , M (const_cast(A.m.memptr()), A.m.n_elem, false, false) + { + arma_extra_debug_sigprint(); + } + + const Row& orig; + const Col M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap< Op, op_strans> > + { + inline + quasi_unwrap(const Op, op_strans>& A) + : orig( A.m.m ) + , M ( const_cast( A.m.colmem ), A.m.n_rows, false, false ) + { + arma_extra_debug_sigprint(); + } + + const Mat& orig; + const Row M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } + }; + + + +template +struct quasi_unwrap_Col_htrans + { + inline quasi_unwrap_Col_htrans(const T1&) {} + }; + + + +template +struct quasi_unwrap_Col_htrans< Op, op_htrans> > + { + inline + quasi_unwrap_Col_htrans(const Op, op_htrans>& A) + : orig(A.m) + , M (const_cast(A.m.memptr()), A.m.n_elem, false, false) + { + arma_extra_debug_sigprint(); + } + + const Col& orig; + const Row M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap_Col_htrans_redirect {}; + +template +struct quasi_unwrap_Col_htrans_redirect { typedef quasi_unwrap_default result; }; + +template +struct quasi_unwrap_Col_htrans_redirect { typedef quasi_unwrap_Col_htrans result; }; + + +template +struct quasi_unwrap< Op, op_htrans> > + : public quasi_unwrap_Col_htrans_redirect< Op, op_htrans>, is_cx::value >::result + { + typedef typename quasi_unwrap_Col_htrans_redirect< Op, op_htrans>, is_cx::value >::result quasi_unwrap_Col_htrans_extra; + + inline + quasi_unwrap(const Op, op_htrans>& A) + : quasi_unwrap_Col_htrans_extra(A) + { + } + + static constexpr bool is_const = quasi_unwrap_Col_htrans_extra::is_const; + static constexpr bool has_subview = quasi_unwrap_Col_htrans_extra::has_subview; + static constexpr bool has_orig_mem = quasi_unwrap_Col_htrans_extra::has_orig_mem; + + using quasi_unwrap_Col_htrans_extra::M; + using quasi_unwrap_Col_htrans_extra::is_alias; + }; + + + +template +struct quasi_unwrap_Row_htrans + { + inline quasi_unwrap_Row_htrans(const T1&) {} + }; + + + +template +struct quasi_unwrap_Row_htrans< Op, op_htrans> > + { + inline + quasi_unwrap_Row_htrans(const Op, op_htrans>& A) + : orig(A.m) + , M (const_cast(A.m.memptr()), A.m.n_elem, false, false) + { + arma_extra_debug_sigprint(); + } + + const Row& orig; + const Col M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap_Row_htrans_redirect {}; + +template +struct quasi_unwrap_Row_htrans_redirect { typedef quasi_unwrap_default result; }; + +template +struct quasi_unwrap_Row_htrans_redirect { typedef quasi_unwrap_Row_htrans result; }; + + +template +struct quasi_unwrap< Op, op_htrans> > + : public quasi_unwrap_Row_htrans_redirect< Op, op_htrans>, is_cx::value >::result + { + typedef typename quasi_unwrap_Row_htrans_redirect< Op, op_htrans>, is_cx::value >::result quasi_unwrap_Row_htrans_extra; + + inline + quasi_unwrap(const Op, op_htrans>& A) + : quasi_unwrap_Row_htrans_extra(A) + { + } + + static constexpr bool is_const = quasi_unwrap_Row_htrans_extra::is_const; + static constexpr bool has_subview = quasi_unwrap_Row_htrans_extra::has_subview; + static constexpr bool has_orig_mem = quasi_unwrap_Row_htrans_extra::has_orig_mem; + + using quasi_unwrap_Row_htrans_extra::M; + using quasi_unwrap_Row_htrans_extra::is_alias; + }; + + + +template +struct quasi_unwrap_subview_col_htrans + { + inline quasi_unwrap_subview_col_htrans(const T1&) {} + }; + + + +template +struct quasi_unwrap_subview_col_htrans< Op, op_htrans> > + { + inline + quasi_unwrap_subview_col_htrans(const Op, op_htrans>& A) + : orig(A.m.m) + , M (const_cast(A.m.colmem), A.m.n_rows, false, false) + { + arma_extra_debug_sigprint(); + } + + const Mat& orig; + const Row M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap_subview_col_htrans_redirect {}; + +template +struct quasi_unwrap_subview_col_htrans_redirect { typedef quasi_unwrap_default result; }; + +template +struct quasi_unwrap_subview_col_htrans_redirect { typedef quasi_unwrap_subview_col_htrans result; }; + + +template +struct quasi_unwrap< Op, op_htrans> > + : public quasi_unwrap_subview_col_htrans_redirect< Op, op_htrans>, is_cx::value >::result + { + typedef typename quasi_unwrap_subview_col_htrans_redirect< Op, op_htrans>, is_cx::value >::result quasi_unwrap_subview_col_htrans_extra; + + inline + quasi_unwrap(const Op, op_htrans>& A) + : quasi_unwrap_subview_col_htrans_extra(A) + { + } + + static constexpr bool is_const = quasi_unwrap_subview_col_htrans_extra::is_const; + static constexpr bool has_subview = quasi_unwrap_subview_col_htrans_extra::has_subview; + static constexpr bool has_orig_mem = quasi_unwrap_subview_col_htrans_extra::has_orig_mem; + + using quasi_unwrap_subview_col_htrans_extra::M; + using quasi_unwrap_subview_col_htrans_extra::is_alias; + }; + + + +template +struct quasi_unwrap< CubeToMatOp > + { + typedef typename T1::elem_type eT; + + inline + quasi_unwrap(const CubeToMatOp& A) + : U( A.m ) + , M( const_cast(U.M.memptr()), U.M.n_elem, 1, false, true ) + { + arma_extra_debug_sigprint(); + } + + const unwrap_cube U; + const Mat M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + constexpr bool is_alias(const Mat&) const { return false; } + }; + + + +template +struct quasi_unwrap< SpToDOp > + { + typedef typename T1::elem_type eT; + + inline + quasi_unwrap(const SpToDOp& A) + : U( A.m ) + , M( const_cast(U.M.values), U.M.n_nonzero, 1, false, true ) + { + arma_extra_debug_sigprint(); + } + + const unwrap_spmat U; + const Mat M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + constexpr bool is_alias(const Mat&) const { return false; } + }; + + + +// +// +// + + + +template +struct unwrap_check_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + unwrap_check_default(const T1& A, const Mat&) + : M(A) + { + arma_extra_debug_sigprint(); + } + + inline + unwrap_check_default(const T1& A, const bool) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Mat M; + }; + + + +template +struct unwrap_check_fixed + { + typedef typename T1::elem_type eT; + typedef T1 stored_type; + + inline + unwrap_check_fixed(const T1& A, const Mat& B) + : M_local( (&A == &B) ? new T1(A) : nullptr ) + , M ( (&A == &B) ? *M_local : A ) + { + arma_extra_debug_sigprint(); + } + + inline + unwrap_check_fixed(const T1& A, const bool is_alias) + : M_local( is_alias ? new T1(A) : nullptr ) + , M ( is_alias ? *M_local : A ) + { + arma_extra_debug_sigprint(); + } + + inline + ~unwrap_check_fixed() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + + // the order below is important + const T1* M_local; + const T1& M; + }; + + + +template +struct unwrap_check_redirect {}; + +template +struct unwrap_check_redirect { typedef unwrap_check_default result; }; + +template +struct unwrap_check_redirect { typedef unwrap_check_fixed result; }; + + +template +struct unwrap_check : public unwrap_check_redirect::value>::result + { + inline unwrap_check(const T1& A, const Mat& B) + : unwrap_check_redirect::value>::result(A, B) + { + } + + inline unwrap_check(const T1& A, const bool is_alias) + : unwrap_check_redirect::value>::result(A, is_alias) + { + } + }; + + + +template +struct unwrap_check< Mat > + { + typedef Mat stored_type; + + inline + unwrap_check(const Mat& A, const Mat& B) + : M_local( (&A == &B) ? new Mat(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + inline + unwrap_check(const Mat& A, const bool is_alias) + : M_local( is_alias ? new Mat(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + inline + ~unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + + // the order below is important + const Mat* M_local; + const Mat& M; + }; + + + +template +struct unwrap_check< Row > + { + typedef Row stored_type; + + inline + unwrap_check(const Row& A, const Mat& B) + : M_local( (&A == &B) ? new Row(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + inline + unwrap_check(const Row& A, const bool is_alias) + : M_local( is_alias ? new Row(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + inline + ~unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + + // the order below is important + const Row* M_local; + const Row& M; + }; + + + +template +struct unwrap_check< Col > + { + typedef Col stored_type; + + inline + unwrap_check(const Col& A, const Mat& B) + : M_local( (&A == &B) ? new Col(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + inline + unwrap_check(const Col& A, const bool is_alias) + : M_local( is_alias ? new Col(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + inline + ~unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + + // the order below is important + const Col* M_local; + const Col& M; + }; + + + +// +// +// + + + +template +struct unwrap_check_mixed + { + typedef typename T1::elem_type eT1; + + template + inline + unwrap_check_mixed(const T1& A, const Mat&) + : M(A) + { + arma_extra_debug_sigprint(); + } + + //template + inline + unwrap_check_mixed(const T1& A, const bool) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Mat M; + }; + + + +template +struct unwrap_check_mixed< Mat > + { + template + inline + unwrap_check_mixed(const Mat& A, const Mat& B) + : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Mat(A) : nullptr ) + , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + //template + inline + unwrap_check_mixed(const Mat& A, const bool is_alias) + : M_local( is_alias ? new Mat(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + inline + ~unwrap_check_mixed() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + + // the order below is important + const Mat* M_local; + const Mat& M; + }; + + + +template +struct unwrap_check_mixed< Row > + { + template + inline + unwrap_check_mixed(const Row& A, const Mat& B) + : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Row(A) : nullptr ) + , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + + //template + inline + unwrap_check_mixed(const Row& A, const bool is_alias) + : M_local( is_alias ? new Row(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + inline + ~unwrap_check_mixed() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + + // the order below is important + const Row* M_local; + const Row& M; + }; + + + +template +struct unwrap_check_mixed< Col > + { + template + inline + unwrap_check_mixed(const Col& A, const Mat& B) + : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Col(A) : nullptr ) + , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + //template + inline + unwrap_check_mixed(const Col& A, const bool is_alias) + : M_local( is_alias ? new Col(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + inline + ~unwrap_check_mixed() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + + // the order below is important + const Col* M_local; + const Col& M; + }; + + + +// +// +// + + + +template +struct partial_unwrap_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + partial_unwrap_default(const T1& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const Mat M; + }; + + +template +struct partial_unwrap_fixed + { + typedef typename T1::elem_type eT; + typedef T1 stored_type; + + inline explicit + partial_unwrap_fixed(const T1& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const T1& M; + }; + + + +template +struct partial_unwrap_redirect {}; + +template +struct partial_unwrap_redirect { typedef partial_unwrap_default result; }; + +template +struct partial_unwrap_redirect { typedef partial_unwrap_fixed result; }; + +template +struct partial_unwrap : public partial_unwrap_redirect::value>::result + { + inline + partial_unwrap(const T1& A) + : partial_unwrap_redirect< T1, is_Mat_fixed::value>::result(A) + { + } + }; + + + +template +struct partial_unwrap< Mat > + { + typedef Mat stored_type; + + inline + partial_unwrap(const Mat& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const Mat& M; + }; + + + +template +struct partial_unwrap< Row > + { + typedef Row stored_type; + + inline + partial_unwrap(const Row& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const Row& M; + }; + + + +template +struct partial_unwrap< Col > + { + typedef Col stored_type; + + inline + partial_unwrap(const Col& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const Col& M; + }; + + + +template +struct partial_unwrap< subview > + { + typedef Mat stored_type; + + inline + partial_unwrap(const subview& A) + : sv( A ) + , M ( A, ((A.aux_row1 == 0) && (A.n_rows == A.m.n_rows)) ) // reuse memory if the subview is a contiguous chunk + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return ( ((sv.aux_row1 == 0) && (sv.n_rows == sv.m.n_rows)) ? (void_ptr(&(sv.m)) == void_ptr(&X)) : false ); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const subview& sv; + const Mat M; + }; + + + +template +struct partial_unwrap< subview_col > + { + typedef Col stored_type; + + inline + partial_unwrap(const subview_col& A) + : orig( A.m ) + , M ( const_cast( A.colmem ), A.n_rows, false, false ) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const Mat& orig; + const Col M; + }; + + + +template +struct partial_unwrap< subview_cols > + { + typedef Mat stored_type; + + inline + partial_unwrap(const subview_cols& A) + : orig( A.m ) + , M ( const_cast( A.colptr(0) ), A.n_rows, A.n_cols, false, false ) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const Mat& orig; + const Mat M; + }; + + + +template +struct partial_unwrap< subview_row > + { + typedef Row stored_type; + + inline + partial_unwrap(const subview_row& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const Row M; + }; + + + +template +struct partial_unwrap_htrans_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + partial_unwrap_htrans_default(const Op& A) + : M(A.m) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const Mat M; + }; + + +template +struct partial_unwrap_htrans_fixed + { + typedef typename T1::elem_type eT; + typedef T1 stored_type; + + inline explicit + partial_unwrap_htrans_fixed(const Op& A) + : M(A.m) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const T1& M; + }; + + + +template +struct partial_unwrap_htrans_redirect {}; + +template +struct partial_unwrap_htrans_redirect { typedef partial_unwrap_htrans_default result; }; + +template +struct partial_unwrap_htrans_redirect { typedef partial_unwrap_htrans_fixed result; }; + +template +struct partial_unwrap< Op > : public partial_unwrap_htrans_redirect::value>::result + { + inline partial_unwrap(const Op& A) + : partial_unwrap_htrans_redirect::value>::result(A) + { + } + }; + + + +template +struct partial_unwrap< Op< Mat, op_htrans> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const Op< Mat, op_htrans>& A) + : M(A.m) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const Mat& M; + }; + + + +template +struct partial_unwrap< Op< Row, op_htrans> > + { + typedef Row stored_type; + + inline + partial_unwrap(const Op< Row, op_htrans>& A) + : M(A.m) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const Row& M; + }; + + + +template +struct partial_unwrap< Op< Col, op_htrans> > + { + typedef Col stored_type; + + inline + partial_unwrap(const Op< Col, op_htrans>& A) + : M(A.m) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const Col& M; + }; + + + +template +struct partial_unwrap< Op< subview, op_htrans> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const Op< subview, op_htrans>& A) + : sv( A.m ) + , M ( A.m, ((A.m.aux_row1 == 0) && (A.m.n_rows == A.m.m.n_rows)) ) // reuse memory if the subview is a contiguous chunk + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return ( ((sv.aux_row1 == 0) && (sv.n_rows == sv.m.n_rows)) ? (void_ptr(&(sv.m)) == void_ptr(&X)) : false ); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const subview& sv; + const Mat M; + }; + + + +template +struct partial_unwrap< Op< subview_cols, op_htrans> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const Op< subview_cols, op_htrans>& A) + : orig( A.m.m ) + , M ( const_cast( A.m.colptr(0) ), A.m.n_rows, A.m.n_cols, false, false ) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const Mat& orig; + const Mat M; + }; + + + +template +struct partial_unwrap< Op< subview_col, op_htrans> > + { + typedef Col stored_type; + + inline + partial_unwrap(const Op< subview_col, op_htrans>& A) + : orig( A.m.m ) + , M ( const_cast( A.m.colmem ), A.m.n_rows, false, false ) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const Mat& orig; + const Col M; + }; + + + +template +struct partial_unwrap< Op< subview_row, op_htrans> > + { + typedef Row stored_type; + + inline + partial_unwrap(const Op< subview_row, op_htrans>& A) + : M(A.m) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const Row M; + }; + + + +template +struct partial_unwrap_htrans2_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + partial_unwrap_htrans2_default(const Op& A) + : val(A.aux) + , M (A.m) + { + arma_extra_debug_sigprint(); + } + + arma_inline eT get_val() const { return val; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const eT val; + const Mat M; + }; + + +template +struct partial_unwrap_htrans2_fixed + { + typedef typename T1::elem_type eT; + typedef T1 stored_type; + + inline explicit + partial_unwrap_htrans2_fixed(const Op& A) + : val(A.aux) + , M (A.m) + { + arma_extra_debug_sigprint(); + } + + arma_inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const eT val; + const T1& M; + }; + + + +template +struct partial_unwrap_htrans2_redirect {}; + +template +struct partial_unwrap_htrans2_redirect { typedef partial_unwrap_htrans2_default result; }; + +template +struct partial_unwrap_htrans2_redirect { typedef partial_unwrap_htrans2_fixed result; }; + +template +struct partial_unwrap< Op > : public partial_unwrap_htrans2_redirect::value>::result + { + inline partial_unwrap(const Op& A) + : partial_unwrap_htrans2_redirect::value>::result(A) + { + } + }; + + + +template +struct partial_unwrap< Op< Mat, op_htrans2> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const Op< Mat, op_htrans2>& A) + : val(A.aux) + , M (A.m) + { + arma_extra_debug_sigprint(); + } + + inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const eT val; + const Mat& M; + }; + + + +template +struct partial_unwrap< Op< Row, op_htrans2> > + { + typedef Row stored_type; + + inline + partial_unwrap(const Op< Row, op_htrans2>& A) + : val(A.aux) + , M (A.m) + { + arma_extra_debug_sigprint(); + } + + inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const eT val; + const Row& M; + }; + + + +template +struct partial_unwrap< Op< Col, op_htrans2> > + { + typedef Col stored_type; + + inline + partial_unwrap(const Op< Col, op_htrans2>& A) + : val(A.aux) + , M (A.m) + { + arma_extra_debug_sigprint(); + } + + inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const eT val; + const Col& M; + }; + + + +template +struct partial_unwrap< Op< subview, op_htrans2> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const Op< subview, op_htrans2>& A) + : sv ( A.m ) + , val( A.aux ) + , M ( A.m, ((A.m.aux_row1 == 0) && (A.m.n_rows == A.m.m.n_rows)) ) // reuse memory if the subview is a contiguous chunk + { + arma_extra_debug_sigprint(); + } + + inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return ( ((sv.aux_row1 == 0) && (sv.n_rows == sv.m.n_rows)) ? (void_ptr(&(sv.m)) == void_ptr(&X)) : false ); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const subview& sv; + const eT val; + const Mat M; + }; + + + +template +struct partial_unwrap< Op< subview_cols, op_htrans2> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const Op< subview_cols, op_htrans2>& A) + : orig( A.m.m ) + , val ( A.aux ) + , M ( const_cast( A.m.colptr(0) ), A.m.n_rows, A.m.n_cols, false, false ) + { + arma_extra_debug_sigprint(); + } + + inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const Mat& orig; + const eT val; + const Mat M; + }; + + + +template +struct partial_unwrap< Op< subview_col, op_htrans2> > + { + typedef Col stored_type; + + inline + partial_unwrap(const Op< subview_col, op_htrans2>& A) + : orig( A.m.m ) + , val ( A.aux ) + , M ( const_cast( A.m.colmem ), A.m.n_rows, false, false ) + { + arma_extra_debug_sigprint(); + } + + inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const Mat& orig; + + const eT val; + const Col M; + }; + + + +template +struct partial_unwrap< Op< subview_row, op_htrans2> > + { + typedef Row stored_type; + + inline + partial_unwrap(const Op< subview_row, op_htrans2>& A) + : val(A.aux) + , M (A.m ) + { + arma_extra_debug_sigprint(); + } + + arma_inline eT get_val() const { return val; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const eT val; + const Row M; + }; + + + +template +struct partial_unwrap_scalar_times_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + partial_unwrap_scalar_times_default(const eOp& A) + : val(A.aux) + , M (A.P.Q) + { + arma_extra_debug_sigprint(); + } + + arma_inline eT get_val() const { return val; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const Mat M; + }; + + + +template +struct partial_unwrap_scalar_times_fixed + { + typedef typename T1::elem_type eT; + typedef T1 stored_type; + + inline explicit + partial_unwrap_scalar_times_fixed(const eOp& A) + : val(A.aux) + , M (A.P.Q) + { + arma_extra_debug_sigprint(); + } + + arma_inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const T1& M; + }; + + + +template +struct partial_unwrap_scalar_times_redirect {}; + +template +struct partial_unwrap_scalar_times_redirect { typedef partial_unwrap_scalar_times_default result; }; + +template +struct partial_unwrap_scalar_times_redirect { typedef partial_unwrap_scalar_times_fixed result; }; + + +template +struct partial_unwrap< eOp > : public partial_unwrap_scalar_times_redirect::value>::result + { + typedef typename T1::elem_type eT; + + inline + partial_unwrap(const eOp& A) + : partial_unwrap_scalar_times_redirect< T1, is_Mat_fixed::value>::result(A) + { + } + }; + + + +template +struct partial_unwrap< eOp, eop_scalar_times> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const eOp,eop_scalar_times>& A) + : val(A.aux) + , M (A.P.Q) + { + arma_extra_debug_sigprint(); + } + + inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const Mat& M; + }; + + + +template +struct partial_unwrap< eOp, eop_scalar_times> > + { + typedef Row stored_type; + + inline + partial_unwrap(const eOp,eop_scalar_times>& A) + : val(A.aux) + , M (A.P.Q) + { + arma_extra_debug_sigprint(); + } + + inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const Row& M; + }; + + + +template +struct partial_unwrap< eOp, eop_scalar_times> > + { + typedef Col stored_type; + + inline + partial_unwrap(const eOp,eop_scalar_times>& A) + : val(A.aux) + , M (A.P.Q) + { + arma_extra_debug_sigprint(); + } + + inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const Col& M; + }; + + + +template +struct partial_unwrap< eOp, eop_scalar_times> > + { + typedef Col stored_type; + + inline + partial_unwrap(const eOp,eop_scalar_times>& A) + : orig( A.P.Q.m ) + , val ( A.aux ) + , M ( const_cast( A.P.Q.colmem ), A.P.Q.n_rows, false, false ) + { + arma_extra_debug_sigprint(); + } + + arma_inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Mat& orig; + + const eT val; + const Col M; + }; + + + +template +struct partial_unwrap< eOp, eop_scalar_times> > + { + typedef Row stored_type; + + inline + partial_unwrap(const eOp,eop_scalar_times>& A) + : val(A.aux) + , M (A.P.Q) + { + arma_extra_debug_sigprint(); + } + + arma_inline eT get_val() const { return val; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const Row M; + }; + + + +template +struct partial_unwrap_neg_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + partial_unwrap_neg_default(const eOp& A) + : M(A.P.Q) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(-1); } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Mat M; + }; + + + +template +struct partial_unwrap_neg_fixed + { + typedef typename T1::elem_type eT; + typedef T1 stored_type; + + inline explicit + partial_unwrap_neg_fixed(const eOp& A) + : M(A.P.Q) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(-1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const T1& M; + }; + + + +template +struct partial_unwrap_neg_redirect {}; + +template +struct partial_unwrap_neg_redirect { typedef partial_unwrap_neg_default result; }; + +template +struct partial_unwrap_neg_redirect { typedef partial_unwrap_neg_fixed result; }; + + +template +struct partial_unwrap< eOp > : public partial_unwrap_neg_redirect::value>::result + { + typedef typename T1::elem_type eT; + + inline + partial_unwrap(const eOp& A) + : partial_unwrap_neg_redirect< T1, is_Mat_fixed::value>::result(A) + { + } + }; + + + +template +struct partial_unwrap< eOp, eop_neg> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const eOp,eop_neg>& A) + : M(A.P.Q) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(-1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Mat& M; + }; + + + +template +struct partial_unwrap< eOp, eop_neg> > + { + typedef Row stored_type; + + inline + partial_unwrap(const eOp,eop_neg>& A) + : M(A.P.Q) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(-1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Row& M; + }; + + + +template +struct partial_unwrap< eOp, eop_neg> > + { + typedef Col stored_type; + + inline + partial_unwrap(const eOp,eop_neg>& A) + : M(A.P.Q) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(-1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Col& M; + }; + + + +template +struct partial_unwrap< eOp, eop_neg> > + { + typedef Col stored_type; + + inline + partial_unwrap(const eOp,eop_neg>& A) + : orig( A.P.Q.m ) + , M ( const_cast( A.P.Q.colmem ), A.P.Q.n_rows, false, false ) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(-1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Mat& orig; + const Col M; + }; + + + +template +struct partial_unwrap< eOp, eop_neg> > + { + typedef Row stored_type; + + inline + partial_unwrap(const eOp,eop_neg>& A) + : M(A.P.Q) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(-1); } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Row M; + }; + + + +// + + + +template +struct partial_unwrap_check_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + partial_unwrap_check_default(const T1& A, const Mat&) + : M(A) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const Mat M; + }; + + +template +struct partial_unwrap_check_fixed + { + typedef typename T1::elem_type eT; + typedef T1 stored_type; + + inline explicit + partial_unwrap_check_fixed(const T1& A, const Mat& B) + : M_local( (&A == &B) ? new T1(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check_fixed() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const T1* M_local; + const T1& M; + }; + + + +template +struct partial_unwrap_check_redirect {}; + +template +struct partial_unwrap_check_redirect { typedef partial_unwrap_check_default result; }; + +template +struct partial_unwrap_check_redirect { typedef partial_unwrap_check_fixed result; }; + +template +struct partial_unwrap_check : public partial_unwrap_check_redirect::value>::result + { + typedef typename T1::elem_type eT; + + inline partial_unwrap_check(const T1& A, const Mat& B) + : partial_unwrap_check_redirect::value>::result(A, B) + { + } + }; + + + +template +struct partial_unwrap_check< Mat > + { + typedef Mat stored_type; + + inline + partial_unwrap_check(const Mat& A, const Mat& B) + : M_local ( (&A == &B) ? new Mat(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + // the order below is important + const Mat* M_local; + const Mat& M; + }; + + + +template +struct partial_unwrap_check< Row > + { + typedef Row stored_type; + + inline + partial_unwrap_check(const Row& A, const Mat& B) + : M_local ( (&A == &B) ? new Row(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + // the order below is important + const Row* M_local; + const Row& M; + }; + + + +template +struct partial_unwrap_check< Col > + { + typedef Col stored_type; + + inline + partial_unwrap_check(const Col& A, const Mat& B) + : M_local ( (&A == &B) ? new Col(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + // the order below is important + const Col* M_local; + const Col& M; + }; + + + +// NOTE: we can get away with this shortcut as the partial_unwrap_check class is only used by the glue_times class, +// NOTE: which relies on partial_unwrap_check to check for aliasing +template +struct partial_unwrap_check< subview_col > + { + typedef Col stored_type; + + inline + partial_unwrap_check(const subview_col& A, const Mat& B) + : M ( const_cast( A.colmem ), A.n_rows, (&(A.m) == &B), false ) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const Col M; + }; + + + +template +struct partial_unwrap_check_htrans_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + partial_unwrap_check_htrans_default(const Op& A, const Mat&) + : M(A.m) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const Mat M; + }; + + +template +struct partial_unwrap_check_htrans_fixed + { + typedef typename T1::elem_type eT; + typedef T1 stored_type; + + inline explicit + partial_unwrap_check_htrans_fixed(const Op& A, const Mat& B) + : M_local( (&(A.m) == &B) ? new T1(A.m) : nullptr ) + , M ( (&(A.m) == &B) ? (*M_local) : A.m ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check_htrans_fixed() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const T1* M_local; + const T1& M; + }; + + + +template +struct partial_unwrap_check_htrans_redirect {}; + +template +struct partial_unwrap_check_htrans_redirect { typedef partial_unwrap_check_htrans_default result; }; + +template +struct partial_unwrap_check_htrans_redirect { typedef partial_unwrap_check_htrans_fixed result; }; + + +template +struct partial_unwrap_check< Op > : public partial_unwrap_check_htrans_redirect::value>::result + { + typedef typename T1::elem_type eT; + + inline partial_unwrap_check(const Op& A, const Mat& B) + : partial_unwrap_check_htrans_redirect::value>::result(A, B) + { + } + }; + + + +template +struct partial_unwrap_check< Op< Mat, op_htrans> > + { + typedef Mat stored_type; + + inline + partial_unwrap_check(const Op< Mat, op_htrans>& A, const Mat& B) + : M_local ( (&A.m == &B) ? new Mat(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + // the order below is important + const Mat* M_local; + const Mat& M; + }; + + + +template +struct partial_unwrap_check< Op< Row, op_htrans> > + { + typedef Row stored_type; + + inline + partial_unwrap_check(const Op< Row, op_htrans>& A, const Mat& B) + : M_local ( (&A.m == &B) ? new Row(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + // the order below is important + const Row* M_local; + const Row& M; + }; + + + +template +struct partial_unwrap_check< Op< Col, op_htrans> > + { + typedef Col stored_type; + + inline + partial_unwrap_check(const Op< Col, op_htrans>& A, const Mat& B) + : M_local ( (&A.m == &B) ? new Col(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + // the order below is important + const Col* M_local; + const Col& M; + }; + + + +// NOTE: we can get away with this shortcut as the partial_unwrap_check class is only used by the glue_times class, +// NOTE: which relies on partial_unwrap_check to check for aliasing +template +struct partial_unwrap_check< Op< subview_col, op_htrans> > + { + typedef Col stored_type; + + inline + partial_unwrap_check(const Op< subview_col, op_htrans>& A, const Mat& B) + : M ( const_cast( A.m.colmem ), A.m.n_rows, (&(A.m.m) == &B), false ) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const Col M; + }; + + + +template +struct partial_unwrap_check_htrans2_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + partial_unwrap_check_htrans2_default(const Op& A, const Mat&) + : val(A.aux) + , M (A.m) + { + arma_extra_debug_sigprint(); + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const eT val; + const Mat M; + }; + + + +template +struct partial_unwrap_check_htrans2_fixed + { + typedef typename T1::elem_type eT; + typedef T1 stored_type; + + inline explicit + partial_unwrap_check_htrans2_fixed(const Op& A, const Mat& B) + : val (A.aux) + , M_local( (&(A.m) == &B) ? new T1(A.m) : nullptr ) + , M ( (&(A.m) == &B) ? (*M_local) : A.m ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check_htrans2_fixed() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const eT val; + const T1* M_local; + const T1& M; + }; + + + +template +struct partial_unwrap_check_htrans2_redirect {}; + +template +struct partial_unwrap_check_htrans2_redirect { typedef partial_unwrap_check_htrans2_default result; }; + +template +struct partial_unwrap_check_htrans2_redirect { typedef partial_unwrap_check_htrans2_fixed result; }; + + +template +struct partial_unwrap_check< Op > : public partial_unwrap_check_htrans2_redirect::value>::result + { + typedef typename T1::elem_type eT; + + inline partial_unwrap_check(const Op& A, const Mat& B) + : partial_unwrap_check_htrans2_redirect::value>::result(A, B) + { + } + }; + + + +template +struct partial_unwrap_check< Op< Mat, op_htrans2> > + { + typedef Mat stored_type; + + inline + partial_unwrap_check(const Op< Mat, op_htrans2>& A, const Mat& B) + : val (A.aux) + , M_local ( (&A.m == &B) ? new Mat(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + // the order below is important + const eT val; + const Mat* M_local; + const Mat& M; + }; + + + +template +struct partial_unwrap_check< Op< Row, op_htrans2> > + { + typedef Row stored_type; + + inline + partial_unwrap_check(const Op< Row, op_htrans2>& A, const Mat& B) + : val (A.aux) + , M_local ( (&A.m == &B) ? new Row(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + // the order below is important + const eT val; + const Row* M_local; + const Row& M; + }; + + + +template +struct partial_unwrap_check< Op< Col, op_htrans2> > + { + typedef Col stored_type; + + inline + partial_unwrap_check(const Op< Col, op_htrans2>& A, const Mat& B) + : val (A.aux) + , M_local ( (&A.m == &B) ? new Col(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + // the order below is important + const eT val; + const Col* M_local; + const Col& M; + }; + + + +// NOTE: we can get away with this shortcut as the partial_unwrap_check class is only used by the glue_times class, +// NOTE: which relies on partial_unwrap_check to check for aliasing +template +struct partial_unwrap_check< Op< subview_col, op_htrans2> > + { + typedef Col stored_type; + + inline + partial_unwrap_check(const Op< subview_col, op_htrans2>& A, const Mat& B) + : val( A.aux ) + , M ( const_cast( A.m.colmem ), A.m.n_rows, (&(A.m.m) == &B), false ) + { + arma_extra_debug_sigprint(); + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const eT val; + const Col M; + }; + + + +template +struct partial_unwrap_check_scalar_times_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + partial_unwrap_check_scalar_times_default(const eOp& A, const Mat&) + : val(A.aux) + , M (A.P.Q) + { + arma_extra_debug_sigprint(); + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const Mat M; + }; + + + +template +struct partial_unwrap_check_scalar_times_fixed + { + typedef typename T1::elem_type eT; + typedef T1 stored_type; + + inline explicit + partial_unwrap_check_scalar_times_fixed(const eOp& A, const Mat& B) + : val ( A.aux ) + , M_local( (&(A.P.Q) == &B) ? new T1(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? (*M_local) : A.P.Q ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check_scalar_times_fixed() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const T1* M_local; + const T1& M; + }; + + + +template +struct partial_unwrap_check_scalar_times_redirect {}; + +template +struct partial_unwrap_check_scalar_times_redirect { typedef partial_unwrap_check_scalar_times_default result; }; + +template +struct partial_unwrap_check_scalar_times_redirect { typedef partial_unwrap_check_scalar_times_fixed result; }; + + +template +struct partial_unwrap_check< eOp > : public partial_unwrap_check_scalar_times_redirect::value>::result + { + typedef typename T1::elem_type eT; + + inline partial_unwrap_check(const eOp& A, const Mat& B) + : partial_unwrap_check_scalar_times_redirect::value>::result(A, B) + { + } + }; + + + +template +struct partial_unwrap_check< eOp, eop_scalar_times> > + { + typedef Mat stored_type; + + inline + partial_unwrap_check(const eOp,eop_scalar_times>& A, const Mat& B) + : val (A.aux) + , M_local( (&(A.P.Q) == &B) ? new Mat(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const Mat* M_local; + const Mat& M; + }; + + + +template +struct partial_unwrap_check< eOp, eop_scalar_times> > + { + typedef Row stored_type; + + inline + partial_unwrap_check(const eOp,eop_scalar_times>& A, const Mat& B) + : val(A.aux) + , M_local( (&(A.P.Q) == &B) ? new Row(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const Row* M_local; + const Row& M; + }; + + + +template +struct partial_unwrap_check< eOp, eop_scalar_times> > + { + typedef Col stored_type; + + inline + partial_unwrap_check(const eOp,eop_scalar_times>& A, const Mat& B) + : val ( A.aux ) + , M_local( (&(A.P.Q) == &B) ? new Col(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const Col* M_local; + const Col& M; + }; + + + +// NOTE: we can get away with this shortcut as the partial_unwrap_check class is only used by the glue_times class, +// NOTE: which relies on partial_unwrap_check to check for aliasing +template +struct partial_unwrap_check< eOp, eop_scalar_times> > + { + typedef Col stored_type; + + inline + partial_unwrap_check(const eOp,eop_scalar_times>& A, const Mat& B) + : val( A.aux ) + , M ( const_cast( A.P.Q.colmem ), A.P.Q.n_rows, (&(A.P.Q.m) == &B), false ) + { + arma_extra_debug_sigprint(); + } + + arma_inline eT get_val() const { return val; } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const eT val; + const Col M; + }; + + + +template +struct partial_unwrap_check_neg_default + { + typedef typename T1::elem_type eT; + typedef Mat stored_type; + + inline + partial_unwrap_check_neg_default(const eOp& A, const Mat&) + : M(A.P.Q) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(-1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Mat M; + }; + + + +template +struct partial_unwrap_check_neg_fixed + { + typedef typename T1::elem_type eT; + typedef T1 stored_type; + + inline explicit + partial_unwrap_check_neg_fixed(const eOp& A, const Mat& B) + : M_local( (&(A.P.Q) == &B) ? new T1(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? (*M_local) : A.P.Q ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check_neg_fixed() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(-1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const T1* M_local; + const T1& M; + }; + + + +template +struct partial_unwrap_check_neg_redirect {}; + +template +struct partial_unwrap_check_neg_redirect { typedef partial_unwrap_check_neg_default result; }; + +template +struct partial_unwrap_check_neg_redirect { typedef partial_unwrap_check_neg_fixed result; }; + + +template +struct partial_unwrap_check< eOp > : public partial_unwrap_check_neg_redirect::value>::result + { + typedef typename T1::elem_type eT; + + inline partial_unwrap_check(const eOp& A, const Mat& B) + : partial_unwrap_check_neg_redirect::value>::result(A, B) + { + } + }; + + + +template +struct partial_unwrap_check< eOp, eop_neg> > + { + typedef Mat stored_type; + + inline + partial_unwrap_check(const eOp,eop_neg>& A, const Mat& B) + : M_local( (&(A.P.Q) == &B) ? new Mat(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(-1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Mat* M_local; + const Mat& M; + }; + + + +template +struct partial_unwrap_check< eOp, eop_neg> > + { + typedef Row stored_type; + + inline + partial_unwrap_check(const eOp,eop_neg>& A, const Mat& B) + : M_local( (&(A.P.Q) == &B) ? new Row(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(-1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Row* M_local; + const Row& M; + }; + + + +template +struct partial_unwrap_check< eOp, eop_neg> > + { + typedef Col stored_type; + + inline + partial_unwrap_check(const eOp,eop_neg>& A, const Mat& B) + : M_local( (&(A.P.Q) == &B) ? new Col(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + { + arma_extra_debug_sigprint(); + } + + inline + ~partial_unwrap_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + constexpr eT get_val() const { return eT(-1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Col* M_local; + const Col& M; + }; + + + +// NOTE: we can get away with this shortcut as the partial_unwrap_check class is only used by the glue_times class, +// NOTE: which relies on partial_unwrap_check to check for aliasing +template +struct partial_unwrap_check< eOp, eop_neg> > + { + typedef Col stored_type; + + inline + partial_unwrap_check(const eOp,eop_neg>& A, const Mat& B) + : M ( const_cast( A.P.Q.colmem ), A.P.Q.n_rows, (&(A.P.Q.m) == &B), false ) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(-1); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = true; + + const Col M; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/unwrap_cube.hpp b/src/armadillo/include/armadillo_bits/unwrap_cube.hpp new file mode 100644 index 0000000..ca91cfa --- /dev/null +++ b/src/armadillo/include/armadillo_bits/unwrap_cube.hpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup unwrap_cube +//! @{ + + + +template +struct unwrap_cube + { + typedef typename T1::elem_type eT; + + inline + unwrap_cube(const T1& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Cube M; + + template + constexpr bool is_alias(const Cube&) const { return false; } + }; + + + +template +struct unwrap_cube< Cube > + { + inline + unwrap_cube(const Cube& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const Cube& M; + + template + arma_inline bool is_alias(const Cube& X) const { return (void_ptr(&M) == void_ptr(&X)); } + }; + + + +// +// +// + + + +template +struct unwrap_cube_check + { + typedef typename T1::elem_type eT; + + inline + unwrap_cube_check(const T1& A, const Cube&) + : M(A) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_arma_cube_type::value == false )); + } + + inline + unwrap_cube_check(const T1& A, const bool) + : M(A) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_arma_cube_type::value == false )); + } + + const Cube M; + }; + + + +template +struct unwrap_cube_check< Cube > + { + inline + unwrap_cube_check(const Cube& A, const Cube& B) + : M_local( (&A == &B) ? new Cube(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + + inline + unwrap_cube_check(const Cube& A, const bool is_alias) + : M_local( is_alias ? new Cube(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + + inline + ~unwrap_cube_check() + { + arma_extra_debug_sigprint(); + + if(M_local) { delete M_local; } + } + + + // the order below is important + const Cube* M_local; + const Cube& M; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/unwrap_spmat.hpp b/src/armadillo/include/armadillo_bits/unwrap_spmat.hpp new file mode 100644 index 0000000..0597aaa --- /dev/null +++ b/src/armadillo/include/armadillo_bits/unwrap_spmat.hpp @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup unwrap_spmat +//! @{ + + + +template +struct unwrap_spmat + { + typedef typename T1::elem_type eT; + + typedef SpMat stored_type; + + inline + unwrap_spmat(const T1& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const SpMat M; + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +template +struct unwrap_spmat< SpMat > + { + typedef SpMat stored_type; + + inline + unwrap_spmat(const SpMat& A) + : M(A) + { + arma_extra_debug_sigprint(); + + M.sync(); + } + + const SpMat& M; + + template + arma_inline bool is_alias(const SpMat& X) const { return (void_ptr(&M) == void_ptr(&X)); } + }; + + + +template +struct unwrap_spmat< SpRow > + { + typedef SpRow stored_type; + + inline + unwrap_spmat(const SpRow& A) + : M(A) + { + arma_extra_debug_sigprint(); + + M.sync(); + } + + const SpRow& M; + + template + arma_inline bool is_alias(const SpMat& X) const { return (void_ptr(&M) == void_ptr(&X)); } + }; + + + +template +struct unwrap_spmat< SpCol > + { + typedef SpCol stored_type; + + inline + unwrap_spmat(const SpCol& A) + : M(A) + { + arma_extra_debug_sigprint(); + + M.sync(); + } + + const SpCol& M; + + template + arma_inline bool is_alias(const SpMat& X) const { return (void_ptr(&M) == void_ptr(&X)); } + }; + + + +template +struct unwrap_spmat< SpOp > + { + typedef typename T1::elem_type eT; + + typedef SpMat stored_type; + + inline + unwrap_spmat(const SpOp& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const SpMat M; + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +template +struct unwrap_spmat< SpGlue > + { + typedef typename T1::elem_type eT; + + typedef SpMat stored_type; + + inline + unwrap_spmat(const SpGlue& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const SpMat M; + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +template +struct unwrap_spmat< mtSpOp > + { + typedef SpMat stored_type; + + inline + unwrap_spmat(const mtSpOp& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const SpMat M; + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +template +struct unwrap_spmat< mtSpGlue > + { + typedef SpMat stored_type; + + inline + unwrap_spmat(const mtSpGlue& A) + : M(A) + { + arma_extra_debug_sigprint(); + } + + const SpMat M; + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/upgrade_val.hpp b/src/armadillo/include/armadillo_bits/upgrade_val.hpp new file mode 100644 index 0000000..a5e9da2 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/upgrade_val.hpp @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup upgrade_val +//! @{ + + + +//! upgrade_val is used to ensure an operation such as multiplication is possible between two types. +//! values are upgraded only where necessary. + +template +struct upgrade_val + { + typedef typename promote_type::result T1_result; + typedef typename promote_type::result T2_result; + + arma_inline + static + typename promote_type::result + apply(const T1 x) + { + typedef typename promote_type::result out_type; + return out_type(x); + } + + arma_inline + static + typename promote_type::result + apply(const T2 x) + { + typedef typename promote_type::result out_type; + return out_type(x); + } + + }; + + +// template<> +template +struct upgrade_val + { + typedef T T1_result; + typedef T T2_result; + + arma_inline static const T& apply(const T& x) { return x; } + }; + + +//! upgrade a type to allow multiplication with a complex type +//! eg. the int in "int * complex" is upgraded to a double +// template<> +template +struct upgrade_val< std::complex, T2 > + { + typedef std::complex T1_result; + typedef T T2_result; + + arma_inline static const std::complex& apply(const std::complex& x) { return x; } + arma_inline static T apply(const T2 x) { return T(x); } + }; + + +// template<> +template +struct upgrade_val< T1, std::complex > + { + typedef T T1_result; + typedef std::complex T2_result; + + arma_inline static T apply(const T1 x) { return T(x); } + arma_inline static const std::complex& apply(const std::complex& x) { return x; } + }; + + +//! ensure we don't lose precision when multiplying a complex number with a higher precision real number +template<> +struct upgrade_val< std::complex, double > + { + typedef std::complex T1_result; + typedef double T2_result; + + arma_inline static const std::complex apply(const std::complex& x) { return std::complex(x); } + arma_inline static double apply(const double x) { return x; } + }; + + +template<> +struct upgrade_val< double, std::complex > + { + typedef double T1_result; + typedef std::complex T2_result; + + arma_inline static double apply(const double x) { return x; } + arma_inline static const std::complex apply(const std::complex& x) { return std::complex(x); } + }; + + +//! ensure we don't lose precision when multiplying complex numbers with different underlying types +template<> +struct upgrade_val< std::complex, std::complex > + { + typedef std::complex T1_result; + typedef std::complex T2_result; + + arma_inline static const std::complex apply(const std::complex& x) { return std::complex(x); } + arma_inline static const std::complex& apply(const std::complex& x) { return x; } + }; + + +template<> +struct upgrade_val< std::complex, std::complex > + { + typedef std::complex T1_result; + typedef std::complex T2_result; + + arma_inline static const std::complex& apply(const std::complex& x) { return x; } + arma_inline static const std::complex apply(const std::complex& x) { return std::complex(x); } + }; + + +//! work around limitations in the complex class (at least as present in gcc 4.1 & 4.3) +template<> +struct upgrade_val< std::complex, float > + { + typedef std::complex T1_result; + typedef double T2_result; + + arma_inline static const std::complex& apply(const std::complex& x) { return x; } + arma_inline static double apply(const float x) { return double(x); } + }; + + +template<> +struct upgrade_val< float, std::complex > + { + typedef double T1_result; + typedef std::complex T2_result; + + arma_inline static double apply(const float x) { return double(x); } + arma_inline static const std::complex& apply(const std::complex& x) { return x; } + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/wall_clock_bones.hpp b/src/armadillo/include/armadillo_bits/wall_clock_bones.hpp new file mode 100644 index 0000000..29c3014 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/wall_clock_bones.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup wall_clock +//! @{ + + +//! Class for measuring time intervals +class wall_clock + { + public: + + inline wall_clock(); + inline ~wall_clock(); + + inline void tic(); //!< start the timer + arma_warn_unused inline double toc(); //!< return the number of seconds since the last call to tic() + + + private: + + bool valid = false; + + std::chrono::steady_clock::time_point chrono_time1; + }; + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/wall_clock_meat.hpp b/src/armadillo/include/armadillo_bits/wall_clock_meat.hpp new file mode 100644 index 0000000..54ed68a --- /dev/null +++ b/src/armadillo/include/armadillo_bits/wall_clock_meat.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup wall_clock +//! @{ + + +inline +wall_clock::wall_clock() + { + arma_extra_debug_sigprint(); + } + + + +inline +wall_clock::~wall_clock() + { + arma_extra_debug_sigprint(); + } + + + +inline +void +wall_clock::tic() + { + arma_extra_debug_sigprint(); + + chrono_time1 = std::chrono::steady_clock::now(); + valid = true; + } + + + +inline +double +wall_clock::toc() + { + arma_extra_debug_sigprint(); + + if(valid) + { + const std::chrono::steady_clock::time_point chrono_time2 = std::chrono::steady_clock::now(); + + typedef std::chrono::duration duration_type; // TODO: check this + + const duration_type chrono_span = std::chrono::duration_cast< duration_type >(chrono_time2 - chrono_time1); + + return chrono_span.count(); + } + + return 0.0; + } + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/xtrans_mat_bones.hpp b/src/armadillo/include/armadillo_bits/xtrans_mat_bones.hpp new file mode 100644 index 0000000..5875666 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/xtrans_mat_bones.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup xtrans_mat +//! @{ + + +template +class xtrans_mat : public Base< eT, xtrans_mat > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + static constexpr bool really_do_conj = (do_conj && is_cx::yes); + + arma_aligned const Mat& X; + arma_aligned mutable Mat Y; + + arma_aligned const uword n_rows; + arma_aligned const uword n_cols; + arma_aligned const uword n_elem; + + inline explicit xtrans_mat(const Mat& in_X); + + inline void extract(Mat& out) const; + + inline eT operator[](const uword ii) const; + inline eT at_alt (const uword ii) const; + + arma_inline eT at(const uword in_row, const uword in_col) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/xtrans_mat_meat.hpp b/src/armadillo/include/armadillo_bits/xtrans_mat_meat.hpp new file mode 100644 index 0000000..1872c30 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/xtrans_mat_meat.hpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup xtrans_mat +//! @{ + + +template +inline +xtrans_mat::xtrans_mat(const Mat& in_X) + : X (in_X ) + , n_rows(in_X.n_cols) // deliberately swapped + , n_cols(in_X.n_rows) + , n_elem(in_X.n_elem) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +xtrans_mat::extract(Mat& out) const + { + arma_extra_debug_sigprint(); + + really_do_conj ? op_htrans::apply_mat(out, X) : op_strans::apply_mat(out, X); + } + + + +template +inline +eT +xtrans_mat::operator[](const uword ii) const + { + if(Y.n_elem > 0) + { + return Y[ii]; + } + else + { + really_do_conj ? op_htrans::apply_mat(Y, X) : op_strans::apply_mat(Y, X); + return Y[ii]; + } + } + + + +template +inline +eT +xtrans_mat::at_alt(const uword ii) const + { + return (*this).operator[](ii); + } + + + +template +arma_inline +eT +xtrans_mat::at(const uword in_row, const uword in_col) const + { + return really_do_conj ? eT(access::alt_conj(X.at(in_col, in_row))) : eT(X.at(in_col, in_row)); + // in_row and in_col deliberately swapped above + } + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/xvec_htrans_bones.hpp b/src/armadillo/include/armadillo_bits/xvec_htrans_bones.hpp new file mode 100644 index 0000000..6eab710 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/xvec_htrans_bones.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup xvec_htrans +//! @{ + + +template +class xvec_htrans : public Base< eT, xvec_htrans > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = true; + + arma_aligned const eT* const mem; + + const uword n_rows; + const uword n_cols; + const uword n_elem; + + + inline explicit xvec_htrans(const eT* const in_mem, const uword in_n_rows, const uword in_n_cols); + + inline void extract(Mat& out) const; + + inline eT operator[](const uword ii) const; + inline eT at_alt (const uword ii) const; + + inline eT at (const uword in_row, const uword in_col) const; + }; + + + +//! @} diff --git a/src/armadillo/include/armadillo_bits/xvec_htrans_meat.hpp b/src/armadillo/include/armadillo_bits/xvec_htrans_meat.hpp new file mode 100644 index 0000000..b79a7ef --- /dev/null +++ b/src/armadillo/include/armadillo_bits/xvec_htrans_meat.hpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup xvec_htrans +//! @{ + + +template +inline +xvec_htrans::xvec_htrans(const eT* const in_mem, const uword in_n_rows, const uword in_n_cols) + : mem (in_mem ) + , n_rows(in_n_cols ) // deliberately swapped + , n_cols(in_n_rows ) + , n_elem(in_n_rows*in_n_cols) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +xvec_htrans::extract(Mat& out) const + { + arma_extra_debug_sigprint(); + + // NOTE: this function assumes that matrix 'out' has already been set to the correct size + + const eT* in_mem = mem; + eT* out_mem = out.memptr(); + + const uword N = n_elem; + + for(uword ii=0; ii < N; ++ii) + { + out_mem[ii] = access::alt_conj( in_mem[ii] ); + } + } + + + +template +inline +eT +xvec_htrans::operator[](const uword ii) const + { + return access::alt_conj( mem[ii] ); + } + + + +template +inline +eT +xvec_htrans::at_alt(const uword ii) const + { + return access::alt_conj( mem[ii] ); + } + + + +template +inline +eT +xvec_htrans::at(const uword in_row, const uword in_col) const + { + //return (n_rows == 1) ? access::alt_conj( mem[in_col] ) : access::alt_conj( mem[in_row] ); + + return access::alt_conj( mem[in_row + in_col] ); // either in_row or in_col must be zero, as we're storing a vector + } + + + +//! @} -- cgit v1.2.1