Skip to content
Snippets Groups Projects
Commit 769f2152 authored by Tobias Leibner's avatar Tobias Leibner
Browse files

[fmatrix/matrix] add heap creation method to MatrixAbstraction

parent 560d5ed8
No related branches found
No related tags found
No related merge requests found
...@@ -386,6 +386,16 @@ struct MatrixAbstraction<Dune::XT::Common::FieldMatrix<K, N, M>> ...@@ -386,6 +386,16 @@ struct MatrixAbstraction<Dune::XT::Common::FieldMatrix<K, N, M>>
return MatrixType(rows, cols, val); return MatrixType(rows, cols, val);
} }
static inline std::unique_ptr<MatrixType> create_dynamic(const size_t rows, const size_t cols)
{
return std::make_unique<MatrixType>(rows, cols);
}
static inline std::unique_ptr<MatrixType> create_dynamic(const size_t rows, const size_t cols, const ScalarType& val)
{
return std::make_unique<MatrixType>(rows, cols, val);
}
static inline size_t rows(const MatrixType& /*mat*/) static inline size_t rows(const MatrixType& /*mat*/)
{ {
return N; return N;
...@@ -575,6 +585,19 @@ Dune::XT::Common::FieldMatrix<K, L_ROWS, R_COLS> operator*(const Dune::FieldMatr ...@@ -575,6 +585,19 @@ Dune::XT::Common::FieldMatrix<K, L_ROWS, R_COLS> operator*(const Dune::FieldMatr
return left.rightmultiplyany(right); return left.rightmultiplyany(right);
} }
template <class K, int L_ROWS, int L_COLS, int R_COLS>
void rightmultiply(Dune::FieldMatrix<K, L_ROWS, R_COLS>& ret,
const Dune::FieldMatrix<K, L_ROWS, L_COLS>& left,
const Dune::FieldMatrix<K, L_COLS, R_COLS>& right)
{
for (size_t ii = 0; ii < L_ROWS; ++ii) {
for (size_t jj = 0; jj < R_COLS; ++jj) {
ret[ii][jj] = 0.;
for (size_t kk = 0; kk < L_COLS; ++kk)
ret[ii][jj] += left[ii][kk] * right[kk][jj];
}
}
}
template <class L, int L_ROWS, int L_COLS, class R, int R_COLS> template <class L, int L_ROWS, int L_COLS, class R, int R_COLS>
typename std::enable_if<!std::is_same<L, R>::value, typename std::enable_if<!std::is_same<L, R>::value,
...@@ -587,6 +610,35 @@ typename std::enable_if<!std::is_same<L, R>::value, ...@@ -587,6 +610,35 @@ typename std::enable_if<!std::is_same<L, R>::value,
return convert_to<Promoted>(left).rightmultiplyany(convert_to<Promoted>(right)); return convert_to<Promoted>(left).rightmultiplyany(convert_to<Promoted>(right));
} }
// versions that do not allocate matrices on the stack (for large matrices)
template <class K, int L_ROWS, int L_COLS, int R_COLS>
std::unique_ptr<Dune::XT::Common::FieldMatrix<K, L_ROWS, R_COLS>>
operator*(const std::unique_ptr<Dune::FieldMatrix<K, L_ROWS, L_COLS>>& left,
const Dune::FieldMatrix<K, L_COLS, R_COLS>& right)
{
auto ret = std::make_unique<Dune::XT::Common::FieldMatrix<K, L_ROWS, R_COLS>>();
rightmultiply(*ret, *left, right);
return ret;
}
template <class K, int L_ROWS, int L_COLS, int R_COLS>
std::unique_ptr<Dune::XT::Common::FieldMatrix<K, L_ROWS, R_COLS>>
operator*(const Dune::FieldMatrix<K, L_ROWS, L_COLS>& left,
const std::unique_ptr<Dune::FieldMatrix<K, L_COLS, R_COLS>>& right)
{
auto ret = std::make_unique<Dune::XT::Common::FieldMatrix<K, L_ROWS, R_COLS>>();
rightmultiply(*ret, left, *right);
return ret;
}
template <class K, int L_ROWS, int L_COLS, int R_COLS>
std::unique_ptr<Dune::XT::Common::FieldMatrix<K, L_ROWS, R_COLS>>
operator*(const std::unique_ptr<Dune::FieldMatrix<K, L_ROWS, L_COLS>>& left,
const std::unique_ptr<Dune::FieldMatrix<K, L_COLS, R_COLS>>& right)
{
return left * *right;
}
} // namespace Dune } // namespace Dune
......
...@@ -320,21 +320,25 @@ void dtrsm(const int layout, ...@@ -320,21 +320,25 @@ void dtrsm(const int layout,
const int ldb) const int ldb)
{ {
#if HAVE_MKL #if HAVE_MKL
return cblas_dtrsm(static_cast<CBLAS_LAYOUT>(layout), cblas_dtrsm(static_cast<CBLAS_LAYOUT>(layout),
static_cast<CBLAS_SIDE>(side), static_cast<CBLAS_SIDE>(side),
static_cast<CBLAS_UPLO>(uplo), static_cast<CBLAS_UPLO>(uplo),
static_cast<CBLAS_TRANSPOSE>(transa), static_cast<CBLAS_TRANSPOSE>(transa),
static_cast<CBLAS_DIAG>(diag), static_cast<CBLAS_DIAG>(diag),
m, m,
n, n,
alpha, alpha,
a, a,
lda, lda,
b, b,
ldb); ldb);
#ifndef NDEBUG
for (size_t ii = 0; ii < m; ++ii)
if (std::isnan(b[ii]) || std::isinf(b[ii]))
DUNE_THROW(Dune::MathError, "Triangular solve using cblas_dtrsm failed!");
#endif
#else #else
DUNE_THROW(Exceptions::dependency_missing, "You are missing CBLAS or the intel mkl, check available() first!"); DUNE_THROW(Exceptions::dependency_missing, "You are missing CBLAS or the intel mkl, check available() first!");
return 1;
#endif #endif
} }
......
...@@ -61,6 +61,12 @@ struct MatrixAbstraction ...@@ -61,6 +61,12 @@ struct MatrixAbstraction
static_assert(AlwaysFalse<MatType>::value, "Do not call me if is_matrix is false!"); static_assert(AlwaysFalse<MatType>::value, "Do not call me if is_matrix is false!");
} }
static inline /*std::unique_ptr<MatrixType>*/ void
create_dynamic(const size_t /*rows*/, const size_t /*cols*/, const ScalarType& /*val*/)
{
static_assert(AlwaysFalse<MatType>::value, "Do not call me if is_matrix is false!");
}
static inline /*size_t*/ void rows(const MatrixType& /*mat*/) static inline /*size_t*/ void rows(const MatrixType& /*mat*/)
{ {
static_assert(AlwaysFalse<MatType>::value, "Do not call me if is_matrix is false!"); static_assert(AlwaysFalse<MatType>::value, "Do not call me if is_matrix is false!");
...@@ -119,6 +125,16 @@ struct MatrixAbstraction<Dune::DynamicMatrix<K>> ...@@ -119,6 +125,16 @@ struct MatrixAbstraction<Dune::DynamicMatrix<K>>
return MatrixType(rows, cols, val); return MatrixType(rows, cols, val);
} }
static inline std::unique_ptr<MatrixType> create_dynamic(const size_t rows, const size_t cols)
{
return std::make_unique<MatrixType>(rows, cols);
}
static inline std::unique_ptr<MatrixType> create_dynamic(const size_t rows, const size_t cols, const ScalarType& val)
{
return std::make_unique<MatrixType>(rows, cols, val);
}
static inline size_t rows(const MatrixType& mat) static inline size_t rows(const MatrixType& mat)
{ {
return mat.rows(); return mat.rows();
...@@ -175,6 +191,24 @@ struct MatrixAbstraction<Dune::FieldMatrix<K, N, M>> ...@@ -175,6 +191,24 @@ struct MatrixAbstraction<Dune::FieldMatrix<K, N, M>>
return MatrixType(val); return MatrixType(val);
} }
static inline std::unique_ptr<MatrixType> create_dynamic(const size_t rows, const size_t cols)
{
if (rows != N)
DUNE_THROW(Exceptions::shapes_do_not_match, "rows = " << rows << "\nN = " << int(N));
if (cols != M)
DUNE_THROW(Exceptions::shapes_do_not_match, "cols = " << cols << "\nM = " << int(M));
return std::make_unique<MatrixType>();
}
static inline std::unique_ptr<MatrixType> create_dynamic(const size_t rows, const size_t cols, const ScalarType& val)
{
if (rows != N)
DUNE_THROW(Exceptions::shapes_do_not_match, "rows = " << rows << "\nN = " << int(N));
if (cols != M)
DUNE_THROW(Exceptions::shapes_do_not_match, "cols = " << cols << "\nM = " << int(M));
return std::make_unique<MatrixType>(val);
}
static inline size_t rows(const MatrixType& /*mat*/) static inline size_t rows(const MatrixType& /*mat*/)
{ {
return numeric_cast<size_t>(N); return numeric_cast<size_t>(N);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment