diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c
index 7ce18f3c097ac4779eb4cf6ed0ad14ac1beb3eb5..28cbacae4e5161d0682356fbd3260fca27ff1bcc 100644
--- a/kernel/sched/fair.c
+++ b/kernel/sched/fair.c
@@ -1196,9 +1196,11 @@ static void task_numa_assign(struct task_numa_env *env,
 static bool load_too_imbalanced(long src_load, long dst_load,
 				struct task_numa_env *env)
 {
-	long imb, old_imb;
-	long orig_src_load, orig_dst_load;
 	long src_capacity, dst_capacity;
+	long orig_src_load;
+	long load_a, load_b;
+	long moved_load;
+	long imb;
 
 	/*
 	 * The load is corrected for the CPU capacity available on each node.
@@ -1211,30 +1213,39 @@ static bool load_too_imbalanced(long src_load, long dst_load,
 	dst_capacity = env->dst_stats.compute_capacity;
 
 	/* We care about the slope of the imbalance, not the direction. */
-	if (dst_load < src_load)
-		swap(dst_load, src_load);
+	load_a = dst_load;
+	load_b = src_load;
+	if (load_a < load_b)
+		swap(load_a, load_b);
 
 	/* Is the difference below the threshold? */
-	imb = dst_load * src_capacity * 100 -
-	      src_load * dst_capacity * env->imbalance_pct;
+	imb = load_a * src_capacity * 100 -
+		load_b * dst_capacity * env->imbalance_pct;
 	if (imb <= 0)
 		return false;
 
 	/*
 	 * The imbalance is above the allowed threshold.
-	 * Compare it with the old imbalance.
+	 * Allow a move that brings us closer to a balanced situation,
+	 * without moving things past the point of balance.
 	 */
 	orig_src_load = env->src_stats.load;
-	orig_dst_load = env->dst_stats.load;
 
-	if (orig_dst_load < orig_src_load)
-		swap(orig_dst_load, orig_src_load);
-
-	old_imb = orig_dst_load * src_capacity * 100 -
-		  orig_src_load * dst_capacity * env->imbalance_pct;
+	/*
+	 * In a task swap, there will be one load moving from src to dst,
+	 * and another moving back. This is the net sum of both moves.
+	 * A simple task move will always have a positive value.
+	 * Allow the move if it brings the system closer to a balanced
+	 * situation, without crossing over the balance point.
+	 */
+	moved_load = orig_src_load - src_load;
 
-	/* Would this change make things worse? */
-	return (imb > old_imb);
+	if (moved_load > 0)
+		/* Moving src -> dst. Did we overshoot balance? */
+		return src_load * dst_capacity < dst_load * src_capacity;
+	else
+		/* Moving dst -> src. Did we overshoot balance? */
+		return dst_load * src_capacity < src_load * dst_capacity;
 }
 
 /*
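
Below is a minimal user-space sketch of the decision logic this hunk introduces. The struct, function and field names (numa_env_sketch, too_imbalanced, src_orig_load, ...) and the numbers in main() are illustrative assumptions, not kernel code; only the arithmetic mirrors the patched load_too_imbalanced().

/*
 * Sketch only: reproduces the threshold check and the new
 * "don't cross the balance point" check from the patch above.
 */
#include <stdbool.h>
#include <stdio.h>

struct numa_env_sketch {
        long src_orig_load;     /* load on the source node before the move */
        long src_capacity;      /* compute capacity of the source node */
        long dst_capacity;      /* compute capacity of the destination node */
        long imbalance_pct;     /* allowed imbalance, e.g. 125 for 25% */
};

static bool too_imbalanced(long src_load, long dst_load,
                           const struct numa_env_sketch *env)
{
        long load_a = dst_load, load_b = src_load;
        long moved_load, imb;

        /* We care about the slope of the imbalance, not the direction. */
        if (load_a < load_b) {
                long tmp = load_a;
                load_a = load_b;
                load_b = tmp;
        }

        /* Is the difference below the threshold? */
        imb = load_a * env->src_capacity * 100 -
              load_b * env->dst_capacity * env->imbalance_pct;
        if (imb <= 0)
                return false;

        /*
         * Net load leaving the source node: positive for a plain
         * src -> dst move, possibly negative for a task swap.
         */
        moved_load = env->src_orig_load - src_load;

        if (moved_load > 0)
                /* Moving src -> dst. Did we overshoot balance? */
                return src_load * env->dst_capacity < dst_load * env->src_capacity;
        else
                /* Moving dst -> src. Did we overshoot balance? */
                return dst_load * env->src_capacity < src_load * env->dst_capacity;
}

int main(void)
{
        /* Source node at 1000 before the move, equal capacities, 25% slack. */
        struct numa_env_sketch env = {
                .src_orig_load  = 1000,
                .src_capacity   = 1024,
                .dst_capacity   = 1024,
                .imbalance_pct  = 125,
        };

        /*
         * 1000/200 becomes 700/500 after moving 300 of load: still above the
         * 25% threshold, but closer to balance and not past it, so the move
         * is allowed (prints 0).
         */
        printf("%d\n", too_imbalanced(700, 500, &env));

        /*
         * 1000/200 becomes 500/700 after moving 500 of load: the load crosses
         * the balance point, so the move is rejected (prints 1).
         */
        printf("%d\n", too_imbalanced(500, 700, &env));

        return 0;
}

For comparison, with these example numbers the removed "imb > old_imb" test would have accepted the second move as well (the new imbalance is smaller than the original 1000-vs-200 one), which appears to be exactly the overshoot the new balance-point check is meant to refuse.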