Skip to content
Snippets Groups Projects
Commit 9be4d00a authored by Dr. Felix Tobias Schindler's avatar Dr. Felix Tobias Schindler
Browse files

[operators.lincomb] fix segfault for temporary lincomb ops

parent e1f55365
No related branches found
No related tags found
No related merge requests found
......@@ -7,8 +7,8 @@
// Authors:
// Felix Schindler (2018)
#ifndef DUNE_GDT_const_operators_LINCOMB_HH
#define DUNE_GDT_const_operators_LINCOMB_HH
#ifndef DUNE_GDT_OPERATORS_LINCOMB_HH
#define DUNE_GDT_OPERATORS_LINCOMB_HH
#include <vector>
......@@ -36,6 +36,8 @@ public:
using typename BaseType::FieldType;
using typename BaseType::MatrixOperatorType;
using typename BaseType::VectorType;
using typename BaseType::ConstLincombOperatorType;
using typename BaseType::LincombOperatorType;
using OperatorType = BaseType;
......@@ -45,35 +47,40 @@ public:
{
}
ConstLincombOperator(const ThisType& other) = default;
ConstLincombOperator(ThisType&& source) = default;
void add(const OperatorType& op, const FieldType& coeff = 1.)
{
const_operators_.emplace_back(op);
const_ops_.emplace_back(op);
coeffs_.emplace_back(coeff);
}
void add(const OperatorType*&& op, const FieldType& coeff = 1.)
void add(OperatorType*&& op, const FieldType& coeff = 1.)
{
const_operators_.emplace_back(std::move(op));
keep_alive_.emplace_back(std::move(op));
const_ops_.emplace_back(*keep_alive_.back());
coeffs_.emplace_back(coeff);
}
void add(const ThisType& op, const FieldType& coeff = 1.)
{
// Only adding op itself would lead to segfaults in some circumstances
// Only adding op itself would lead to segfaults if op is a temporary
for (size_t ii = 0; ii < op.num_ops(); ++ii) {
const_operators_.emplace_back(op.const_operators_[ii]);
const_ops_.emplace_back(op.const_ops_[ii]);
coeffs_.emplace_back(coeff * op.coeffs_[ii]);
}
}
void add(const ThisType*&& op, const FieldType& coeff = 1.)
void add(ThisType*&& op, const FieldType& coeff = 1.)
{
this->add(*op, coeff);
}
size_t num_ops() const
{
return const_operators_.size();
return const_ops_.size();
}
const OperatorType& op(const size_t ii) const
......@@ -81,7 +88,7 @@ public:
DUNE_THROW_IF(ii >= this->num_ops(),
Exceptions::operator_error,
"ii = " << ii << "\n this->num_ops() = " << this->num_ops());
return const_operators_[ii].access();
return const_ops_[ii].access();
}
const FieldType& coeff(const size_t ii) const
......@@ -94,7 +101,7 @@ public:
bool linear() const override final
{
for (const auto& op : const_operators_)
for (const auto& op : const_ops_)
if (!op.access().linear())
return false;
return true;
......@@ -112,7 +119,7 @@ public:
using BaseType::apply;
void apply(const VectorType& source, VectorType& range, const XT::Common::Parameter& param = {}) const
void apply(const VectorType& source, VectorType& range, const XT::Common::Parameter& param = {}) const override final
{
range.set_all(0);
auto tmp = range;
......@@ -123,32 +130,12 @@ public:
}
} // ... append(...)
std::vector<std::string> invert_options() const
{
DUNE_THROW(Exceptions::operator_error, "This operator is not invertible!");
return std::vector<std::string>();
}
XT::Common::Configuration invert_options(const std::string& /*type*/) const
{
DUNE_THROW(Exceptions::operator_error, "This operator is not invertible!");
return XT::Common::Configuration();
}
void apply_inverse(const VectorType& /*range*/,
VectorType& /*source*/,
const XT::Common::Configuration& /*opts*/,
const XT::Common::Parameter& /*param*/ = {}) const
{
DUNE_THROW(Exceptions::operator_error, "This operator is not invertible!");
}
std::vector<std::string> jacobian_options() const
std::vector<std::string> jacobian_options() const override final
{
return {"lincomb"};
}
XT::Common::Configuration jacobian_options(const std::string& type) const
XT::Common::Configuration jacobian_options(const std::string& type) const override final
{
DUNE_THROW_IF(type != this->jacobian_options().at(0), Exceptions::operator_error, "type = " << type);
using XT::Common::to_string;
......@@ -173,7 +160,7 @@ public:
void jacobian(const VectorType& source,
MatrixOperatorType& jacobian_op,
const XT::Common::Configuration& opts,
const XT::Common::Parameter& param = {}) const
const XT::Common::Parameter& param = {}) const override final
{
// some checks
DUNE_THROW_IF(!source.valid(), Exceptions::operator_error, "source contains inf or nan!");
......@@ -216,10 +203,119 @@ public:
}
} // ... jacobian(...)
private:
ThisType& operator*=(const FieldType& alpha)
{
for (auto& coeff : coeffs_)
coeff *= alpha;
return *this;
}
ThisType& operator/=(const FieldType& alpha)
{
for (auto& coeff : coeffs_)
coeff /= alpha;
return *this;
}
ThisType& operator+=(const BaseType& other)
{
this->add(other);
return *this;
}
ThisType& operator+=(const ThisType& other)
{
for (size_t ii = 0; ii < other.num_ops(); ++ii) {
const_ops_.emplace_back(other.const_ops_[ii]);
coeffs_.emplace_back(other.coeffs_[ii]);
}
return *this;
}
ThisType& operator-=(const BaseType& other)
{
this->add(other, -1.);
return *this;
}
ThisType& operator-=(const ThisType& other)
{
for (size_t ii = 0; ii < other.num_ops(); ++ii) {
const_ops_.emplace_back(other.const_ops_[ii]);
coeffs_.emplace_back(-1 * other.coeffs_[ii]);
}
return *this;
}
// We need to override some operator+-*/ from the interface to avoid segfaults due to temporaries
ConstLincombOperatorType operator*(const FieldType& alpha) const override final
{
ConstLincombOperatorType ret(*this);
ret *= alpha;
return ret;
}
ConstLincombOperatorType operator/(const FieldType& alpha) const override final
{
ConstLincombOperatorType ret(*this);
ret /= alpha;
return ret;
}
using BaseType::operator+;
ConstLincombOperatorType operator+(const ConstLincombOperatorType& other) const override final
{
ConstLincombOperatorType ret(*this);
ret += other;
return ret;
}
ConstLincombOperatorType operator+(const BaseType& other) const override final
{
ConstLincombOperatorType ret(*this);
ret += other;
return ret;
}
ConstLincombOperatorType operator+(const VectorType& vector) const override final
{
ConstLincombOperatorType ret(*this);
ret.add(new ConstantOperator<M, SGV, s_r, s_rC, r_r, r_rC, RGV>(this->source_space(), this->range_space(), vector),
1.);
return ret;
}
using BaseType::operator-;
ConstLincombOperatorType operator-(const ConstLincombOperatorType& other) const override final
{
ConstLincombOperatorType ret(*this);
ret += other;
return ret;
}
ConstLincombOperatorType operator-(const BaseType& other) const override final
{
ConstLincombOperatorType ret(*this);
ret += other;
return ret;
}
ConstLincombOperatorType operator-(const VectorType& vector) const override final
{
ConstLincombOperatorType ret(*this);
ret.add(new ConstantOperator<M, SGV, s_r, s_rC, r_r, r_rC, RGV>(this->source_space(), this->range_space(), vector),
-1.);
return ret;
}
protected:
const SourceSpaceType& source_space_;
const RangeSpaceType& range_space_;
std::vector<XT::Common::ConstStorageProvider<OperatorType>> const_operators_;
std::vector<std::shared_ptr<OperatorType>> keep_alive_;
std::vector<XT::Common::ConstStorageProvider<OperatorType>> const_ops_;
std::vector<FieldType> coeffs_;
}; // class ConstLincombOperator
......@@ -250,35 +346,48 @@ public:
using typename BaseType::SourceSpaceType;
using typename BaseType::RangeSpaceType;
using typename BaseType::FieldType;
using typename BaseType::OperatorType;
using typename BaseType::MatrixOperatorType;
using typename BaseType::LincombOperatorType;
using typename BaseType::VectorType;
using OperatorType = OperatorInterface<M, SGV, s_r, s_rC, r_r, r_rC, RGV>;
LincombOperator(const SourceSpaceType& src_space, const RangeSpaceType& rng_space)
: BaseType(src_space, rng_space)
{
}
LincombOperator(ThisType& other)
: BaseType(other)
{
for (auto& op : other.ops_)
this->ops_.emplace_back(op);
}
LincombOperator(ThisType&& source)
: BaseType(source)
, ops_(std::move(source.ops_))
{
}
using BaseType::add;
void add(OperatorType& op, const FieldType& coeff = 1.)
{
operators_.emplace_back(op);
BaseType::add(operators_.back().access(), coeff);
ops_.emplace_back(op);
BaseType::add(ops_.back().access(), coeff);
}
void add(OperatorType*&& op, const FieldType& coeff = 1.)
{
operators_.emplace_back(std::move(op));
BaseType::add(operators_.back().access(), coeff);
BaseType::add(std::move(op), coeff);
ops_.emplace_back(*this->keep_alive_.back());
}
void add(ThisType& op, const FieldType& coeff = 1.)
{
for (size_t ii = 0; ii < op.num_ops(); ++ii)
operators_.emplace_back(op.operators_[ii]);
BaseType::add(op, coeff);
for (size_t ii = 0; ii < op.num_ops(); ++ii)
ops_.emplace_back(op.ops_[ii]);
}
void add(ThisType*&& op, const FieldType& coeff = 1.)
......@@ -293,18 +402,110 @@ public:
DUNE_THROW_IF(ii >= this->num_ops(),
Exceptions::operator_error,
"ii = " << ii << "\n this->num_ops() = " << this->num_ops());
return operators_[ii].access();
return ops_[ii].access();
}
OperatorType& assemble(const bool use_tbb = false)
OperatorType& assemble(const bool use_tbb = false) override final
{
for (auto& op : operators_)
for (auto& op : ops_)
op.access().assemble(use_tbb);
return *this;
}
// we need to override some operators, see above
using BaseType::operator+=;
ThisType& operator+=(ThisType& other)
{
for (size_t ii = 0; ii < other.num_ops(); ++ii) {
this->const_ops_.emplace_back(other.const_ops_[ii]);
ops_.emplace_back(other.ops_[ii]);
this->coeffs_.emplace_back(other.coeffs_[ii]);
}
return *this;
}
using BaseType::operator-=;
ThisType& operator-=(ThisType& other)
{
for (size_t ii = 0; ii < other.num_ops(); ++ii) {
this->const_ops_.emplace_back(other.const_ops_[ii]);
ops_.emplace_back(other.ops_[ii]);
this->coeffs_.emplace_back(-1 * other.coeffs_[ii]);
}
return *this;
}
using BaseType::operator*;
LincombOperatorType operator*(const FieldType& alpha)override final
{
LincombOperatorType ret(*this);
ret *= alpha;
return ret;
}
using BaseType::operator/;
LincombOperatorType operator/(const FieldType& alpha) override final
{
LincombOperatorType ret(*this);
ret /= alpha;
return ret;
}
using BaseType::operator+;
LincombOperatorType operator+(LincombOperatorType& other) override final
{
LincombOperatorType ret(*this);
ret += other;
return ret;
}
LincombOperatorType operator+(OperatorType& other) override final
{
LincombOperatorType ret(*this);
ret += other;
return ret;
}
LincombOperatorType operator+(const VectorType& vector) override final
{
LincombOperatorType ret(*this);
ret.add(new ConstantOperator<M, SGV, s_r, s_rC, r_r, r_rC, RGV>(this->source_space(), this->range_space(), vector),
-1.);
return ret;
}
using BaseType::operator-;
LincombOperatorType operator-(LincombOperatorType& other) override final
{
LincombOperatorType ret(*this);
ret -= other;
return ret;
}
LincombOperatorType operator-(OperatorType& other) override final
{
LincombOperatorType ret(*this);
ret -= other;
return ret;
}
LincombOperatorType operator-(const VectorType& vector) override final
{
LincombOperatorType ret(*this);
ret.add(new ConstantOperator<M, SGV, s_r, s_rC, r_r, r_rC, RGV>(this->source_space(), this->range_space(), vector),
-1.);
return ret;
}
private:
std::vector<XT::Common::StorageProvider<OperatorType>> operators_;
std::vector<XT::Common::StorageProvider<OperatorType>> ops_;
}; // class LincombOperator
......@@ -327,4 +528,4 @@ make_lincomb_operator(const SpaceInterface<GV, r, rC, F>& space)
} // namespace GDT
} // namespace Dune
#endif // DUNE_GDT_const_operators_LINCOMB_HH
#endif // DUNE_GDT_OPERATORS_LINCOMB_HH
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment