diff --git a/arch/x86/kernel/entry_64.S b/arch/x86/kernel/entry_64.S
index 03c52e2176809a562692ce1b89457bc8cc595680..386375d43d147debdbd90a0d8e3c2b9325754fd6 100644
--- a/arch/x86/kernel/entry_64.S
+++ b/arch/x86/kernel/entry_64.S
@@ -1355,19 +1355,7 @@ ENTRY(error_exit)
 	CFI_ENDPROC
 END(error_exit)
 
-/*
- * Test if a given stack is an NMI stack or not.
- */
-	.macro test_in_nmi reg stack nmi_ret normal_ret
-	cmpq %\reg, \stack
-	ja \normal_ret
-	subq $EXCEPTION_STKSZ, %\reg
-	cmpq %\reg, \stack
-	jb \normal_ret
-	jmp \nmi_ret
-	.endm
-
-	/* runs on exception stack */
+/* Runs on exception stack */
 ENTRY(nmi)
 	INTR_FRAME
 	PARAVIRT_ADJUST_EXCEPTION_FRAME
@@ -1428,8 +1416,18 @@ ENTRY(nmi)
 	 * We check the variable because the first NMI could be in a
 	 * breakpoint routine using a breakpoint stack.
 	 */
-	lea 6*8(%rsp), %rdx
-	test_in_nmi rdx, 4*8(%rsp), nested_nmi, first_nmi
+	lea	6*8(%rsp), %rdx
+	/* Compare the NMI stack (rdx) with the stack we came from (4*8(%rsp)) */
+	cmpq	%rdx, 4*8(%rsp)
+	/* If the stack pointer is above the NMI stack, this is a normal NMI */
+	ja	first_nmi
+	subq	$EXCEPTION_STKSZ, %rdx
+	cmpq	%rdx, 4*8(%rsp)
+	/* If it is below the NMI stack, it is a normal NMI */
+	jb	first_nmi
+	/* Ah, it is within the NMI stack, treat it as nested */
+	jmp	nested_nmi
+
 	CFI_REMEMBER_STATE
 
 nested_nmi: