Skip to content
Snippets Groups Projects
Unverified Commit 6b2c8c9d authored by Tobias Leibner's avatar Tobias Leibner
Browse files

[float_cmp] fix for matrices, add tests (fixes #19)

parent e9cff463
No related branches found
No related tags found
No related merge requests found
......@@ -86,9 +86,9 @@ float_cmp_eq(const XType& xx, const YType& yy, const TolType& rtol, const TolTyp
} // ... float_cmp(...)
template <class XType, class YType, class TolType>
typename std::enable_if<is_matrix<XType>::value && is_matrix<YType>::value && std::is_arithmetic<TolType>::value
&& std::is_same<typename MatrixAbstraction<XType>::R, TolType>::value
&& std::is_same<typename MatrixAbstraction<YType>::R, TolType>::value,
typename std::enable_if<is_matrix<XType>::value && is_matrix<YType>::value
&& std::is_same<typename MatrixAbstraction<XType>::S, TolType>::value
&& std::is_same<typename MatrixAbstraction<YType>::S, TolType>::value,
bool>::type
float_cmp_eq(const XType& xx, const YType& yy, const TolType& rtol, const TolType& atol)
{
......@@ -138,9 +138,9 @@ dune_float_cmp_eq(const XType& xx, const YType& yy, const EpsType& eps)
} // ... dune_float_cmp(...)
template <Dune::FloatCmp::CmpStyle style, class XType, class YType, class EpsType>
typename std::enable_if<is_matrix<XType>::value && is_matrix<YType>::value && std::is_arithmetic<EpsType>::value
&& std::is_same<typename MatrixAbstraction<XType>::R, EpsType>::value
&& std::is_same<typename MatrixAbstraction<YType>::R, EpsType>::value,
typename std::enable_if<is_matrix<XType>::value && is_matrix<YType>::value
&& std::is_same<typename MatrixAbstraction<XType>::S, EpsType>::value
&& std::is_same<typename MatrixAbstraction<YType>::S, EpsType>::value,
bool>::type
dune_float_cmp_eq(const XType& xx, const YType& yy, const EpsType& eps)
{
......
......@@ -248,6 +248,21 @@ create(const size_t rows, const size_t cols, const typename MatrixAbstraction<Ma
return MatrixAbstraction<MatrixType>::create(rows, cols, val);
}
template <class T, class SR>
typename std::enable_if<is_complex<T>::value, T>::type
create(const size_t /*rows*/, const size_t /*cols*/, const SR& val)
{
return VectorAbstraction<T>::create(0, val);
}
template <class MatrixType>
typename std::enable_if<std::is_arithmetic<MatrixType>::value, MatrixType>::type
create(const size_t /*rows*/, const size_t /*cols*/, const typename MatrixAbstraction<MatrixType>::S& val)
{
return val;
}
} // namespace Common
} // namespace XT
} // namespace Dune
......
......@@ -26,26 +26,59 @@ using namespace Dune;
using XT::Common::create;
using XT::Common::FloatCmp::Style;
// add operator+= for std::array and std::vector
template <typename T, size_t N>
std::array<T, N>& operator+=(std::array<T, N>& arr, const std::array<T, N>& other)
{
std::transform(arr.begin(), arr.end(), other.begin(), arr.begin(), std::plus<T>());
return arr;
}
template <typename T>
std::vector<T>& operator+=(std::vector<T>& vec, const std::vector<T>& other)
{
assert(vec.size() == other.size());
std::transform(vec.begin(), vec.end(), other.begin(), vec.begin(), std::plus<T>());
return vec;
}
struct FloatCmpTest : public testing::Test
{
typedef TESTTYPE V;
static const size_t s_size =
XT::Common::VectorAbstraction<V>::has_static_size ? XT::Common::VectorAbstraction<V>::static_size : VECSIZE;
typedef typename XT::Common::VectorAbstraction<V>::ScalarType S;
typedef typename XT::Common::VectorAbstraction<V>::RealType R;
typedef typename std::conditional<XT::Common::is_matrix<V>::value,
typename XT::Common::MatrixAbstraction<V>::ScalarType,
typename XT::Common::VectorAbstraction<V>::ScalarType>::type S;
typedef typename std::conditional<XT::Common::is_matrix<V>::value,
typename XT::Common::MatrixAbstraction<V>::RealType,
typename XT::Common::VectorAbstraction<V>::RealType>::type R;
static constexpr bool fieldtype_is_float = std::is_floating_point<R>::value;
static const size_t s_size =
XT::Common::is_matrix<V>::value
? (XT::Common::MatrixAbstraction<V>::has_static_size ? XT::Common::MatrixAbstraction<V>::static_rows
: VECSIZE)
: (XT::Common::VectorAbstraction<V>::has_static_size ? XT::Common::VectorAbstraction<V>::static_size
: VECSIZE);
static const size_t s_cols = XT::Common::is_matrix<V>::value ? (XT::Common::MatrixAbstraction<V>::has_static_size
? XT::Common::MatrixAbstraction<V>::static_cols
: NUMCOLS)
: 1.;
FloatCmpTest()
: zero(create<V>(s_size, create<S>(0, 0)))
, one(create<V>(s_size, create<S>(0, 1)))
, neg_one(create<V>(s_size, create<S>(0, -1)))
, epsilon(create<V>(s_size, XT::Common::FloatCmp::DEFAULT_EPSILON::value()))
, eps_plus(create<V>(s_size, XT::Common::FloatCmp::DEFAULT_EPSILON::value() * 1.1))
, eps_minus(create<V>(s_size, XT::Common::FloatCmp::DEFAULT_EPSILON::value() * 0.9))
, two(create<V>(s_size, create<S>(0, 2)))
: zero(create<V>(s_size, s_cols, create<S>(0, 0)))
, one(create<V>(s_size, s_cols, create<S>(0, 1)))
, neg_one(create<V>(s_size, s_cols, create<S>(0, -1)))
, epsilon(create<V>(s_size, s_cols, XT::Common::FloatCmp::DEFAULT_EPSILON::value()))
, eps_plus(create<V>(s_size, s_cols, XT::Common::FloatCmp::DEFAULT_EPSILON::value() * 1.1))
, eps_minus(create<V>(s_size, s_cols, XT::Common::FloatCmp::DEFAULT_EPSILON::value() * 0.9))
, two(create<V>(s_size, s_cols, create<S>(0, 2)))
, test_config(DXTC_CONFIG.sub("test_common_float_cmp"))
{
one_plus_eps_minus_ = one;
one_plus_eps_minus_ += eps_minus;
one_plus_eps_plus_ = one;
one_plus_eps_plus_ += eps_plus;
}
const V zero;
......@@ -55,14 +88,15 @@ struct FloatCmpTest : public testing::Test
const V eps_plus;
const V eps_minus;
const V two;
V one_plus_eps_minus_;
V one_plus_eps_plus_;
const typename XT::Common::Configuration test_config;
void check_eq()
{
TEST_DXTC_EXPECT_FLOAT_EQ(zero, zero);
EXPECT_FALSE(FLOATCMP_EQ(zero, one));
TEST_DXTC_EXPECT_FLOAT_EQ(one, one + eps_minus);
TEST_DXTC_EXPECT_FLOAT_EQ(one, one_plus_eps_minus_);
EXPECT_FALSE(FLOATCMP_EQ(one, two));
if (test_config["cmpstyle_is_relative"] == "true" && fieldtype_is_float)
......@@ -71,9 +105,9 @@ struct FloatCmpTest : public testing::Test
TEST_DXTC_EXPECT_FLOAT_EQ(zero, eps_minus);
if (test_config["cmpstyle_is_numpy"] == "true" || !fieldtype_is_float)
TEST_DXTC_EXPECT_FLOAT_EQ(one, one + eps_plus);
TEST_DXTC_EXPECT_FLOAT_EQ(one, one_plus_eps_plus_);
else
EXPECT_FALSE(FLOATCMP_EQ(one, one + eps_plus));
EXPECT_FALSE(FLOATCMP_EQ(one, one_plus_eps_plus_));
if (fieldtype_is_float)
EXPECT_FALSE(FLOATCMP_EQ(zero, eps_plus));
......@@ -85,7 +119,7 @@ struct FloatCmpTest : public testing::Test
{
EXPECT_FALSE(FLOATCMP_NE(zero, zero));
TEST_DXTC_EXPECT_FLOAT_NE(zero, one);
EXPECT_FALSE(FLOATCMP_NE(one, one + eps_minus));
EXPECT_FALSE(FLOATCMP_NE(one, one_plus_eps_minus_));
TEST_DXTC_EXPECT_FLOAT_NE(one, two);
if (test_config["cmpstyle_is_relative"] == "true" && fieldtype_is_float)
TEST_DXTC_EXPECT_FLOAT_NE(zero, eps_minus);
......@@ -93,9 +127,9 @@ struct FloatCmpTest : public testing::Test
EXPECT_FALSE(FLOATCMP_NE(zero, eps_minus));
if (test_config["cmpstyle_is_numpy"] == "true" || !fieldtype_is_float)
EXPECT_FALSE(FLOATCMP_NE(one, one + eps_plus));
EXPECT_FALSE(FLOATCMP_NE(one, one_plus_eps_plus_));
else
TEST_DXTC_EXPECT_FLOAT_NE(one, one + eps_plus);
TEST_DXTC_EXPECT_FLOAT_NE(one, one_plus_eps_plus_);
if (fieldtype_is_float)
TEST_DXTC_EXPECT_FLOAT_NE(zero, eps_plus);
......@@ -107,7 +141,7 @@ struct FloatCmpTest : public testing::Test
{
EXPECT_FALSE(FLOATCMP_GT(zero, zero));
TEST_DXTC_EXPECT_FLOAT_GT(one, zero);
EXPECT_FALSE(FLOATCMP_GT(one + eps_minus, one));
EXPECT_FALSE(FLOATCMP_GT(one_plus_eps_minus_, one));
TEST_DXTC_EXPECT_FLOAT_GT(two, one);
if (test_config["cmpstyle_is_relative"] == "true" && fieldtype_is_float)
......@@ -116,9 +150,9 @@ struct FloatCmpTest : public testing::Test
EXPECT_FALSE(FLOATCMP_GT(eps_minus, zero));
if (test_config["cmpstyle_is_numpy"] == "true" || !fieldtype_is_float)
EXPECT_FALSE(FLOATCMP_GT(one + eps_plus, one));
EXPECT_FALSE(FLOATCMP_GT(one_plus_eps_plus_, one));
else
TEST_DXTC_EXPECT_FLOAT_GT(one + eps_plus, one);
TEST_DXTC_EXPECT_FLOAT_GT(one_plus_eps_plus_, one);
if (fieldtype_is_float)
TEST_DXTC_EXPECT_FLOAT_GT(eps_plus, zero);
......@@ -130,7 +164,7 @@ struct FloatCmpTest : public testing::Test
{
EXPECT_FALSE(FLOATCMP_LT(zero, zero));
TEST_DXTC_EXPECT_FLOAT_LT(zero, one);
EXPECT_FALSE(FLOATCMP_LT(one, one + eps_minus));
EXPECT_FALSE(FLOATCMP_LT(one, one_plus_eps_minus_));
TEST_DXTC_EXPECT_FLOAT_LT(one, two);
if (test_config["cmpstyle_is_relative"] == "true" && fieldtype_is_float)
......@@ -139,9 +173,9 @@ struct FloatCmpTest : public testing::Test
EXPECT_FALSE(FLOATCMP_LT(zero, eps_minus));
if (test_config["cmpstyle_is_numpy"] == "true" || !fieldtype_is_float)
EXPECT_FALSE(FLOATCMP_LT(one, one + eps_plus));
EXPECT_FALSE(FLOATCMP_LT(one, one_plus_eps_plus_));
else
TEST_DXTC_EXPECT_FLOAT_LT(one, one + eps_plus);
TEST_DXTC_EXPECT_FLOAT_LT(one, one_plus_eps_plus_);
if (fieldtype_is_float)
TEST_DXTC_EXPECT_FLOAT_LT(zero, eps_plus);
......@@ -155,8 +189,8 @@ struct FloatCmpTest : public testing::Test
TEST_DXTC_EXPECT_FLOAT_GE(eps_minus, zero);
TEST_DXTC_EXPECT_FLOAT_GE(eps_plus, zero);
TEST_DXTC_EXPECT_FLOAT_GE(one, zero);
TEST_DXTC_EXPECT_FLOAT_GE(one + eps_minus, one);
TEST_DXTC_EXPECT_FLOAT_GE(one + eps_plus, one);
TEST_DXTC_EXPECT_FLOAT_GE(one_plus_eps_minus_, one);
TEST_DXTC_EXPECT_FLOAT_GE(one_plus_eps_plus_, one);
TEST_DXTC_EXPECT_FLOAT_GE(two, one);
}
......@@ -166,8 +200,8 @@ struct FloatCmpTest : public testing::Test
TEST_DXTC_EXPECT_FLOAT_LE(zero, eps_minus);
TEST_DXTC_EXPECT_FLOAT_LE(zero, eps_plus);
TEST_DXTC_EXPECT_FLOAT_LE(zero, one);
TEST_DXTC_EXPECT_FLOAT_LE(one, one + eps_minus);
TEST_DXTC_EXPECT_FLOAT_LE(one, one + eps_plus);
TEST_DXTC_EXPECT_FLOAT_LE(one, one_plus_eps_minus_);
TEST_DXTC_EXPECT_FLOAT_LE(one, one_plus_eps_plus_);
TEST_DXTC_EXPECT_FLOAT_LE(one, two);
}
......
......@@ -13,10 +13,13 @@ __local.fieldtype = std::size_t, long, double, std::complex<double>, unsigned |
__local.fieldtype_short = std_size_t, long, double, complex, unsigned | expand field
__local.vectortype = std::vector<{__local.fieldtype}>, std::array<{__local.fieldtype}\,{__local.vec_size}>, Dune::FieldVector<{__local.fieldtype}\,{__local.vec_size}>, Dune::XT::Common::FieldVector<{__local.fieldtype}\,{__local.vec_size}>, Dune::DynamicVector<{__local.fieldtype}> | expand vector
__local.vectortype_short = std_vector, std_array, dune_fieldvector, xt_fieldvector, dynamic_vector | expand vector
__local.testtype = {__local.fieldtype}, {__local.vectortype} | expand test
__local.testtype_short = {__local.fieldtype_short}, {__local.vectortype_short}_{__local.fieldtype_short} | expand test
__local.matrixtype = Dune::FieldMatrix<{__local.fieldtype}\,{__local.vec_size}\,{__local.num_cols}>, Dune::XT::Common::FieldMatrix<{__local.fieldtype}\,{__local.vec_size}\,{__local.num_cols}>, Dune::DynamicMatrix<{__local.fieldtype}> | expand matrix
__local.matrixtype_short = dune_fieldmatrix, xt_fieldmatrix, dynamic_matrix | expand matrix
__local.testtype = {__local.fieldtype}, {__local.vectortype}, {__local.matrixtype} | expand test
__local.testtype_short = {__local.fieldtype_short}, {__local.vectortype_short}_{__local.fieldtype_short}, {__local.matrixtype_short}_{__local.fieldtype_short} | expand test
__local.vec_size = 4
__local.num_cols = 4
[test_common_float_cmp]
cmpstyle_is_relative = false, {__local.cmpstyle_is_relative} | expand cmpstyle_templ
......@@ -25,6 +28,7 @@ cmpstyle_is_numpy = true, {__local.cmpstyle_is_numpy} | expand cmpstyle_templ
[__static]
TESTTYPE = {__local.testtype}
VECSIZE = {__local.vec_size}
NUMCOLS = {__local.num_cols}
DEFAULT_EPSILON = {__local.default_eps}
FLOATCMP_EQ = Dune::XT::Common::FloatCmp::eq{__local.cmpstyle_template}
FLOATCMP_NE = Dune::XT::Common::FloatCmp::ne{__local.cmpstyle_template}
......
......@@ -238,10 +238,20 @@ create(const size_t /*sz*/,
return val;
}
// for compatibility with matrix types
template <class VectorType>
typename std::enable_if<is_vector<VectorType>::value, VectorType>::type
create(const size_t rows, const size_t /*cols*/, const typename VectorAbstraction<VectorType>::S& val)
{
return VectorAbstraction<VectorType>::create(rows, val);
}
} // namespace Common
} // namespace XT
} // namespace Dune
template <class S, class V>
typename std::enable_if<std::is_arithmetic<S>::value && Dune::XT::Common::is_vector<V>::value, V>::type
operator*(const S& scalar, const V& vec)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment