diff --git a/arch/x86/kvm/vmx/nested.c b/arch/x86/kvm/vmx/nested.c
index cd8d0b040daa813915395b3e727a6ebc1cfec846..bdb9b302825073d5e946c8fc9bf84096dee8c7a7 100644
--- a/arch/x86/kvm/vmx/nested.c
+++ b/arch/x86/kvm/vmx/nested.c
@@ -2073,6 +2073,7 @@ static void prepare_vmcs02_early(struct vcpu_vmx *vmx, struct vmcs12 *vmcs12)
 	exec_control &= ~CPU_BASED_TPR_SHADOW;
 	exec_control |= vmcs12->cpu_based_vm_exec_control;
 
+	vmx->nested.l1_tpr_threshold = -1;
 	if (exec_control & CPU_BASED_TPR_SHADOW)
 		vmcs_write32(TPR_THRESHOLD, vmcs12->tpr_threshold);
 #ifdef CONFIG_X86_64
@@ -4115,6 +4116,8 @@ void nested_vmx_vmexit(struct kvm_vcpu *vcpu, u32 exit_reason,
 	vmcs_write32(VM_EXIT_MSR_LOAD_COUNT, vmx->msr_autoload.host.nr);
 	vmcs_write32(VM_ENTRY_MSR_LOAD_COUNT, vmx->msr_autoload.guest.nr);
 	vmcs_write64(TSC_OFFSET, vcpu->arch.tsc_offset);
+	if (vmx->nested.l1_tpr_threshold != -1)
+		vmcs_write32(TPR_THRESHOLD, vmx->nested.l1_tpr_threshold);
 
 	if (kvm_has_tsc_control)
 		decache_tsc_multiplier(vmx);
diff --git a/arch/x86/kvm/vmx/vmx.c b/arch/x86/kvm/vmx/vmx.c
index 2a64bf3c62b9fbffc4a14ecaadd49c754ee92a4b..765086756177518241d980712326abd4e243d5dc 100644
--- a/arch/x86/kvm/vmx/vmx.c
+++ b/arch/x86/kvm/vmx/vmx.c
@@ -5988,7 +5988,10 @@ static void update_cr8_intercept(struct kvm_vcpu *vcpu, int tpr, int irr)
 		return;
 
 	tpr_threshold = (irr == -1 || tpr < irr) ? 0 : irr;
-	vmcs_write32(TPR_THRESHOLD, tpr_threshold);
+	if (is_guest_mode(vcpu))
+		to_vmx(vcpu)->nested.l1_tpr_threshold = tpr_threshold;
+	else
+		vmcs_write32(TPR_THRESHOLD, tpr_threshold);
 }
 
 void vmx_set_virtual_apic_mode(struct kvm_vcpu *vcpu)
diff --git a/arch/x86/kvm/vmx/vmx.h b/arch/x86/kvm/vmx/vmx.h
index bee16687dc0bf054957273d520306cfcbbd178dc..43331dfafffe6f91007c1e701bb343a91afee41c 100644
--- a/arch/x86/kvm/vmx/vmx.h
+++ b/arch/x86/kvm/vmx/vmx.h
@@ -167,6 +167,9 @@ struct nested_vmx {
 	u64 vmcs01_debugctl;
 	u64 vmcs01_guest_bndcfgs;
 
+	/* to migrate it to L1 if L2 writes to L1's CR8 directly */
+	int l1_tpr_threshold;
+
 	u16 vpid02;
 	u16 last_vpid;