diff --git a/dune/xt/common/numeric.hh b/dune/xt/common/numeric.hh
index bf48996d7824bdf3db9d2788bbcdd780e3779288..48abc3f7c45e3e8a3658e1db403a0b9f23499a88 100644
--- a/dune/xt/common/numeric.hh
+++ b/dune/xt/common/numeric.hh
@@ -25,43 +25,23 @@ namespace XT {
 namespace Common {
 
 
-template <class InputIt, class T>
-T reduce(InputIt first, InputIt last, T init)
+template <class... Args>
+decltype(auto) reduce(Args&&... args)
 {
 #if CPP17_PARALLELISM_TS_SUPPORTED
-  return std::reduce(first, last, init);
+  return std::reduce(std::forward<Args>(args)...);
 #else
-  return std::accumulate(first, last, init);
+  return std::accumulate(std::forward<Args>(args)...);
 #endif
 }
 
-template <class InputIt, class T, class BinaryOp>
-T reduce(InputIt first, InputIt last, T init, BinaryOp binary_op)
+template <class... Args>
+decltype(auto) transform_reduce(Args&&... args)
 {
 #if CPP17_PARALLELISM_TS_SUPPORTED
-  return std::reduce(first, last, init, binary_op);
+  return std::transform_reduce(std::forward<Args>(args)...);
 #else
-  return std::accumulate(first, last, init, binary_op);
-#endif
-}
-
-template <class InputIt1, class InputIt2, class T>
-T transform_reduce(InputIt1 first1, InputIt1 last1, InputIt2 first2, T init)
-{
-#if CPP17_PARALLELISM_TS_SUPPORTED
-  return std::transform_reduce(first1, last1, first2, init);
-#else
-  return std::inner_product(first1, last1, first2, init);
-#endif
-}
-
-template <class InputIt1, class InputIt2, class T, class BinaryOp1, class BinaryOp2>
-T transform_reduce(InputIt1 first1, InputIt1 last1, InputIt2 first2, T init, BinaryOp1 binary_op1, BinaryOp2 binary_op2)
-{
-#if CPP17_PARALLELISM_TS_SUPPORTED
-  return std::transform_reduce(first1, last1, first2, init, binary_op1, binary_op2);
-#else
-  return std::inner_product(first1, last1, first2, init, binary_op1, binary_op2);
+  return std::inner_product(std::forward<Args>(args)...);
 #endif
 }