// File: matrix-view.hh (16.4 KiB). Captured from a GitLab blame view; header lines committed by René Fritze.
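// This file is part of the dune-xt project:
//   https://github.com/dune-community/dune-xt
// Copyright 2009-2018 dune-xt developers and contributors. All rights reserved.
// License: Dual licensed as BSD 2-Clause License (http://opensource.org/licenses/BSD-2-Clause)
//      or  GPL-2.0+ (http://opensource.org/licenses/gpl-license)
//          with "runtime exception" (http://www.dune-project.org/license.html)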
// Authors:
//   Tobias Leibner (2019)


#include <dune/xt/common/exceptions.hh>
#include <dune/xt/common/parallel/threadstorage.hh>

#include "matrix-interface.hh"

namespace Dune {
namespace XT {
namespace LA {

// Forward declarations of the two view classes that are defined further below in this namespace.
// NOTE(review): presumably needed so that the traits classes in namespace internal can name the view types; confirm.
template <class MatrixImp>
class ConstMatrixView;

template <class MatrixImp>
class MatrixView;

namespace internal {

// Traits of ConstMatrixView<MatrixImp>, built on MatrixTraitsBase with the scalar type and backend type of MatrixImp.
// NOTE(review): the remaining template arguments and the class body are missing from this copy of the file and
// have to be restored from version control before this header can compile.
template <class MatrixImp>
class ConstMatrixViewTraits
  : public MatrixTraitsBase<typename MatrixImp::ScalarType,
                            typename MatrixImp::Traits::BackendType,

// Traits of MatrixView<MatrixImp>, built on MatrixTraitsBase with the scalar type and backend type of MatrixImp.
// NOTE(review): the remaining template arguments and the class body are missing from this copy of the file and
// have to be restored from version control before this header can compile.
template <class MatrixImp>
class MatrixViewTraits
  : public MatrixTraitsBase<typename MatrixImp::ScalarType,
                            typename MatrixImp::Traits::BackendType,

/**
 * \brief Returns a reference to a function-local, default-constructed MatrixImp.
 *
 * The placeholder constructors of the matrix views have to bind their reference member to some object before
 * they throw; this helper provides that object. There is exactly one such object per MatrixImp instantiation, and
 * every call returns a reference to the same object.
 */
template <class MatrixImp>
MatrixImp& empty_matrix_ref()
{
  static MatrixImp matrix_;
  return matrix_;
}

} // namespace internal

/**
 * \brief Read-only view on the block [first_row, past_last_row) x [first_col, past_last_col) of a matrix.
 *
 * All indices passed to this class are local to the view: entry (ii, jj) of the view is entry
 * (first_row + ii, first_col + jj) of the underlying matrix. The view only stores a reference to the matrix, so the
 * matrix has to outlive the view. All modifying methods throw XT::Common::Exceptions::you_are_using_this_wrong.
 *
 * The sparsity pattern of the view is computed on the first call of get_pattern() and cached in per-thread storage.
 * It is not recomputed if the pattern of the underlying matrix changes afterwards.
 */
template <class MatrixImp>
class ConstMatrixView
  : public MatrixInterface<internal::ConstMatrixViewTraits<MatrixImp>, typename MatrixImp::ScalarType>
{
  using BaseType = MatrixInterface<internal::ConstMatrixViewTraits<MatrixImp>, typename MatrixImp::ScalarType>;
  using ThisType = ConstMatrixView;

public:
  using ScalarType = typename BaseType::ScalarType;
  using RealType = typename BaseType::RealType;
  using Matrix = MatrixImp;

  // This constructor is only here for the interface to compile
  explicit ConstMatrixView(const size_t /*rr*/ = 0,
                           const size_t /*cc*/ = 0,
                           const ScalarType /*value*/ = ScalarType(0),
                           const size_t /*num_mutexes*/ = 1)
    : matrix_(internal::empty_matrix_ref<MatrixImp>())
    , first_row_(0)
    , past_last_row_(0)
    , first_col_(0)
    , past_last_col_(0)
    , pattern_(std::shared_ptr<SparsityPatternDefault>(nullptr))
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "This constructor does not make sense for MatrixView");
  }

  // This constructor is only here for the interface to compile
  ConstMatrixView(const size_t /*rr*/,
                  const size_t /*cc*/,
                  const SparsityPatternDefault& /*pattern*/,
                  const size_t /*num_mutexes*/ = 1)
    : matrix_(internal::empty_matrix_ref<MatrixImp>())
    , first_row_(0)
    , past_last_row_(0)
    , first_col_(0)
    , past_last_col_(0)
    , pattern_(std::shared_ptr<SparsityPatternDefault>(nullptr))
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "This constructor does not make sense for MatrixView");
  }

  /**
   * \brief This is the actual constructor.
   * \param matrix        The viewed matrix (stored by reference, must outlive the view).
   * \param first_row     First row of matrix that belongs to the view.
   * \param past_last_row One past the last row of matrix that belongs to the view.
   * \param first_col     First column of matrix that belongs to the view.
   * \param past_last_col One past the last column of matrix that belongs to the view.
   */
  ConstMatrixView(const Matrix& matrix,
                  const size_t first_row,
                  const size_t past_last_row,
                  const size_t first_col,
                  const size_t past_last_col)
    : matrix_(matrix)
    , first_row_(first_row)
    , past_last_row_(past_last_row)
    , first_col_(first_col)
    , past_last_col_(past_last_col)
    , pattern_(std::shared_ptr<SparsityPatternDefault>(nullptr))
  {
    // rows() and cols() are unsigned differences, so a reversed range would silently wrap around.
    assert(first_row_ <= past_last_row_ && past_last_row_ <= matrix_.rows());
    assert(first_col_ <= past_last_col_ && past_last_col_ <= matrix_.cols());
  }

  /// \brief Maps the view-local row index ii to the corresponding row index of the underlying matrix.
  size_t row_index(const size_t ii) const
  {
    assert(ii < rows());
    return first_row_ + ii;
  }

  /// \brief Maps the view-local column index jj to the corresponding column index of the underlying matrix.
  size_t col_index(const size_t jj) const
  {
    assert(jj < cols());
    return first_col_ + jj;
  }

  inline size_t rows() const
  {
    return past_last_row_ - first_row_;
  }

  inline size_t cols() const
  {
    return past_last_col_ - first_col_;
  }

  inline void scal(const ScalarType& /*alpha*/)
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "You cannot use non-const methods on ConstMatrixView");
  }

  inline void axpy(const ScalarType& /*alpha*/, const ThisType& /*xx*/)
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "You cannot use non-const methods on ConstMatrixView");
  }

  /// \brief Computes yy = V * xx, where V is this view; only entries in the view's pattern contribute.
  template <class XX, class YY>
  inline void mv(const XX& xx, YY& yy) const
  {
    assert(xx.size() == cols() && yy.size() == rows());
    const auto& patt = get_pattern();
    for (size_t ii = 0; ii < rows(); ++ii) {
      yy[ii] = 0.;
      for (auto&& jj : patt.inner(ii))
        yy[ii] += get_entry(ii, jj) * xx[jj];
    }
  }

  /// \brief Computes yy = V^T * xx, where V is this view; only entries in the view's pattern contribute.
  template <class XX, class YY>
  inline void mtv(const XX& xx, YY& yy) const
  {
    assert(xx.size() == rows() && yy.size() == cols());
    const auto& patt = get_pattern();
    std::fill(yy.begin(), yy.end(), 0.);
    for (size_t ii = 0; ii < rows(); ++ii) {
      for (auto&& jj : patt.inner(ii))
        yy[jj] += get_entry(ii, jj) * xx[ii];
    }
  }

  inline void add_to_entry(const size_t /*ii*/, const size_t /*jj*/, const ScalarType& /*value*/)
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "You cannot use non-const methods on ConstMatrixView");
  }

  inline void set_entry(const size_t /*ii*/, const size_t /*jj*/, const ScalarType& /*value*/)
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "You cannot use non-const methods on ConstMatrixView");
  }

  /// \brief Returns entry (ii, jj) of the view, i.e. entry (row_index(ii), col_index(jj)) of the matrix.
  inline ScalarType get_entry(const size_t ii, const size_t jj) const
  {
    assert(ii < rows() && jj < cols());
    return matrix_.get_entry(row_index(ii), col_index(jj));
  }

  inline void clear_row(const size_t /*ii*/)
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "You cannot use non-const methods on ConstMatrixView");
  }

  inline void clear_col(const size_t /*jj*/)
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "You cannot use non-const methods on ConstMatrixView");
  }

  inline void unit_row(const size_t /*ii*/)
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "You cannot use non-const methods on ConstMatrixView");
  }

  inline void unit_col(const size_t /*jj*/)
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "You cannot use non-const methods on ConstMatrixView");
  }

  /// \brief Returns false if any entry of the view (dense traversal) is NaN or infinite.
  inline bool valid() const
  {
    for (size_t ii = 0; ii < rows(); ++ii) {
      for (size_t jj = 0; jj < cols(); ++jj) {
        const auto entry = get_entry(ii, jj);
        if (Common::isnan(entry) || Common::isinf(entry))
          return false;
      }
    }
    return true;
  }

  /// \brief Returns the maximum absolute value over all entries of the view (0 for an empty view).
  RealType sup_norm() const override final
  {
    RealType ret = 0;
    for (size_t ii = 0; ii < rows(); ++ii)
      for (size_t jj = 0; jj < cols(); ++jj)
        ret = std::max(ret, std::abs(get_entry(ii, jj)));
    return ret;
  } // ... sup_norm(...)

  /**
   * \brief Computes the sparsity pattern of the view from the pattern of the underlying matrix.
   *
   * prune and eps are passed on to the matrix' pattern(); the resulting column indices are shifted to view-local
   * indices and columns outside [first_col, past_last_col) are dropped. This always recomputes; see get_pattern()
   * for the cached variant.
   */
  virtual SparsityPatternDefault pattern(const bool prune = false,
                                         const typename Common::FloatCmp::DefaultEpsilon<ScalarType>::Type eps =
                                             Common::FloatCmp::DefaultEpsilon<ScalarType>::value()) const override final
  {
    SparsityPatternDefault ret(rows());
    const auto matrix_patt = matrix_.pattern(prune, eps);
    for (size_t ii = 0; ii < rows(); ++ii)
      for (auto&& jj : matrix_patt.inner(row_index(ii)))
        if (jj >= first_col_ && jj < past_last_col_)
          ret.insert(ii, jj - first_col_);
    return ret;
  } // ... pattern(...)

  /// \brief Returns the cached (unpruned) pattern of the view, computing it first if necessary.
  const SparsityPatternDefault& get_pattern() const
  {
    // The cache starts out as a null shared_ptr; it has to be filled before it may be dereferenced.
    initialize_pattern();
    return **pattern_;
  }

  /// \brief Copies the view into a new matrix of size rows() x cols() with the view's pattern.
  operator Matrix() const
  {
    const auto& patt = get_pattern();
    Matrix ret(rows(), cols(), patt);
    for (size_t ii = 0; ii < rows(); ++ii)
      for (auto&& jj : patt.inner(ii))
        ret.set_entry(ii, jj, get_entry(ii, jj));
    return ret;
  }

private:
  // Fills the per-thread pattern cache with pattern() if it has not been filled yet.
  void initialize_pattern() const
  {
    if (!*pattern_)
      *pattern_ = std::make_shared<SparsityPatternDefault>(pattern());
  }

  const Matrix& matrix_;
  const size_t first_row_;
  const size_t past_last_row_;
  const size_t first_col_;
  const size_t past_last_col_;
  mutable XT::Common::PerThreadValue<std::shared_ptr<SparsityPatternDefault>> pattern_;
}; // class ConstMatrixView

/**
 * \brief Mutable view on the block [first_row, past_last_row) x [first_col, past_last_col) of a matrix.
 *
 * Read access is forwarded to an internal ConstMatrixView over the same block. Write access (set_entry,
 * add_to_entry) goes to the underlying matrix, with the view-local indices shifted by first_row/first_col. The view
 * only stores a reference, so the matrix has to outlive the view. Methods that iterate over the view's sparsity
 * pattern use the pattern cached by the internal ConstMatrixView.
 */
template <class MatrixImp>
class MatrixView : public MatrixInterface<internal::MatrixViewTraits<MatrixImp>, typename MatrixImp::ScalarType>
{
  using BaseType = MatrixInterface<internal::MatrixViewTraits<MatrixImp>, typename MatrixImp::ScalarType>;
  using ConstMatrixViewType = ConstMatrixView<MatrixImp>;
  using ThisType = MatrixView;

public:
  using ScalarType = typename BaseType::ScalarType;
  using RealType = typename BaseType::RealType;
  using Matrix = MatrixImp;

  // This constructor is only here for the interface to compile
  // (default-constructing const_matrix_view_ already throws the same exception).
  explicit MatrixView(const size_t /*rr*/ = 0,
                      const size_t /*cc*/ = 0,
                      const ScalarType /*value*/ = ScalarType(0),
                      const size_t /*num_mutexes*/ = 1)
    : const_matrix_view_()
    , matrix_(internal::empty_matrix_ref<MatrixImp>())
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "This constructor does not make sense for MatrixView");
  }

  // This constructor is only here for the interface to compile
  MatrixView(const size_t /*rr*/,
             const size_t /*cc*/,
             const SparsityPatternDefault& /*pattern*/,
             const size_t /*num_mutexes*/ = 1)
    : const_matrix_view_()
    , matrix_(internal::empty_matrix_ref<MatrixImp>())
  {
    DUNE_THROW(XT::Common::Exceptions::you_are_using_this_wrong, "This constructor does not make sense for MatrixView");
  }

  // This is the actual constructor; the arguments have the same meaning as for ConstMatrixView.
  MatrixView(Matrix& matrix,
             const size_t first_row,
             const size_t past_last_row,
             const size_t first_col,
             const size_t past_last_col)
    : const_matrix_view_(matrix, first_row, past_last_row, first_col, past_last_col)
    , matrix_(matrix)
  {}

  /**
   * \brief Copies the entries of other at the positions of this view's pattern into the view.
   * \note other is accessed with view-local indices, so it is expected to have (at least) the shape of the view.
   */
  ThisType& operator=(const Matrix& other)
  {
    const auto& patt = get_pattern();
    for (size_t ii = 0; ii < rows(); ++ii)
      for (auto&& jj : patt.inner(ii))
        set_entry(ii, jj, other.get_entry(ii, jj));
    return *this;
  }

  size_t row_index(const size_t ii) const
  {
    return const_matrix_view_.row_index(ii);
  }

  size_t col_index(const size_t jj) const
  {
    return const_matrix_view_.col_index(jj);
  }

  inline size_t rows() const
  {
    return const_matrix_view_.rows();
  }

  inline size_t cols() const
  {
    return const_matrix_view_.cols();
  }

  /// \brief Multiplies all entries in the view's pattern by alpha.
  inline void scal(const ScalarType& alpha)
  {
    const auto& patt = get_pattern();
    for (size_t ii = 0; ii < rows(); ++ii)
      for (auto&& jj : patt.inner(ii))
        set_entry(ii, jj, get_entry(ii, jj) * alpha);
  }

  /**
   * \brief Adds alpha * xx to this view, entry by entry over the pattern of xx.
   * \note In debug builds, throws if the shapes differ or if the pattern of xx is not contained in this pattern.
   */
  inline void axpy(const ScalarType& alpha, const ThisType& xx)
  {
    const auto other_patt = xx.pattern();
#ifndef NDEBUG
    const auto& patt = get_pattern();
    if (xx.rows() != rows() || xx.cols() != cols())
      DUNE_THROW(Common::Exceptions::shapes_do_not_match, "Shapes do not match!");
    for (size_t ii = 0; ii < rows(); ++ii)
      for (auto&& jj : other_patt.inner(ii))
        if (!patt.contains(ii, jj))
          DUNE_THROW(Dune::MathError, "Pattern of xx has to be a subset of this pattern!");
#endif // NDEBUG
    for (size_t ii = 0; ii < rows(); ++ii)
      for (auto&& jj : other_patt.inner(ii))
        add_to_entry(ii, jj, xx.get_entry(ii, jj) * alpha);
  }

  /**
   * \brief Adds alpha * xx to this view, entry by entry over the pattern of this view.
   * \note In debug builds, throws if the shapes differ or if the pattern of xx is not contained in this pattern.
   */
  template <class OtherTraits>
  inline void axpy(const ScalarType& alpha, const MatrixInterface<OtherTraits, ScalarType>& xx)
  {
    const auto& patt = get_pattern();
#ifndef NDEBUG
    const auto& other_patt = xx.pattern();
    if (xx.rows() != rows() || xx.cols() != cols())
      DUNE_THROW(Common::Exceptions::shapes_do_not_match, "Shapes do not match!");
    for (size_t ii = 0; ii < rows(); ++ii) {
      for (auto&& jj : other_patt.inner(ii)) {
        // The EigenRowMajorSparseMatrix automatically adds one entry to an empty row
        if (!patt.contains(ii, jj)
            && !(patt.inner(ii).size() == 0 && other_patt.inner(ii).size() == 1 && other_patt.inner(ii)[0] == 0))
          DUNE_THROW(Dune::MathError, "Pattern of xx has to be a subset of this pattern!");
      }
    }
#endif // NDEBUG
    for (size_t ii = 0; ii < rows(); ++ii)
      for (auto&& jj : patt.inner(ii))
        add_to_entry(ii, jj, xx.get_entry(ii, jj) * alpha);
  }

  /// \brief Computes yy = V * xx (forwarded to the internal ConstMatrixView).
  template <class XX, class YY>
  inline void mv(const XX& xx, YY& yy) const
  {
    return const_matrix_view_.mv(xx, yy);
  }

  /// \brief Computes yy = V^T * xx (forwarded to the internal ConstMatrixView).
  template <class XX, class YY>
  inline void mtv(const XX& xx, YY& yy) const
  {
    return const_matrix_view_.mtv(xx, yy);
  }

  inline bool valid() const
  {
    return const_matrix_view_.valid();
  }

  RealType sup_norm() const override final
  {
    return const_matrix_view_.sup_norm();
  } // ... sup_norm(...)

  virtual SparsityPatternDefault pattern(const bool prune = false,
                                         const typename Common::FloatCmp::DefaultEpsilon<ScalarType>::Type eps =
                                             Common::FloatCmp::DefaultEpsilon<ScalarType>::value()) const override final
  {
    return const_matrix_view_.pattern(prune, eps);
  } // ... pattern(...)

  /// \brief Adds value to entry (ii, jj) of the view, i.e. to entry (row_index(ii), col_index(jj)) of the matrix.
  inline void add_to_entry(const size_t ii, const size_t jj, const ScalarType& value)
  {
    assert(ii < rows() && jj < cols());
    matrix_.add_to_entry(row_index(ii), col_index(jj), value);
  }

  /// \brief Sets entry (ii, jj) of the view, i.e. entry (row_index(ii), col_index(jj)) of the matrix.
  inline void set_entry(const size_t ii, const size_t jj, const ScalarType& value)
  {
    assert(ii < rows() && jj < cols());
    matrix_.set_entry(row_index(ii), col_index(jj), value);
  }

  inline ScalarType get_entry(const size_t ii, const size_t jj) const
  {
    return const_matrix_view_.get_entry(ii, jj);
  }

  /// \brief Sets all entries of row ii of the view that are in the view's pattern to zero.
  inline void clear_row(const size_t ii)
  {
    const auto& patt = get_pattern();
    for (auto&& jj : patt.inner(ii))
      set_entry(ii, jj, 0.);
  }

  /// \brief Sets all entries of column jj of the view that are in the view's pattern to zero.
  inline void clear_col(const size_t jj)
  {
    const auto& patt = get_pattern();
    for (size_t ii = 0; ii < rows(); ++ii)
      if (patt.contains(ii, jj))
        set_entry(ii, jj, 0.);
  }

  /// \brief Turns row ii of the view into a unit row: clears the row, then sets view entry (ii, ii) to one.
  inline void unit_row(const size_t ii)
  {
    clear_row(ii);
    set_entry(ii, ii, 1.);
  }

  /// \brief Turns column jj of the view into a unit column: clears the column, then sets view entry (jj, jj) to one.
  inline void unit_col(const size_t jj)
  {
    clear_col(jj);
    set_entry(jj, jj, 1.);
  }

  operator Matrix() const
  {
    return const_matrix_view_.operator Matrix();
  }

  const SparsityPatternDefault& get_pattern() const
  {
    return const_matrix_view_.get_pattern();
  }

  /**
   * \brief Returns true if, within the first rows() rows, every entry in the pattern of other either lies in this
   *        view's pattern or is zero.
   */
  bool pattern_assignable(const Matrix& other) const
  {
    const auto& patt = get_pattern();
    const auto& other_patt = other.pattern();
    for (size_t ii = 0; ii < rows(); ++ii)
      for (auto&& jj : other_patt.inner(ii))
        if (!patt.contains(ii, jj) && !XT::Common::is_zero(other.get_entry(ii, jj)))
          return false;
    return true;
  }

private:
  ConstMatrixViewType const_matrix_view_;
  Matrix& matrix_;
}; // class MatrixView

} // namespace LA
namespace Common {

/**
 * \brief MatrixAbstraction specialization for LA::ConstMatrixView.
 *
 * The storage layout is taken from MatrixAbstraction<MatrixImp>, and the static size defaults of
 * MatrixTypeTemplate are taken from MatrixAbstractionBase<MatrixImp>.
 */
template <class MatrixImp>
struct MatrixAbstraction<LA::ConstMatrixView<MatrixImp>>
  : public LA::internal::MatrixAbstractionBase<LA::ConstMatrixView<MatrixImp>>
{
  // NOTE(review): refers to the abstraction base of MatrixImp, not of the view; within this struct it only supplies
  // the static_rows/static_cols defaults below -- confirm this is intended.
  using BaseType = LA::internal::MatrixAbstractionBase<MatrixImp>;
  static const constexpr Common::StorageLayout storage_layout = MatrixAbstraction<MatrixImp>::storage_layout;

  // Independent of the requested sizes and field type, this always yields the view type itself.
  template <size_t rows = BaseType::static_rows,
            size_t cols = BaseType::static_cols,
            class FieldType = typename MatrixImp::ScalarType>
  using MatrixTypeTemplate = LA::ConstMatrixView<MatrixImp>;
};

/**
 * \brief MatrixAbstraction specialization for LA::MatrixView.
 *
 * The storage layout is taken from MatrixAbstraction<MatrixImp>, and the static size defaults of
 * MatrixTypeTemplate are taken from MatrixAbstractionBase<MatrixImp>.
 */
template <class MatrixImp>
struct MatrixAbstraction<LA::MatrixView<MatrixImp>>
  : public LA::internal::MatrixAbstractionBase<LA::MatrixView<MatrixImp>>
{
  // NOTE(review): refers to the abstraction base of MatrixImp, not of the view; within this struct it only supplies
  // the static_rows/static_cols defaults below -- confirm this is intended.
  using BaseType = LA::internal::MatrixAbstractionBase<MatrixImp>;
  static const constexpr Common::StorageLayout storage_layout = MatrixAbstraction<MatrixImp>::storage_layout;

  // Independent of the requested sizes and field type, this always yields the view type itself.
  template <size_t rows = BaseType::static_rows,
            size_t cols = BaseType::static_cols,
            class FieldType = typename MatrixImp::ScalarType>
  using MatrixTypeTemplate = LA::MatrixView<MatrixImp>;
};

} // namespace Common
} // namespace XT
} // namespace Dune