diff --git a/lib/Sema/SemaExpr.cpp b/lib/Sema/SemaExpr.cpp
index 36250839cf7d6f48a96834be6bc62fb92636755e..ca261cd78490f706021f67a61c7209c3a21d1b4d 100644
--- a/lib/Sema/SemaExpr.cpp
+++ b/lib/Sema/SemaExpr.cpp
@@ -6617,7 +6617,8 @@ static bool areVectorOperandsLaxBitCastable(ASTContext &Ctx,
   if (!Ctx.getLangOpts().LaxVectorConversions)
     return false;
 
-  if (!LHSType->isVectorType() || !RHSType->isVectorType())
+  if (!(LHSType->isVectorType() || LHSType->isScalarType()) ||
+      !(RHSType->isVectorType() || RHSType->isScalarType()))
     return false;
 
   unsigned LHSSize = Ctx.getTypeSize(LHSType);
@@ -6631,13 +6632,20 @@ static bool areVectorOperandsLaxBitCastable(ASTContext &Ctx,
   // Make sure such width is the same between the types, otherwise we may end
   // up with an invalid bitcast.
   unsigned LHSIRSize, RHSIRSize;
-  const VectorType *LVec = LHSType->getAs<VectorType>();
-  LHSIRSize = LVec->getNumElements() *
-      Ctx.getTypeSize(LVec->getElementType());
-  const VectorType *RVec = RHSType->getAs<VectorType>();
-  RHSIRSize = RVec->getNumElements() *
-      Ctx.getTypeSize(RVec->getElementType());
-
+  if (LHSType->isVectorType()) {
+    const VectorType *Vec = LHSType->getAs<VectorType>();
+    LHSIRSize = Vec->getNumElements() *
+        Ctx.getTypeSize(Vec->getElementType());
+  } else {
+    LHSIRSize = LHSSize;
+  }
+  if (RHSType->isVectorType()) {
+    const VectorType *Vec = RHSType->getAs<VectorType>();
+    RHSIRSize = Vec->getNumElements() *
+        Ctx.getTypeSize(Vec->getElementType());
+  } else {
+    RHSIRSize = RHSSize;
+  }
   if (LHSIRSize != RHSIRSize)
     return false;
 
diff --git a/test/Sema/vector-cast.c b/test/Sema/vector-cast.c
index 7fa6e86aa10e7dce0fa2aa7e71a958df8617e11a..6a5f0eca4254d5aac7f56088ebcb691691c8d846 100644
--- a/test/Sema/vector-cast.c
+++ b/test/Sema/vector-cast.c
@@ -36,3 +36,11 @@ void f3(t3 Y) {
   f2(Y);  // expected-warning {{incompatible vector types passing 't3' to parameter of type 't2'}}
 }
 
+typedef float float2 __attribute__ ((vector_size (8)));
+
+void f4() {
+  float2 f2;
+  double d;
+  f2 += d;
+  d += f2;
+}