From eda5bc26f44ee9a6f83dcf8c91f17296d7fc509d Mon Sep 17 00:00:00 2001 From: Nao Pross Date: Mon, 12 Feb 2024 14:52:43 +0100 Subject: Move into version control --- .../include/armadillo_bits/translate_atlas.hpp | 282 +++++++++++++++++++++ 1 file changed, 282 insertions(+) create mode 100644 src/armadillo/include/armadillo_bits/translate_atlas.hpp (limited to 'src/armadillo/include/armadillo_bits/translate_atlas.hpp') diff --git a/src/armadillo/include/armadillo_bits/translate_atlas.hpp b/src/armadillo/include/armadillo_bits/translate_atlas.hpp new file mode 100644 index 0000000..95d43d5 --- /dev/null +++ b/src/armadillo/include/armadillo_bits/translate_atlas.hpp @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +#if defined(ARMA_USE_ATLAS) + + +// TODO: remove support for ATLAS in next major version + +//! \namespace atlas namespace for ATLAS functions +namespace atlas + { + + template + inline static const eT& tmp_real(const eT& X) { return X; } + + template + inline static const T tmp_real(const std::complex& X) { return X.real(); } + + + + template + arma_inline + eT + cblas_asum(const int N, const eT* X) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + return eT( arma_wrapper(cblas_sasum)(N, (const T*)X, 1) ); + } + else + if(is_double::value) + { + typedef double T; + return eT( arma_wrapper(cblas_dasum)(N, (const T*)X, 1) ); + } + + return eT(0); + } + + + + template + arma_inline + eT + cblas_nrm2(const int N, const eT* X) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + return eT( arma_wrapper(cblas_snrm2)(N, (const T*)X, 1) ); + } + else + if(is_double::value) + { + typedef double T; + return eT( arma_wrapper(cblas_dnrm2)(N, (const T*)X, 1) ); + } + + return eT(0); + } + + + + template + arma_inline + eT + cblas_dot(const int N, const eT* X, const eT* Y) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + return eT( arma_wrapper(cblas_sdot)(N, (const T*)X, 1, (const T*)Y, 1) ); + } + else + if(is_double::value) + { + typedef double T; + return eT( arma_wrapper(cblas_ddot)(N, (const T*)X, 1, (const T*)Y, 1) ); + } + + return eT(0); + } + + + + template + arma_inline + eT + cblas_cx_dot(const int N, const eT* X, const eT* Y) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_cx_float::value) + { + typedef typename std::complex T; + + T out; + arma_wrapper(cblas_cdotu_sub)(N, (const T*)X, 1, (const T*)Y, 1, &out); + + return eT(out); + } + else + if(is_cx_double::value) + { + typedef typename std::complex T; + + T out; + arma_wrapper(cblas_zdotu_sub)(N, (const T*)X, 1, (const T*)Y, 1, &out); + + return eT(out); + } + + return eT(0); + } + + + + template + inline + void + cblas_gemv + ( + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, + const int M, const int N, + const eT alpha, + const eT *A, const int lda, + const eT *X, const int incX, + const eT beta, + eT *Y, const int incY + ) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + arma_wrapper(cblas_sgemv)(layout, TransA, M, N, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)X, incX, (const T)tmp_real(beta), (T*)Y, incY); + } + else + if(is_double::value) + { + typedef double T; + arma_wrapper(cblas_dgemv)(layout, TransA, M, N, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)X, incX, (const T)tmp_real(beta), (T*)Y, incY); + } + else + if(is_cx_float::value) + { + typedef std::complex T; + arma_wrapper(cblas_cgemv)(layout, TransA, M, N, (const T*)&alpha, (const T*)A, lda, (const T*)X, incX, (const T*)&beta, (T*)Y, incY); + } + else + if(is_cx_double::value) + { + typedef std::complex T; + arma_wrapper(cblas_zgemv)(layout, TransA, M, N, (const T*)&alpha, (const T*)A, lda, (const T*)X, incX, (const T*)&beta, (T*)Y, incY); + } + } + + + + template + inline + void + cblas_gemm + ( + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, + const atlas_CBLAS_TRANS TransB, const int M, const int N, + const int K, const eT alpha, const eT *A, + const int lda, const eT *B, const int ldb, + const eT beta, eT *C, const int ldc + ) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + arma_wrapper(cblas_sgemm)(layout, TransA, TransB, M, N, K, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)B, ldb, (const T)tmp_real(beta), (T*)C, ldc); + } + else + if(is_double::value) + { + typedef double T; + arma_wrapper(cblas_dgemm)(layout, TransA, TransB, M, N, K, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)B, ldb, (const T)tmp_real(beta), (T*)C, ldc); + } + else + if(is_cx_float::value) + { + typedef std::complex T; + arma_wrapper(cblas_cgemm)(layout, TransA, TransB, M, N, K, (const T*)&alpha, (const T*)A, lda, (const T*)B, ldb, (const T*)&beta, (T*)C, ldc); + } + else + if(is_cx_double::value) + { + typedef std::complex T; + arma_wrapper(cblas_zgemm)(layout, TransA, TransB, M, N, K, (const T*)&alpha, (const T*)A, lda, (const T*)B, ldb, (const T*)&beta, (T*)C, ldc); + } + } + + + + template + inline + void + cblas_syrk + ( + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, + const int N, const int K, const eT alpha, + const eT* A, const int lda, const eT beta, eT* C, const int ldc + ) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float T; + arma_wrapper(cblas_ssyrk)(layout, Uplo, Trans, N, K, (const T)alpha, (const T*)A, lda, (const T)beta, (T*)C, ldc); + } + else + if(is_double::value) + { + typedef double T; + arma_wrapper(cblas_dsyrk)(layout, Uplo, Trans, N, K, (const T)alpha, (const T*)A, lda, (const T)beta, (T*)C, ldc); + } + } + + + + template + inline + void + cblas_herk + ( + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, + const int N, const int K, const T alpha, + const std::complex* A, const int lda, const T beta, std::complex* C, const int ldc + ) + { + arma_type_check((is_supported_blas_type::value == false)); + + if(is_float::value) + { + typedef float TT; + typedef std::complex cx_TT; + + arma_wrapper(cblas_cherk)(layout, Uplo, Trans, N, K, (const TT)alpha, (const cx_TT*)A, lda, (const TT)beta, (cx_TT*)C, ldc); + } + else + if(is_double::value) + { + typedef double TT; + typedef std::complex cx_TT; + + arma_wrapper(cblas_zherk)(layout, Uplo, Trans, N, K, (const TT)alpha, (const cx_TT*)A, lda, (const TT)beta, (cx_TT*)C, ldc); + } + } + + } + +#endif -- cgit v1.2.1