From 8f7cd92f699c7695ec7b2d8ffd193cb177f4c8f1 Mon Sep 17 00:00:00 2001
From: Tobias Leibner <tobias.leibner@googlemail.com>
Date: Thu, 29 Jun 2017 14:18:17 +0200
Subject: [PATCH] [container.matrix-interface] add operators

---
 dune/xt/la/container/common.hh           |  25 +++--
 dune/xt/la/container/eigen/dense.hh      |   5 +
 dune/xt/la/container/eigen/sparse.hh     |   5 +
 dune/xt/la/container/istl.hh             |   5 +
 dune/xt/la/container/matrix-interface.hh | 111 +++++++++++++++++++++++
 dune/xt/la/container/pattern.cc          |  11 +++
 dune/xt/la/container/pattern.hh          |   2 +
 7 files changed, 156 insertions(+), 8 deletions(-)

diff --git a/dune/xt/la/container/common.hh b/dune/xt/la/container/common.hh
index 8ae76d14a..01dbb6ede 100644
--- a/dune/xt/la/container/common.hh
+++ b/dune/xt/la/container/common.hh
@@ -632,6 +632,11 @@ public:
 
   /// \}
 
+  using MatrixInterfaceType::operator+;
+  using MatrixInterfaceType::operator-;
+  using MatrixInterfaceType::operator+=;
+  using MatrixInterfaceType::operator-=;
+
 protected:
   /**
    * \see ContainerInterface
@@ -944,6 +949,18 @@ public:
 
   /// \}
 
+  using MatrixInterfaceType::operator+;
+  using MatrixInterfaceType::operator-;
+  using MatrixInterfaceType::operator+=;
+  using MatrixInterfaceType::operator-=;
+
+protected:
+  inline void ensure_uniqueness()
+  {
+    if (!entries_.unique())
+      entries_ = std::make_shared<EntriesVectorType>(*entries_);
+  } // ... ensure_uniqueness(...)
+
 private:
   size_t get_entry_index(const size_t rr, const size_t cc, const bool throw_if_not_in_pattern = true) const
   {
@@ -959,14 +976,6 @@ private:
     return size_t(-1);
   }
 
-protected:
-  inline void ensure_uniqueness()
-  {
-    if (!entries_.unique())
-      entries_ = std::make_shared<EntriesVectorType>(*entries_);
-  } // ... ensure_uniqueness(...)
-
-private:
   size_t num_rows_, num_cols_;
   std::shared_ptr<EntriesVectorType> entries_;
   std::shared_ptr<IndexVectorType> row_pointers_;
diff --git a/dune/xt/la/container/eigen/dense.hh b/dune/xt/la/container/eigen/dense.hh
index 5f261be8f..cf926cff3 100644
--- a/dune/xt/la/container/eigen/dense.hh
+++ b/dune/xt/la/container/eigen/dense.hh
@@ -647,6 +647,11 @@ public:
    * \}
    */
 
+  using MatrixInterfaceType::operator+;
+  using MatrixInterfaceType::operator-;
+  using MatrixInterfaceType::operator+=;
+  using MatrixInterfaceType::operator-=;
+
 protected:
   inline void ensure_uniqueness()
   {
diff --git a/dune/xt/la/container/eigen/sparse.hh b/dune/xt/la/container/eigen/sparse.hh
index 281f7fbbb..799b42bb5 100644
--- a/dune/xt/la/container/eigen/sparse.hh
+++ b/dune/xt/la/container/eigen/sparse.hh
@@ -434,6 +434,11 @@ public:
 
   /// \}
 
+  using MatrixInterfaceType::operator+;
+  using MatrixInterfaceType::operator-;
+  using MatrixInterfaceType::operator+=;
+  using MatrixInterfaceType::operator-=;
+
 private:
   bool these_are_valid_indices(const size_t ii, const size_t jj) const
   {
diff --git a/dune/xt/la/container/istl.hh b/dune/xt/la/container/istl.hh
index fae004561..fac0a369c 100644
--- a/dune/xt/la/container/istl.hh
+++ b/dune/xt/la/container/istl.hh
@@ -690,6 +690,11 @@ public:
 
   /// \}
 
+  using MatrixInterfaceType::operator+;
+  using MatrixInterfaceType::operator-;
+  using MatrixInterfaceType::operator+=;
+  using MatrixInterfaceType::operator-=;
+
 private:
   void build_sparse_matrix(const size_t rr, const size_t cc, const SparsityPatternDefault& patt)
   {
diff --git a/dune/xt/la/container/matrix-interface.hh b/dune/xt/la/container/matrix-interface.hh
index a3c4b6bcc..62c1037cf 100644
--- a/dune/xt/la/container/matrix-interface.hh
+++ b/dune/xt/la/container/matrix-interface.hh
@@ -114,6 +114,36 @@ public:
     CHECK_AND_CALL_CRTP(this->as_imp().unit_col(jj));
   }
 
+  template <class MM>
+  derived_type operator*(const MatrixInterface<MM, ScalarType>& other) const
+  {
+    return multiply(other);
+  }
+
+  template <class MM>
+  derived_type operator+(const MatrixInterface<MM, ScalarType>& other) const
+  {
+    return add(other);
+  }
+
+  template <class MM>
+  derived_type operator-(const MatrixInterface<MM, ScalarType>& other) const
+  {
+    return subtract(other);
+  }
+
+  template <class MM>
+  derived_type& operator+=(const MatrixInterface<MM, ScalarType>& other)
+  {
+    return add_assign(other);
+  }
+
+  template <class MM>
+  derived_type& operator-=(const MatrixInterface<MM, ScalarType>& other)
+  {
+    return subtract_assign(other);
+  }
+
   /**
    * \brief  Checks entries for inf or nan.
    * \return false if any entry is inf or nan, else true
@@ -148,6 +178,15 @@ public:
     return ret;
   } // ... sup_norm(...)
 
+  derived_type transposed() const
+  {
+    derived_type yy(rows(), cols(), 0.);
+    for (size_t rr = 0; rr < rows(); ++rr)
+      for (size_t cc = 0; cc < cols(); ++cc)
+        yy.set_entry(rr, cc, get_entry(cc, rr));
+    return yy;
+  }
+
   /**
    * \brief Returns the number of entries in the sparsity pattern of the matrix.
    *
@@ -243,6 +282,78 @@ public:
   } // ... almost_equal(...)
   /// \}
 
+protected:
+  template <class MM>
+  derived_type multiply(const MatrixInterface<MM, ScalarType>& other) const
+  {
+    if (other.rows() != cols())
+      DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, "Dimensions of matrices to be multiplied do not match!");
+    derived_type yy(rows(), other.cols(), 0.);
+    for (size_t rr = 0; rr < rows(); ++rr)
+      for (size_t cc = 0; cc < other.cols(); ++cc)
+        for (size_t kk = 0; kk < cols(); ++kk)
+          yy.add_to_entry(rr, cc, get_entry(rr, kk) * other.get_entry(kk, cc));
+    return yy;
+  }
+
+  template <class MM>
+  derived_type add(const MatrixInterface<MM, ScalarType>& other) const
+  {
+    if (other.rows() != rows() || other.cols() != cols())
+      DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, "Dimensions of matrices to be added do not match!");
+    auto new_pattern = pattern() + other.pattern();
+    derived_type yy(rows(), other.cols(), new_pattern);
+    for (size_t rr = 0; rr < rows(); ++rr)
+      for (const auto& cc : new_pattern[rr])
+        yy.set_entry(rr, cc, get_entry(rr, cc) + other.get_entry(rr, cc));
+    return yy;
+  }
+
+  template <class MM>
+  derived_type subtract(const MatrixInterface<MM, ScalarType>& other) const
+  {
+    if (other.rows() != rows() || other.cols() != cols())
+      DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, "Dimensions of matrices to be added do not match!");
+    auto new_pattern = pattern() + other.pattern();
+    derived_type yy(rows(), other.cols(), new_pattern);
+    for (size_t rr = 0; rr < rows(); ++rr)
+      for (const auto& cc : new_pattern[rr])
+        yy.set_entry(rr, cc, get_entry(rr, cc) - other.get_entry(rr, cc));
+    return yy;
+  }
+
+  template <class MM>
+  derived_type& add_assign(const MatrixInterface<MM, ScalarType>& other) const
+  {
+    if (other.rows() != rows() || other.cols() != cols())
+      DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, "Dimensions of matrices to be added do not match!");
+    const auto this_pattern = pattern();
+    auto new_pattern = this_pattern + other.pattern();
+    if (new_pattern != this_pattern)
+      DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match,
+                 "The matrix to be added contains entries that are not in this' pattern!");
+    for (size_t rr = 0; rr < rows(); ++rr)
+      for (const auto& cc : this_pattern[rr])
+        add_to_entry(rr, cc, other.get_entry(rr, cc));
+    return this->as_imp();
+  }
+
+  template <class MM>
+  derived_type& subtract_assign(const MatrixInterface<MM, ScalarType>& other) const
+  {
+    if (other.rows() != rows() || other.cols() != cols())
+      DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match, "Dimensions of matrices to be added do not match!");
+    const auto this_pattern = pattern();
+    auto new_pattern = this_pattern + other.pattern();
+    if (new_pattern != this_pattern)
+      DUNE_THROW(XT::Common::Exceptions::shapes_do_not_match,
+                 "The matrix to be added contains entries that are not in this' pattern!");
+    for (size_t rr = 0; rr < rows(); ++rr)
+      for (const auto& cc : this_pattern[rr])
+        add_to_entry(rr, cc, -other.get_entry(rr, cc));
+    return this->as_imp();
+  }
+
 private:
   template <class T, class S>
   friend std::ostream& operator<<(std::ostream& /*out*/, const MatrixInterface<T, S>& /*matrix*/);
diff --git a/dune/xt/la/container/pattern.cc b/dune/xt/la/container/pattern.cc
index edeb185c9..a9ea00301 100644
--- a/dune/xt/la/container/pattern.cc
+++ b/dune/xt/la/container/pattern.cc
@@ -63,6 +63,17 @@ bool SparsityPatternDefault::operator!=(const SparsityPatternDefault& other) con
   return vector_of_vectors_ != other.vector_of_vectors_;
 }
 
+SparsityPatternDefault operator+(const SparsityPatternDefault& other) const
+{
+  assert(other.size() == size() && "SparsityPatterns must have the same number of rows for addition!");
+  SparsityPatternDefault ret = *this;
+  for (size_t rr = 0; rr < size(); ++rr)
+    for (const auto& cc : vector_of_vectors_[rr])
+      ret.insert(rr, cc);
+  ret.sort();
+  return ret;
+}
+
 void SparsityPatternDefault::insert(const size_t outer_index, const size_t inner_index)
 {
   assert(outer_index < size() && "Wrong index requested!");
diff --git a/dune/xt/la/container/pattern.hh b/dune/xt/la/container/pattern.hh
index 0a0e84873..aa36343a3 100644
--- a/dune/xt/la/container/pattern.hh
+++ b/dune/xt/la/container/pattern.hh
@@ -45,6 +45,8 @@ public:
 
   bool operator!=(const SparsityPatternDefault& other) const;
 
+  SparsityPatternDefault operator+(const SparsityPatternDefault& other) const;
+
   void insert(const size_t outer_index, const size_t inner_index);
 
   void sort(const size_t outer_index);
-- 
GitLab