diff --git a/mm/vmscan.c b/mm/vmscan.c
index d59b22a7ba9a2dad5694dbeb3742c420ebf218fa..42f24473756b7792bec5d3640402dbfe072d1d5b 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -185,7 +185,7 @@ int vm_cache_reclaim_weight __read_mostly;
 int vm_cache_reclaim_weight_min;
 int vm_cache_reclaim_weight_max;
 int vm_cache_reclaim_enable;
-static DEFINE_PER_CPU(struct delayed_work, vmscan_work);
+static struct work_struct vmscan_works[MAX_NUMNODES];
 #endif
 
 static LIST_HEAD(shrinker_list);
@@ -3972,8 +3972,8 @@ static unsigned long __shrink_page_cache(gfp_t mask)
 		.gfp_mask = current_gfp_context(mask),
 		.reclaim_idx = gfp_zone(mask),
 		.may_writepage = !laptop_mode,
-		.nr_to_reclaim = SWAP_CLUSTER_MAX *
-				 (unsigned long)vm_cache_reclaim_weight,
+		.nr_to_reclaim = SWAP_CLUSTER_MAX * nr_cpus_node(numa_node_id()) *
+				 vm_cache_reclaim_weight,
 		.may_unmap = 1,
 		.may_swap = mem_reliable_is_enabled() ? 0 : 1,
 		.no_shrink_slab = mem_reliable_is_enabled() ? 0 : 1,
@@ -3997,26 +3997,21 @@ static DECLARE_DEFERRABLE_WORK(shepherd, shrink_shepherd);
 
 static void shrink_shepherd(struct work_struct *w)
 {
-	int cpu;
 
-	get_online_cpus();
+	int node;
 
-	for_each_online_cpu(cpu) {
-		struct delayed_work *work = &per_cpu(vmscan_work, cpu);
-
-		if (!delayed_work_pending(work) && vm_cache_reclaim_enable)
-			queue_delayed_work_on(cpu, system_wq, work, 0);
+	for_each_online_node(node) {
+		if (!work_pending(&vmscan_works[node]) && vm_cache_reclaim_enable)
+			queue_work_node(node, system_unbound_wq, &vmscan_works[node]);
 	}
 
-	put_online_cpus();
-
 	/* we want all kernel thread to stop */
 	if (vm_cache_reclaim_enable) {
 		if (vm_cache_reclaim_s == 0)
-			schedule_delayed_work(&shepherd,
+			queue_delayed_work(system_unbound_wq, &shepherd,
 				round_jiffies_relative(120 * HZ));
 		else
-			schedule_delayed_work(&shepherd,
+			queue_delayed_work(system_unbound_wq, &shepherd,
 				round_jiffies_relative((unsigned long)
 				vm_cache_reclaim_s * HZ));
 	}
@@ -4024,15 +4019,12 @@ static void shrink_shepherd(struct work_struct *w)
 
 static void shrink_shepherd_timer(void)
 {
-	int cpu;
-
-	for_each_possible_cpu(cpu) {
-		struct delayed_work *work = &per_cpu(vmscan_work, cpu);
+	int i;
 
-		INIT_DEFERRABLE_WORK(work, shrink_page_cache_work);
-	}
+	for (i = 0; i < MAX_NUMNODES; i++)
+		INIT_WORK(&vmscan_works[i], shrink_page_cache_work);
 
-	schedule_delayed_work(&shepherd,
+	queue_delayed_work(system_unbound_wq, &shepherd,
 		round_jiffies_relative((unsigned long)vm_cache_reclaim_s * HZ));
 }
 
@@ -4048,9 +4040,6 @@ unsigned long shrink_page_cache(gfp_t mask)
 
 static void shrink_page_cache_work(struct work_struct *w)
 {
-	struct delayed_work *work = to_delayed_work(w);
-	unsigned long nr_pages;
-
 	/*
 	 * if vm_cache_reclaim_enable or vm_cache_reclaim_s is zero,
 	 * we do not shrink page cache again.
@@ -4063,10 +4052,7 @@ static void shrink_page_cache_work(struct work_struct *w)
 		return;
 
-	/* It should wait more time if we hardly reclaim the page cache */
-	nr_pages = shrink_page_cache(GFP_KERNEL);
-	if ((nr_pages < SWAP_CLUSTER_MAX) && vm_cache_reclaim_enable)
-		queue_delayed_work_on(smp_processor_id(), system_wq, work,
-		round_jiffies_relative((vm_cache_reclaim_s + 120) * HZ));
+	/* The shepherd requeues this per-node work each reclaim period */
+	shrink_page_cache(GFP_KERNEL);
 }
 
 static void shrink_page_cache_init(void)
@@ -4088,13 +4074,6 @@ static void shrink_page_cache_init(void)
 	shrink_shepherd_timer();
 }
 
-static int kswapd_cpu_down_prep(unsigned int cpu)
-{
-	cancel_delayed_work_sync(&per_cpu(vmscan_work, cpu));
-
-	return 0;
-}
-
 int cache_reclaim_enable_handler(struct ctl_table *table, int write,
 			void __user *buffer, size_t *length, loff_t *ppos)
 {
@@ -4105,8 +4084,8 @@ int cache_reclaim_enable_handler(struct ctl_table *table, int write,
 		return ret;
 
 	if (write)
-		schedule_delayed_work(&shepherd, round_jiffies_relative(
-			(unsigned long)vm_cache_reclaim_s * HZ));
+		queue_delayed_work(system_unbound_wq, &shepherd,
+			round_jiffies_relative((unsigned long)vm_cache_reclaim_s * HZ));
 
 	return 0;
 }
@@ -4121,7 +4100,7 @@ int cache_reclaim_sysctl_handler(struct ctl_table *table, int write,
 		return ret;
 
 	if (write)
-		mod_delayed_work(system_wq, &shepherd,
+		mod_delayed_work(system_unbound_wq, &shepherd,
 				round_jiffies_relative(
 				(unsigned long)vm_cache_reclaim_s * HZ));
 
@@ -4194,15 +4173,9 @@ static int __init kswapd_init(void)
 	swap_setup();
 	for_each_node_state(nid, N_MEMORY)
  		kswapd_run(nid);
-#ifdef CONFIG_SHRINK_PAGECACHE
-	ret = cpuhp_setup_state_nocalls(CPUHP_AP_ONLINE_DYN,
-					"mm/vmscan:online", kswapd_cpu_online,
-					kswapd_cpu_down_prep);
-#else
 	ret = cpuhp_setup_state_nocalls(CPUHP_AP_ONLINE_DYN,
 					"mm/vmscan:online", kswapd_cpu_online,
 					NULL);
-#endif
 	WARN_ON(ret < 0);
 #ifdef CONFIG_SHRINK_PAGECACHE
 	shrink_page_cache_init();