diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index c13cd28d9d1be5abdff8fdf93692d51755c8930c..57d418061c55d876fcafe16f4f7fca550035583e 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -713,6 +713,9 @@ struct kvm_vcpu_arch {
 
 	/* be preempted when it's in kernel-mode(cpl=0) */
 	bool preempted_in_kernel;
+
+	/* Flush the L1 Data cache for L1TF mitigation on VMENTER */
+	bool l1tf_flush_l1d;
 };
 
 struct kvm_lpage_info {
@@ -881,6 +884,7 @@ struct kvm_vcpu_stat {
 	u64 signal_exits;
 	u64 irq_window_exits;
 	u64 nmi_window_exits;
+	u64 l1d_flush;
 	u64 halt_exits;
 	u64 halt_successful_poll;
 	u64 halt_attempted_poll;
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index d594690d8b9597a87f4cba26e8c1be5cb2de22de..9beb772b9eb6450eda337e0c56f5db15ba9f8fd3 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -3840,6 +3840,7 @@ int kvm_handle_page_fault(struct kvm_vcpu *vcpu, u64 error_code,
 {
 	int r = 1;
 
+	vcpu->arch.l1tf_flush_l1d = true;
 	switch (vcpu->arch.apf.host_apf_reason) {
 	default:
 		trace_kvm_page_fault(fault_address, error_code);
diff --git a/arch/x86/kvm/vmx.c b/arch/x86/kvm/vmx.c
index a1dbc17c7a0318083951741bf0a5ca92f3711b6e..882df9605e4260d7d1e56fb95dedf896d437da02 100644
--- a/arch/x86/kvm/vmx.c
+++ b/arch/x86/kvm/vmx.c
@@ -9576,9 +9576,20 @@ static int vmx_handle_exit(struct kvm_vcpu *vcpu)
 #define L1D_CACHE_ORDER 4
 static void *vmx_l1d_flush_pages;
 
-static void __maybe_unused vmx_l1d_flush(void)
+static void vmx_l1d_flush(struct kvm_vcpu *vcpu)
 {
 	int size = PAGE_SIZE << L1D_CACHE_ORDER;
+	bool always;
+
+	/*
+	 * If the mitigation mode is 'flush always', keep the flush bit
+	 * set, otherwise clear it. It gets set again either from
+	 * vcpu_run() or from one of the unsafe VMEXIT handlers.
+	 */
+	always = vmentry_l1d_flush == VMENTER_L1D_FLUSH_ALWAYS;
+	vcpu->arch.l1tf_flush_l1d = always;
+
+	vcpu->stat.l1d_flush++;
 
 	if (static_cpu_has(X86_FEATURE_FLUSH_L1D)) {
 		wrmsrl(MSR_IA32_FLUSH_CMD, L1D_FLUSH);
@@ -9847,6 +9858,7 @@ static void vmx_handle_external_intr(struct kvm_vcpu *vcpu)
 			[ss]"i"(__KERNEL_DS),
 			[cs]"i"(__KERNEL_CS)
 			);
+		vcpu->arch.l1tf_flush_l1d = true;
 	}
 }
 STACK_FRAME_NON_STANDARD(vmx_handle_external_intr);
@@ -10104,6 +10116,11 @@ static void __noclone vmx_vcpu_run(struct kvm_vcpu *vcpu)
 	evmcs_rsp = static_branch_unlikely(&enable_evmcs) ?
 		(unsigned long)&current_evmcs->host_rsp : 0;
 
+	if (static_branch_unlikely(&vmx_l1d_should_flush)) {
+		if (vcpu->arch.l1tf_flush_l1d)
+			vmx_l1d_flush(vcpu);
+	}
+
 	asm(
 		/* Store host registers */
 		"push %%" _ASM_DX "; push %%" _ASM_BP ";"
@@ -11917,6 +11934,9 @@ static int nested_vmx_run(struct kvm_vcpu *vcpu, bool launch)
 		return ret;
 	}
 
+	/* Hide L1D cache contents from the nested guest.  */
+	vmx->vcpu.arch.l1tf_flush_l1d = true;
+
 	/*
 	 * If we're entering a halted L2 vcpu and the L2 vcpu won't be woken
 	 * by event injection, halt vcpu.
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index 0046aa70205aa2dfbc0577065250be717ca25b4e..902d535dff8f8c5704a9c26b88819e3afdf58862 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -195,6 +195,7 @@ struct kvm_stats_debugfs_item debugfs_entries[] = {
 	{ "irq_injections", VCPU_STAT(irq_injections) },
 	{ "nmi_injections", VCPU_STAT(nmi_injections) },
 	{ "req_event", VCPU_STAT(req_event) },
+	{ "l1d_flush", VCPU_STAT(l1d_flush) },
 	{ "mmu_shadow_zapped", VM_STAT(mmu_shadow_zapped) },
 	{ "mmu_pte_write", VM_STAT(mmu_pte_write) },
 	{ "mmu_pte_updated", VM_STAT(mmu_pte_updated) },
@@ -4874,6 +4875,9 @@ static int emulator_write_std(struct x86_emulate_ctxt *ctxt, gva_t addr, void *v
 int kvm_write_guest_virt_system(struct kvm_vcpu *vcpu, gva_t addr, void *val,
 				unsigned int bytes, struct x86_exception *exception)
 {
+	/* kvm_write_guest_virt_system can pull in tons of pages. */
+	vcpu->arch.l1tf_flush_l1d = true;
+
 	return kvm_write_guest_virt_helper(addr, val, bytes, vcpu,
 					   PFERR_WRITE_MASK, exception);
 }
@@ -6050,6 +6054,8 @@ int x86_emulate_instruction(struct kvm_vcpu *vcpu,
 	bool writeback = true;
 	bool write_fault_to_spt = vcpu->arch.write_fault_to_shadow_pgtable;
 
+	vcpu->arch.l1tf_flush_l1d = true;
+
 	/*
 	 * Clear write_fault_to_shadow_pgtable here to ensure it is
 	 * never reused.
@@ -7579,6 +7585,7 @@ static int vcpu_run(struct kvm_vcpu *vcpu)
 	struct kvm *kvm = vcpu->kvm;
 
 	vcpu->srcu_idx = srcu_read_lock(&kvm->srcu);
+	vcpu->arch.l1tf_flush_l1d = true;
 
 	for (;;) {
 		if (kvm_vcpu_running(vcpu)) {
@@ -8698,6 +8705,7 @@ void kvm_arch_vcpu_uninit(struct kvm_vcpu *vcpu)
 
 void kvm_arch_sched_in(struct kvm_vcpu *vcpu, int cpu)
 {
+	vcpu->arch.l1tf_flush_l1d = true;
 	kvm_x86_ops->sched_in(vcpu, cpu);
 }