[openib-general] [PATCH 2 of 2] ipoib: ipoib_multicast.c cleanup

Michael S. Tsirkin mst at mellanox.co.il
Mon Nov 21 08:36:53 PST 2005


Fix several race conditions in ipoib_multicast.c:
1. Make sure mcast->query is set to NULL if, and only if,
   no query is outstanding.
2. Make sure mcast->done is initialized to an uncompleted value
   before we submit a new query, so that it's safe to wait on.
3. Protect all accesses to priv->broadcast, priv->multicast_list,
   mcast->query and mcast->done with priv->lock.
   To make this last part work, I had to replace the mcast_mutex
   semaphore with the ipoib_mcast_lock spinlock: it is now taken while
   priv->lock is held, where sleeping is not allowed. For the same
   reason, ipoib_mcast_alloc() now always allocates with GFP_ATOMIC.

Signed-off-by: Michael S. Tsirkin <mst at mellanox.co.il>

Index: linux-2.6.14-dbg/drivers/infiniband/ulp/ipoib/ipoib_multicast.c
===================================================================
--- linux-2.6.14-dbg.orig/drivers/infiniband/ulp/ipoib/ipoib_multicast.c	2005-11-20 12:34:04.000000000 +0200
+++ linux-2.6.14-dbg/drivers/infiniband/ulp/ipoib/ipoib_multicast.c	2005-11-20 14:57:18.000000000 +0200
@@ -53,7 +53,7 @@ MODULE_PARM_DESC(mcast_debug_level,
 		 "Enable multicast debug tracing if > 0");
 #endif
 
-static DECLARE_MUTEX(mcast_mutex);
+DEFINE_SPINLOCK(ipoib_mcast_lock);	/* orders IPOIB_MCAST_RUN vs. mcast_task queueing */
 
 /* Used for all multicast joins (broadcast, IPv4 mcast and IPv6 mcast) */
 struct ipoib_mcast {
@@ -126,17 +126,14 @@ static void ipoib_mcast_free(struct ipoi
 	kfree(mcast);
 }
 
-static struct ipoib_mcast *ipoib_mcast_alloc(struct net_device *dev,
-					     int can_sleep)
+static struct ipoib_mcast *ipoib_mcast_alloc(struct net_device *dev)
 {
 	struct ipoib_mcast *mcast;
 
-	mcast = kzalloc(sizeof *mcast, can_sleep ? GFP_KERNEL : GFP_ATOMIC);
+	mcast = kzalloc(sizeof *mcast, GFP_ATOMIC);
 	if (!mcast)
 		return NULL;
 
-	init_completion(&mcast->done);
-
 	mcast->dev = dev;
 	mcast->created = jiffies;
 	mcast->backoff = 1;
@@ -209,17 +206,23 @@ static int ipoib_mcast_join_finish(struc
 {
 	struct net_device *dev = mcast->dev;
 	struct ipoib_dev_priv *priv = netdev_priv(dev);
+	unsigned long flags;
 	int ret;
 
 	mcast->mcmember = *mcmember;
 
+	spin_lock_irqsave(&priv->lock, flags);
+
 	/* Set the cached Q_Key before we attach if it's the broadcast group */
-	if (!memcmp(mcast->mcmember.mgid.raw, priv->dev->broadcast + 4,
+	if (priv->broadcast &&
+	    !memcmp(mcast->mcmember.mgid.raw, priv->dev->broadcast + 4,
 		    sizeof (union ib_gid))) {
 		priv->qkey = be32_to_cpu(priv->broadcast->mcmember.qkey);
 		priv->tx_wr.wr.ud.remote_qkey = priv->qkey;
 	}
 
+	spin_unlock_irqrestore(&priv->lock, flags);
+
 	if (!test_bit(IPOIB_MCAST_FLAG_SENDONLY, &mcast->flags)) {
 		if (test_and_set_bit(IPOIB_MCAST_FLAG_ATTACHED, &mcast->flags)) {
 			ipoib_warn(priv, "multicast group " IPOIB_GID_FMT
@@ -303,6 +306,12 @@ ipoib_mcast_sendonly_join_complete(int s
 {
 	struct ipoib_mcast *mcast = mcast_ptr;
 	struct net_device *dev = mcast->dev;
+	struct ipoib_dev_priv *priv = netdev_priv(dev);
+	unsigned long flags;
+
+	ipoib_dbg_mcast(priv, "sendonly join completion for " IPOIB_GID_FMT
+			" (status %d)\n",
+			IPOIB_GID_ARG(mcast->mcmember.mgid), status);
 
 	if (!status)
 		ipoib_mcast_join_finish(mcast, mcmember);
@@ -320,7 +329,11 @@ ipoib_mcast_sendonly_join_complete(int s
 		clear_bit(IPOIB_MCAST_FLAG_BUSY, &mcast->flags);
 	}
 
+	spin_lock_irqsave(&priv->lock, flags);
+	mcast->query = NULL;
+
 	complete(&mcast->done);
+	spin_unlock_irqrestore(&priv->lock, flags);
 }
 
 static int ipoib_mcast_sendonly_join(struct ipoib_mcast *mcast)
@@ -350,6 +363,7 @@ static int ipoib_mcast_sendonly_join(str
 	rec.port_gid = priv->local_gid;
 	rec.pkey     = cpu_to_be16(priv->pkey);
 
+	init_completion(&mcast->done);
 	ret = ib_sa_mcmember_rec_set(priv->ca, priv->port, &rec,
 				     IB_SA_MCMEMBER_REC_MGID		|
 				     IB_SA_MCMEMBER_REC_PORT_GID	|
@@ -379,23 +393,31 @@ static void ipoib_mcast_join_complete(in
 	struct ipoib_mcast *mcast = mcast_ptr;
 	struct net_device *dev = mcast->dev;
 	struct ipoib_dev_priv *priv = netdev_priv(dev);
+	unsigned long flags;
 
 	ipoib_dbg_mcast(priv, "join completion for " IPOIB_GID_FMT
 			" (status %d)\n",
 			IPOIB_GID_ARG(mcast->mcmember.mgid), status);
 
+
 	if (!status && !ipoib_mcast_join_finish(mcast, mcmember)) {
 		mcast->backoff = 1;
-		down(&mcast_mutex);
+		spin_lock(&ipoib_mcast_lock);
 		if (test_bit(IPOIB_MCAST_RUN, &priv->flags))
 			queue_work(ipoib_workqueue, &priv->mcast_task);
-		up(&mcast_mutex);
+		spin_unlock(&ipoib_mcast_lock);
+		spin_lock_irqsave(&priv->lock, flags);
+		mcast->query = NULL;
 		complete(&mcast->done);
+		spin_unlock_irqrestore(&priv->lock, flags);
 		return;
 	}
 
 	if (status == -EINTR) {
+		spin_lock_irqsave(&priv->lock, flags);
+		mcast->query = NULL;
 		complete(&mcast->done);
+		spin_unlock_irqrestore(&priv->lock, flags);
 		return;
 	}
 
@@ -417,20 +439,21 @@ static void ipoib_mcast_join_complete(in
 	if (mcast->backoff > IPOIB_MAX_BACKOFF_SECONDS)
 		mcast->backoff = IPOIB_MAX_BACKOFF_SECONDS;
 
-	mcast->query = NULL;
+	spin_lock_irqsave(&priv->lock, flags);
 
-	down(&mcast_mutex);
+	spin_lock(&ipoib_mcast_lock);
 	if (test_bit(IPOIB_MCAST_RUN, &priv->flags)) {
 		if (status == -ETIMEDOUT)
 			queue_work(ipoib_workqueue, &priv->mcast_task);
 		else
 			queue_delayed_work(ipoib_workqueue, &priv->mcast_task,
 					   mcast->backoff * HZ);
-	} else
-		complete(&mcast->done);
-	up(&mcast_mutex);
+	}
+	spin_unlock(&ipoib_mcast_lock);
 
-	return;
+	mcast->query = NULL;
+	complete(&mcast->done);
+	spin_unlock_irqrestore(&priv->lock, flags);
 }
 
 static void ipoib_mcast_join(struct net_device *dev, struct ipoib_mcast *mcast,
@@ -469,6 +492,7 @@ static void ipoib_mcast_join(struct net_
 		rec.traffic_class = priv->broadcast->mcmember.traffic_class;
 	}
 
+	init_completion(&mcast->done);
 	ret = ib_sa_mcmember_rec_set(priv->ca, priv->port, &rec, comp_mask,
 				     mcast->backoff * 1000, GFP_ATOMIC,
 				     ipoib_mcast_join_complete,
@@ -481,12 +505,12 @@ static void ipoib_mcast_join(struct net_
 		if (mcast->backoff > IPOIB_MAX_BACKOFF_SECONDS)
 			mcast->backoff = IPOIB_MAX_BACKOFF_SECONDS;
 
-		down(&mcast_mutex);
+		spin_lock(&ipoib_mcast_lock);
 		if (test_bit(IPOIB_MCAST_RUN, &priv->flags))
 			queue_delayed_work(ipoib_workqueue,
 					   &priv->mcast_task,
 					   mcast->backoff * HZ);
-		up(&mcast_mutex);
+		spin_unlock(&ipoib_mcast_lock);
 	} else
 		mcast->query_id = ret;
 }
@@ -515,44 +539,44 @@ void ipoib_mcast_join_task(void *dev_ptr
 			ipoib_warn(priv, "ib_query_port failed\n");
 	}
 
+	spin_lock_irq(&priv->lock);
+
 	if (!priv->broadcast) {
-		priv->broadcast = ipoib_mcast_alloc(dev, 1);
+		priv->broadcast = ipoib_mcast_alloc(dev);
 		if (!priv->broadcast) {
 			ipoib_warn(priv, "failed to allocate broadcast group\n");
-			down(&mcast_mutex);
+			spin_lock(&ipoib_mcast_lock);
 			if (test_bit(IPOIB_MCAST_RUN, &priv->flags))
 				queue_delayed_work(ipoib_workqueue,
 						   &priv->mcast_task, HZ);
-			up(&mcast_mutex);
-			return;
+			spin_unlock(&ipoib_mcast_lock);
+			goto unlock;
 		}
 
 		memcpy(priv->broadcast->mcmember.mgid.raw, priv->dev->broadcast + 4,
 		       sizeof (union ib_gid));
 
-		spin_lock_irq(&priv->lock);
 		__ipoib_mcast_add(dev, priv->broadcast);
-		spin_unlock_irq(&priv->lock);
 	}
 
-	if (!test_bit(IPOIB_MCAST_FLAG_ATTACHED, &priv->broadcast->flags)) {
+	if (!test_bit(IPOIB_MCAST_FLAG_ATTACHED, &priv->broadcast->flags) &&
+	    !priv->broadcast->query) {
 		ipoib_mcast_join(dev, priv->broadcast, 0);
-		return;
+		goto unlock;
 	}
 
 	while (1) {
 		struct ipoib_mcast *mcast = NULL;
 
-		spin_lock_irq(&priv->lock);
 		list_for_each_entry(mcast, &priv->multicast_list, list) {
 			if (!test_bit(IPOIB_MCAST_FLAG_SENDONLY, &mcast->flags)
 			    && !test_bit(IPOIB_MCAST_FLAG_BUSY, &mcast->flags)
-			    && !test_bit(IPOIB_MCAST_FLAG_ATTACHED, &mcast->flags)) {
+			    && !test_bit(IPOIB_MCAST_FLAG_ATTACHED, &mcast->flags)
+			    && !mcast->query) {
 				/* Found the next unjoined group */
 				break;
 			}
 		}
-		spin_unlock_irq(&priv->lock);
 
 		if (&mcast->list == &priv->multicast_list) {
 			/* All done */
@@ -560,7 +584,7 @@ void ipoib_mcast_join_task(void *dev_ptr
 		}
 
 		ipoib_mcast_join(dev, mcast, 1);
-		return;
+		goto unlock;
 	}
 
 	priv->mcast_mtu = ib_mtu_enum_to_int(priv->broadcast->mcmember.mtu) -
@@ -571,48 +595,59 @@ void ipoib_mcast_join_task(void *dev_ptr
 
 	clear_bit(IPOIB_MCAST_RUN, &priv->flags);
 	netif_carrier_on(dev);
+
+unlock:
+	spin_unlock_irq(&priv->lock);
 }
 
 static void ipoib_mcast_start_thread(struct net_device *dev)
 {
 	struct ipoib_dev_priv *priv = netdev_priv(dev);
+	unsigned long flags;
 
 	ipoib_dbg_mcast(priv, "starting multicast thread\n");
 
-	down(&mcast_mutex);
+	spin_lock_irqsave(&ipoib_mcast_lock, flags);
 	if (!test_and_set_bit(IPOIB_MCAST_RUN, &priv->flags))
 		queue_work(ipoib_workqueue, &priv->mcast_task);
-	up(&mcast_mutex);
+	spin_unlock_irqrestore(&ipoib_mcast_lock, flags);
 }
 
 static void ipoib_mcast_stop_thread(struct net_device *dev)
 {
 	struct ipoib_dev_priv *priv = netdev_priv(dev);
 	struct ipoib_mcast *mcast;
+	unsigned long flags;
 
 	ipoib_dbg_mcast(priv, "stopping multicast thread\n");
 
-	down(&mcast_mutex);
+	spin_lock_irqsave(&priv->lock, flags);
+
+	spin_lock(&ipoib_mcast_lock);
 	clear_bit(IPOIB_MCAST_RUN, &priv->flags);
 	cancel_delayed_work(&priv->mcast_task);
-	up(&mcast_mutex);
+	spin_unlock(&ipoib_mcast_lock);
 
 	if (priv->broadcast && priv->broadcast->query) {
 		ib_sa_cancel_query(priv->broadcast->query_id, priv->broadcast->query);
-		priv->broadcast->query = NULL;
+		spin_unlock_irqrestore(&priv->lock, flags);
 		ipoib_dbg_mcast(priv, "waiting for bcast\n");
 		wait_for_completion(&priv->broadcast->done);
+		spin_lock_irqsave(&priv->lock, flags);
 	}
 
 	list_for_each_entry(mcast, &priv->multicast_list, list) {
 		if (mcast->query) {
 			ib_sa_cancel_query(mcast->query_id, mcast->query);
-			mcast->query = NULL;
+			spin_unlock_irqrestore(&priv->lock, flags);
 			ipoib_dbg_mcast(priv, "waiting for MGID " IPOIB_GID_FMT "\n",
 					IPOIB_GID_ARG(mcast->mcmember.mgid));
 			wait_for_completion(&mcast->done);
+			spin_lock_irqsave(&priv->lock, flags);
 		}
 	}
+
+	spin_unlock_irqrestore(&priv->lock, flags);
 }
 
 static int ipoib_mcast_leave(struct net_device *dev, struct ipoib_mcast *mcast)
@@ -621,6 +656,7 @@ static int ipoib_mcast_leave(struct net_
 	struct ib_sa_mcmember_rec rec = {
 		.join_state = 1
 	};
+	struct ib_sa_query *query;
 	int ret = 0;
 
 	if (!test_and_clear_bit(IPOIB_MCAST_FLAG_ATTACHED, &mcast->flags))
@@ -629,6 +665,8 @@ static int ipoib_mcast_leave(struct net_
 	ipoib_dbg_mcast(priv, "leaving MGID " IPOIB_GID_FMT "\n",
 			IPOIB_GID_ARG(mcast->mcmember.mgid));
 
+	BUG_ON(mcast->query);
+
 	rec.mgid     = mcast->mcmember.mgid;
 	rec.port_gid = priv->local_gid;
 	rec.pkey     = cpu_to_be16(priv->pkey);
@@ -649,7 +687,7 @@ static int ipoib_mcast_leave(struct net_
 					IB_SA_MCMEMBER_REC_PKEY		|
 					IB_SA_MCMEMBER_REC_JOIN_STATE,
 					0, GFP_ATOMIC, NULL,
-					mcast, &mcast->query);
+					mcast, &query);
 	if (ret < 0)
 		ipoib_warn(priv, "ib_sa_mcmember_rec_delete failed "
 			   "for leave (result = %d)\n", ret);
@@ -675,7 +713,7 @@ void ipoib_mcast_send(struct net_device 
 		ipoib_dbg_mcast(priv, "setting up send only multicast group for "
 				IPOIB_GID_FMT "\n", IPOIB_GID_ARG(*mgid));
 
-		mcast = ipoib_mcast_alloc(dev, 0);
+		mcast = ipoib_mcast_alloc(dev);
 		if (!mcast) {
 			ipoib_warn(priv, "unable to allocate memory for "
 				   "multicast structure\n");
@@ -741,7 +779,7 @@ static void ipoib_mcast_dev_flush(struct
 
 	spin_lock_irqsave(&priv->lock, flags);
 	list_for_each_entry_safe(mcast, tmcast, &priv->multicast_list, list) {
-		nmcast = ipoib_mcast_alloc(dev, 0);
+		nmcast = ipoib_mcast_alloc(dev);
 		if (nmcast) {
 			nmcast->flags =
 				mcast->flags & (1 << IPOIB_MCAST_FLAG_SENDONLY);
@@ -764,17 +802,16 @@ static void ipoib_mcast_dev_flush(struct
 	}
 
 	if (priv->broadcast) {
-		nmcast = ipoib_mcast_alloc(dev, 0);
+		nmcast = ipoib_mcast_alloc(dev);
 		if (nmcast) {
 			nmcast->mcmember.mgid = priv->broadcast->mcmember.mgid;
 
 			rb_replace_node(&priv->broadcast->rb_node,
 					&nmcast->rb_node,
 					&priv->multicast_tree);
-
-			list_add_tail(&priv->broadcast->list, &remove_list);
 		}
 
+		list_add_tail(&priv->broadcast->list, &remove_list);
 		priv->broadcast = nmcast;
 	}
 
@@ -789,19 +826,23 @@ static void ipoib_mcast_dev_flush(struct
 static void ipoib_mcast_dev_down(struct net_device *dev)
 {
 	struct ipoib_dev_priv *priv = netdev_priv(dev);
+	struct ipoib_mcast *mcast;
 	unsigned long flags;
 
+	spin_lock_irqsave(&priv->lock, flags);
+
 	/* Delete broadcast since it will be recreated */
 	if (priv->broadcast) {
 		ipoib_dbg_mcast(priv, "deleting broadcast group\n");
 
-		spin_lock_irqsave(&priv->lock, flags);
 		rb_erase(&priv->broadcast->rb_node, &priv->multicast_tree);
-		spin_unlock_irqrestore(&priv->lock, flags);
-		ipoib_mcast_leave(dev, priv->broadcast);
-		ipoib_mcast_free(priv->broadcast);
+		mcast = priv->broadcast;
 		priv->broadcast = NULL;
-	}
+		spin_unlock_irqrestore(&priv->lock, flags);
+		ipoib_mcast_leave(dev, mcast);
+		ipoib_mcast_free(mcast);
+	} else
+		spin_unlock_irqrestore(&priv->lock, flags);
 }
 
 void ipoib_mcast_restart_task(void *dev_ptr)
@@ -847,7 +888,7 @@ void ipoib_mcast_restart_task(void *dev_
 			ipoib_dbg_mcast(priv, "adding multicast entry for mgid "
 					IPOIB_GID_FMT "\n", IPOIB_GID_ARG(mgid));
 
-			nmcast = ipoib_mcast_alloc(dev, 0);
+			nmcast = ipoib_mcast_alloc(dev);
 			if (!nmcast) {
 				ipoib_warn(priv, "unable to allocate memory for multicast structure\n");
 				continue;
Index: linux-2.6.14-dbg/drivers/infiniband/ulp/ipoib/ipoib.h
===================================================================
--- linux-2.6.14-dbg.orig/drivers/infiniband/ulp/ipoib/ipoib.h	2005-11-20 12:18:43.000000000 +0200
+++ linux-2.6.14-dbg/drivers/infiniband/ulp/ipoib/ipoib.h	2005-11-20 14:56:53.000000000 +0200
@@ -226,6 +226,7 @@ static inline struct ipoib_neigh **to_ip
 }
 
 extern struct workqueue_struct *ipoib_workqueue;
+extern spinlock_t ipoib_mcast_lock;
 
 /* functions */
 
-- 
MST



More information about the general mailing list