diff --git a/arch/x86/mm/mpx.c b/arch/x86/mm/mpx.c
index 8cc79347903968835327a3bf8a8ec2dcbb23940f..294ea2092ef538dd22b0ecd93506ab7d3d225dfb 100644
--- a/arch/x86/mm/mpx.c
+++ b/arch/x86/mm/mpx.c
@@ -419,6 +419,35 @@ int mpx_disable_management(void)
 	return 0;
 }
 
+static int mpx_cmpxchg_bd_entry(struct mm_struct *mm,
+		unsigned long *curval,
+		unsigned long __user *addr,
+		unsigned long old_val, unsigned long new_val)
+{
+	int ret;
+	/*
+	 * user_atomic_cmpxchg_inatomic() actually uses sizeof()
+	 * the pointer that we pass to it to figure out how much
+	 * data to cmpxchg.  We have to be careful here not to
+	 * pass a pointer to a 64-bit data type when we only want
+	 * a 32-bit copy.
+	 */
+	if (is_64bit_mm(mm)) {
+		ret = user_atomic_cmpxchg_inatomic(curval,
+				addr, old_val, new_val);
+	} else {
+		u32 uninitialized_var(curval_32);
+		u32 old_val_32 = old_val;
+		u32 new_val_32 = new_val;
+		u32 __user *addr_32 = (u32 __user *)addr;
+
+		ret = user_atomic_cmpxchg_inatomic(&curval_32,
+				addr_32, old_val_32, new_val_32);
+		*curval = curval_32;
+	}
+	return ret;
+}
+
 /*
  * With 32-bit mode, MPX_BT_SIZE_BYTES is 4MB, and the size of each
  * bounds table is 16KB. With 64-bit mode, MPX_BT_SIZE_BYTES is 2GB,
@@ -426,6 +455,7 @@ int mpx_disable_management(void)
  */
 static int allocate_bt(long __user *bd_entry)
 {
+	struct mm_struct *mm = current->mm;
 	unsigned long expected_old_val = 0;
 	unsigned long actual_old_val = 0;
 	unsigned long bt_addr;
@@ -455,8 +485,8 @@ static int allocate_bt(long __user *bd_entry)
 	 * mmap_sem at this point, unlike some of the other part
 	 * of the MPX code that have to pagefault_disable().
 	 */
-	ret = user_atomic_cmpxchg_inatomic(&actual_old_val, bd_entry,
-					   expected_old_val, bd_new_entry);
+	ret = mpx_cmpxchg_bd_entry(mm, &actual_old_val,	bd_entry,
+				   expected_old_val, bd_new_entry);
 	if (ret)
 		goto out_unmap;
 
@@ -710,15 +740,16 @@ static int unmap_single_bt(struct mm_struct *mm,
 		long __user *bd_entry, unsigned long bt_addr)
 {
 	unsigned long expected_old_val = bt_addr | MPX_BD_ENTRY_VALID_FLAG;
-	unsigned long actual_old_val = 0;
+	unsigned long uninitialized_var(actual_old_val);
 	int ret;
 
 	while (1) {
 		int need_write = 1;
+		unsigned long cleared_bd_entry = 0;
 
 		pagefault_disable();
-		ret = user_atomic_cmpxchg_inatomic(&actual_old_val, bd_entry,
-						   expected_old_val, 0);
+		ret = mpx_cmpxchg_bd_entry(mm, &actual_old_val,
+				bd_entry, expected_old_val, cleared_bd_entry);
 		pagefault_enable();
 		if (!ret)
 			break;