SUNRPC: Fix a potential race in rpc_wake_up_task()

Use RCU to ensure that we can safely call rpc_finish_wakeup() after we've
called __rpc_do_wake_up_task(). Without it there is a theoretical race in
which the rpc_task finishes executing, and is freed, before we get to
rpc_finish_wakeup(). The wakeup paths therefore hold rcu_read_lock_bh()
across the wakeup (so the bh-disabling spin_lock_bh() can become a plain
spin_lock()), and RPC_TASK_DYNAMIC tasks are now freed via call_rcu_bh()
rather than immediately.
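For reference, a minimal sketch of the deferred-free pattern this relies on;
the names (obj, obj_release, obj_wake) are illustrative only and are not the
SUNRPC code itself:

	/* Illustrative only: the RCU-bh deferred-free pattern.
	 * Names here are made up for the example. */
	#include <linux/rcupdate.h>
	#include <linux/slab.h>

	struct obj {
		int state;
		struct rcu_head rcu;
	};

	static void obj_free_rcu(struct rcu_head *head)
	{
		/* Runs only after all rcu_read_lock_bh() sections that
		 * could still reference the object have completed. */
		kfree(container_of(head, struct obj, rcu));
	}

	static void obj_release(struct obj *o)
	{
		/* Defer the actual free past current bh read-side sections. */
		call_rcu_bh(&o->rcu, obj_free_rcu);
	}

	static void obj_wake(struct obj *o)
	{
		rcu_read_lock_bh();
		/* While the read lock is held, a concurrent obj_release()
		 * cannot free the object out from under us. */
		o->state = 1;
		rcu_read_unlock_bh();
	}

In the patch below, rpc_wake_up_task() and the queue-wide wakeup helpers play
the role of the reader, and rpc_free_task() becomes the call_rcu_bh() callback.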

Signed-off-by: Trond Myklebust <Trond.Myklebust@netapp.com>
diff --git a/net/sunrpc/sched.c b/net/sunrpc/sched.c
index 66d01365..6b808c0 100644
--- a/net/sunrpc/sched.c
+++ b/net/sunrpc/sched.c
@@ -427,16 +427,19 @@
  */
 void rpc_wake_up_task(struct rpc_task *task)
 {
+	rcu_read_lock_bh();
 	if (rpc_start_wakeup(task)) {
 		if (RPC_IS_QUEUED(task)) {
 			struct rpc_wait_queue *queue = task->u.tk_wait.rpc_waitq;
 
-			spin_lock_bh(&queue->lock);
+			/* Note: we're already in a bh-safe context */
+			spin_lock(&queue->lock);
 			__rpc_do_wake_up_task(task);
-			spin_unlock_bh(&queue->lock);
+			spin_unlock(&queue->lock);
 		}
 		rpc_finish_wakeup(task);
 	}
+	rcu_read_unlock_bh();
 }
 
 /*
@@ -499,14 +502,16 @@
 	struct rpc_task	*task = NULL;
 
 	dprintk("RPC:      wake_up_next(%p \"%s\")\n", queue, rpc_qname(queue));
-	spin_lock_bh(&queue->lock);
+	rcu_read_lock_bh();
+	spin_lock(&queue->lock);
 	if (RPC_IS_PRIORITY(queue))
 		task = __rpc_wake_up_next_priority(queue);
 	else {
 		task_for_first(task, &queue->tasks[0])
 			__rpc_wake_up_task(task);
 	}
-	spin_unlock_bh(&queue->lock);
+	spin_unlock(&queue->lock);
+	rcu_read_unlock_bh();
 
 	return task;
 }
@@ -522,7 +527,8 @@
 	struct rpc_task *task, *next;
 	struct list_head *head;
 
-	spin_lock_bh(&queue->lock);
+	rcu_read_lock_bh();
+	spin_lock(&queue->lock);
 	head = &queue->tasks[queue->maxpriority];
 	for (;;) {
 		list_for_each_entry_safe(task, next, head, u.tk_wait.list)
@@ -531,7 +537,8 @@
 			break;
 		head--;
 	}
-	spin_unlock_bh(&queue->lock);
+	spin_unlock(&queue->lock);
+	rcu_read_unlock_bh();
 }
 
 /**
@@ -546,7 +553,8 @@
 	struct rpc_task *task, *next;
 	struct list_head *head;
 
-	spin_lock_bh(&queue->lock);
+	rcu_read_lock_bh();
+	spin_lock(&queue->lock);
 	head = &queue->tasks[queue->maxpriority];
 	for (;;) {
 		list_for_each_entry_safe(task, next, head, u.tk_wait.list) {
@@ -557,7 +565,8 @@
 			break;
 		head--;
 	}
-	spin_unlock_bh(&queue->lock);
+	spin_unlock(&queue->lock);
+	rcu_read_unlock_bh();
 }
 
 static void __rpc_atrun(struct rpc_task *task)
@@ -817,8 +826,9 @@
 	return (struct rpc_task *)mempool_alloc(rpc_task_mempool, GFP_NOFS);
 }
 
-static void rpc_free_task(struct rpc_task *task)
+static void rpc_free_task(struct rcu_head *rcu)
 {
+	struct rpc_task *task = container_of(rcu, struct rpc_task, u.tk_rcu);
 	dprintk("RPC: %4d freeing task\n", task->tk_pid);
 	mempool_free(task, rpc_task_mempool);
 }
@@ -872,7 +882,7 @@
 		task->tk_client = NULL;
 	}
 	if (task->tk_flags & RPC_TASK_DYNAMIC)
-		rpc_free_task(task);
+		call_rcu_bh(&task->u.tk_rcu, rpc_free_task);
 	if (tk_ops->rpc_release)
 		tk_ops->rpc_release(calldata);
 }