[openib-general] [PATCH] [CM] verify cm_id state before allocating messages

Sean Hefty sean.hefty at intel.com
Thu Jun 30 14:22:47 PDT 2005


The following patch moves verification of the cm_id state higher up
in the CM send functions, so that the state is checked (while holding
the cm_id lock) before a message is allocated.  This fixes an issue
where a message could be allocated for a cm_id that is in an invalid
state, which could result in a crash.

This problem was first reported by Hal a while ago, and was recently
hit again by Arlin while running a user-mode application.  Arlin, can
you verify that this fixes your issues?

Signed-off-by: Sean Hefty <sean.hefty at intel.com>


Index: cm.c
===================================================================
--- cm.c	(revision 2759)
+++ cm.c	(working copy)
@@ -1348,12 +1348,17 @@ int ib_send_cm_rep(struct ib_cm_id *cm_i
 	int ret;
 
 	if (param->private_data &&
-	    param->private_data_len > IB_CM_REP_PRIVATE_DATA_SIZE) {
+	    param->private_data_len > IB_CM_REP_PRIVATE_DATA_SIZE)
+		return -EINVAL;
+
+	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
+	spin_lock_irqsave(&cm_id_priv->lock, flags);
+	if (cm_id->state != IB_CM_REQ_RCVD &&
+	    cm_id->state != IB_CM_MRA_REQ_SENT) {
 		ret = -EINVAL;
 		goto out;
 	}
 
-	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
 	ret = cm_alloc_msg(cm_id_priv, &msg);
 	if (ret)
 		goto out;
@@ -1363,18 +1368,12 @@ int ib_send_cm_rep(struct ib_cm_id *cm_i
 	msg->send_wr.wr.ud.timeout_ms = cm_id_priv->timeout_ms;
 	msg->context[1] = (void *) (unsigned long) IB_CM_REP_SENT;
 
-	spin_lock_irqsave(&cm_id_priv->lock, flags);
-	if (cm_id->state == IB_CM_REQ_RCVD ||
-	    cm_id->state == IB_CM_MRA_REQ_SENT)
-		ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
-				       &msg->send_wr, &bad_send_wr);
-	else
-		ret = -EINVAL;
-
+	ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
+			       &msg->send_wr, &bad_send_wr);
 	if (ret) {
 		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 		cm_free_msg(msg);
-		goto out;
+		return ret;
 	}
 
 	cm_id->state = IB_CM_REP_SENT;
@@ -1383,8 +1382,8 @@ int ib_send_cm_rep(struct ib_cm_id *cm_i
 	cm_id_priv->responder_resources = param->responder_resources;
 	cm_id_priv->rq_psn = cm_rep_get_starting_psn(rep_msg);
 	cm_id_priv->local_qpn = cm_rep_get_local_qpn(rep_msg);
-	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-out:
+
+out:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 	return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_rep);
@@ -1416,39 +1415,41 @@ int ib_send_cm_rtu(struct ib_cm_id *cm_i
 	if (private_data && private_data_len > IB_CM_RTU_PRIVATE_DATA_SIZE)
 		return -EINVAL;
 
+	data = cm_copy_private_data(private_data, private_data_len);
+	if (IS_ERR(data))
+		return PTR_ERR(data);
+
 	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
+	spin_lock_irqsave(&cm_id_priv->lock, flags);
+	if (cm_id->state != IB_CM_REP_RCVD &&
+	    cm_id->state != IB_CM_MRA_REP_SENT) {
+		ret = -EINVAL;
+		goto error;
+	}
+
 	ret = cm_alloc_msg(cm_id_priv, &msg);
 	if (ret)
-		return ret;
-
-	data = cm_copy_private_data(private_data, private_data_len);
-	if (IS_ERR(data)) {
-		ret = PTR_ERR(data);
-		goto error1;
-	}
+		goto error;
 
 	cm_format_rtu((struct cm_rtu_msg *) msg->mad, cm_id_priv,
 		      private_data, private_data_len);
 
-	spin_lock_irqsave(&cm_id_priv->lock, flags);
-	if (cm_id->state == IB_CM_REP_RCVD ||
-	    cm_id->state == IB_CM_MRA_REP_SENT)
-		ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
-				       &msg->send_wr, &bad_send_wr);
-	else
-		ret = -EINVAL;
-
-	if (ret)
-		goto error2;
+	ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
+			       &msg->send_wr, &bad_send_wr);
+	if (ret) {
+		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+		cm_free_msg(msg);
+		kfree(data);
+		return ret;
+	}
 
 	cm_id->state = IB_CM_ESTABLISHED;
 	cm_set_private_data(cm_id_priv, data, private_data_len);
 	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 	return 0;
 
-error2:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+error:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 	kfree(data);
-error1:	cm_free_msg(msg);
 	return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_rtu);
@@ -1691,38 +1692,41 @@ int ib_send_cm_dreq(struct ib_cm_id *cm_
 	struct ib_mad_send_buf *msg;
 	struct ib_send_wr *bad_send_wr;
 	unsigned long flags;
-	int msg_ret, ret;
+	int ret;
 
 	if (private_data && private_data_len > IB_CM_DREQ_PRIVATE_DATA_SIZE)
 		return -EINVAL;
 
 	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
-	msg_ret = cm_alloc_msg(cm_id_priv, &msg);
-	if (!msg_ret) {
-		cm_format_dreq((struct cm_dreq_msg *) msg->mad, cm_id_priv,
-			       private_data, private_data_len);
-		msg->send_wr.wr.ud.timeout_ms = cm_id_priv->timeout_ms;
-		msg->context[1] = (void *) (unsigned long) IB_CM_DREQ_SENT;
-	}
-
 	spin_lock_irqsave(&cm_id_priv->lock, flags);
 	if (cm_id->state != IB_CM_ESTABLISHED) {
-		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 		ret = -EINVAL;
 		goto out;
 	}
-	ret = msg_ret ? msg_ret :
-		ib_post_send_mad(cm_id_priv->av.port->mad_agent,
-				 &msg->send_wr, &bad_send_wr);
-	if (!ret) {
-		cm_id->state = IB_CM_DREQ_SENT;
-		cm_id_priv->msg = msg;
-	} else
+
+	ret = cm_alloc_msg(cm_id_priv, &msg);
+	if (ret) {
 		cm_enter_timewait(cm_id_priv);
-	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-out:
-	if (!msg_ret && ret)
+		goto out;
+	}
+
+	cm_format_dreq((struct cm_dreq_msg *) msg->mad, cm_id_priv,
+		       private_data, private_data_len);
+	msg->send_wr.wr.ud.timeout_ms = cm_id_priv->timeout_ms;
+	msg->context[1] = (void *) (unsigned long) IB_CM_DREQ_SENT;
+
+	ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
+			       &msg->send_wr, &bad_send_wr);
+	if (ret) {
+		cm_enter_timewait(cm_id_priv);
+		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 		cm_free_msg(msg);
+		return ret;
+	}
+
+	cm_id->state = IB_CM_DREQ_SENT;
+	cm_id_priv->msg = msg;
+out:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 	return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_dreq);
@@ -1754,45 +1758,37 @@ int ib_send_cm_drep(struct ib_cm_id *cm_
 	if (private_data && private_data_len > IB_CM_DREP_PRIVATE_DATA_SIZE)
 		return -EINVAL;
 
-	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
-	ret = cm_alloc_msg(cm_id_priv, &msg);
-	if (ret)
-		goto error1;
-
 	data = cm_copy_private_data(private_data, private_data_len);
-	if (IS_ERR(data)) {
-		ret = PTR_ERR(data);
-		goto error2;
-	}
-
-	cm_format_drep((struct cm_drep_msg *) msg->mad, cm_id_priv,
-		       private_data, private_data_len);
+	if (IS_ERR(data))
+		return PTR_ERR(data);
 
+	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
 	spin_lock_irqsave(&cm_id_priv->lock, flags);
 	if (cm_id->state != IB_CM_DREQ_RCVD) {
 		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-		cm_free_msg(msg);
 		kfree(data);
 		return -EINVAL;
 	}
 
+	cm_set_private_data(cm_id_priv, data, private_data_len);
+	cm_enter_timewait(cm_id_priv);
+
+	ret = cm_alloc_msg(cm_id_priv, &msg);
+	if (ret)
+		goto out;
+
+	cm_format_drep((struct cm_drep_msg *) msg->mad, cm_id_priv,
+		       private_data, private_data_len);
+
 	ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent, &msg->send_wr,
 			       &bad_send_wr);
-	if (ret)
+	if (ret) {
+		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 		cm_free_msg(msg);
+		return ret;
+	}
 
-	cm_set_private_data(cm_id_priv, data, private_data_len);
-	cm_enter_timewait(cm_id_priv);
-	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-	return ret;
-
-error2:
-	cm_free_msg(msg);
-error1:
-	spin_lock_irqsave(&cm_id_priv->lock, flags);
-	if (cm_id->state == IB_CM_DREQ_RCVD)
-		cm_enter_timewait(cm_id_priv);
-	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+out:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 	return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_drep);
@@ -1912,18 +1908,13 @@ int ib_send_cm_rej(struct ib_cm_id *cm_i
 	struct ib_mad_send_buf *msg;
 	struct ib_send_wr *bad_send_wr;
 	unsigned long flags;
-	int msg_ret, ret;
+	int ret;
 
 	if ((private_data && private_data_len > IB_CM_REJ_PRIVATE_DATA_SIZE) ||
 	    (ari && ari_length > IB_CM_REJ_ARI_LENGTH))
 		return -EINVAL;
 
 	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
-	msg_ret = cm_alloc_msg(cm_id_priv, &msg);
-	if (!msg_ret)
-		cm_format_rej((struct cm_rej_msg *) msg->mad, cm_id_priv,
-			      reason, ari, ari_length, private_data,
-			      private_data_len);
 
 	spin_lock_irqsave(&cm_id_priv->lock, flags);
 	switch (cm_id->state) {
@@ -1933,25 +1924,38 @@ int ib_send_cm_rej(struct ib_cm_id *cm_i
 	case IB_CM_MRA_REQ_SENT:
 	case IB_CM_REP_RCVD:
 	case IB_CM_MRA_REP_SENT:
+		ret = cm_alloc_msg(cm_id_priv, &msg);
+		if (!ret)
+			cm_format_rej((struct cm_rej_msg *) msg->mad,
+				      cm_id_priv, reason, ari, ari_length,
+				      private_data, private_data_len);
+
 		cm_reset_to_idle(cm_id_priv);
 		break;
 	case IB_CM_REP_SENT:
 	case IB_CM_MRA_REP_RCVD:
+		ret = cm_alloc_msg(cm_id_priv, &msg);
+		if (!ret)
+			cm_format_rej((struct cm_rej_msg *) msg->mad,
+				      cm_id_priv, reason, ari, ari_length,
+				      private_data, private_data_len);
+
 		cm_enter_timewait(cm_id_priv);
 		break;
 	default:
-		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 		ret = -EINVAL;
 		goto out;
 	}
-	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 
-	ret = msg_ret ? msg_ret :
-		ib_post_send_mad(cm_id_priv->av.port->mad_agent,
-				 &msg->send_wr, &bad_send_wr);
-out:
-	if (!msg_ret && ret)
+	if (ret)
+		goto out;
+
+	ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
+			       &msg->send_wr, &bad_send_wr);
+	if (ret)
 		cm_free_msg(msg);
+
+out:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 	return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_rej);
@@ -2078,20 +2082,19 @@ int ib_send_cm_mra(struct ib_cm_id *cm_i
 	if (private_data && private_data_len > IB_CM_MRA_PRIVATE_DATA_SIZE)
 		return -EINVAL;
 
-	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
-	ret = cm_alloc_msg(cm_id_priv, &msg);
-	if (ret)
-		return ret;
-
 	data = cm_copy_private_data(private_data, private_data_len);
-	if (IS_ERR(data)) {
-		ret = PTR_ERR(data);
-		goto error1;
-	}
+	if (IS_ERR(data))
+		return PTR_ERR(data);
+
+	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
 
 	spin_lock_irqsave(&cm_id_priv->lock, flags);
 	switch(cm_id_priv->id.state) {
 	case IB_CM_REQ_RCVD:
+		ret = cm_alloc_msg(cm_id_priv, &msg);
+		if (ret)
+			goto error1;
+
 		cm_format_mra((struct cm_mra_msg *) msg->mad, cm_id_priv,
 			      CM_MSG_RESPONSE_REQ, service_timeout,
 			      private_data, private_data_len);
@@ -2102,6 +2105,10 @@ int ib_send_cm_mra(struct ib_cm_id *cm_i
 		cm_id->state = IB_CM_MRA_REQ_SENT;
 		break;
 	case IB_CM_REP_RCVD:
+		ret = cm_alloc_msg(cm_id_priv, &msg);
+		if (ret)
+			goto error1;
+
 		cm_format_mra((struct cm_mra_msg *) msg->mad, cm_id_priv,
 			      CM_MSG_RESPONSE_REP, service_timeout,
 			      private_data, private_data_len);
@@ -2112,6 +2119,10 @@ int ib_send_cm_mra(struct ib_cm_id *cm_i
 		cm_id->state = IB_CM_MRA_REP_SENT;
 		break;
 	case IB_CM_ESTABLISHED:
+		ret = cm_alloc_msg(cm_id_priv, &msg);
+		if (ret)
+			goto error1;
+
 		cm_format_mra((struct cm_mra_msg *) msg->mad, cm_id_priv,
 			      CM_MSG_RESPONSE_OTHER, service_timeout,
 			      private_data, private_data_len);
@@ -2123,16 +2134,20 @@ int ib_send_cm_mra(struct ib_cm_id *cm_i
 		break;
 	default:
 		ret = -EINVAL;
-		goto error2;
+		goto error1;
 	}
 	cm_id_priv->service_timeout = service_timeout;
 	cm_set_private_data(cm_id_priv, data, private_data_len);
 	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 	return 0;
 
+error1:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+	kfree(data);
+	return ret;
+
 error2:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 	kfree(data);
-error1:	cm_free_msg(msg);
+	cm_free_msg(msg);
 	return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_mra);
@@ -2260,6 +2275,13 @@ int ib_send_cm_lap(struct ib_cm_id *cm_i
 		return -EINVAL;
 
 	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
+	spin_lock_irqsave(&cm_id_priv->lock, flags);
+	if (cm_id->state != IB_CM_ESTABLISHED ||
+	    cm_id->lap_state != IB_CM_LAP_IDLE) {
+		ret = -EINVAL;
+		goto out;
+	}
+
 	ret = cm_alloc_msg(cm_id_priv, &msg);
 	if (ret)
 		goto out;
@@ -2269,24 +2291,18 @@ int ib_send_cm_lap(struct ib_cm_id *cm_i
 	msg->send_wr.wr.ud.timeout_ms = cm_id_priv->timeout_ms;
 	msg->context[1] = (void *) (unsigned long) IB_CM_ESTABLISHED;
 
-	spin_lock_irqsave(&cm_id_priv->lock, flags);
-	if (cm_id->state == IB_CM_ESTABLISHED &&
-	    cm_id->lap_state == IB_CM_LAP_IDLE)
-		ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
-				       &msg->send_wr, &bad_send_wr);
-	else
-		ret = -EINVAL;
-
+	ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
+			       &msg->send_wr, &bad_send_wr);
 	if (ret) {
 		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 		cm_free_msg(msg);
-		goto out;
+		return ret;
 	}
 
 	cm_id->lap_state = IB_CM_LAP_SENT;
 	cm_id_priv->msg = msg;
-	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-out:
+
+out:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 	return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_lap);
@@ -2420,30 +2436,30 @@ int ib_send_cm_apr(struct ib_cm_id *cm_i
 		return -EINVAL;
 
 	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
+	spin_lock_irqsave(&cm_id_priv->lock, flags);
+	if (cm_id->state != IB_CM_ESTABLISHED ||
+	    (cm_id->lap_state != IB_CM_LAP_RCVD &&
+	     cm_id->lap_state != IB_CM_MRA_LAP_SENT)) {
+		ret = -EINVAL;
+		goto out;
+	}
+
 	ret = cm_alloc_msg(cm_id_priv, &msg);
 	if (ret)
 		goto out;
 
 	cm_format_apr((struct cm_apr_msg *) msg->mad, cm_id_priv, status,
 		      info, info_length, private_data, private_data_len);
-
-	spin_lock_irqsave(&cm_id_priv->lock, flags);
-	if (cm_id->state == IB_CM_ESTABLISHED &&
-	    (cm_id->lap_state == IB_CM_LAP_RCVD ||
-	     cm_id->lap_state == IB_CM_MRA_LAP_SENT))
-		ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
-				       &msg->send_wr, &bad_send_wr);
-	else
-		ret = -EINVAL;
-
+	ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
+			       &msg->send_wr, &bad_send_wr);
 	if (ret) {
 		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 		cm_free_msg(msg);
-		goto out;
+		return ret;
 	}
+
 	cm_id->lap_state = IB_CM_LAP_IDLE;
-	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-out:
+out:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 	return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_apr);
@@ -2703,24 +2719,24 @@ int ib_send_cm_sidr_rep(struct ib_cm_id 
 		return -EINVAL;
 
 	cm_id_priv = container_of(cm_id, struct cm_id_private, id);
+	spin_lock_irqsave(&cm_id_priv->lock, flags);
+	if (cm_id->state != IB_CM_SIDR_REQ_RCVD) {
+		ret = -EINVAL;
+		goto error;
+	}
+
 	ret = cm_alloc_msg(cm_id_priv, &msg);
 	if (ret)
-		goto out;
+		goto error;
 
 	cm_format_sidr_rep((struct cm_sidr_rep_msg *) msg->mad, cm_id_priv,
 			   param);
-
-	spin_lock_irqsave(&cm_id_priv->lock, flags);
-	if (cm_id->state == IB_CM_SIDR_REQ_RCVD)
-		ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
-				       &msg->send_wr, &bad_send_wr);
-	else
-		ret = -EINVAL;
-
+	ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
+			       &msg->send_wr, &bad_send_wr);
 	if (ret) {
 		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 		cm_free_msg(msg);
-		goto out;
+		return ret;
 	}
 	cm_id->state = IB_CM_IDLE;
 	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
@@ -2728,7 +2744,9 @@ int ib_send_cm_sidr_rep(struct ib_cm_id 
 	spin_lock_irqsave(&cm.lock, flags);
 	rb_erase(&cm_id_priv->sidr_id_node, &cm.remote_sidr_table);
 	spin_unlock_irqrestore(&cm.lock, flags);
-out:
+	return 0;
+
+error:	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 	return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_sidr_rep);






More information about the general mailing list