Skip to content
Snippets Groups Projects
Commit 94e418ff authored by Tobias Leibner's avatar Tobias Leibner
Browse files

[pattern] add contains for patterns

parent d873b0a1
No related branches found
No related tags found
No related merge requests found
......@@ -253,12 +253,23 @@ public:
return **pattern_;
}
operator Matrix() const
{
const auto& patt = get_pattern();
Matrix ret(rows(), cols(), patt);
for (size_t ii = 0; ii < rows(); ++ii)
for (auto&& jj : patt.inner(ii))
ret.set_entry(ii, jj, get_entry(ii, jj));
return ret;
}
private:
void initialize_pattern() const
{
if (!*pattern_)
*pattern_ = std::make_shared<SparsityPatternDefault>(pattern());
}
const Matrix& matrix_;
const size_t first_row_;
const size_t past_last_row_;
......@@ -314,7 +325,7 @@ public:
ThisType& operator=(const Matrix& other)
{
const auto& patt = const_matrix_view_.get_pattern();
assert(patt == other.pattern());
assert(pattern_assignable(other));
for (size_t ii = 0; ii < rows(); ++ii)
for (auto&& jj : patt.inner(ii))
set_entry(ii, jj, other.get_entry(ii, jj));
......@@ -439,7 +450,23 @@ public:
set_entry(jj, jj, 1.);
}
operator Matrix() const
{
return const_matrix_view_.operator Matrix();
}
private:
bool pattern_assignable(const Matrix& other) const
{
const auto& patt = const_matrix_view_.get_pattern();
const auto& other_patt = other.pattern();
for (size_t ii = 0; ii < rows(); ++ii)
for (auto&& jj : other_patt.inner(ii))
if (!patt.contains(ii, jj) && !XT::Common::is_zero(other.get_entry(ii, jj)))
return false;
return true;
}
ConstMatrixViewType const_matrix_view_;
Matrix& matrix_;
}; // class MatrixView
......
......@@ -117,6 +117,15 @@ bool SparsityPatternDefault::contains(const size_t outer_index, const size_t inn
return std::find(row.begin(), row.end(), inner_index) != row.end();
}
bool SparsityPatternDefault::contains(const SparsityPatternDefault& other) const
{
for (size_t rr = 0; rr < size(); ++rr)
for (auto&& cc : other.inner(rr))
if (!this->contains(rr, cc))
return false;
return true;
}
SparsityPatternDefault SparsityPatternDefault::transposed(const size_t cols) const
{
SparsityPatternDefault transposed_pattern(cols);
......
......@@ -63,6 +63,8 @@ public:
bool contains(const size_t outer_index, const size_t inner_index) const;
bool contains(const SparsityPatternDefault& other) const;
SparsityPatternDefault transposed(const size_t cols) const;
private:
......
......@@ -114,6 +114,24 @@ struct MatrixViewTest_{{T_NAME}} : public ::testing::Test
EXPECT_DOUBLE_OR_COMPLEX_EQ(RealType(0.5), testmatrix_sparse.get_entry(1, 0));
sparse_view_upperleft.set_entry(1, 0, 1.);
// test operator=
MatrixImp lowerright_saved = view_lowerright;
MatrixImp sparse_lowerright_saved = sparse_view_lowerright;
MatrixImp zeros_dense(2, 1, 0.);
LA::SparsityPatternDefault lowerright_pattern(2);
lowerright_pattern.insert(0, 0);
MatrixImp zeros_sparse(2, 1, lowerright_pattern);
view_lowerright = zeros_dense;
sparse_view_lowerright = zeros_sparse;
for (size_t ii = 0; ii < 2; ++ii) {
EXPECT_DOUBLE_OR_COMPLEX_EQ(RealType(0), view_lowerright.get_entry(ii, 0));
EXPECT_DOUBLE_OR_COMPLEX_EQ(RealType(0), const_view_lowerright.get_entry(ii, 0));
EXPECT_DOUBLE_OR_COMPLEX_EQ(RealType(0), sparse_view_lowerright.get_entry(ii, 0));
EXPECT_DOUBLE_OR_COMPLEX_EQ(RealType(0), sparse_const_view_lowerright.get_entry(ii, 0));
}
view_lowerright = lowerright_saved;
sparse_view_lowerright = sparse_lowerright_saved;
// test scal, operator*
const auto testmatrix_copy = testmatrix;
const auto testmatrix_sparse_copy = testmatrix_sparse;
......
......@@ -47,12 +47,17 @@ GTEST_TEST(SparsityPatternDefaultTest, test_interface)
}
}
auto pattern2 = pattern;
EXPECT_TRUE(pattern2.contains(pattern));
EXPECT_TRUE(pattern.contains(pattern2));
EXPECT_TRUE(pattern.contains(const_pattern));
pattern.sort();
for (size_t ii = 0; ii < ROWS; ++ii)
pattern2.sort(ii);
EXPECT_TRUE(pattern == pattern2);
EXPECT_TRUE(pattern != const_pattern);
auto test_pattern = pattern + pattern2;
EXPECT_TRUE(test_pattern.contains(pattern));
EXPECT_TRUE(test_pattern.contains(pattern2));
EXPECT_TRUE(pattern + pattern2 == pattern);
Pattern pattern3(ROWS), pattern4(ROWS), pattern5(ROWS);
for (size_t ii = 0; ii < ROWS; ++ii) {
......@@ -126,6 +131,11 @@ GTEST_TEST(SparsityPatternDefaultTest, test_creation_functions)
}
} // jj
} // ii
const auto diagonal_subdiagonal_patt = diagonal_patt + subdiagonal_patt;
EXPECT_TRUE(diagonal_subdiagonal_patt.contains(diagonal_patt));
EXPECT_TRUE(diagonal_subdiagonal_patt.contains(subdiagonal_patt));
EXPECT_FALSE(diagonal_patt.contains(diagonal_subdiagonal_patt));
EXPECT_FALSE(subdiagonal_patt.contains(diagonal_subdiagonal_patt));
XT::LA::SparsityPatternDefault pattern(ROWS), pattern2(COLS);
pattern.insert(0, 0);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment