sd: Honor block layer integrity handling flags

A set of flags introduced in the block layer enables better control over
how protection information (PI) is handled. These flags are useful for
both error injection and data recovery purposes. PI checking can be
enabled and disabled independently for the controller and the disk, and
the guard tag format is now a per-I/O property.

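Replace sd_prot_op() with sd_setup_protect_cmnd(), which communicates
the relevant information to the low-level device driver via a set of
flags in scsi_cmnd (prot_flags) and returns the RDPROTECT/WRPROTECT
bits for the CDB. The DIX operation lookup moves to sd.h as
sd_prot_op(), alongside sd_prot_flag_mask(), which filters out flags
that do not apply to a given DIX operation.

sd_dif_prepare() now takes the scsi_cmnd, and both sd_dif_prepare() and
sd_dif_complete() obtain the initial reference tag from
scsi_prot_ref_tag() instead of deriving it from the starting sector.
sd_dif_complete() also counts protection intervals using
scsi_prot_interval() rather than the device sector size.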

Signed-off-by: Martin K. Petersen <martin.petersen@oracle.com>
Reviewed-by: Sagi Grimberg <sagig@mellanox.com>
Acked-by: Christoph Hellwig <hch@lst.de>
Signed-off-by: Jens Axboe <axboe@fb.com>
diff --git a/drivers/scsi/sd.c b/drivers/scsi/sd.c
index 2c2041c..9f7099f 100644
--- a/drivers/scsi/sd.c
+++ b/drivers/scsi/sd.c
@@ -610,29 +610,44 @@
 	mutex_unlock(&sd_ref_mutex);
 }
 
-static void sd_prot_op(struct scsi_cmnd *scmd, unsigned int dif)
-{
-	unsigned int prot_op = SCSI_PROT_NORMAL;
-	unsigned int dix = scsi_prot_sg_count(scmd);
 
-	if (scmd->sc_data_direction == DMA_FROM_DEVICE) {
-		if (dif && dix)
-			prot_op = SCSI_PROT_READ_PASS;
-		else if (dif && !dix)
-			prot_op = SCSI_PROT_READ_STRIP;
-		else if (!dif && dix)
-			prot_op = SCSI_PROT_READ_INSERT;
-	} else {
-		if (dif && dix)
-			prot_op = SCSI_PROT_WRITE_PASS;
-		else if (dif && !dix)
-			prot_op = SCSI_PROT_WRITE_INSERT;
-		else if (!dif && dix)
-			prot_op = SCSI_PROT_WRITE_STRIP;
+/* Set up DIX/DIF for scmd; returns RDPROTECT/WRPROTECT bits */
+static unsigned char sd_setup_protect_cmnd(struct scsi_cmnd *scmd,
+					   unsigned int dix, unsigned int dif)
+{
+	struct bio *bio = scmd->request->bio;
+	unsigned int prot_op = sd_prot_op(rq_data_dir(scmd->request), dix, dif);
+	unsigned int protect = 0;
+
+	if (dix) {				/* DIX Type 0, 1, 2, 3 */
+		if (bio_integrity_flagged(bio, BIP_IP_CHECKSUM))
+			scmd->prot_flags |= SCSI_PROT_IP_CHECKSUM;
+
+		if (!bio_integrity_flagged(bio, BIP_CTRL_NOCHECK))
+			scmd->prot_flags |= SCSI_PROT_GUARD_CHECK;
+	}
+
+	if (dif != SD_DIF_TYPE3_PROTECTION) {	/* DIX/DIF Type 0, 1, 2 */
+		scmd->prot_flags |= SCSI_PROT_REF_INCREMENT;
+
+		if (!bio_integrity_flagged(bio, BIP_CTRL_NOCHECK))
+			scmd->prot_flags |= SCSI_PROT_REF_CHECK;
+	}
+
+	if (dif) {				/* DIX/DIF Type 1, 2, 3 */
+		scmd->prot_flags |= SCSI_PROT_TRANSFER_PI;
+
+		if (bio_integrity_flagged(bio, BIP_DISK_NOCHECK))
+			protect = 3 << 5;	/* Disable target PI checking */
+		else
+			protect = 1 << 5;	/* Enable target PI checking */
 	}
 
 	scsi_set_prot_op(scmd, prot_op);
 	scsi_set_prot_type(scmd, dif);
+	scmd->prot_flags &= sd_prot_flag_mask(prot_op);
+
+	return protect;
 }
 
 static void sd_config_discard(struct scsi_disk *sdkp, unsigned int mode)
@@ -893,7 +908,8 @@
 	sector_t block = blk_rq_pos(rq);
 	sector_t threshold;
 	unsigned int this_count = blk_rq_sectors(rq);
-	int ret, host_dif;
+	unsigned int dif, dix;
+	int ret;
 	unsigned char protect;
 
 	ret = scsi_init_io(SCpnt, GFP_ATOMIC);
@@ -995,7 +1011,7 @@
 		SCpnt->cmnd[0] = WRITE_6;
 
 		if (blk_integrity_rq(rq))
-			sd_dif_prepare(rq, block, sdp->sector_size);
+			sd_dif_prepare(SCpnt);
 
 	} else if (rq_data_dir(rq) == READ) {
 		SCpnt->cmnd[0] = READ_6;
@@ -1010,14 +1026,15 @@
 					"writing" : "reading", this_count,
 					blk_rq_sectors(rq)));
 
-	/* Set RDPROTECT/WRPROTECT if disk is formatted with DIF */
-	host_dif = scsi_host_dif_capable(sdp->host, sdkp->protection_type);
-	if (host_dif)
-		protect = 1 << 5;
+	dix = scsi_prot_sg_count(SCpnt);
+	dif = scsi_host_dif_capable(SCpnt->device->host, sdkp->protection_type);
+
+	if (dif || dix)
+		protect = sd_setup_protect_cmnd(SCpnt, dix, dif);
 	else
 		protect = 0;
 
-	if (host_dif == SD_DIF_TYPE2_PROTECTION) {
+	if (protect && sdkp->protection_type == SD_DIF_TYPE2_PROTECTION) {
 		SCpnt->cmnd = mempool_alloc(sd_cdb_pool, GFP_ATOMIC);
 
 		if (unlikely(SCpnt->cmnd == NULL)) {
@@ -1102,10 +1119,6 @@
 	}
 	SCpnt->sdb.length = this_count * sdp->sector_size;
 
-	/* If DIF or DIX is enabled, tell HBA how to handle request */
-	if (host_dif || scsi_prot_sg_count(SCpnt))
-		sd_prot_op(SCpnt, host_dif);
-
 	/*
 	 * We shouldn't disconnect in the middle of a sector, so with a dumb
 	 * host adapter, it's safe to assume that we can at least transfer
diff --git a/drivers/scsi/sd.h b/drivers/scsi/sd.h
index 4c3ab83..4673778 100644
--- a/drivers/scsi/sd.h
+++ b/drivers/scsi/sd.h
@@ -167,6 +167,68 @@
 };
 
 /*
+ * Look up the DIX operation based on whether the command is read or
+ * write and whether dix and dif are enabled.
+ */
+static inline unsigned int sd_prot_op(bool write, bool dix, bool dif)
+{
+	/* Index: bit 2 write, bit 1 dix, bit 0 dif; bool args keep it 0..7 */
+	static const unsigned int ops[] = {	/* wrt dix dif */
+		SCSI_PROT_NORMAL,	/*  0	0   0  */
+		SCSI_PROT_READ_STRIP,	/*  0	0   1  */
+		SCSI_PROT_READ_INSERT,	/*  0	1   0  */
+		SCSI_PROT_READ_PASS,	/*  0	1   1  */
+		SCSI_PROT_NORMAL,	/*  1	0   0  */
+		SCSI_PROT_WRITE_INSERT, /*  1	0   1  */
+		SCSI_PROT_WRITE_STRIP,	/*  1	1   0  */
+		SCSI_PROT_WRITE_PASS,	/*  1	1   1  */
+	};
+
+	return ops[write << 2 | dix << 1 | dif];
+}
+
+/*
+ * Returns a mask of the protection flags that are valid for a given DIX
+ * operation; unknown operations yield an empty mask.
+ */
+static inline unsigned int sd_prot_flag_mask(unsigned int prot_op)
+{
+	static const unsigned int flag_mask[] = {
+		[SCSI_PROT_NORMAL]		= 0,
+
+		[SCSI_PROT_READ_STRIP]		= SCSI_PROT_TRANSFER_PI |
+						  SCSI_PROT_GUARD_CHECK |
+						  SCSI_PROT_REF_CHECK |
+						  SCSI_PROT_REF_INCREMENT,
+
+		[SCSI_PROT_READ_INSERT]		= SCSI_PROT_REF_INCREMENT |
+						  SCSI_PROT_IP_CHECKSUM,
+
+		[SCSI_PROT_READ_PASS]		= SCSI_PROT_TRANSFER_PI |
+						  SCSI_PROT_GUARD_CHECK |
+						  SCSI_PROT_REF_CHECK |
+						  SCSI_PROT_REF_INCREMENT |
+						  SCSI_PROT_IP_CHECKSUM,
+
+		[SCSI_PROT_WRITE_INSERT]	= SCSI_PROT_TRANSFER_PI |
+						  SCSI_PROT_REF_INCREMENT,
+
+		[SCSI_PROT_WRITE_STRIP]		= SCSI_PROT_GUARD_CHECK |
+						  SCSI_PROT_REF_CHECK |
+						  SCSI_PROT_REF_INCREMENT |
+						  SCSI_PROT_IP_CHECKSUM,
+
+		[SCSI_PROT_WRITE_PASS]		= SCSI_PROT_TRANSFER_PI |
+						  SCSI_PROT_GUARD_CHECK |
+						  SCSI_PROT_REF_CHECK |
+						  SCSI_PROT_REF_INCREMENT |
+						  SCSI_PROT_IP_CHECKSUM,
+	};
+
+	return prot_op < ARRAY_SIZE(flag_mask) ? flag_mask[prot_op] : 0;
+}
+
+/*
  * Data Integrity Field tuple.
  */
 struct sd_dif_tuple {
@@ -178,7 +240,7 @@
 #ifdef CONFIG_BLK_DEV_INTEGRITY
 
 extern void sd_dif_config_host(struct scsi_disk *);
-extern void sd_dif_prepare(struct request *rq, sector_t, unsigned int);
+extern void sd_dif_prepare(struct scsi_cmnd *scmd);
 extern void sd_dif_complete(struct scsi_cmnd *, unsigned int);
 
 #else /* CONFIG_BLK_DEV_INTEGRITY */
@@ -186,7 +248,7 @@
 static inline void sd_dif_config_host(struct scsi_disk *disk)
 {
 }
-static inline int sd_dif_prepare(struct request *rq, sector_t s, unsigned int a)
+static inline int sd_dif_prepare(struct scsi_cmnd *scmd)
 {
 	return 0;
 }
diff --git a/drivers/scsi/sd_dif.c b/drivers/scsi/sd_dif.c
index b7eaead..14c7d42 100644
--- a/drivers/scsi/sd_dif.c
+++ b/drivers/scsi/sd_dif.c
@@ -106,8 +106,7 @@
  *
  * Type 3 does not have a reference tag so no remapping is required.
  */
-void sd_dif_prepare(struct request *rq, sector_t hw_sector,
-		    unsigned int sector_sz)
+void sd_dif_prepare(struct scsi_cmnd *scmd)
 {
 	const int tuple_sz = sizeof(struct t10_pi_tuple);
 	struct bio *bio;
@@ -115,14 +114,14 @@
 	struct t10_pi_tuple *pi;
 	u32 phys, virt;
 
-	sdkp = rq->bio->bi_bdev->bd_disk->private_data;
+	sdkp = scsi_disk(scmd->request->rq_disk);
 
 	if (sdkp->protection_type == SD_DIF_TYPE3_PROTECTION)
 		return;
 
-	phys = hw_sector & 0xffffffff;
+	phys = scsi_prot_ref_tag(scmd);
 
-	__rq_for_each_bio(bio, rq) {
+	__rq_for_each_bio(bio, scmd->request) {
 		struct bio_integrity_payload *bip = bio_integrity(bio);
 		struct bio_vec iv;
 		struct bvec_iter iter;
@@ -163,7 +162,7 @@
 	struct scsi_disk *sdkp;
 	struct bio *bio;
 	struct t10_pi_tuple *pi;
-	unsigned int j, sectors, sector_sz;
+	unsigned int j, intervals;
 	u32 phys, virt;
 
 	sdkp = scsi_disk(scmd->request->rq_disk);
@@ -171,12 +170,8 @@
 	if (sdkp->protection_type == SD_DIF_TYPE3_PROTECTION || good_bytes == 0)
 		return;
 
-	sector_sz = scmd->device->sector_size;
-	sectors = good_bytes / sector_sz;
-
-	phys = blk_rq_pos(scmd->request) & 0xffffffff;
-	if (sector_sz == 4096)
-		phys >>= 3;
+	intervals = good_bytes / scsi_prot_interval(scmd);
+	phys = scsi_prot_ref_tag(scmd);
 
 	__rq_for_each_bio(bio, scmd->request) {
 		struct bio_integrity_payload *bip = bio_integrity(bio);
@@ -190,7 +185,7 @@
 
 			for (j = 0; j < iv.bv_len; j += tuple_sz, pi++) {
 
-				if (sectors == 0) {
+				if (intervals == 0) {
 					kunmap_atomic(pi);
 					return;
 				}
@@ -200,7 +195,7 @@
 
 				virt++;
 				phys++;
-				sectors--;
+				intervals--;
 			}
 
 			kunmap_atomic(pi);