From ba9920a8f83b7cb36acfa96637fbd90b5403aa7a Mon Sep 17 00:00:00 2001
From: Tobias Leibner <tobias.leibner@uni-muenster.de>
Date: Tue, 28 Feb 2017 15:39:35 +0100
Subject: [PATCH] [container] add get_entry_ref to matrix interface

---
 dune/xt/la/container/common.hh           | 16 ++++++++++++++++
 dune/xt/la/container/eigen/dense.hh      |  7 +++++++
 dune/xt/la/container/eigen/sparse.hh     |  9 +++++++++
 dune/xt/la/container/istl.hh             |  9 +++++++++
 dune/xt/la/container/matrix-interface.hh |  6 ++++++
 5 files changed, 47 insertions(+)

diff --git a/dune/xt/la/container/common.hh b/dune/xt/la/container/common.hh
index f48884936..55bb01fe2 100644
--- a/dune/xt/la/container/common.hh
+++ b/dune/xt/la/container/common.hh
@@ -553,6 +553,13 @@ public:
     return backend()[ii][jj];
   } // ... get_entry(...)
 
+  ScalarType& get_entry_ref(const size_t ii, const size_t jj)
+  {
+    assert(ii < rows());
+    assert(jj < cols());
+    return backend()[ii][jj];
+  } // ... get_entry(...)
+
   void clear_row(const size_t ii)
   {
     auto& backend_ref = backend();
@@ -840,6 +847,15 @@ public:
     return index == size_t(-1) ? ScalarType(0) : entries_->operator[](index);
   }
 
+  inline ScalarType& get_entry_ref(const size_t rr, const size_t cc)
+  {
+    std::lock_guard<std::mutex> DUNE_UNUSED(lock)(mutex_);
+    const size_t index = get_entry_index(rr, cc, false);
+    if (index == size_t(-1))
+      DUNE_THROW(Dune::RangeError, "Entry not in matrix pattern!");
+    return entries_->operator[](index);
+  }
+
   inline void set_entry(const size_t rr, const size_t cc, const ScalarType value)
   {
     ensure_uniqueness();
diff --git a/dune/xt/la/container/eigen/dense.hh b/dune/xt/la/container/eigen/dense.hh
index 2d32790a3..974924026 100644
--- a/dune/xt/la/container/eigen/dense.hh
+++ b/dune/xt/la/container/eigen/dense.hh
@@ -566,6 +566,13 @@ public:
     return backend()(ii, jj);
   } // ... get_entry(...)
 
+  ScalarType& get_entry_ref(const size_t ii, const size_t jj)
+  {
+    assert(ii < rows());
+    assert(jj < cols());
+    return backend()(ii, jj);
+  } // ... get_entry(...)
+
   void clear_row(const size_t ii)
   {
     auto& backend_ref = backend();
diff --git a/dune/xt/la/container/eigen/sparse.hh b/dune/xt/la/container/eigen/sparse.hh
index 5ebad97bf..2b15e7ddd 100644
--- a/dune/xt/la/container/eigen/sparse.hh
+++ b/dune/xt/la/container/eigen/sparse.hh
@@ -292,6 +292,15 @@ public:
                            internal::boost_numeric_cast<EIGEN_size_t>(jj));
   }
 
+  ScalarType& get_entry_ref(const size_t ii, const size_t jj)
+  {
+    std::lock_guard<std::mutex> DUNE_UNUSED(lock)(mutex_);
+    assert(ii < rows());
+    assert(jj < cols());
+    return backend().coeffRef(internal::boost_numeric_cast<EIGEN_size_t>(ii),
+                              internal::boost_numeric_cast<EIGEN_size_t>(jj));
+  }
+
   void clear_row(const size_t ii)
   {
     auto& backend_ref = backend();
diff --git a/dune/xt/la/container/istl.hh b/dune/xt/la/container/istl.hh
index 19ebf86c2..a2a98335d 100644
--- a/dune/xt/la/container/istl.hh
+++ b/dune/xt/la/container/istl.hh
@@ -558,6 +558,15 @@ public:
       return ScalarType(0);
   } // ... get_entry(...)
 
+  ScalarType& get_entry_ref(const size_t ii, const size_t jj)
+  {
+    assert(ii < rows());
+    assert(jj < cols());
+    if (!these_are_valid_indices(ii, jj))
+      DUNE_THROW(Dune::RangeError, "Matrix entry not in pattern!");
+    return backend_->operator[](ii)[jj][0][0];
+  } // ... get_entry(...)
+
   void clear_row(const size_t ii)
   {
     auto& backend_ref = backend();
diff --git a/dune/xt/la/container/matrix-interface.hh b/dune/xt/la/container/matrix-interface.hh
index fded9b502..9094b73c3 100644
--- a/dune/xt/la/container/matrix-interface.hh
+++ b/dune/xt/la/container/matrix-interface.hh
@@ -86,6 +86,12 @@ public:
     return this->as_imp().get_entry(ii, jj);
   }
 
+  inline ScalarType& get_entry_ref(const size_t ii, const size_t jj)
+  {
+    CHECK_CRTP(this->as_imp().get_entry_ref(ii, jj));
+    return this->as_imp().get_entry_ref(ii, jj);
+  }
+
   inline void clear_row(const size_t ii)
   {
     CHECK_AND_CALL_CRTP(this->as_imp().clear_row(ii));
-- 
GitLab