diff --git a/dune/xt/functions/base/visualization.hh b/dune/xt/functions/base/visualization.hh index cfbbb053442e358dae829e0886db9bf160a9b5a9..00b898a44c5d361fbd44cb187e7cea45622fe3da 100644 --- a/dune/xt/functions/base/visualization.hh +++ b/dune/xt/functions/base/visualization.hh @@ -13,61 +13,60 @@ #ifndef DUNE_XT_FUNCTIONS_BASE_VISUALIZATION_HH #define DUNE_XT_FUNCTIONS_BASE_VISUALIZATION_HH +#include <algorithm> + #include <dune/grid/io/file/vtk/function.hh> #include <dune/xt/common/numeric_cast.hh> +#include <dune/xt/common/parameter.hh> + #include <dune/xt/grid/type_traits.hh> -#include <dune/xt/functions/interfaces/grid-function.hh> namespace Dune { namespace XT { namespace Functions { -template <class GridViewType, size_t range_dim, size_t range_dim_cols, class RangeField> -class VisualizationAdapter : public VTKFunction<GridViewType> -{ - static_assert(XT::Grid::is_view<GridViewType>::value, ""); +// forward +template <class Element, size_t rangeDim, size_t rangeDimCols, class RangeField> +class GridFunctionInterface; + +template <size_t r, size_t rC, class R = double> +class VisualizerInterface +{ public: - using EntityType = XT::Grid::extract_entity_t<GridViewType>; - using GridFunctionType = GridFunctionInterface<EntityType, range_dim, range_dim_cols, RangeField>; + using RangeType = typename RangeTypeSelector<R, r, rC>::type; -private: - using LocalFunctionType = typename GridFunctionType::LocalFunctionType; - using DomainType = typename LocalFunctionType::DomainType; - using RangeType = typename LocalFunctionType::RangeType; + virtual int ncomps() const = 0; + + virtual double evaluate(const int& comp, const RangeType& val) const = 0; +}; // class VisualizerInterface + + +// visualizes all components of the function +template <size_t r, size_t rC, class R = double> +class DefaultVisualizer : public VisualizerInterface<r, rC, R> +{ + using BaseType = VisualizerInterface<r, rC, R>; public: - VisualizationAdapter(const GridFunctionType& localizable_function, - const std::string nm = "", - const XT::Common::Parameter& param = {}) - : local_function_(localizable_function.local_function()) - , name_(nm.empty() ? localizable_function.name() : nm) - , param_(param) - {} + using typename BaseType::RangeType; - int ncomps() const override final + virtual int ncomps() const override final { return helper<>::ncomps(); } - std::string name() const override final - { - return name_; - } - - double evaluate(int comp, const EntityType& en, const DomainType& xx) const override final + virtual double evaluate(const int& comp, const RangeType& val) const override final { - local_function_->bind(en); - return helper<>::evaluate(comp, local_function_->evaluate(xx, param_)); + return helper<>::evaluate(comp, val); } private: - template <size_t r_ = range_dim, size_t rC_ = range_dim_cols, bool anything = true> - class helper + template <size_t r_ = r, size_t rC_ = rC, bool anything = true> + struct helper { - public: static int ncomps() { return 1; @@ -80,9 +79,8 @@ private: }; // class helper<...> template <size_t r_, bool anything> - class helper<r_, 1, anything> + struct helper<r_, 1, anything> { - public: static int ncomps() { return r_; @@ -94,9 +92,133 @@ private: assert(comp < Common::numeric_cast<int>(r_)); return val[comp]; } - }; // class helper<..., 1> + }; // struct helper<..., 1> +}; // class VisualizerInterface + + +template <size_t r, size_t rC = 1, class R = double> +class SumVisualizer : public VisualizerInterface<r, rC, R> +{ + static_assert(rC == 1, "Not implemented for rC > 1"); + using BaseType = VisualizerInterface<r, rC, R>; + +public: + using typename BaseType::RangeType; + + virtual int ncomps() const override final + { + return 1; + } + + virtual double evaluate(const int& comp, const RangeType& val) const override final + { + return std::accumulate(val.begin(), val.end(), 0.); + } +}; // class SumVisualizer + + +template <size_t r, size_t rC = 1, class R = double> +class ComponentVisualizer : public VisualizerInterface<r, rC, R> +{ + static_assert(rC == 1, "Not implemented for rC > 1"); + using BaseType = VisualizerInterface<r, rC, R>; + +public: + using typename BaseType::RangeType; + + ComponentVisualizer(const int comp) + : comp_(comp) + {} + + virtual int ncomps() const override final + { + return 1; + } + + virtual double evaluate(const int& comp, const RangeType& val) const override final + { + return val[comp]; + } +private: + int comp_; +}; // class ComponentVisualizer + + +template <size_t r, size_t rC = 1, class R = double> +class GenericVisualizer : public VisualizerInterface<r, rC, R> +{ + using BaseType = VisualizerInterface<r, rC, R>; + +public: + using typename BaseType::RangeType; + using EvalType = std::function<double(const int comp, const RangeType& val)>; + + GenericVisualizer(const int ncomps, EvalType eval) + : ncomps_(ncomps) + , eval_(eval) + {} + + virtual int ncomps() const override final + { + return ncomps_; + } + + virtual double evaluate(const int& comp, const RangeType& val) const override final + { + return eval_(comp, val); + } + +private: + int ncomps_; + EvalType eval_; +}; // class GenericVisualizer + + +template <class GridViewType, size_t range_dim, size_t range_dim_cols, class RangeField> +class VisualizationAdapter : public VTKFunction<GridViewType> +{ + static_assert(XT::Grid::is_view<GridViewType>::value, ""); + +public: + using EntityType = XT::Grid::extract_entity_t<GridViewType>; + using GridFunctionType = GridFunctionInterface<EntityType, range_dim, range_dim_cols, RangeField>; + +private: + using LocalFunctionType = typename GridFunctionType::LocalFunctionType; + using DomainType = typename LocalFunctionType::DomainType; + +public: + VisualizationAdapter(const GridFunctionType& localizable_function, + const VisualizerInterface<range_dim, range_dim_cols, RangeField>& visualizer, + const std::string nm = "", + const XT::Common::Parameter& param = {}) + : local_function_(localizable_function.local_function()) + , visualizer_(visualizer) + , name_(nm.empty() ? localizable_function.name() : nm) + , param_(param) + {} + + int ncomps() const override final + { + return visualizer_.ncomps(); + } + + std::string name() const override final + { + return name_; + } + + double evaluate(int comp, const EntityType& en, const DomainType& xx) const override final + { + local_function_->bind(en); + const auto value = local_function_->evaluate(xx, param_); + return visualizer_.evaluate(comp, value); + } + +private: mutable std::unique_ptr<LocalFunctionType> local_function_; + const VisualizerInterface<range_dim, range_dim_cols, RangeField>& visualizer_; const std::string name_; const XT::Common::Parameter param_; }; // class VisualizationAdapter diff --git a/dune/xt/functions/interfaces/function.hh b/dune/xt/functions/interfaces/function.hh index cba06a0cd744af1b4068e9ae50f54fb65cf59a93..538fd35333bd6d1188ed7b039e98bf48df56daf2 100644 --- a/dune/xt/functions/interfaces/function.hh +++ b/dune/xt/functions/interfaces/function.hh @@ -100,6 +100,7 @@ public: using RangeReturnType = typename RangeSelector::return_type; using DerivativeRangeReturnType = typename DerivativeRangeSelector::return_type; + using RowDerivativeRangeReturnType = typename DerivativeRangeSelector::row_derivative_return_type; using SingleDerivativeRangeReturnType = typename DerivativeRangeSelector::return_single_type; /** diff --git a/dune/xt/functions/interfaces/grid-function.hh b/dune/xt/functions/interfaces/grid-function.hh index 8679658c57fb5fc66a07d9bac2d2396c7becfa73..1b310e68180bb7ec39351cfad413c4b5dc70678a 100644 --- a/dune/xt/functions/interfaces/grid-function.hh +++ b/dune/xt/functions/interfaces/grid-function.hh @@ -36,6 +36,7 @@ #include <dune/xt/functions/exceptions.hh> #include <dune/xt/functions/type_traits.hh> +#include <dune/xt/functions/base/visualization.hh> #include "element-functions.hh" @@ -44,10 +45,6 @@ namespace XT { namespace Functions { -// forward, required in GridFunctionInterface::visualize -template <class GridViewType, size_t range_dim, size_t range_dim_cols, class RangeField> -class VisualizationAdapter; - template <class MinuendType, class SubtrahendType> class DifferenceGridFunction; @@ -157,7 +154,8 @@ public: const std::string path, const bool subsampling = true, const VTK::OutputType vtk_output_type = VTK::appendedraw, - const XT::Common::Parameter& param = {}) const + const XT::Common::Parameter& param = {}, + const VisualizerInterface<r, rC, R>& visualizer = DefaultVisualizer<r, rC, R>()) const { if (path.empty()) DUNE_THROW(Exceptions::wrong_input_given, "path must not be empty!"); @@ -168,7 +166,7 @@ public: using GridViewType = std::decay_t<decltype(grid_view)>; const auto adapter = std::make_shared<VisualizationAdapter<GridViewType, range_dim, range_dim_cols, RangeFieldType>>( - *this, "", param); + *this, visualizer, "", param); std::unique_ptr<VTKWriter<GridViewType>> vtk_writer = subsampling ? Common::make_unique<SubsamplingVTKWriter<GridViewType>>(grid_view, /*subsampling_level=*/2) : Common::make_unique<VTKWriter<GridViewType>>(grid_view, VTK::nonconforming); diff --git a/dune/xt/functions/type_traits.hh b/dune/xt/functions/type_traits.hh index cdc32bce04a96c05bdf58d7ee2f85333e8e15bd0..a81ccce14e3ac2472a9cd5e2f01ecfaecfe08d73 100644 --- a/dune/xt/functions/type_traits.hh +++ b/dune/xt/functions/type_traits.hh @@ -126,9 +126,13 @@ struct DerivativeRangeTypeSelector using return_single_type = FieldVector<R, d>; using dynamic_single_type = DynamicVector<R>; - using type = FieldVector<FieldMatrix<R, rC, d>, r>; - using return_type = XT::Common::FieldVector<XT::Common::FieldMatrix<R, rC, d>, r>; - using dynamic_type = DynamicVector<DynamicMatrix<R>>; + using row_derivative_type = FieldMatrix<R, rC, d>; + using row_derivative_return_type = XT::Common::FieldMatrix<R, rC, d>; + using dynamic_row_derivative_type = DynamicMatrix<R>; + + using type = FieldVector<row_derivative_type, r>; + using return_type = XT::Common::FieldVector<row_derivative_return_type, r>; + using dynamic_type = DynamicVector<dynamic_row_derivative_type>; static void ensure_size(dynamic_type& arg) { @@ -159,6 +163,10 @@ struct DerivativeRangeTypeSelector<d, R, r, 1> using return_single_type = XT::Common::FieldVector<R, d>; using dynamic_single_type = DynamicVector<R>; + using row_derivative_type = FieldMatrix<R, r, d>; + using row_derivative_return_type = XT::Common::FieldMatrix<R, r, d>; + using dynamic_row_derivative_type = DynamicMatrix<R>; + using type = FieldMatrix<R, r, d>; using return_type = XT::Common::FieldMatrix<R, r, d>; using dynamic_type = DynamicMatrix<R>;