Skip to content
Snippets Groups Projects
Commit 2e3cfc21 authored by René Milk's avatar René Milk
Browse files

Merge pull request #76 from tobiasleibner/threadsafe_expression

Thread-safe expression function
parents c3bef2ad 12963113
No related branches found
No related tags found
No related merge requests found
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <dune/stuff/common/configuration.hh> #include <dune/stuff/common/configuration.hh>
#include <dune/stuff/common/exceptions.hh> #include <dune/stuff/common/exceptions.hh>
#include <dune/stuff/common/parallel/threadstorage.hh>
#include "expression/base.hh" #include "expression/base.hh"
#include "interfaces.hh" #include "interfaces.hh"
...@@ -215,15 +216,15 @@ public: ...@@ -215,15 +216,15 @@ public:
bool failure = false; bool failure = false;
std::string error_type; std::string error_type;
for (size_t rr = 0; rr < dimRange; ++rr) { for (size_t rr = 0; rr < dimRange; ++rr) {
tmp_row_ = ret[rr]; *tmp_row_ = ret[rr];
for (size_t cc = 0; cc < dimRangeCols; ++cc) { for (size_t cc = 0; cc < dimRangeCols; ++cc) {
if (DSC::isnan(tmp_row_[cc])) { if (DSC::isnan(tmp_row_->operator[](cc))) {
failure = true; failure = true;
error_type = "NaN"; error_type = "NaN";
} else if (DSC::isinf(tmp_row_[cc])) { } else if (DSC::isinf(tmp_row_->operator[](cc))) {
failure = true; failure = true;
error_type = "inf"; error_type = "inf";
} else if (std::abs(tmp_row_[cc]) > (0.9 * std::numeric_limits<double>::max())) { } else if (std::abs(tmp_row_->operator[](cc)) > (0.9 * std::numeric_limits<double>::max())) {
failure = true; failure = true;
error_type = "an unlikely value"; error_type = "an unlikely value";
} }
...@@ -236,13 +237,13 @@ public: ...@@ -236,13 +237,13 @@ public:
<< function_->variable() << function_->variable()
<< "\n" << "\n"
<< "The expression of this functional is: " << "The expression of this functional is: "
<< function_->expression().at(0) << function_->expression().at(rr * dimRangeCols + cc)
<< "\n" << "\n"
<< "You tried to evaluate it with: xx = " << "You tried to evaluate it with: xx = "
<< xx << xx
<< "\n" << "\n"
<< "The result was: " << "The result was: "
<< ret << tmp_row_->operator[](cc)
<< "\n\n" << "\n\n"
<< "You can disable this check by defining DUNE_STUFF_FUNCTIONS_EXPRESSION_DISABLE_CHECKS\n"); << "You can disable this check by defining DUNE_STUFF_FUNCTIONS_EXPRESSION_DISABLE_CHECKS\n");
} }
...@@ -296,11 +297,11 @@ private: ...@@ -296,11 +297,11 @@ private:
template <size_t rC> template <size_t rC>
void evaluate_helper(const DomainType& xx, RangeType& ret, internal::ChooseVariant<rC>) const void evaluate_helper(const DomainType& xx, RangeType& ret, internal::ChooseVariant<rC>) const
{ {
function_->evaluate(xx, tmp_vector_); function_->evaluate(xx, *tmp_vector_);
for (size_t rr = 0; rr < dimRange; ++rr) { for (size_t rr = 0; rr < dimRange; ++rr) {
auto& retRow = ret[rr]; auto& retRow = ret[rr];
for (size_t cc = 0; cc < dimRangeCols; ++cc) for (size_t cc = 0; cc < dimRangeCols; ++cc)
retRow[cc] = tmp_vector_[rr * dimRangeCols + cc]; retRow[cc] = (*tmp_vector_)[rr * dimRangeCols + cc];
} }
} // ... evaluate_helper(...) } // ... evaluate_helper(...)
...@@ -381,9 +382,8 @@ private: ...@@ -381,9 +382,8 @@ private:
std::shared_ptr<const MathExpressionFunctionType> function_; std::shared_ptr<const MathExpressionFunctionType> function_;
size_t order_; size_t order_;
std::string name_; std::string name_;
mutable FieldVector<RangeFieldType, dimRange * dimRangeCols> tmp_vector_; mutable typename DS::PerThreadValue<FieldVector<RangeFieldType, dimRange * dimRangeCols>> tmp_vector_;
mutable FieldVector<RangeFieldType, dimRangeCols> tmp_row_; mutable typename DS::PerThreadValue<FieldVector<RangeFieldType, dimRangeCols>> tmp_row_;
mutable FieldVector<RangeFieldType, dimDomain> tmp_gradient_row_;
std::vector<std::vector<std::shared_ptr<const MathExpressionGradientType>>> gradients_; std::vector<std::vector<std::shared_ptr<const MathExpressionGradientType>>> gradients_;
}; // class Expression }; // class Expression
......
...@@ -86,6 +86,7 @@ public: ...@@ -86,6 +86,7 @@ public:
void evaluate(const Dune::FieldVector<DomainFieldType, dimDomain>& arg, void evaluate(const Dune::FieldVector<DomainFieldType, dimDomain>& arg,
Dune::FieldVector<RangeFieldType, dimRange>& ret) const Dune::FieldVector<RangeFieldType, dimRange>& ret) const
{ {
std::lock_guard<std::mutex> guard(mutex_);
// copy arg // copy arg
for (typename Dune::FieldVector<DomainFieldType, dimDomain>::size_type ii = 0; ii < dimDomain; ++ii) for (typename Dune::FieldVector<DomainFieldType, dimDomain>::size_type ii = 0; ii < dimDomain; ++ii)
*(arg_[ii]) = arg[ii]; *(arg_[ii]) = arg[ii];
...@@ -99,6 +100,7 @@ public: ...@@ -99,6 +100,7 @@ public:
*/ */
void evaluate(const Dune::DynamicVector<DomainFieldType>& arg, Dune::DynamicVector<RangeFieldType>& ret) const void evaluate(const Dune::DynamicVector<DomainFieldType>& arg, Dune::DynamicVector<RangeFieldType>& ret) const
{ {
std::lock_guard<std::mutex> guard(mutex_);
// check for sizes // check for sizes
assert(arg.size() > 0); assert(arg.size() > 0);
if (ret.size() != dimRange) if (ret.size() != dimRange)
...@@ -114,6 +116,7 @@ public: ...@@ -114,6 +116,7 @@ public:
void evaluate(const Dune::FieldVector<DomainFieldType, dimDomain>& arg, void evaluate(const Dune::FieldVector<DomainFieldType, dimDomain>& arg,
Dune::DynamicVector<RangeFieldType>& ret) const Dune::DynamicVector<RangeFieldType>& ret) const
{ {
std::lock_guard<std::mutex> guard(mutex_);
// check for sizes // check for sizes
if (ret.size() != dimRange) if (ret.size() != dimRange)
ret = Dune::DynamicVector<RangeFieldType>(dimRange); ret = Dune::DynamicVector<RangeFieldType>(dimRange);
...@@ -130,6 +133,7 @@ public: ...@@ -130,6 +133,7 @@ public:
*/ */
void evaluate(const Dune::DynamicVector<DomainFieldType>& arg, Dune::FieldVector<RangeFieldType, dimRange>& ret) const void evaluate(const Dune::DynamicVector<DomainFieldType>& arg, Dune::FieldVector<RangeFieldType, dimRange>& ret) const
{ {
std::lock_guard<std::mutex> guard(mutex_);
assert(arg.size() > 0); assert(arg.size() > 0);
// copy arg // copy arg
for (size_t ii = 0; ii < std::min(dimDomain, arg.size()); ++ii) for (size_t ii = 0; ii < std::min(dimDomain, arg.size()); ++ii)
...@@ -208,6 +212,7 @@ private: ...@@ -208,6 +212,7 @@ private:
RVar* var_arg_[dimDomain]; RVar* var_arg_[dimDomain];
RVar* vararray_[dimDomain]; RVar* vararray_[dimDomain];
ROperation* op_[dimRange]; ROperation* op_[dimRange];
mutable std::mutex mutex_;
}; // class MathExpressionBase }; // class MathExpressionBase
} // namespace Functions } // namespace Functions
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment