From d0952ab03ad5d4aa38257608458da75be3c3136c Mon Sep 17 00:00:00 2001
From: Tobias Leibner <tobias.leibner@googlemail.com>
Date: Wed, 31 Oct 2018 12:04:02 +0100
Subject: [PATCH] [walker] do not dereference PerThreadValue on all elements

---
 .gitsuper              |   4 +-
 dune/xt/grid/walker.hh | 126 +++++++++++++++++++++++++++++------------
 2 files changed, 93 insertions(+), 37 deletions(-)

diff --git a/.gitsuper b/.gitsuper
index 85c1e6f39..3152dc3c7 100644
--- a/.gitsuper
+++ b/.gitsuper
@@ -17,7 +17,7 @@ status = 1a3bcab04b011a5d6e44f9983cae6ff89fa695e8 bin (heads/master)
 	+30e7ad34db59be19bbf67bb72fc52eba50a5245d dune-xt-common (heads/dailywork_tleibner)
 	+3e3f3bf06e21cbbf3c0a559891b44c6f5d987d0f dune-xt-data (heads/master)
 	+f05aa7470ead4150ca7a91894cd2ad77dfcedc46 dune-xt-functions (heads/new-master)
-	+4276ffe6f0f1f28217eb6f8f064f5b86d7b89862 dune-xt-grid (heads/new-master)
+	+cf70396c9755a34d8cea36a220ce800d44360967 dune-xt-grid (heads/new-master)
 	+f6904b69f9a3ee5d45ee824d3b244e59cfed7ff4 dune-xt-la (heads/master)
 	 09d0378f616b94d68bcdd9fc6114813181849ec0 scripts (heads/master)
 commit = 5f5841ee7a2dff290b98845c46262151752189c1
@@ -110,7 +110,7 @@ commit = f05aa7470ead4150ca7a91894cd2ad77dfcedc46
 [submodule.dune-xt-grid]
 remote = git@github.com:dune-community/dune-xt-grid.git
 status = 2424627f0ad5de7e4aaa5e7f48bc2a02414d95a1 .vcsetup (heads/master)
-commit = 4276ffe6f0f1f28217eb6f8f064f5b86d7b89862
+commit = cf70396c9755a34d8cea36a220ce800d44360967
 
 [submodule.dune-xt-la]
 remote = git@github.com:dune-community/dune-xt-la.git
diff --git a/dune/xt/grid/walker.hh b/dune/xt/grid/walker.hh
index 667cbc9b3..b5cfe770a 100644
--- a/dune/xt/grid/walker.hh
+++ b/dune/xt/grid/walker.hh
@@ -229,6 +229,18 @@ private:
     thread_storage = Common::PerThreadValue<std::list<WrapperType>>(storage);
   }
 
+  void reinitialize_thread_storage()
+  {
+    element_functor_wrappers_ = Common::PerThreadValue<std::list<internal::ElementFunctorWrapper<GridViewType>>>(
+        stored_element_functor_wrappers_);
+    intersection_functor_wrappers_ =
+        Common::PerThreadValue<std::list<internal::IntersectionFunctorWrapper<GridViewType>>>(
+            stored_intersection_functor_wrappers_);
+    element_and_intersection_functor_wrappers_ =
+        Common::PerThreadValue<std::list<internal::ElementAndIntersectionFunctorWrapper<GridViewType>>>(
+            stored_element_and_intersection_functor_wrappers_);
+  }
+
 public:
   explicit Walker(GridViewType grd_vw)
     : grid_view_(grd_vw)
@@ -241,14 +253,7 @@ public:
     , stored_intersection_functor_wrappers_(other.stored_intersection_functor_wrappers_)
     , stored_element_and_intersection_functor_wrappers_(other.stored_element_and_intersection_functor_wrappers_)
   {
-    element_functor_wrappers_ = Common::PerThreadValue<std::list<internal::ElementFunctorWrapper<GridViewType>>>(
-        stored_element_functor_wrappers_);
-    intersection_functor_wrappers_ =
-        Common::PerThreadValue<std::list<internal::IntersectionFunctorWrapper<GridViewType>>>(
-            stored_intersection_functor_wrappers_);
-    element_and_intersection_functor_wrappers_ =
-        Common::PerThreadValue<std::list<internal::ElementAndIntersectionFunctorWrapper<GridViewType>>>(
-            stored_element_and_intersection_functor_wrappers_);
+    reinitialize_thread_storage();
   }
 
   Walker(ThisType&& source) = default;
@@ -492,24 +497,44 @@ public:
 
   virtual void prepare() override
   {
-    auto prep = [](auto& functor_list) {
-      for (auto&& pt_wrapper : functor_list) {
-        for (auto&& wrapper : pt_wrapper)
-          wrapper.functor().prepare();
-      }
+    auto prep = [](auto& wrapper_list) {
+      for (auto&& wrapper : wrapper_list)
+        wrapper.functor().prepare();
     };
-    prep(element_functor_wrappers_);
-    prep(intersection_functor_wrappers_);
-    prep(element_and_intersection_functor_wrappers_);
+    prep(stored_element_functor_wrappers_);
+    prep(stored_intersection_functor_wrappers_);
+    prep(stored_element_and_intersection_functor_wrappers_);
+    // prepare is called in a single thread, so make sure all functors in the thread are also prepared
+    reinitialize_thread_storage();
   } // ... prepare()
 
+  // prepare calling thread
+  void prepare_thread()
+  {
+    auto prep = [](auto& wrapper_list) {
+      for (auto&& wrapper : wrapper_list)
+        wrapper.functor().prepare();
+    };
+    prep(*element_functor_wrappers_);
+    prep(*intersection_functor_wrappers_);
+    prep(*element_and_intersection_functor_wrappers_);
+  }
+
   virtual void apply_local(const ElementType& element) override
   {
-    for (auto&& wrapper : *element_functor_wrappers_) {
+    apply_local(element, *element_functor_wrappers_, *element_and_intersection_functor_wrappers_);
+  } // ... apply_local(...)
+
+  void apply_local(const ElementType& element,
+                   std::list<internal::ElementFunctorWrapper<GridViewType>>& element_functor_wrappers,
+                   std::list<internal::ElementAndIntersectionFunctorWrapper<GridViewType>>&
+                       element_and_intersection_functor_wrappers)
+  {
+    for (auto&& wrapper : element_functor_wrappers) {
       if (wrapper.filter().contains(grid_view_, element))
         wrapper.functor().apply_local(element);
     }
-    for (auto&& wrapper : *element_and_intersection_functor_wrappers_) {
+    for (auto&& wrapper : element_and_intersection_functor_wrappers) {
       if (wrapper.element_filter().contains(grid_view_, element))
         wrapper.functor().apply_local(element);
     }
@@ -519,21 +544,36 @@ public:
                            const ElementType& inside_element,
                            const ElementType& outside_element) override
   {
-    for (auto&& wrapper : *intersection_functor_wrappers_) {
+    apply_local(intersection,
+                inside_element,
+                outside_element,
+                *intersection_functor_wrappers_,
+                *element_and_intersection_functor_wrappers_);
+  } // ... apply_local(...)
+
+  virtual void apply_local(const IntersectionType& intersection,
+                           const ElementType& inside_element,
+                           const ElementType& outside_element,
+                           std::list<internal::IntersectionFunctorWrapper<GridViewType>>& intersection_functor_wrappers,
+                           std::list<internal::ElementAndIntersectionFunctorWrapper<GridViewType>>&
+                               element_and_intersection_functor_wrappers)
+  {
+    for (auto&& wrapper : intersection_functor_wrappers) {
       if (wrapper.filter().contains(grid_view_, intersection))
         wrapper.functor().apply_local(intersection, inside_element, outside_element);
     }
-    for (auto&& wrapper : *element_and_intersection_functor_wrappers_) {
+    for (auto&& wrapper : element_and_intersection_functor_wrappers) {
       if (wrapper.intersection_filter().contains(grid_view_, intersection))
         wrapper.functor().apply_local(intersection, inside_element, outside_element);
     }
   } // ... apply_local(...)
 
+  // finalize all threads
   virtual void finalize() override
   {
-    auto fin = [](auto& list) {
-      for (auto&& pt : list) {
-        for (auto&& wrapper : pt)
+    auto fin = [](auto& per_thread_value) {
+      for (auto&& list : per_thread_value) {
+        for (auto&& wrapper : list)
           wrapper.functor().finalize();
       }
     };
@@ -542,6 +582,18 @@ public:
     fin(intersection_functor_wrappers_);
   } // ... finalize()
 
+  // finalize calling thread
+  void finalize_thread()
+  {
+    auto fin = [](auto& wrapper_list) {
+      for (auto&& wrapper : wrapper_list)
+        wrapper.functor().finalize();
+    };
+    fin(*element_functor_wrappers_);
+    fin(*intersection_functor_wrappers_);
+    fin(*element_and_intersection_functor_wrappers_);
+  }
+
   /**
    * \}
    */
@@ -581,14 +633,7 @@ public:
     stored_element_functor_wrappers_.clear();
     stored_intersection_functor_wrappers_.clear();
     stored_element_and_intersection_functor_wrappers_.clear();
-    element_functor_wrappers_ = Common::PerThreadValue<std::list<internal::ElementFunctorWrapper<GridViewType>>>(
-        stored_element_functor_wrappers_);
-    intersection_functor_wrappers_ =
-        Common::PerThreadValue<std::list<internal::IntersectionFunctorWrapper<GridViewType>>>(
-            stored_intersection_functor_wrappers_);
-    element_and_intersection_functor_wrappers_ =
-        Common::PerThreadValue<std::list<internal::ElementAndIntersectionFunctorWrapper<GridViewType>>>(
-            stored_element_and_intersection_functor_wrappers_);
+    reinitialize_thread_storage();
   }
 
   BaseType* copy() override
@@ -679,6 +724,9 @@ private:
   template <class ElementRange>
   void walk_range(const ElementRange& element_range)
   {
+    auto& element_functor_wrappers = *element_functor_wrappers_;
+    auto& intersection_functor_wrappers = *intersection_functor_wrappers_;
+    auto& element_and_intersection_functor_wrappers = *element_and_intersection_functor_wrappers_;
 #ifdef __INTEL_COMPILER
     const auto it_end = element_range.end();
     for (auto it = element_range.begin(); it != it_end; ++it) {
@@ -687,10 +735,10 @@ private:
     for (const ElementType& element : element_range) {
 #endif
       // apply element functors
-      apply_local(element);
+      apply_local(element, element_functor_wrappers, element_and_intersection_functor_wrappers);
 
       // only walk the intersections, if there are codim1 functors present
-      if ((intersection_functor_wrappers_->size() + element_and_intersection_functor_wrappers_->size()) > 0) {
+      if ((intersection_functor_wrappers.size() + element_and_intersection_functor_wrappers.size()) > 0) {
         // Do not use intersections(...) here, since that does not work for a SubdomainGridPart which is based on
         // alugrid and then wrapped as a grid view (see also https://github.com/dune-community/dune-xt-grid/issues/26)
         const auto intersection_it_end = grid_view_.iend(element);
@@ -699,9 +747,17 @@ private:
           const auto& intersection = *intersection_it;
           if (intersection.neighbor()) {
             const auto neighbor = intersection.outside();
-            apply_local(intersection, element, neighbor);
+            apply_local(intersection,
+                        element,
+                        neighbor,
+                        intersection_functor_wrappers,
+                        element_and_intersection_functor_wrappers);
           } else
-            apply_local(intersection, element, element);
+            apply_local(intersection,
+                        element,
+                        element,
+                        intersection_functor_wrappers,
+                        element_and_intersection_functor_wrappers);
         } // walk the intersections
       } // only walk the intersections, if there are codim1 functors present
     } // .. walk elements
-- 
GitLab