From 8811f681a05c1ff06d51af0a3b07357db3f8355b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Fritze?= <rene.fritze@wwu.de>
Date: Wed, 17 Jul 2019 13:27:51 +0200
Subject: [PATCH] [provider] refactor visualization to remove duplication

---
 dune/xt/grid/gridprovider/provider.hh       |  68 +-------
 dune/xt/grid/output/entity_visualization.hh | 172 ++++++++++++--------
 2 files changed, 115 insertions(+), 125 deletions(-)

diff --git a/dune/xt/grid/gridprovider/provider.hh b/dune/xt/grid/gridprovider/provider.hh
index 4fb50ec84..929c6219a 100644
--- a/dune/xt/grid/gridprovider/provider.hh
+++ b/dune/xt/grid/gridprovider/provider.hh
@@ -31,6 +31,7 @@
 #include <dune/xt/grid/grids.hh>
 #include <dune/xt/grid/layers.hh>
 #include <dune/xt/grid/type_traits.hh>
+#include <dune/xt/grid/output/entity_visualization.hh>
 
 namespace Dune {
 namespace XT {
@@ -356,59 +357,6 @@ public:
   }
 
 private:
-  template <class G, bool enable = has_boundary_id<G>::value>
-  struct add_boundary_id_visualization
-  {
-    add_boundary_id_visualization() {}
-
-    template <class V>
-    void operator()(V& vtk_writer, const std::vector<double>& boundary_id, const int lvl) const
-    {
-      vtk_writer.addCellData(boundary_id, "boundary_id__level_" + Common::to_string(lvl));
-    }
-
-    std::vector<double> generateBoundaryIdVisualization(const LevelGridViewType& gridView) const
-    {
-      std::vector<double> data(gridView.indexSet().size(0));
-      // walk the grid
-      const auto it_end = gridView.template end<0>();
-      for (auto it = gridView.template begin<0>(); it != it_end; ++it) {
-        const auto& entity = *it;
-        const auto& index = gridView.indexSet().index(entity);
-        data[index] = 0.0;
-        size_t numberOfBoundarySegments = 0;
-        bool isOnBoundary = false;
-        const auto intersectionItEnd = gridView.iend(entity);
-        for (auto intersectionIt = gridView.ibegin(entity); intersectionIt != intersectionItEnd; ++intersectionIt) {
-          if (!intersectionIt->neighbor() && intersectionIt->boundary()) {
-            isOnBoundary = true;
-            numberOfBoundarySegments += 1;
-            data[index] += double(intersectionIt->boundaryId());
-          }
-        }
-        if (isOnBoundary) {
-          data[index] /= double(numberOfBoundarySegments);
-        }
-      } // walk the grid
-      return data;
-    }
-  }; // struct add_boundary_id_visualization<..., true>
-
-  template <class G>
-  struct add_boundary_id_visualization<G, false>
-  {
-    add_boundary_id_visualization() {}
-
-    template <class V>
-    void operator()(V& /*vtk_writer*/, const std::vector<double>& /*boundary_id*/, const int /*lvl*/) const
-    {}
-
-    std::vector<double> generateBoundaryIdVisualization(const LevelGridViewType&) const
-    {
-      return std::vector<double>();
-    }
-  };
-
   void visualize_plain(const std::string filename) const
   {
     if (GridType::dimension > 3) // give us a call if you have any idea!
@@ -421,9 +369,9 @@ private:
       std::vector<double> entityId = generateEntityVisualization(grid_view);
       vtkwriter.addCellData(entityId, "entity_id__level_" + Common::to_string(lvl));
       // boundary id
-      const add_boundary_id_visualization<GridType> add_boundary_id;
-      const std::vector<double> boundary_id = add_boundary_id.generateBoundaryIdVisualization(grid_view);
-      add_boundary_id(vtkwriter, boundary_id, lvl);
+      const std::vector<double> boundary_ids =
+          ElementVisualization::BoundaryIDFunctor<LevelGridViewType>(grid_view).values(grid_view);
+      vtkwriter.addCellData(boundary_ids, "boundary_id__level_" + Common::to_string(lvl));
       // write
       vtkwriter.write(filename + "__level_" + Common::to_string(lvl), VTK::appendedraw);
     }
@@ -444,11 +392,11 @@ private:
       std::vector<double> entityId = generateEntityVisualization(grid_view);
       vtkwriter.addCellData(entityId, "entity_id__level_" + Common::to_string(lvl));
       // boundary id
-      const add_boundary_id_visualization<GridType> add_boundary_id;
-      const std::vector<double> boundary_id = add_boundary_id.generateBoundaryIdVisualization(grid_view);
-      add_boundary_id(vtkwriter, boundary_id, lvl);
+      const std::vector<double> boundary_ids =
+          ElementVisualization::BoundaryIDFunctor<LevelGridViewType>(grid_view).values(grid_view);
+      vtkwriter.addCellData(boundary_ids, "boundary_id__level_" + Common::to_string(lvl));
       // dirichlet values
-      std::vector<double> dirichlet = generateBoundaryVisualization(grid_view, *boundary_info_ptr, "dirichlet");
+      const auto dirichlet = generateBoundaryVisualization(grid_view, *boundary_info_ptr, "dirichlet");
       vtkwriter.addCellData(dirichlet, "isDirichletBoundary__level_" + Common::to_string(lvl));
       // neumann values
       std::vector<double> neumann = generateBoundaryVisualization(grid_view, *boundary_info_ptr, "neumann");
diff --git a/dune/xt/grid/output/entity_visualization.hh b/dune/xt/grid/output/entity_visualization.hh
index 88b6ff13a..e890b00dd 100644
--- a/dune/xt/grid/output/entity_visualization.hh
+++ b/dune/xt/grid/output/entity_visualization.hh
@@ -23,6 +23,9 @@
 #include <dune/xt/common/filesystem.hh>
 #include <dune/xt/common/logging.hh>
 #include <dune/xt/common/ranges.hh>
+#include <dune/xt/grid/capabilities.hh>
+#include <dune/xt/common/type_traits.hh>
+#include <dune/xt/grid/type_traits.hh>
 
 namespace Dune {
 namespace XT {
@@ -43,21 +46,19 @@ struct ElementVisualization
   };
 
   // demonstrate attaching data to elements
-  template <class Grid, class F>
-  static void elementdata(const Grid& grid, const F& f)
+  template <class View, class F>
+  static void elementdata(const View& view, const F& f)
   {
-    // get grid view on leaf part
-    const auto gridView = grid.leafGridView();
-
     // make a mapper for codim 0 entities in the leaf grid
-    Dune::LeafMultipleCodimMultipleGeomTypeMapper<Grid, P0Layout> mapper(grid);
+    using Grid = extract_grid_t<View>;
+    Dune::LeafMultipleCodimMultipleGeomTypeMapper<Grid, P0Layout> mapper(view.grid());
 
     std::vector<double> values(mapper.size());
-    for (auto&& entity : elements(gridView)) {
+    for (auto&& entity : elements(view)) {
       values[mapper.index(entity)] = f(entity);
     }
 
-    Dune::VTKWriter<typename Grid::LeafGridView> vtkwriter(gridView);
+    Dune::VTKWriter<typename Grid::LeafGridView> vtkwriter(view);
     vtkwriter.addCellData(values, "data");
     const std::string piecefilesFolderName = "piecefiles";
     const std::string piecefilesPath = f.dir() + "/" + piecefilesFolderName + "/";
@@ -65,13 +66,18 @@ struct ElementVisualization
     vtkwriter.pwrite(f.filename(), f.dir(), piecefilesFolderName, Dune::VTK::appendedraw);
   }
 
+  template <class GridViewType>
   class FunctorBase
   {
   public:
-    FunctorBase(const std::string fname, const std::string dname)
-      : filename_(fname)
-      , dir_(dname)
+    using Element = extract_entity_t<GridViewType>;
+    FunctorBase(std::string filename = "Functor", const std::string dirname = ".")
+      : filename_(filename)
+      , dir_(dirname)
     {}
+
+    virtual ~FunctorBase() {}
+
     const std::string filename() const
     {
       return filename_;
@@ -81,59 +87,71 @@ struct ElementVisualization
       return dir_;
     }
 
+    virtual double operator()(const Element& /*ent*/) const = 0;
+
+    std::vector<double> values(const GridViewType& view)
+    {
+      std::vector<double> ret(view.size(0));
+      return ret;
+    }
+
   protected:
     const std::string filename_;
     const std::string dir_;
   };
 
-  class VolumeFunctor : public FunctorBase
+  template <class GridViewType>
+  class VolumeFunctor : public FunctorBase<GridViewType>
   {
   public:
-    VolumeFunctor(const std::string fname, const std::string dname)
-      : FunctorBase(fname, dname)
+    using Element = typename FunctorBase<GridViewType>::Element;
+    VolumeFunctor(std::string filename = "VolumeFunctor", const std::string dirname = ".")
+      : FunctorBase<GridViewType>(filename, dirname)
     {}
 
-    template <class Entity>
-    double operator()(const Entity& ent) const
+    double operator()(const Element& ent) const
     {
       return ent.geometry().volume();
     }
   };
 
-  class ProcessIdFunctor : public FunctorBase
+  template <class GridViewType>
+  class ProcessIdFunctor : public FunctorBase<GridViewType>
   {
   public:
-    ProcessIdFunctor(const std::string fname, const std::string dname)
-      : FunctorBase(fname, dname)
+    using Element = typename FunctorBase<GridViewType>::Element;
+    ProcessIdFunctor(std::string filename = "ProcessIDFunctor", const std::string dirname = ".")
+      : FunctorBase<GridViewType>(filename, dirname)
     {}
 
-    template <class Entity>
-    double operator()(const Entity& /*ent*/) const
+    double operator()(const Element& /*ent*/) const
     {
       return Dune::MPIHelper::getCollectiveCommunication().rank();
     }
   };
 
-  template <class GridType>
-  class BoundaryFunctor : public FunctorBase
+  template <class GridViewType, bool enable = has_boundary_id<GridViewType>::value>
+  class BoundaryIDFunctor : public FunctorBase<GridViewType>
   {
-    const GridType& grid_;
+    const GridViewType& gridview_;
 
   public:
-    BoundaryFunctor(const GridType& grid, const std::string fname, const std::string dname)
-      : FunctorBase(fname, dname)
-      , grid_(grid)
+    using Element = typename FunctorBase<GridViewType>::Element;
+    BoundaryIDFunctor(const GridViewType& view,
+                      std::string filename = "BoundaryIDFunctor",
+                      const std::string dirname = ".")
+      : FunctorBase<GridViewType>(filename, dirname)
+      , gridview_(view)
     {}
 
-    template <class Entity>
-    double operator()(const Entity& entity) const
+    double operator()(const Element& entity) const
     {
       double ret(0.0);
       int numberOfBoundarySegments(0);
       bool isOnBoundary = false;
-      const auto leafview = grid_.leafGridView();
-      const auto intersection_it_end = leafview.iend(entity);
-      for (auto intersection_it = leafview.ibegin(entity); intersection_it != intersection_it_end; ++intersection_it) {
+
+      const auto intersection_it_end = gridview_.iend(entity);
+      for (auto intersection_it = gridview_.ibegin(entity); intersection_it != intersection_it_end; ++intersection_it) {
         const auto& intersection = *intersection_it;
         if (!intersection.neighbor() && intersection.boundary()) {
           isOnBoundary = true;
@@ -148,23 +166,44 @@ struct ElementVisualization
     }
   };
 
-  class AreaMarker : public FunctorBase
+  template <class GridViewType>
+  class BoundaryIDFunctor<GridViewType, false> : public FunctorBase<GridViewType>
   {
+    const GridViewType& gridview_;
 
   public:
-    AreaMarker(const std::string fname, const std::string dname)
-      : FunctorBase(fname, dname)
-    {}
+    using Element = typename FunctorBase<GridViewType>::Element;
+    BoundaryIDFunctor(const GridViewType& view,
+                      std::string filename = "BoundaryIDFunctor",
+                      const std::string dirname = ".")
+      : FunctorBase<GridViewType>(filename, dirname)
+      , gridview_(view)
+    {
+      DXTC_LOG_INFO_0 << "Boundary visualization for unsupported grid requested " << XT::Common::get_typename(gridview_)
+                      << std::endl;
+    }
 
-    template <class Entity>
-    double operator()(const Entity& entity) const
+    double operator()(const Element&) const
     {
-      typedef typename Entity::Geometry EntityGeometryType;
+      return -1;
+    }
+  };
 
-      typedef Dune::FieldVector<typename EntityGeometryType::ctype, EntityGeometryType::coorddimension> DomainType;
+  template <class GridViewType>
+  class AreaMarker : public FunctorBase<GridViewType>
+  {
 
-      const EntityGeometryType& geometry = entity.geometry();
+  public:
+    using Element = typename FunctorBase<GridViewType>::Element;
+    AreaMarker(std::string filename = "AreaFunctor", const std::string dirname = ".")
+      : FunctorBase<GridViewType>(filename, dirname)
+    {}
 
+    double operator()(const Element& entity) const
+    {
+      typedef typename Element::Geometry EntityGeometryType;
+      typedef Dune::FieldVector<typename EntityGeometryType::ctype, EntityGeometryType::coorddimension> DomainType;
+      const EntityGeometryType& geometry = entity.geometry();
       DomainType baryCenter(0.0);
 
       for (auto corner : Common::value_range(geometry.corners())) {
@@ -183,17 +222,18 @@ struct ElementVisualization
     }
   };
 
-  class GeometryFunctor : public FunctorBase
+  template <class GridViewType>
+  class GeometryFunctor : public FunctorBase<GridViewType>
   {
   public:
-    GeometryFunctor(const std::string fname, const std::string dname)
-      : FunctorBase(fname, dname)
+    using Element = typename FunctorBase<GridViewType>::Element;
+    GeometryFunctor(std::string filename = "GeometryFunctor", const std::string dirname = ".")
+      : FunctorBase<GridViewType>(filename, dirname)
     {}
 
-    template <class Entity>
-    double operator()(const Entity& ent) const
+    double operator()(const Element& ent) const
     {
-      const typename Entity::Geometry& geo = ent.geometry();
+      const auto& geo = ent.geometry();
       double vol = geo.volume();
       if (vol < 0) {
         boost::io::ios_all_saver guard(DXTC_LOG_ERROR);
@@ -208,17 +248,18 @@ struct ElementVisualization
     }
   };
 
-  class PartitionTypeFunctor : public FunctorBase
+  template <class GridViewType>
+  class PartitionTypeFunctor : public FunctorBase<GridViewType>
   {
   public:
-    PartitionTypeFunctor(const std::string fname, const std::string dname)
-      : FunctorBase(fname, dname)
+    using Element = typename FunctorBase<GridViewType>::Element;
+    PartitionTypeFunctor(std::string filename = "PartitionTypeFunctor", const std::string dirname = ".")
+      : FunctorBase<GridViewType>(filename, dirname)
     {}
 
-    template <class Entity>
-    double operator()(const Entity& ent) const
+    double operator()(const Element& ent) const
     {
-      const typename Entity::Geometry& geo = ent.geometry();
+      const auto& geo = ent.geometry();
       const int type{static_cast<int>(ent.partitionType())};
       DXTC_LOG_ERROR << "TYPE " << type << std::endl;
       return static_cast<double>(type);
@@ -230,20 +271,21 @@ struct ElementVisualization
   static void all(const Grid& grid, const std::string outputDir = "visualisation")
   {
     // make function objects
-    BoundaryFunctor<Grid> boundaryFunctor(grid, "boundaryFunctor", outputDir);
-    AreaMarker areaMarker("areaMarker", outputDir);
-    GeometryFunctor geometryFunctor("geometryFunctor", outputDir);
-    ProcessIdFunctor processIdFunctor("ProcessIdFunctor", outputDir);
-    VolumeFunctor volumeFunctor("volumeFunctor", outputDir);
-    PartitionTypeFunctor partitionTypeFunctor("partitionTypeFunctor", outputDir);
+    BoundaryIDFunctor<Grid> boundaryFunctor(grid, "boundaryFunctor", outputDir);
+    AreaMarker<Grid> areaMarker("areaMarker", outputDir);
+    GeometryFunctor<Grid> geometryFunctor("geometryFunctor", outputDir);
+    ProcessIdFunctor<Grid> processIdFunctor("ProcessIdFunctor", outputDir);
+    VolumeFunctor<Grid> volumeFunctor("volumeFunctor", outputDir);
+    PartitionTypeFunctor<Grid> partitionTypeFunctor("partitionTypeFunctor", outputDir);
 
     // call the visualization functions
-    elementdata(grid, boundaryFunctor);
-    elementdata(grid, areaMarker);
-    elementdata(grid, geometryFunctor);
-    elementdata(grid, processIdFunctor);
-    elementdata(grid, volumeFunctor);
-    elementdata(grid, partitionTypeFunctor);
+    const auto view = grid.leafGridView();
+    elementdata(view, boundaryFunctor);
+    elementdata(view, areaMarker);
+    elementdata(view, geometryFunctor);
+    elementdata(view, processIdFunctor);
+    elementdata(view, volumeFunctor);
+    elementdata(view, partitionTypeFunctor);
   }
 };
 
-- 
GitLab