diff --git a/arch/x86/include/asm/processor.h b/arch/x86/include/asm/processor.h
index e119c78581d14562bfc5159b3099b8afd9de4135..307aa9454df30b76a5d87f7cfd7bbe9945e5a61f 100644
--- a/arch/x86/include/asm/processor.h
+++ b/arch/x86/include/asm/processor.h
@@ -763,6 +763,7 @@ extern void load_direct_gdt(int);
 extern void load_fixmap_gdt(int);
 extern void load_percpu_segment(int);
 extern void cpu_init(void);
+extern void cr4_init(void);
 
 static inline unsigned long get_debugctlmsr(void)
 {
diff --git a/arch/x86/include/asm/special_insns.h b/arch/x86/include/asm/special_insns.h
index 87c371e87743fb193b86a803cadd5806fa2739af..af5f0d9a99511edeba28d5ece62d68753f3715ff 100644
--- a/arch/x86/include/asm/special_insns.h
+++ b/arch/x86/include/asm/special_insns.h
@@ -18,9 +18,7 @@
  */
 extern unsigned long __force_order;
 
-/* Starts false and gets enabled once CPU feature detection is done. */
-DECLARE_STATIC_KEY_FALSE(cr_pinning);
-extern unsigned long cr4_pinned_bits;
+void native_write_cr0(unsigned long val);
 
 static inline unsigned long native_read_cr0(void)
 {
@@ -29,24 +27,6 @@ static inline unsigned long native_read_cr0(void)
 	return val;
 }
 
-static inline void native_write_cr0(unsigned long val)
-{
-	unsigned long bits_missing = 0;
-
-set_register:
-	asm volatile("mov %0,%%cr0": "+r" (val), "+m" (__force_order));
-
-	if (static_branch_likely(&cr_pinning)) {
-		if (unlikely((val & X86_CR0_WP) != X86_CR0_WP)) {
-			bits_missing = X86_CR0_WP;
-			val |= bits_missing;
-			goto set_register;
-		}
-		/* Warn after we've set the missing bits. */
-		WARN_ONCE(bits_missing, "CR0 WP bit went missing!?\n");
-	}
-}
-
 static inline unsigned long native_read_cr2(void)
 {
 	unsigned long val;
@@ -91,24 +71,7 @@ static inline unsigned long native_read_cr4(void)
 	return val;
 }
 
-static inline void native_write_cr4(unsigned long val)
-{
-	unsigned long bits_missing = 0;
-
-set_register:
-	asm volatile("mov %0,%%cr4": "+r" (val), "+m" (cr4_pinned_bits));
-
-	if (static_branch_likely(&cr_pinning)) {
-		if (unlikely((val & cr4_pinned_bits) != cr4_pinned_bits)) {
-			bits_missing = ~val & cr4_pinned_bits;
-			val |= bits_missing;
-			goto set_register;
-		}
-		/* Warn after we've set the missing bits. */
-		WARN_ONCE(bits_missing, "CR4 bits went missing: %lx!?\n",
-			  bits_missing);
-	}
-}
+void native_write_cr4(unsigned long val);
 
 #ifdef CONFIG_X86_64
 static inline unsigned long native_read_cr8(void)
diff --git a/arch/x86/kernel/cpu/common.c b/arch/x86/kernel/cpu/common.c
index a7d2fff957552c7ff34bacbfcd5d00a6d941fa38..a0306327e7ed38d029a0b9ba9eef61eb3f7df0b6 100644
--- a/arch/x86/kernel/cpu/common.c
+++ b/arch/x86/kernel/cpu/common.c
@@ -365,10 +365,62 @@ static __always_inline void setup_umip(struct cpuinfo_x86 *c)
 	cr4_clear_bits(X86_CR4_UMIP);
 }
 
-DEFINE_STATIC_KEY_FALSE_RO(cr_pinning);
-EXPORT_SYMBOL(cr_pinning);
-unsigned long cr4_pinned_bits __ro_after_init;
-EXPORT_SYMBOL(cr4_pinned_bits);
+static DEFINE_STATIC_KEY_FALSE_RO(cr_pinning);
+static unsigned long cr4_pinned_bits __ro_after_init;
+
+void native_write_cr0(unsigned long val)
+{
+	unsigned long bits_missing = 0;
+
+set_register:
+	asm volatile("mov %0,%%cr0": "+r" (val), "+m" (__force_order));
+
+	if (static_branch_likely(&cr_pinning)) {
+		if (unlikely((val & X86_CR0_WP) != X86_CR0_WP)) {
+			bits_missing = X86_CR0_WP;
+			val |= bits_missing;
+			goto set_register;
+		}
+		/* Warn after we've set the missing bits. */
+		WARN_ONCE(bits_missing, "CR0 WP bit went missing!?\n");
+	}
+}
+EXPORT_SYMBOL(native_write_cr0);
+
+void native_write_cr4(unsigned long val)
+{
+	unsigned long bits_missing = 0;
+
+set_register:
+	asm volatile("mov %0,%%cr4": "+r" (val), "+m" (cr4_pinned_bits));
+
+	if (static_branch_likely(&cr_pinning)) {
+		if (unlikely((val & cr4_pinned_bits) != cr4_pinned_bits)) {
+			bits_missing = ~val & cr4_pinned_bits;
+			val |= bits_missing;
+			goto set_register;
+		}
+		/* Warn after we've set the missing bits. */
+		WARN_ONCE(bits_missing, "CR4 bits went missing: %lx!?\n",
+			  bits_missing);
+	}
+}
+EXPORT_SYMBOL(native_write_cr4);
+
+void cr4_init(void)
+{
+	unsigned long cr4 = __read_cr4();
+
+	if (boot_cpu_has(X86_FEATURE_PCID))
+		cr4 |= X86_CR4_PCIDE;
+	if (static_branch_likely(&cr_pinning))
+		cr4 |= cr4_pinned_bits;
+
+	__write_cr4(cr4);
+
+	/* Initialize cr4 shadow for this CPU. */
+	this_cpu_write(cpu_tlbstate.cr4, cr4);
+}
 
 /*
  * Once CPU feature detection is finished (and boot params have been
@@ -1851,12 +1903,6 @@ void cpu_init(void)
 
 	wait_for_master_cpu(cpu);
 
-	/*
-	 * Initialize the CR4 shadow before doing anything that could
-	 * try to read it.
-	 */
-	cr4_init_shadow();
-
 	if (cpu)
 		load_ucode_ap();
 
@@ -1956,12 +2002,6 @@ void cpu_init(void)
 
 	wait_for_master_cpu(cpu);
 
-	/*
-	 * Initialize the CR4 shadow before doing anything that could
-	 * try to read it.
-	 */
-	cr4_init_shadow();
-
 	show_ucode_info_early();
 
 	pr_info("Initializing CPU#%d\n", cpu);
diff --git a/arch/x86/kernel/smpboot.c b/arch/x86/kernel/smpboot.c
index 23021970136ead2f2c01abaf88b007166846aa44..ee697fa8847d6f2d32baef8bc1218cb7a9a39f82 100644
--- a/arch/x86/kernel/smpboot.c
+++ b/arch/x86/kernel/smpboot.c
@@ -212,28 +212,16 @@ static int enable_start_cpu0;
  */
 static void notrace start_secondary(void *unused)
 {
-	unsigned long cr4 = __read_cr4();
-
 	/*
 	 * Don't put *anything* except direct CPU state initialization
 	 * before cpu_init(), SMP booting is too fragile that we want to
 	 * limit the things done here to the most necessary things.
 	 */
-	if (boot_cpu_has(X86_FEATURE_PCID))
-		cr4 |= X86_CR4_PCIDE;
-	if (static_branch_likely(&cr_pinning))
-		cr4 |= cr4_pinned_bits;
-
-	__write_cr4(cr4);
+	cr4_init();
 
 #ifdef CONFIG_X86_32
 	/* switch away from the initial page table */
 	load_cr3(swapper_pg_dir);
-	/*
-	 * Initialize the CR4 shadow before doing anything that could
-	 * try to read it.
-	 */
-	cr4_init_shadow();
 	__flush_tlb_all();
 #endif
 	load_current_idt();
diff --git a/arch/x86/xen/smp_pv.c b/arch/x86/xen/smp_pv.c
index e3b18ad49889afc5ae35d2e2796aecd108a93819..32a9c22121249672cc9256c7d368a0787b129a9b 100644
--- a/arch/x86/xen/smp_pv.c
+++ b/arch/x86/xen/smp_pv.c
@@ -57,6 +57,7 @@ static void cpu_bringup(void)
 {
 	int cpu;
 
+	cr4_init();
 	cpu_init();
 	touch_softlockup_watchdog();
 	preempt_disable();
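
The hunks above move the CR0/CR4 pin-and-restore loops out of the header and into arch/x86/kernel/cpu/common.c, which is what allows cr_pinning and cr4_pinned_bits to become static there, and they consolidate the secondary-CPU CR4 setup (PCID enable, pinned bits, shadow init) into the new cr4_init(). The pattern the out-of-lined writers implement is small: write the register, and if a pinned bit came back clear, restore it first and only then warn once. Below is a minimal user-space sketch of that pattern, not kernel code; the names (pinned_bits, fake_cr4, write_reg) are hypothetical stand-ins for the inline asm register write and WARN_ONCE().

/*
 * User-space sketch of the pin-and-restore pattern. The real code
 * writes the hardware register with "mov %0,%%cr4" and reports via
 * WARN_ONCE(); everything named here is illustrative only.
 */
#include <stdio.h>
#include <stdbool.h>

static unsigned long pinned_bits = 0x00100000UL; /* a pinned bit, e.g. SMEP-like */
static unsigned long fake_cr4;                   /* stands in for the hardware register */

static void write_reg(unsigned long val)
{
	static bool warned;
	unsigned long bits_missing = 0;

set_register:
	fake_cr4 = val;		/* the kernel does the CR write here */

	if ((val & pinned_bits) != pinned_bits) {
		bits_missing = ~val & pinned_bits;
		val |= bits_missing;
		goto set_register;	/* restore first, complain second */
	}
	if (bits_missing && !warned) {	/* WARN_ONCE() analogue */
		fprintf(stderr, "pinned bits went missing: %lx\n", bits_missing);
		warned = true;
	}
}

int main(void)
{
	write_reg(0x0UL);			/* attempt to clear a pinned bit */
	printf("register: %#lx\n", fake_cr4);	/* pinned bit has been restored */
	return 0;
}

As in the patch, the restore happens before the warning so an attacker who cleared a pinned bit never wins a race against the diagnostic path.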