diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 6d8326409974c951de2842b7394fd0ed9a5930e6..69b652547489ed041fb0c8c3694c779fe1c75f48 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -648,6 +648,8 @@ void kvm_mmu_set_mask_ptes(u64 user_mask, u64 accessed_mask,
 
 int kvm_mmu_reset_context(struct kvm_vcpu *vcpu);
 void kvm_mmu_slot_remove_write_access(struct kvm *kvm, int slot);
+int kvm_mmu_rmap_write_protect(struct kvm *kvm, u64 gfn,
+			       struct kvm_memory_slot *slot);
 void kvm_mmu_zap_all(struct kvm *kvm);
 unsigned int kvm_mmu_calculate_mmu_pages(struct kvm *kvm);
 void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned int kvm_nr_mmu_pages);
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index fa71085f75a3e8989d464daf65bd963e49f47fee..aecdea265f7e602ae42f24777b32126865bca98f 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -1023,15 +1023,13 @@ static void drop_spte(struct kvm *kvm, u64 *sptep)
 		rmap_remove(kvm, sptep);
 }
 
-static int rmap_write_protect(struct kvm *kvm, u64 gfn)
+int kvm_mmu_rmap_write_protect(struct kvm *kvm, u64 gfn,
+			       struct kvm_memory_slot *slot)
 {
-	struct kvm_memory_slot *slot;
 	unsigned long *rmapp;
 	u64 *spte;
 	int i, write_protected = 0;
 
-	slot = gfn_to_memslot(kvm, gfn);
-
 	rmapp = __gfn_to_rmap(kvm, gfn, PT_PAGE_TABLE_LEVEL, slot);
 	spte = rmap_next(kvm, rmapp, NULL);
 	while (spte) {
@@ -1066,6 +1064,14 @@ static int rmap_write_protect(struct kvm *kvm, u64 gfn)
 	return write_protected;
 }
 
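+/*
+ * Wrapper for callers in mmu.c that only have the gfn at hand: it looks
+ * up the memslot itself before write protecting through the rmap.
+ */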
+static int rmap_write_protect(struct kvm *kvm, u64 gfn)
+{
+	struct kvm_memory_slot *slot;
+
+	slot = gfn_to_memslot(kvm, gfn);
+	return kvm_mmu_rmap_write_protect(kvm, gfn, slot);
+}
+
 static int kvm_unmap_rmapp(struct kvm *kvm, unsigned long *rmapp,
 			   unsigned long data)
 {
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index 220c83b0fbdad2725afe40cf2fc0a045837d5162..af546b768ffd3bebfeeef05d51365c2dce1810d5 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -3460,6 +3460,50 @@ static int kvm_vm_ioctl_reinject(struct kvm *kvm,
 	return 0;
 }
 
+/**
+ * write_protect_slot - write protect a slot for dirty logging
+ * @kvm: the kvm instance
+ * @memslot: the slot we protect
+ * @dirty_bitmap: the bitmap indicating which pages are dirty
+ * @nr_dirty_pages: the number of dirty pages
+ *
+ * We have two ways to find all sptes to protect:
+ * 1. Use kvm_mmu_slot_remove_write_access() which walks all shadow pages and
+ *    checks those that have a spte mapping a page in the slot.
+ * 2. Use kvm_mmu_rmap_write_protect() for each gfn found in the bitmap.
+ *
+ * Generally speaking, if the number of dirty pages is small compared to the
+ * number of shadow pages, the latter is cheaper.
+ *
+ * Note that letting others write into a page marked dirty in the old bitmap
+ * through a stale TLB entry is not a problem.  Such a page will be write
+ * protected again when we flush the TLB, and will then be reported dirty to
+ * userspace by copying the old bitmap.
+ */
+static void write_protect_slot(struct kvm *kvm,
+			       struct kvm_memory_slot *memslot,
+			       unsigned long *dirty_bitmap,
+			       unsigned long nr_dirty_pages)
+{
+	/* Few dirty pages compared to the number of shadow pages. */
+	if (nr_dirty_pages < kvm->arch.n_used_mmu_pages) {
+		unsigned long gfn_offset;
+
+		for_each_set_bit(gfn_offset, dirty_bitmap, memslot->npages) {
+			unsigned long gfn = memslot->base_gfn + gfn_offset;
+
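+			/*
+			 * Take mmu_lock per gfn so that walking a large
+			 * dirty bitmap does not hold the lock for too long.
+			 */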
+			spin_lock(&kvm->mmu_lock);
+			kvm_mmu_rmap_write_protect(kvm, gfn, memslot);
+			spin_unlock(&kvm->mmu_lock);
+		}
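+		/*
+		 * A single remote TLB flush makes the write protection
+		 * effective for all gfns handled above.
+		 */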
+		kvm_flush_remote_tlbs(kvm);
+	} else {
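+		/*
+		 * Many dirty pages: one walk over all shadow pages is
+		 * cheaper than per-gfn rmap walks.
+		 */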
+		spin_lock(&kvm->mmu_lock);
+		kvm_mmu_slot_remove_write_access(kvm, memslot->id);
+		spin_unlock(&kvm->mmu_lock);
+	}
+}
+
 /*
  * Get (and clear) the dirty memory log for a memory slot.
  */
@@ -3468,7 +3512,7 @@ int kvm_vm_ioctl_get_dirty_log(struct kvm *kvm,
 {
 	int r;
 	struct kvm_memory_slot *memslot;
-	unsigned long n;
+	unsigned long n, nr_dirty_pages;
 
 	mutex_lock(&kvm->slots_lock);
 
@@ -3482,9 +3526,10 @@ int kvm_vm_ioctl_get_dirty_log(struct kvm *kvm,
 		goto out;
 
 	n = kvm_dirty_bitmap_bytes(memslot);
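+	/* Snapshot the count; it is reset in the new slots copy below. */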
+	nr_dirty_pages = memslot->nr_dirty_pages;
 
 	/* If nothing is dirty, don't bother messing with page tables. */
-	if (memslot->nr_dirty_pages) {
+	if (nr_dirty_pages) {
 		struct kvm_memslots *slots, *old_slots;
 		unsigned long *dirty_bitmap;
 
@@ -3498,8 +3543,9 @@ int kvm_vm_ioctl_get_dirty_log(struct kvm *kvm,
 		if (!slots)
 			goto out;
 		memcpy(slots, kvm->memslots, sizeof(struct kvm_memslots));
-		slots->memslots[log->slot].dirty_bitmap = dirty_bitmap;
-		slots->memslots[log->slot].nr_dirty_pages = 0;
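+		/*
+		 * Repoint memslot at the copy that stays live after the
+		 * swap below; write_protect_slot() relies on this.
+		 */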
+		memslot = &slots->memslots[log->slot];
+		memslot->dirty_bitmap = dirty_bitmap;
+		memslot->nr_dirty_pages = 0;
 		slots->generation++;
 
 		old_slots = kvm->memslots;
@@ -3508,9 +3554,7 @@ int kvm_vm_ioctl_get_dirty_log(struct kvm *kvm,
 		dirty_bitmap = old_slots->memslots[log->slot].dirty_bitmap;
 		kfree(old_slots);
 
-		spin_lock(&kvm->mmu_lock);
-		kvm_mmu_slot_remove_write_access(kvm, log->slot);
-		spin_unlock(&kvm->mmu_lock);
+		write_protect_slot(kvm, memslot, dirty_bitmap, nr_dirty_pages);
 
 		r = -EFAULT;
 		if (copy_to_user(log->dirty_bitmap, dirty_bitmap, n))