diff --git a/arch/x86/entry/entry_64.S b/arch/x86/entry/entry_64.S
index ca651113bea2dcb061ac666aad807becf063c5ee..323b395c9cd893352c1b6ade762a68c40741ae33 100644
--- a/arch/x86/entry/entry_64.S
+++ b/arch/x86/entry/entry_64.S
@@ -575,7 +575,7 @@ END(spurious_entries_start)
  * +----------------------------------------------------+
  */
 ENTRY(interrupt_entry)
-	UNWIND_HINT_FUNC
+	UNWIND_HINT_IRET_REGS offset=16
 	ASM_CLAC
 	cld
 
@@ -607,9 +607,9 @@ ENTRY(interrupt_entry)
 	pushq	5*8(%rdi)		/* regs->eflags */
 	pushq	4*8(%rdi)		/* regs->cs */
 	pushq	3*8(%rdi)		/* regs->ip */
+	UNWIND_HINT_IRET_REGS
 	pushq	2*8(%rdi)		/* regs->orig_ax */
 	pushq	8(%rdi)			/* return address */
-	UNWIND_HINT_FUNC
 
 	movq	(%rdi), %rdi
 	jmp	2f
diff --git a/arch/x86/include/asm/unwind.h b/arch/x86/include/asm/unwind.h
index 499578f7e6d7bb5020fe0c503c70b1be0f5b95f1..70fc159ebe6959fead369c46d8a143812a1bc058 100644
--- a/arch/x86/include/asm/unwind.h
+++ b/arch/x86/include/asm/unwind.h
@@ -19,7 +19,7 @@ struct unwind_state {
 #if defined(CONFIG_UNWINDER_ORC)
 	bool signal, full_regs;
 	unsigned long sp, bp, ip;
-	struct pt_regs *regs;
+	struct pt_regs *regs, *prev_regs;
 #elif defined(CONFIG_UNWINDER_FRAME_POINTER)
 	bool got_irq;
 	unsigned long *bp, *orig_sp, ip;
diff --git a/arch/x86/kernel/unwind_orc.c b/arch/x86/kernel/unwind_orc.c
index 7a978ed9cb037f3083a04e5d7d65ec1b300e4fd6..bafe953f5d7ff4097889af90e3ca8f7764fcc7d2 100644
--- a/arch/x86/kernel/unwind_orc.c
+++ b/arch/x86/kernel/unwind_orc.c
@@ -375,9 +375,38 @@ static bool deref_stack_iret_regs(struct unwind_state *state, unsigned long addr
 	return true;
 }
 
+/*
+ * If state->regs is non-NULL, and points to a full pt_regs, just get the reg
+ * value from state->regs.
+ *
+ * Otherwise, if state->regs just points to IRET regs, and the previous frame
+ * had full regs, it's safe to get the value from the previous regs.  This can
+ * happen when early/late IRQ entry code gets interrupted by an NMI.
+ */
+static bool get_reg(struct unwind_state *state, unsigned int reg_off,
+		    unsigned long *val)
+{
+	unsigned int reg = reg_off/8;
+
+	if (!state->regs)
+		return false;
+
+	if (state->full_regs) {
+		*val = ((unsigned long *)state->regs)[reg];
+		return true;
+	}
+
+	if (state->prev_regs) {
+		*val = ((unsigned long *)state->prev_regs)[reg];
+		return true;
+	}
+
+	return false;
+}
+
 bool unwind_next_frame(struct unwind_state *state)
 {
-	unsigned long ip_p, sp, orig_ip = state->ip, prev_sp = state->sp;
+	unsigned long ip_p, sp, tmp, orig_ip = state->ip, prev_sp = state->sp;
 	enum stack_type prev_type = state->stack_info.type;
 	struct orc_entry *orc;
 	bool indirect = false;
@@ -431,39 +460,35 @@ bool unwind_next_frame(struct unwind_state *state)
 		break;
 
 	case ORC_REG_R10:
-		if (!state->regs || !state->full_regs) {
+		if (!get_reg(state, offsetof(struct pt_regs, r10), &sp)) {
 			orc_warn_current("missing R10 value at %pB\n",
 					 (void *)state->ip);
 			goto err;
 		}
-		sp = state->regs->r10;
 		break;
 
 	case ORC_REG_R13:
-		if (!state->regs || !state->full_regs) {
+		if (!get_reg(state, offsetof(struct pt_regs, r13), &sp)) {
 			orc_warn_current("missing R13 value at %pB\n",
 					 (void *)state->ip);
 			goto err;
 		}
-		sp = state->regs->r13;
 		break;
 
 	case ORC_REG_DI:
-		if (!state->regs || !state->full_regs) {
+		if (!get_reg(state, offsetof(struct pt_regs, di), &sp)) {
 			orc_warn_current("missing RDI value at %pB\n",
 					 (void *)state->ip);
 			goto err;
 		}
-		sp = state->regs->di;
 		break;
 
 	case ORC_REG_DX:
-		if (!state->regs || !state->full_regs) {
+		if (!get_reg(state, offsetof(struct pt_regs, dx), &sp)) {
 			orc_warn_current("missing DX value at %pB\n",
 					 (void *)state->ip);
 			goto err;
 		}
-		sp = state->regs->dx;
 		break;
 
 	default:
@@ -490,6 +515,7 @@ bool unwind_next_frame(struct unwind_state *state)
 
 		state->sp = sp;
 		state->regs = NULL;
+		state->prev_regs = NULL;
 		state->signal = false;
 		break;
 
@@ -501,6 +527,7 @@ bool unwind_next_frame(struct unwind_state *state)
 		}
 
 		state->regs = (struct pt_regs *)sp;
+		state->prev_regs = NULL;
 		state->full_regs = true;
 		state->signal = true;
 		break;
@@ -512,6 +539,8 @@ bool unwind_next_frame(struct unwind_state *state)
 			goto err;
 		}
 
+		if (state->full_regs)
+			state->prev_regs = state->regs;
 		state->regs = (void *)sp - IRET_FRAME_OFFSET;
 		state->full_regs = false;
 		state->signal = true;
@@ -526,8 +555,8 @@ bool unwind_next_frame(struct unwind_state *state)
 	/* Find BP: */
 	switch (orc->bp_reg) {
 	case ORC_REG_UNDEFINED:
-		if (state->regs && state->full_regs)
-			state->bp = state->regs->bp;
+		if (get_reg(state, offsetof(struct pt_regs, bp), &tmp))
+			state->bp = tmp;
 		break;
 
 	case ORC_REG_PREV_SP: