From 3ee9c1342afd7d5d378019d8bf2c86bab9684bb7 Mon Sep 17 00:00:00 2001 From: Tobias Leibner <tobias.leibner@googlemail.com> Date: Thu, 31 Jan 2019 11:58:23 +0100 Subject: [PATCH] [container.matrix] add transposed pattern, fix multiply for sparse matrices --- dune/xt/la/container/matrix-interface.hh | 14 ++++++++------ dune/xt/la/container/pattern.cc | 10 ++++++++++ dune/xt/la/container/pattern.hh | 2 ++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/dune/xt/la/container/matrix-interface.hh b/dune/xt/la/container/matrix-interface.hh index 7de50637a..18b303052 100644 --- a/dune/xt/la/container/matrix-interface.hh +++ b/dune/xt/la/container/matrix-interface.hh @@ -235,10 +235,11 @@ public: derived_type transposed() const { - derived_type yy(rows(), cols(), 0.); + const auto this_pattern = pattern(); + derived_type yy(cols(), rows(), this_pattern.transposed(cols())); for (size_t rr = 0; rr < rows(); ++rr) - for (size_t cc = 0; cc < cols(); ++cc) - yy.set_entry(rr, cc, get_entry(cc, rr)); + for (const auto& cc : this_pattern.inner(rr)) + yy.set_entry(cc, rr, get_entry(rr, cc)); return yy; } @@ -376,11 +377,12 @@ protected: { if (other.rows() != cols()) DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, "Dimensions of matrices to be multiplied do not match!"); - const auto new_pattern = multiplication_pattern(pattern(), other.pattern(), other.cols()); + const auto this_pattern = pattern(); + const auto new_pattern = multiplication_pattern(this_pattern, other.pattern(), other.cols()); derived_type yy(rows(), other.cols(), new_pattern); for (size_t rr = 0; rr < rows(); ++rr) - for (size_t cc = 0; cc < other.cols(); ++cc) - for (size_t kk = 0; kk < cols(); ++kk) + for (auto&& cc : new_pattern.inner(rr)) + for (auto&& kk : this_pattern.inner(rr)) yy.add_to_entry(rr, cc, get_entry(rr, kk) * other.get_entry(kk, cc)); return yy; } diff --git a/dune/xt/la/container/pattern.cc b/dune/xt/la/container/pattern.cc index a2ba53365..c44c9fb5a 100644 --- a/dune/xt/la/container/pattern.cc +++ b/dune/xt/la/container/pattern.cc @@ -110,6 +110,16 @@ void SparsityPatternDefault::sort() std::sort(inner_vector.begin(), inner_vector.end()); } +SparsityPatternDefault SparsityPatternDefault::transposed(const size_t cols) const +{ + SparsityPatternDefault transposed_pattern(cols); + for (size_t rr = 0; rr < size(); ++rr) + for (const auto& cc : inner(rr)) + transposed_pattern.insert(cc, rr); + transposed_pattern.sort(); + return transposed_pattern; +} + SparsityPatternDefault dense_pattern(const size_t rows, const size_t cols) { SparsityPatternDefault ret(rows); diff --git a/dune/xt/la/container/pattern.hh b/dune/xt/la/container/pattern.hh index 6d1e85e4a..b30a203d4 100644 --- a/dune/xt/la/container/pattern.hh +++ b/dune/xt/la/container/pattern.hh @@ -61,6 +61,8 @@ public: void sort(); + SparsityPatternDefault transposed(const size_t cols) const; + private: BaseType vector_of_vectors_; }; // class SparsityPatternDefault -- GitLab