diff --git a/arch/x86/mm/mem_encrypt.c b/arch/x86/mm/mem_encrypt.c
index 5a20696c54405935ddef0d1dff3a20353ebc0617..35f38caa1fa3fb713c3d43d81f03841c15fd6ea3 100644
--- a/arch/x86/mm/mem_encrypt.c
+++ b/arch/x86/mm/mem_encrypt.c
@@ -468,31 +468,39 @@ struct sme_populate_pgd_data {
 	void	*pgtable_area;
 	pgd_t	*pgd;
 
-	pmdval_t pmd_val;
+	pmdval_t pmd_flags;
+	unsigned long paddr;
+
 	unsigned long vaddr;
+	unsigned long vaddr_end;
 };
 
-static void __init sme_clear_pgd(pgd_t *pgd_base, unsigned long start,
-				 unsigned long end)
+static void __init sme_clear_pgd(struct sme_populate_pgd_data *ppd)
 {
 	unsigned long pgd_start, pgd_end, pgd_size;
 	pgd_t *pgd_p;
 
-	pgd_start = start & PGDIR_MASK;
-	pgd_end = end & PGDIR_MASK;
+	pgd_start = ppd->vaddr & PGDIR_MASK;
+	pgd_end = ppd->vaddr_end & PGDIR_MASK;
 
-	pgd_size = (((pgd_end - pgd_start) / PGDIR_SIZE) + 1);
-	pgd_size *= sizeof(pgd_t);
+	pgd_size = (((pgd_end - pgd_start) / PGDIR_SIZE) + 1) * sizeof(pgd_t);
 
-	pgd_p = pgd_base + pgd_index(start);
+	pgd_p = ppd->pgd + pgd_index(ppd->vaddr);
 
 	memset(pgd_p, 0, pgd_size);
 }
 
-#define PGD_FLAGS	_KERNPG_TABLE_NOENC
-#define P4D_FLAGS	_KERNPG_TABLE_NOENC
-#define PUD_FLAGS	_KERNPG_TABLE_NOENC
-#define PMD_FLAGS	(__PAGE_KERNEL_LARGE_EXEC & ~_PAGE_GLOBAL)
+#define PGD_FLAGS		_KERNPG_TABLE_NOENC
+#define P4D_FLAGS		_KERNPG_TABLE_NOENC
+#define PUD_FLAGS		_KERNPG_TABLE_NOENC
+
+#define PMD_FLAGS_LARGE		(__PAGE_KERNEL_LARGE_EXEC & ~_PAGE_GLOBAL)
+
+#define PMD_FLAGS_DEC		PMD_FLAGS_LARGE
+#define PMD_FLAGS_DEC_WP	((PMD_FLAGS_DEC & ~_PAGE_CACHE_MASK) | \
+				 (_PAGE_PAT | _PAGE_PWT))
+
+#define PMD_FLAGS_ENC		(PMD_FLAGS_LARGE | _PAGE_ENC)
 
 static void __init sme_populate_pgd_large(struct sme_populate_pgd_data *ppd)
 {
@@ -561,7 +569,35 @@ static void __init sme_populate_pgd_large(struct sme_populate_pgd_data *ppd)
 
 	pmd_p += pmd_index(ppd->vaddr);
 	if (!native_pmd_val(*pmd_p) || !(native_pmd_val(*pmd_p) & _PAGE_PSE))
-		native_set_pmd(pmd_p, native_make_pmd(ppd->pmd_val));
+		native_set_pmd(pmd_p, native_make_pmd(ppd->paddr | ppd->pmd_flags));
+}
+
+static void __init __sme_map_range(struct sme_populate_pgd_data *ppd,
+				   pmdval_t pmd_flags)
+{
+	ppd->pmd_flags = pmd_flags;
+
+	while (ppd->vaddr < ppd->vaddr_end) {
+		sme_populate_pgd_large(ppd);
+
+		ppd->vaddr += PMD_PAGE_SIZE;
+		ppd->paddr += PMD_PAGE_SIZE;
+	}
+}
+
+static void __init sme_map_range_encrypted(struct sme_populate_pgd_data *ppd)
+{
+	__sme_map_range(ppd, PMD_FLAGS_ENC);
+}
+
+static void __init sme_map_range_decrypted(struct sme_populate_pgd_data *ppd)
+{
+	__sme_map_range(ppd, PMD_FLAGS_DEC);
+}
+
+static void __init sme_map_range_decrypted_wp(struct sme_populate_pgd_data *ppd)
+{
+	__sme_map_range(ppd, PMD_FLAGS_DEC_WP);
 }
 
 static unsigned long __init sme_pgtable_calc(unsigned long len)
@@ -621,7 +657,6 @@ void __init sme_encrypt_kernel(void)
 	unsigned long kernel_start, kernel_end, kernel_len;
 	struct sme_populate_pgd_data ppd;
 	unsigned long pgtable_area_len;
-	unsigned long paddr, pmd_flags;
 	unsigned long decrypted_base;
 
 	if (!sme_active())
@@ -693,14 +728,10 @@ void __init sme_encrypt_kernel(void)
 	 * addressing the workarea.
 	 */
 	ppd.pgd = (pgd_t *)native_read_cr3_pa();
-	paddr = workarea_start;
-	while (paddr < workarea_end) {
-		ppd.pmd_val = paddr + PMD_FLAGS;
-		ppd.vaddr = paddr;
-		sme_populate_pgd_large(&ppd);
-
-		paddr += PMD_PAGE_SIZE;
-	}
+	ppd.paddr = workarea_start;
+	ppd.vaddr = workarea_start;
+	ppd.vaddr_end = workarea_end;
+	sme_map_range_decrypted(&ppd);
 
 	/* Flush the TLB - no globals so cr3 is enough */
 	native_write_cr3(__native_read_cr3());
@@ -715,17 +746,6 @@ void __init sme_encrypt_kernel(void)
 	memset(ppd.pgd, 0, sizeof(pgd_t) * PTRS_PER_PGD);
 	ppd.pgtable_area += sizeof(pgd_t) * PTRS_PER_PGD;
 
-	/* Add encrypted kernel (identity) mappings */
-	pmd_flags = PMD_FLAGS | _PAGE_ENC;
-	paddr = kernel_start;
-	while (paddr < kernel_end) {
-		ppd.pmd_val = paddr + pmd_flags;
-		ppd.vaddr = paddr;
-		sme_populate_pgd_large(&ppd);
-
-		paddr += PMD_PAGE_SIZE;
-	}
-
 	/*
 	 * A different PGD index/entry must be used to get different
 	 * pagetable entries for the decrypted mapping. Choose the next
@@ -735,29 +755,28 @@ void __init sme_encrypt_kernel(void)
 	decrypted_base = (pgd_index(workarea_end) + 1) & (PTRS_PER_PGD - 1);
 	decrypted_base <<= PGDIR_SHIFT;
 
+	/* Add encrypted kernel (identity) mappings */
+	ppd.paddr = kernel_start;
+	ppd.vaddr = kernel_start;
+	ppd.vaddr_end = kernel_end;
+	sme_map_range_encrypted(&ppd);
+
 	/* Add decrypted, write-protected kernel (non-identity) mappings */
-	pmd_flags = (PMD_FLAGS & ~_PAGE_CACHE_MASK) | (_PAGE_PAT | _PAGE_PWT);
-	paddr = kernel_start;
-	while (paddr < kernel_end) {
-		ppd.pmd_val = paddr + pmd_flags;
-		ppd.vaddr = paddr + decrypted_base;
-		sme_populate_pgd_large(&ppd);
-
-		paddr += PMD_PAGE_SIZE;
-	}
+	ppd.paddr = kernel_start;
+	ppd.vaddr = kernel_start + decrypted_base;
+	ppd.vaddr_end = kernel_end + decrypted_base;
+	sme_map_range_decrypted_wp(&ppd);
 
 	/* Add decrypted workarea mappings to both kernel mappings */
-	paddr = workarea_start;
-	while (paddr < workarea_end) {
-		ppd.pmd_val = paddr + PMD_FLAGS;
-		ppd.vaddr = paddr;
-		sme_populate_pgd_large(&ppd);
-
-		ppd.vaddr = paddr + decrypted_base;
-		sme_populate_pgd_large(&ppd);
+	ppd.paddr = workarea_start;
+	ppd.vaddr = workarea_start;
+	ppd.vaddr_end = workarea_end;
+	sme_map_range_decrypted(&ppd);
 
-		paddr += PMD_PAGE_SIZE;
-	}
+	ppd.paddr = workarea_start;
+	ppd.vaddr = workarea_start + decrypted_base;
+	ppd.vaddr_end = workarea_end + decrypted_base;
+	sme_map_range_decrypted(&ppd);
 
 	/* Perform the encryption */
 	sme_encrypt_execute(kernel_start, kernel_start + decrypted_base,
@@ -768,11 +787,13 @@ void __init sme_encrypt_kernel(void)
 	 * the decrypted areas - all that is needed for this is to remove
 	 * the PGD entry/entries.
 	 */
-	sme_clear_pgd(ppd.pgd, kernel_start + decrypted_base,
-		      kernel_end + decrypted_base);
+	ppd.vaddr = kernel_start + decrypted_base;
+	ppd.vaddr_end = kernel_end + decrypted_base;
+	sme_clear_pgd(&ppd);
 
-	sme_clear_pgd(ppd.pgd, workarea_start + decrypted_base,
-		      workarea_end + decrypted_base);
+	ppd.vaddr = workarea_start + decrypted_base;
+	ppd.vaddr_end = workarea_end + decrypted_base;
+	sme_clear_pgd(&ppd);
 
 	/* Flush the TLB - no globals so cr3 is enough */
 	native_write_cr3(__native_read_cr3());