diff --git a/drivers/scsi/ufs/ufs.h b/drivers/scsi/ufs/ufs.h
index 8770255b5dc0a5f565278bc8f859aab626641b5c..250e1a905a148f3d5229e264ab2b38c926e516af 100644
--- a/drivers/scsi/ufs/ufs.h
+++ b/drivers/scsi/ufs/ufs.h
@@ -63,6 +63,7 @@
 #define UFS_UPIU_MAX_UNIT_NUM_ID	0x7F
 #define UFS_MAX_LUNS		(SCSI_W_LUN_BASE + UFS_UPIU_MAX_UNIT_NUM_ID)
 #define UFS_UPIU_WLUN_ID	(1 << 7)
+#define UFS_RPMB_UNIT		0xC4
 
 /* WriteBooster buffer is available only for the logical unit from 0 to 7 */
 #define UFS_UPIU_MAX_WB_LUN_ID	8
diff --git a/drivers/scsi/ufs/ufshcd.c b/drivers/scsi/ufs/ufshcd.c
index 116b95cd24844608b1f4473efdd65522dc57dba7..42ba2bab35fd3c1e08234154f5411caeb11698de 100644
--- a/drivers/scsi/ufs/ufshcd.c
+++ b/drivers/scsi/ufs/ufshcd.c
@@ -3083,11 +3083,16 @@ void ufshcd_map_desc_id_to_length(struct ufs_hba *hba, enum desc_idn desc_id,
 EXPORT_SYMBOL(ufshcd_map_desc_id_to_length);
 
 static void ufshcd_update_desc_length(struct ufs_hba *hba,
-				      enum desc_idn desc_id,
+				      enum desc_idn desc_id, int desc_index,
 				      unsigned char desc_len)
 {
 	if (hba->desc_size[desc_id] == QUERY_DESC_MAX_SIZE &&
-	    desc_id != QUERY_DESC_IDN_STRING)
+	    desc_id != QUERY_DESC_IDN_STRING && desc_index != UFS_RPMB_UNIT)
+		/* For UFS 3.1, the normal unit descriptor is 10 bytes larger
+		 * than the RPMB unit, however, both descriptors share the same
+		 * desc_idn, to cover both unit descriptors with one length, we
+		 * choose the normal unit descriptor length by desc_index.
+		 */
 		hba->desc_size[desc_id] = desc_len;
 }
 
@@ -3156,7 +3161,7 @@ int ufshcd_read_desc_param(struct ufs_hba *hba,
 
 	/* Update descriptor length */
 	buff_len = desc_buf[QUERY_DESC_LENGTH_OFFSET];
-	ufshcd_update_desc_length(hba, desc_id, buff_len);
+	ufshcd_update_desc_length(hba, desc_id, desc_index, buff_len);
 
 	/* Check wherher we will not copy more data, than available */
 	if (is_kmalloc && (param_offset + param_size) > buff_len)