From 8f7cd92f699c7695ec7b2d8ffd193cb177f4c8f1 Mon Sep 17 00:00:00 2001 From: Tobias Leibner <tobias.leibner@googlemail.com> Date: Thu, 29 Jun 2017 14:18:17 +0200 Subject: [PATCH] [container.matrix-interface] add operators --- dune/xt/la/container/common.hh | 25 +++-- dune/xt/la/container/eigen/dense.hh | 5 + dune/xt/la/container/eigen/sparse.hh | 5 + dune/xt/la/container/istl.hh | 5 + dune/xt/la/container/matrix-interface.hh | 111 +++++++++++++++++++++++ dune/xt/la/container/pattern.cc | 11 +++ dune/xt/la/container/pattern.hh | 2 + 7 files changed, 156 insertions(+), 8 deletions(-) diff --git a/dune/xt/la/container/common.hh b/dune/xt/la/container/common.hh index 8ae76d14a..01dbb6ede 100644 --- a/dune/xt/la/container/common.hh +++ b/dune/xt/la/container/common.hh @@ -632,6 +632,11 @@ public: /// \} + using MatrixInterfaceType::operator+; + using MatrixInterfaceType::operator-; + using MatrixInterfaceType::operator+=; + using MatrixInterfaceType::operator-=; + protected: /** * \see ContainerInterface @@ -944,6 +949,18 @@ public: /// \} + using MatrixInterfaceType::operator+; + using MatrixInterfaceType::operator-; + using MatrixInterfaceType::operator+=; + using MatrixInterfaceType::operator-=; + +protected: + inline void ensure_uniqueness() + { + if (!entries_.unique()) + entries_ = std::make_shared<EntriesVectorType>(*entries_); + } // ... ensure_uniqueness(...) + private: size_t get_entry_index(const size_t rr, const size_t cc, const bool throw_if_not_in_pattern = true) const { @@ -959,14 +976,6 @@ private: return size_t(-1); } -protected: - inline void ensure_uniqueness() - { - if (!entries_.unique()) - entries_ = std::make_shared<EntriesVectorType>(*entries_); - } // ... ensure_uniqueness(...) - -private: size_t num_rows_, num_cols_; std::shared_ptr<EntriesVectorType> entries_; std::shared_ptr<IndexVectorType> row_pointers_; diff --git a/dune/xt/la/container/eigen/dense.hh b/dune/xt/la/container/eigen/dense.hh index 5f261be8f..cf926cff3 100644 --- a/dune/xt/la/container/eigen/dense.hh +++ b/dune/xt/la/container/eigen/dense.hh @@ -647,6 +647,11 @@ public: * \} */ + using MatrixInterfaceType::operator+; + using MatrixInterfaceType::operator-; + using MatrixInterfaceType::operator+=; + using MatrixInterfaceType::operator-=; + protected: inline void ensure_uniqueness() { diff --git a/dune/xt/la/container/eigen/sparse.hh b/dune/xt/la/container/eigen/sparse.hh index 281f7fbbb..799b42bb5 100644 --- a/dune/xt/la/container/eigen/sparse.hh +++ b/dune/xt/la/container/eigen/sparse.hh @@ -434,6 +434,11 @@ public: /// \} + using MatrixInterfaceType::operator+; + using MatrixInterfaceType::operator-; + using MatrixInterfaceType::operator+=; + using MatrixInterfaceType::operator-=; + private: bool these_are_valid_indices(const size_t ii, const size_t jj) const { diff --git a/dune/xt/la/container/istl.hh b/dune/xt/la/container/istl.hh index fae004561..fac0a369c 100644 --- a/dune/xt/la/container/istl.hh +++ b/dune/xt/la/container/istl.hh @@ -690,6 +690,11 @@ public: /// \} + using MatrixInterfaceType::operator+; + using MatrixInterfaceType::operator-; + using MatrixInterfaceType::operator+=; + using MatrixInterfaceType::operator-=; + private: void build_sparse_matrix(const size_t rr, const size_t cc, const SparsityPatternDefault& patt) { diff --git a/dune/xt/la/container/matrix-interface.hh b/dune/xt/la/container/matrix-interface.hh index a3c4b6bcc..62c1037cf 100644 --- a/dune/xt/la/container/matrix-interface.hh +++ b/dune/xt/la/container/matrix-interface.hh @@ -114,6 +114,36 @@ public: CHECK_AND_CALL_CRTP(this->as_imp().unit_col(jj)); } + template <class MM> + derived_type operator*(const MatrixInterface<MM, ScalarType>& other) const + { + return multiply(other); + } + + template <class MM> + derived_type operator+(const MatrixInterface<MM, ScalarType>& other) const + { + return add(other); + } + + template <class MM> + derived_type operator-(const MatrixInterface<MM, ScalarType>& other) const + { + return subtract(other); + } + + template <class MM> + derived_type& operator+=(const MatrixInterface<MM, ScalarType>& other) + { + return add_assign(other); + } + + template <class MM> + derived_type& operator-=(const MatrixInterface<MM, ScalarType>& other) + { + return subtract_assign(other); + } + /** * \brief Checks entries for inf or nan. * \return false if any entry is inf or nan, else true @@ -148,6 +178,15 @@ public: return ret; } // ... sup_norm(...) + derived_type transposed() const + { + derived_type yy(rows(), cols(), 0.); + for (size_t rr = 0; rr < rows(); ++rr) + for (size_t cc = 0; cc < cols(); ++cc) + yy.set_entry(rr, cc, get_entry(cc, rr)); + return yy; + } + /** * \brief Returns the number of entries in the sparsity pattern of the matrix. * @@ -243,6 +282,78 @@ public: } // ... almost_equal(...) /// \} +protected: + template <class MM> + derived_type multiply(const MatrixInterface<MM, ScalarType>& other) const + { + if (other.rows() != cols()) + DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, "Dimensions of matrices to be multiplied do not match!"); + derived_type yy(rows(), other.cols(), 0.); + for (size_t rr = 0; rr < rows(); ++rr) + for (size_t cc = 0; cc < other.cols(); ++cc) + for (size_t kk = 0; kk < cols(); ++kk) + yy.add_to_entry(rr, cc, get_entry(rr, kk) * other.get_entry(kk, cc)); + return yy; + } + + template <class MM> + derived_type add(const MatrixInterface<MM, ScalarType>& other) const + { + if (other.rows() != rows() || other.cols() != cols()) + DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, "Dimensions of matrices to be added do not match!"); + auto new_pattern = pattern() + other.pattern(); + derived_type yy(rows(), other.cols(), new_pattern); + for (size_t rr = 0; rr < rows(); ++rr) + for (const auto& cc : new_pattern[rr]) + yy.set_entry(rr, cc, get_entry(rr, cc) + other.get_entry(rr, cc)); + return yy; + } + + template <class MM> + derived_type subtract(const MatrixInterface<MM, ScalarType>& other) const + { + if (other.rows() != rows() || other.cols() != cols()) + DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, "Dimensions of matrices to be added do not match!"); + auto new_pattern = pattern() + other.pattern(); + derived_type yy(rows(), other.cols(), new_pattern); + for (size_t rr = 0; rr < rows(); ++rr) + for (const auto& cc : new_pattern[rr]) + yy.set_entry(rr, cc, get_entry(rr, cc) - other.get_entry(rr, cc)); + return yy; + } + + template <class MM> + derived_type& add_assign(const MatrixInterface<MM, ScalarType>& other) const + { + if (other.rows() != rows() || other.cols() != cols()) + DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, "Dimensions of matrices to be added do not match!"); + const auto this_pattern = pattern(); + auto new_pattern = this_pattern + other.pattern(); + if (new_pattern != this_pattern) + DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, + "The matrix to be added contains entries that are not in this' pattern!"); + for (size_t rr = 0; rr < rows(); ++rr) + for (const auto& cc : this_pattern[rr]) + add_to_entry(rr, cc, other.get_entry(rr, cc)); + return this->as_imp(); + } + + template <class MM> + derived_type& subtract_assign(const MatrixInterface<MM, ScalarType>& other) const + { + if (other.rows() != rows() || other.cols() != cols()) + DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, "Dimensions of matrices to be added do not match!"); + const auto this_pattern = pattern(); + auto new_pattern = this_pattern + other.pattern(); + if (new_pattern != this_pattern) + DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, + "The matrix to be added contains entries that are not in this' pattern!"); + for (size_t rr = 0; rr < rows(); ++rr) + for (const auto& cc : this_pattern[rr]) + add_to_entry(rr, cc, -other.get_entry(rr, cc)); + return this->as_imp(); + } + private: template <class T, class S> friend std::ostream& operator<<(std::ostream& /*out*/, const MatrixInterface<T, S>& /*matrix*/); diff --git a/dune/xt/la/container/pattern.cc b/dune/xt/la/container/pattern.cc index edeb185c9..a9ea00301 100644 --- a/dune/xt/la/container/pattern.cc +++ b/dune/xt/la/container/pattern.cc @@ -63,6 +63,17 @@ bool SparsityPatternDefault::operator!=(const SparsityPatternDefault& other) con return vector_of_vectors_ != other.vector_of_vectors_; } +SparsityPatternDefault operator+(const SparsityPatternDefault& other) const +{ + assert(other.size() == size() && "SparsityPatterns must have the same number of rows for addition!"); + SparsityPatternDefault ret = *this; + for (size_t rr = 0; rr < size(); ++rr) + for (const auto& cc : vector_of_vectors_[rr]) + ret.insert(rr, cc); + ret.sort(); + return ret; +} + void SparsityPatternDefault::insert(const size_t outer_index, const size_t inner_index) { assert(outer_index < size() && "Wrong index requested!"); diff --git a/dune/xt/la/container/pattern.hh b/dune/xt/la/container/pattern.hh index 0a0e84873..aa36343a3 100644 --- a/dune/xt/la/container/pattern.hh +++ b/dune/xt/la/container/pattern.hh @@ -45,6 +45,8 @@ public: bool operator!=(const SparsityPatternDefault& other) const; + SparsityPatternDefault operator+(const SparsityPatternDefault& other) const; + void insert(const size_t outer_index, const size_t inner_index); void sort(const size_t outer_index); -- GitLab