diff --git a/dune/xt/common/numeric.hh b/dune/xt/common/numeric.hh index bf48996d7824bdf3db9d2788bbcdd780e3779288..48abc3f7c45e3e8a3658e1db403a0b9f23499a88 100644 --- a/dune/xt/common/numeric.hh +++ b/dune/xt/common/numeric.hh @@ -25,43 +25,23 @@ namespace XT { namespace Common { -template <class InputIt, class T> -T reduce(InputIt first, InputIt last, T init) +template <class... Args> +decltype(auto) reduce(Args&&... args) { #if CPP17_PARALLELISM_TS_SUPPORTED - return std::reduce(first, last, init); + return std::reduce(std::forward<Args>(args)...); #else - return std::accumulate(first, last, init); + return std::accumulate(std::forward<Args>(args)...); #endif } -template <class InputIt, class T, class BinaryOp> -T reduce(InputIt first, InputIt last, T init, BinaryOp binary_op) +template <class... Args> +decltype(auto) transform_reduce(Args&&... args) { #if CPP17_PARALLELISM_TS_SUPPORTED - return std::reduce(first, last, init, binary_op); + return std::transform_reduce(std::forward<Args>(args)...); #else - return std::accumulate(first, last, init, binary_op); -#endif -} - -template <class InputIt1, class InputIt2, class T> -T transform_reduce(InputIt1 first1, InputIt1 last1, InputIt2 first2, T init) -{ -#if CPP17_PARALLELISM_TS_SUPPORTED - return std::transform_reduce(first1, last1, first2, init); -#else - return std::inner_product(first1, last1, first2, init); -#endif -} - -template <class InputIt1, class InputIt2, class T, class BinaryOp1, class BinaryOp2> -T transform_reduce(InputIt1 first1, InputIt1 last1, InputIt2 first2, T init, BinaryOp1 binary_op1, BinaryOp2 binary_op2) -{ -#if CPP17_PARALLELISM_TS_SUPPORTED - return std::transform_reduce(first1, last1, first2, init, binary_op1, binary_op2); -#else - return std::inner_product(first1, last1, first2, init, binary_op1, binary_op2); + return std::inner_product(std::forward<Args>(args)...); #endif }