diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index c1871d7b134cfdc7aadca32aa016a57f2f529e9b..b42d615fa847988119fe4c704f77c1708ada7b58 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -3407,6 +3407,25 @@ static int mem_cgroup_move_charge_write(struct cgroup_subsys_state *css,
 int sysctl_memcg_qos_stat = DISABLE_MEMCG_QOS;
 DEFINE_STATIC_KEY_FALSE(memcg_qos_stat_key);
 
+static void memcg_hierarchy_qos_set(struct mem_cgroup *memcg, int val)
+{
+	struct mem_cgroup *iter;
+	struct cgroup_subsys_state *css;
+	struct mem_cgroup_extension *memcg_ext;
+
+	if (!memcg)
+		memcg = root_mem_cgroup;
+
+	rcu_read_lock();
+	css_for_each_descendant_pre(css, &memcg->css) {
+		iter = mem_cgroup_from_css(css);
+		memcg_ext = to_memcg_ext(iter);
+
+		memcg_ext->memcg_priority = val;
+	}
+	rcu_read_unlock();
+}
+
 static void memcg_qos_init(struct mem_cgroup *memcg)
 {
 	struct mem_cgroup *parent = parent_mem_cgroup(memcg);
@@ -3449,12 +3468,17 @@ static int memcg_qos_write(struct cgroup_subsys_state *css,
 	if (!static_branch_likely(&memcg_qos_stat_key))
 		return -EACCES;
 
+	if (mem_cgroup_is_root(memcg))
+		return -EINVAL;
+
+	if (val != 0 && val != -1)
+		return -EINVAL;
+
 	memcg_ext = to_memcg_ext(memcg);
 
-	if (val >= 0)
-		memcg_ext->memcg_priority = 0;
-	else
-		memcg_ext->memcg_priority = -1;
+	memcg_ext->memcg_priority = val;
+	if (memcg->use_hierarchy)
+		memcg_hierarchy_qos_set(memcg, val);
 
 	return 0;
 }
@@ -3548,23 +3572,6 @@ void memcg_print_bad_task(void *arg, int ret)
 	}
 }
 
-static void memcg_qos_reset(void)
-{
-	struct mem_cgroup *iter;
-	struct cgroup_subsys_state *css;
-	struct mem_cgroup_extension *memcg_ext;
-
-	rcu_read_lock();
-	css_for_each_descendant_pre(css, &root_mem_cgroup->css) {
-		iter = mem_cgroup_from_css(css);
-		memcg_ext = to_memcg_ext(iter);
-
-		if (memcg_ext->memcg_priority)
-			memcg_ext->memcg_priority = 0;
-	}
-	rcu_read_unlock();
-}
-
 int sysctl_memcg_qos_handler(struct ctl_table *table, int write,
 		void __user *buffer, size_t *length, loff_t *ppos)
 {
@@ -3579,7 +3586,7 @@ int sysctl_memcg_qos_handler(struct ctl_table *table, int write,
 			pr_info("enable memcg priority.\n");
 		} else {
 			static_branch_disable(&memcg_qos_stat_key);
-			memcg_qos_reset();
+			memcg_hierarchy_qos_set(NULL, 0);
 			pr_info("disable memcg priority.\n");
 		}
 	}