lib: radix-tree: check accounting of existing slot replacement users

The bug in khugepaged fixed earlier in this series shows that radix tree
slot replacement is fragile; and it will become more so when not only
NULL<->!NULL transitions need to be caught but transitions from and to
exceptional entries as well.  We need checks.

Re-implement radix_tree_replace_slot() on top of the sanity-checked
__radix_tree_replace().  This requires existing callers to also pass the
radix tree root, but it'll warn us when somebody replaces slots with
contents that need proper accounting (transitions between NULL entries,
real entries, exceptional entries) and where a replacement through the
slot pointer would corrupt the radix tree node counts.

Link: http://lkml.kernel.org/r/20161117193021.GB23430@cmpxchg.org
Signed-off-by: Johannes Weiner <hannes@cmpxchg.org>
Suggested-by: Jan Kara <jack@suse.cz>
Reviewed-by: Jan Kara <jack@suse.cz>
Cc: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
Cc: Hugh Dickins <hughd@google.com>
Cc: Matthew Wilcox <mawilcox@linuxonhyperv.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c
index 3ba6227..ec1f0de 100644
--- a/arch/s390/mm/gmap.c
+++ b/arch/s390/mm/gmap.c
@@ -1015,7 +1015,7 @@
 	if (slot) {
 		rmap->next = radix_tree_deref_slot_protected(slot,
 							&sg->guest_table_lock);
-		radix_tree_replace_slot(slot, rmap);
+		radix_tree_replace_slot(&sg->host_to_rmap, slot, rmap);
 	} else {
 		rmap->next = NULL;
 		radix_tree_insert(&sg->host_to_rmap, vmaddr >> PAGE_SHIFT,
diff --git a/drivers/sh/intc/virq.c b/drivers/sh/intc/virq.c
index e789962..35bbe28 100644
--- a/drivers/sh/intc/virq.c
+++ b/drivers/sh/intc/virq.c
@@ -254,7 +254,7 @@
 
 		radix_tree_tag_clear(&d->tree, entry->enum_id,
 				     INTC_TAG_VIRQ_NEEDS_ALLOC);
-		radix_tree_replace_slot((void **)entries[i],
+		radix_tree_replace_slot(&d->tree, (void **)entries[i],
 					&intc_irq_xlate[irq]);
 	}
 
diff --git a/fs/dax.c b/fs/dax.c
index db78bae..85930c2 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -342,7 +342,7 @@
 		radix_tree_deref_slot_protected(slot, &mapping->tree_lock);
 
 	entry |= RADIX_DAX_ENTRY_LOCK;
-	radix_tree_replace_slot(slot, (void *)entry);
+	radix_tree_replace_slot(&mapping->page_tree, slot, (void *)entry);
 	return (void *)entry;
 }
 
@@ -356,7 +356,7 @@
 		radix_tree_deref_slot_protected(slot, &mapping->tree_lock);
 
 	entry &= ~(unsigned long)RADIX_DAX_ENTRY_LOCK;
-	radix_tree_replace_slot(slot, (void *)entry);
+	radix_tree_replace_slot(&mapping->page_tree, slot, (void *)entry);
 	return (void *)entry;
 }
 
diff --git a/include/linux/radix-tree.h b/include/linux/radix-tree.h
index 7ced8a7..2d1b9b8 100644
--- a/include/linux/radix-tree.h
+++ b/include/linux/radix-tree.h
@@ -249,20 +249,6 @@
 	return unlikely((unsigned long)arg & RADIX_TREE_ENTRY_MASK);
 }
 
-/**
- * radix_tree_replace_slot	- replace item in a slot
- * @pslot:	pointer to slot, returned by radix_tree_lookup_slot
- * @item:	new item to store in the slot.
- *
- * For use with radix_tree_lookup_slot().  Caller must hold tree write locked
- * across slot lookup and replacement.
- */
-static inline void radix_tree_replace_slot(void **pslot, void *item)
-{
-	BUG_ON(radix_tree_is_internal_node(item));
-	rcu_assign_pointer(*pslot, item);
-}
-
 int __radix_tree_create(struct radix_tree_root *root, unsigned long index,
 			unsigned order, struct radix_tree_node **nodep,
 			void ***slotp);
@@ -280,6 +266,8 @@
 void __radix_tree_replace(struct radix_tree_root *root,
 			  struct radix_tree_node *node,
 			  void **slot, void *item);
+void radix_tree_replace_slot(struct radix_tree_root *root,
+			     void **slot, void *item);
 bool __radix_tree_delete_node(struct radix_tree_root *root,
 			      struct radix_tree_node *node);
 void *radix_tree_delete_item(struct radix_tree_root *, unsigned long, void *);
diff --git a/lib/radix-tree.c b/lib/radix-tree.c
index 7885796..f91d5b0 100644
--- a/lib/radix-tree.c
+++ b/lib/radix-tree.c
@@ -753,6 +753,28 @@
 }
 EXPORT_SYMBOL(radix_tree_lookup);
 
+static void replace_slot(struct radix_tree_root *root,
+			 struct radix_tree_node *node,
+			 void **slot, void *item,
+			 bool warn_typeswitch)
+{
+	void *old = rcu_dereference_raw(*slot);
+	int exceptional;
+
+	WARN_ON_ONCE(radix_tree_is_internal_node(item));
+	WARN_ON_ONCE(!!item - !!old);
+
+	exceptional = !!radix_tree_exceptional_entry(item) -
+		      !!radix_tree_exceptional_entry(old);
+
+	WARN_ON_ONCE(warn_typeswitch && exceptional);
+
+	if (node)
+		node->exceptional += exceptional;
+
+	rcu_assign_pointer(*slot, item);
+}
+
 /**
  * __radix_tree_replace		- replace item in a slot
  * @root:	radix tree root
@@ -767,21 +789,34 @@
 			  struct radix_tree_node *node,
 			  void **slot, void *item)
 {
-	void *old = rcu_dereference_raw(*slot);
-	int exceptional;
+	/*
+	 * This function supports replacing exceptional entries, but
+	 * that needs accounting against the node unless the slot is
+	 * root->rnode.
+	 */
+	replace_slot(root, node, slot, item,
+		     !node && slot != (void **)&root->rnode);
+}
 
-	WARN_ON_ONCE(radix_tree_is_internal_node(item));
-	WARN_ON_ONCE(!!item - !!old);
-
-	exceptional = !!radix_tree_exceptional_entry(item) -
-		      !!radix_tree_exceptional_entry(old);
-
-	WARN_ON_ONCE(exceptional && !node && slot != (void **)&root->rnode);
-
-	if (node)
-		node->exceptional += exceptional;
-
-	rcu_assign_pointer(*slot, item);
+/**
+ * radix_tree_replace_slot	- replace item in a slot
+ * @root:	radix tree root
+ * @slot:	pointer to slot
+ * @item:	new item to store in the slot.
+ *
+ * For use with radix_tree_lookup_slot(), radix_tree_gang_lookup_slot(),
+ * radix_tree_gang_lookup_tag_slot().  Caller must hold tree write locked
+ * across slot lookup and replacement.
+ *
+ * NOTE: This cannot be used to switch between non-entries (empty slots),
+ * regular entries, and exceptional entries, as that requires accounting
+ * inside the radix tree node. When switching from one type of entry to
+ * another, use __radix_tree_lookup() and __radix_tree_replace().
+ */
+void radix_tree_replace_slot(struct radix_tree_root *root,
+			     void **slot, void *item)
+{
+	replace_slot(root, NULL, slot, item, true);
 }
 
 /**
diff --git a/mm/filemap.c b/mm/filemap.c
index caa779f..1ba726a 100644
--- a/mm/filemap.c
+++ b/mm/filemap.c
@@ -147,7 +147,7 @@
 						      false);
 		}
 	}
-	radix_tree_replace_slot(slot, page);
+	radix_tree_replace_slot(&mapping->page_tree, slot, page);
 	mapping->nrpages++;
 	if (node) {
 		workingset_node_pages_inc(node);
@@ -196,7 +196,7 @@
 			shadow = NULL;
 		}
 
-		radix_tree_replace_slot(slot, shadow);
+		radix_tree_replace_slot(&mapping->page_tree, slot, shadow);
 
 		if (!node)
 			break;
diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index 5d7c006..7a50c72 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -1426,7 +1426,7 @@
 		list_add_tail(&page->lru, &pagelist);
 
 		/* Finally, replace with the new page. */
-		radix_tree_replace_slot(slot,
+		radix_tree_replace_slot(&mapping->page_tree, slot,
 				new_page + (index % HPAGE_PMD_NR));
 
 		slot = radix_tree_iter_next(&iter);
@@ -1538,7 +1538,8 @@
 			/* Unfreeze the page. */
 			list_del(&page->lru);
 			page_ref_unfreeze(page, 2);
-			radix_tree_replace_slot(slot, page);
+			radix_tree_replace_slot(&mapping->page_tree,
+						slot, page);
 			spin_unlock_irq(&mapping->tree_lock);
 			putback_lru_page(page);
 			unlock_page(page);
diff --git a/mm/migrate.c b/mm/migrate.c
index 66ce6b4..0ed24b1 100644
--- a/mm/migrate.c
+++ b/mm/migrate.c
@@ -482,7 +482,7 @@
 		SetPageDirty(newpage);
 	}
 
-	radix_tree_replace_slot(pslot, newpage);
+	radix_tree_replace_slot(&mapping->page_tree, pslot, newpage);
 
 	/*
 	 * Drop cache reference from old page by unfreezing
@@ -556,7 +556,7 @@
 
 	get_page(newpage);
 
-	radix_tree_replace_slot(pslot, newpage);
+	radix_tree_replace_slot(&mapping->page_tree, pslot, newpage);
 
 	page_ref_unfreeze(page, expected_count - 1);
 
diff --git a/mm/truncate.c b/mm/truncate.c
index 8d8c62d..3c631c3 100644
--- a/mm/truncate.c
+++ b/mm/truncate.c
@@ -49,7 +49,7 @@
 		goto unlock;
 	if (*slot != entry)
 		goto unlock;
-	radix_tree_replace_slot(slot, NULL);
+	radix_tree_replace_slot(&mapping->page_tree, slot, NULL);
 	mapping->nrexceptional--;
 	if (!node)
 		goto unlock;
diff --git a/tools/testing/radix-tree/multiorder.c b/tools/testing/radix-tree/multiorder.c
index 05d7bc4..d1be946 100644
--- a/tools/testing/radix-tree/multiorder.c
+++ b/tools/testing/radix-tree/multiorder.c
@@ -146,7 +146,7 @@
 
 	slot = radix_tree_lookup_slot(&tree, index);
 	free(*slot);
-	radix_tree_replace_slot(slot, item2);
+	radix_tree_replace_slot(&tree, slot, item2);
 	for (i = min; i < max; i++) {
 		struct item *item = item_lookup(&tree, i);
 		assert(item != 0);