diff --git a/dune/gdt/operators/lincomb.hh b/dune/gdt/operators/lincomb.hh index e37c47eacceb81fdf9c00ba6fa6ed0844e2a327f..f888b32b60000fb923a6c7434c389ec5f57bff37 100644 --- a/dune/gdt/operators/lincomb.hh +++ b/dune/gdt/operators/lincomb.hh @@ -7,8 +7,8 @@ // Authors: // Felix Schindler (2018) -#ifndef DUNE_GDT_const_operators_LINCOMB_HH -#define DUNE_GDT_const_operators_LINCOMB_HH +#ifndef DUNE_GDT_OPERATORS_LINCOMB_HH +#define DUNE_GDT_OPERATORS_LINCOMB_HH #include <vector> @@ -36,6 +36,8 @@ public: using typename BaseType::FieldType; using typename BaseType::MatrixOperatorType; using typename BaseType::VectorType; + using typename BaseType::ConstLincombOperatorType; + using typename BaseType::LincombOperatorType; using OperatorType = BaseType; @@ -45,35 +47,40 @@ public: { } + ConstLincombOperator(const ThisType& other) = default; + + ConstLincombOperator(ThisType&& source) = default; + void add(const OperatorType& op, const FieldType& coeff = 1.) { - const_operators_.emplace_back(op); + const_ops_.emplace_back(op); coeffs_.emplace_back(coeff); } - void add(const OperatorType*&& op, const FieldType& coeff = 1.) + void add(OperatorType*&& op, const FieldType& coeff = 1.) { - const_operators_.emplace_back(std::move(op)); + keep_alive_.emplace_back(std::move(op)); + const_ops_.emplace_back(*keep_alive_.back()); coeffs_.emplace_back(coeff); } void add(const ThisType& op, const FieldType& coeff = 1.) { - // Only adding op itself would lead to segfaults in some circumstances + // Only adding op itself would lead to segfaults if op is a temporary for (size_t ii = 0; ii < op.num_ops(); ++ii) { - const_operators_.emplace_back(op.const_operators_[ii]); + const_ops_.emplace_back(op.const_ops_[ii]); coeffs_.emplace_back(coeff * op.coeffs_[ii]); } } - void add(const ThisType*&& op, const FieldType& coeff = 1.) + void add(ThisType*&& op, const FieldType& coeff = 1.) { this->add(*op, coeff); } size_t num_ops() const { - return const_operators_.size(); + return const_ops_.size(); } const OperatorType& op(const size_t ii) const @@ -81,7 +88,7 @@ public: DUNE_THROW_IF(ii >= this->num_ops(), Exceptions::operator_error, "ii = " << ii << "\n this->num_ops() = " << this->num_ops()); - return const_operators_[ii].access(); + return const_ops_[ii].access(); } const FieldType& coeff(const size_t ii) const @@ -94,7 +101,7 @@ public: bool linear() const override final { - for (const auto& op : const_operators_) + for (const auto& op : const_ops_) if (!op.access().linear()) return false; return true; @@ -112,7 +119,7 @@ public: using BaseType::apply; - void apply(const VectorType& source, VectorType& range, const XT::Common::Parameter& param = {}) const + void apply(const VectorType& source, VectorType& range, const XT::Common::Parameter& param = {}) const override final { range.set_all(0); auto tmp = range; @@ -123,32 +130,12 @@ public: } } // ... append(...) - std::vector<std::string> invert_options() const - { - DUNE_THROW(Exceptions::operator_error, "This operator is not invertible!"); - return std::vector<std::string>(); - } - - XT::Common::Configuration invert_options(const std::string& /*type*/) const - { - DUNE_THROW(Exceptions::operator_error, "This operator is not invertible!"); - return XT::Common::Configuration(); - } - - void apply_inverse(const VectorType& /*range*/, - VectorType& /*source*/, - const XT::Common::Configuration& /*opts*/, - const XT::Common::Parameter& /*param*/ = {}) const - { - DUNE_THROW(Exceptions::operator_error, "This operator is not invertible!"); - } - - std::vector<std::string> jacobian_options() const + std::vector<std::string> jacobian_options() const override final { return {"lincomb"}; } - XT::Common::Configuration jacobian_options(const std::string& type) const + XT::Common::Configuration jacobian_options(const std::string& type) const override final { DUNE_THROW_IF(type != this->jacobian_options().at(0), Exceptions::operator_error, "type = " << type); using XT::Common::to_string; @@ -173,7 +160,7 @@ public: void jacobian(const VectorType& source, MatrixOperatorType& jacobian_op, const XT::Common::Configuration& opts, - const XT::Common::Parameter& param = {}) const + const XT::Common::Parameter& param = {}) const override final { // some checks DUNE_THROW_IF(!source.valid(), Exceptions::operator_error, "source contains inf or nan!"); @@ -216,10 +203,119 @@ public: } } // ... jacobian(...) -private: + ThisType& operator*=(const FieldType& alpha) + { + for (auto& coeff : coeffs_) + coeff *= alpha; + return *this; + } + + ThisType& operator/=(const FieldType& alpha) + { + for (auto& coeff : coeffs_) + coeff /= alpha; + return *this; + } + + ThisType& operator+=(const BaseType& other) + { + this->add(other); + return *this; + } + + ThisType& operator+=(const ThisType& other) + { + for (size_t ii = 0; ii < other.num_ops(); ++ii) { + const_ops_.emplace_back(other.const_ops_[ii]); + coeffs_.emplace_back(other.coeffs_[ii]); + } + return *this; + } + + ThisType& operator-=(const BaseType& other) + { + this->add(other, -1.); + return *this; + } + + ThisType& operator-=(const ThisType& other) + { + for (size_t ii = 0; ii < other.num_ops(); ++ii) { + const_ops_.emplace_back(other.const_ops_[ii]); + coeffs_.emplace_back(-1 * other.coeffs_[ii]); + } + return *this; + } + + // We need to override some operator+-*/ from the interface to avoid segfaults due to temporaries + + ConstLincombOperatorType operator*(const FieldType& alpha) const override final + { + ConstLincombOperatorType ret(*this); + ret *= alpha; + return ret; + } + + ConstLincombOperatorType operator/(const FieldType& alpha) const override final + { + ConstLincombOperatorType ret(*this); + ret /= alpha; + return ret; + } + + using BaseType::operator+; + + ConstLincombOperatorType operator+(const ConstLincombOperatorType& other) const override final + { + ConstLincombOperatorType ret(*this); + ret += other; + return ret; + } + + ConstLincombOperatorType operator+(const BaseType& other) const override final + { + ConstLincombOperatorType ret(*this); + ret += other; + return ret; + } + + ConstLincombOperatorType operator+(const VectorType& vector) const override final + { + ConstLincombOperatorType ret(*this); + ret.add(new ConstantOperator<M, SGV, s_r, s_rC, r_r, r_rC, RGV>(this->source_space(), this->range_space(), vector), + 1.); + return ret; + } + + using BaseType::operator-; + + ConstLincombOperatorType operator-(const ConstLincombOperatorType& other) const override final + { + ConstLincombOperatorType ret(*this); + ret += other; + return ret; + } + + ConstLincombOperatorType operator-(const BaseType& other) const override final + { + ConstLincombOperatorType ret(*this); + ret += other; + return ret; + } + + ConstLincombOperatorType operator-(const VectorType& vector) const override final + { + ConstLincombOperatorType ret(*this); + ret.add(new ConstantOperator<M, SGV, s_r, s_rC, r_r, r_rC, RGV>(this->source_space(), this->range_space(), vector), + -1.); + return ret; + } + +protected: const SourceSpaceType& source_space_; const RangeSpaceType& range_space_; - std::vector<XT::Common::ConstStorageProvider<OperatorType>> const_operators_; + std::vector<std::shared_ptr<OperatorType>> keep_alive_; + std::vector<XT::Common::ConstStorageProvider<OperatorType>> const_ops_; std::vector<FieldType> coeffs_; }; // class ConstLincombOperator @@ -250,35 +346,48 @@ public: using typename BaseType::SourceSpaceType; using typename BaseType::RangeSpaceType; using typename BaseType::FieldType; + using typename BaseType::OperatorType; using typename BaseType::MatrixOperatorType; + using typename BaseType::LincombOperatorType; using typename BaseType::VectorType; - using OperatorType = OperatorInterface<M, SGV, s_r, s_rC, r_r, r_rC, RGV>; - LincombOperator(const SourceSpaceType& src_space, const RangeSpaceType& rng_space) : BaseType(src_space, rng_space) { } + LincombOperator(ThisType& other) + : BaseType(other) + { + for (auto& op : other.ops_) + this->ops_.emplace_back(op); + } + + LincombOperator(ThisType&& source) + : BaseType(source) + , ops_(std::move(source.ops_)) + { + } + using BaseType::add; void add(OperatorType& op, const FieldType& coeff = 1.) { - operators_.emplace_back(op); - BaseType::add(operators_.back().access(), coeff); + ops_.emplace_back(op); + BaseType::add(ops_.back().access(), coeff); } void add(OperatorType*&& op, const FieldType& coeff = 1.) { - operators_.emplace_back(std::move(op)); - BaseType::add(operators_.back().access(), coeff); + BaseType::add(std::move(op), coeff); + ops_.emplace_back(*this->keep_alive_.back()); } void add(ThisType& op, const FieldType& coeff = 1.) { - for (size_t ii = 0; ii < op.num_ops(); ++ii) - operators_.emplace_back(op.operators_[ii]); BaseType::add(op, coeff); + for (size_t ii = 0; ii < op.num_ops(); ++ii) + ops_.emplace_back(op.ops_[ii]); } void add(ThisType*&& op, const FieldType& coeff = 1.) @@ -293,18 +402,110 @@ public: DUNE_THROW_IF(ii >= this->num_ops(), Exceptions::operator_error, "ii = " << ii << "\n this->num_ops() = " << this->num_ops()); - return operators_[ii].access(); + return ops_[ii].access(); } - OperatorType& assemble(const bool use_tbb = false) + OperatorType& assemble(const bool use_tbb = false) override final { - for (auto& op : operators_) + for (auto& op : ops_) op.access().assemble(use_tbb); return *this; } + // we need to override some operators, see above + + using BaseType::operator+=; + + ThisType& operator+=(ThisType& other) + { + for (size_t ii = 0; ii < other.num_ops(); ++ii) { + this->const_ops_.emplace_back(other.const_ops_[ii]); + ops_.emplace_back(other.ops_[ii]); + this->coeffs_.emplace_back(other.coeffs_[ii]); + } + return *this; + } + + using BaseType::operator-=; + + ThisType& operator-=(ThisType& other) + { + for (size_t ii = 0; ii < other.num_ops(); ++ii) { + this->const_ops_.emplace_back(other.const_ops_[ii]); + ops_.emplace_back(other.ops_[ii]); + this->coeffs_.emplace_back(-1 * other.coeffs_[ii]); + } + return *this; + } + + using BaseType::operator*; + + LincombOperatorType operator*(const FieldType& alpha)override final + { + LincombOperatorType ret(*this); + ret *= alpha; + return ret; + } + + using BaseType::operator/; + + LincombOperatorType operator/(const FieldType& alpha) override final + { + LincombOperatorType ret(*this); + ret /= alpha; + return ret; + } + + using BaseType::operator+; + + LincombOperatorType operator+(LincombOperatorType& other) override final + { + LincombOperatorType ret(*this); + ret += other; + return ret; + } + + LincombOperatorType operator+(OperatorType& other) override final + { + LincombOperatorType ret(*this); + ret += other; + return ret; + } + + LincombOperatorType operator+(const VectorType& vector) override final + { + LincombOperatorType ret(*this); + ret.add(new ConstantOperator<M, SGV, s_r, s_rC, r_r, r_rC, RGV>(this->source_space(), this->range_space(), vector), + -1.); + return ret; + } + + using BaseType::operator-; + + LincombOperatorType operator-(LincombOperatorType& other) override final + { + LincombOperatorType ret(*this); + ret -= other; + return ret; + } + + LincombOperatorType operator-(OperatorType& other) override final + { + LincombOperatorType ret(*this); + ret -= other; + return ret; + } + + LincombOperatorType operator-(const VectorType& vector) override final + { + LincombOperatorType ret(*this); + ret.add(new ConstantOperator<M, SGV, s_r, s_rC, r_r, r_rC, RGV>(this->source_space(), this->range_space(), vector), + -1.); + return ret; + } + private: - std::vector<XT::Common::StorageProvider<OperatorType>> operators_; + std::vector<XT::Common::StorageProvider<OperatorType>> ops_; }; // class LincombOperator @@ -327,4 +528,4 @@ make_lincomb_operator(const SpaceInterface<GV, r, rC, F>& space) } // namespace GDT } // namespace Dune -#endif // DUNE_GDT_const_operators_LINCOMB_HH +#endif // DUNE_GDT_OPERATORS_LINCOMB_HH