diff --git a/drivers/md/dm.c b/drivers/md/dm.c
index 2c907bc10fe9ac0cab0b081ce920fdefe7bd6b98..7538b897282072fbc73b60079d2fc87a64a03351 100644
--- a/drivers/md/dm.c
+++ b/drivers/md/dm.c
@@ -2521,26 +2521,76 @@ void dm_free_md_mempools(struct dm_md_mempools *pools)
 	kfree(pools);
 }
 
-static int dm_pr_register(struct block_device *bdev, u64 old_key, u64 new_key,
-			  u32 flags)
+struct dm_pr {
+	u64	old_key;
+	u64	new_key;
+	u32	flags;
+	bool	fail_early;
+};
+
+static int dm_call_pr(struct block_device *bdev, iterate_devices_callout_fn fn,
+		      void *data)
 {
 	struct mapped_device *md = bdev->bd_disk->private_data;
-	const struct pr_ops *ops;
-	fmode_t mode;
-	int r;
+	struct dm_table *table;
+	struct dm_target *ti;
+	int ret = -ENOTTY, srcu_idx;
 
-	r = dm_grab_bdev_for_ioctl(md, &bdev, &mode);
-	if (r < 0)
-		return r;
+	table = dm_get_live_table(md, &srcu_idx);
+	if (!table || !dm_table_get_size(table))
+		goto out;
 
-	ops = bdev->bd_disk->fops->pr_ops;
-	if (ops && ops->pr_register)
-		r = ops->pr_register(bdev, old_key, new_key, flags);
-	else
-		r = -EOPNOTSUPP;
+	/* We only support devices that have a single target */
+	if (dm_table_get_num_targets(table) != 1)
+		goto out;
+	ti = dm_table_get_target(table, 0);
 
-	bdput(bdev);
-	return r;
+	ret = -EINVAL;
+	if (!ti->type->iterate_devices)
+		goto out;
+
+	ret = ti->type->iterate_devices(ti, fn, data);
+out:
+	dm_put_live_table(md, srcu_idx);
+	return ret;
+}
+
+/*
+ * For register / unregister we need to manually call out to every path.
+ */
+static int __dm_pr_register(struct dm_target *ti, struct dm_dev *dev,
+			    sector_t start, sector_t len, void *data)
+{
+	struct dm_pr *pr = data;
+	const struct pr_ops *ops = dev->bdev->bd_disk->fops->pr_ops;
+
+	if (!ops || !ops->pr_register)
+		return -EOPNOTSUPP;
+	return ops->pr_register(dev->bdev, pr->old_key, pr->new_key, pr->flags);
+}
+
+static int dm_pr_register(struct block_device *bdev, u64 old_key, u64 new_key,
+			  u32 flags)
+{
+	struct dm_pr pr = {
+		.old_key	= old_key,
+		.new_key	= new_key,
+		.flags		= flags,
+		.fail_early	= true,
+	};
+	int ret;
+
+	ret = dm_call_pr(bdev, __dm_pr_register, &pr);
+	if (ret && new_key) {
+		/* unregister all paths if we failed to register any path */
+		pr.old_key = new_key;
+		pr.new_key = 0;
+		pr.flags = 0;
+		pr.fail_early = false;
+		dm_call_pr(bdev, __dm_pr_register, &pr);
+	}
+
+	return ret;
 }
 
 static int dm_pr_reserve(struct block_device *bdev, u64 key, enum pr_type type,