[openib-general] [PATCH] [CM] reject stale connections

Sean Hefty sean.hefty at intel.com
Fri May 13 13:36:52 PDT 2005


The following patch rejects stale connection requests (REQs) and replies
(REPs) by sending a REJ with reason IB_CM_REJ_STALE_CONN.  It also fixes a
couple of minor error-handling cleanup issues found during testing, and
fixes a bug in handling a stale-connection reject message, where the cm_id
could access an invalid timewait pointer.

Signed-off-by: Sean Hefty <sean.hefty at intel.com>


Index: core/cm.c
===================================================================
--- core/cm.c	(revision 2346)
+++ core/cm.c	(working copy)
@@ -185,10 +185,51 @@ static int cm_alloc_msg(struct cm_id_pri
 	return 0;
 }
 
+/*
+ * Allocate a send buffer for replying directly to a received MAD when no
+ * cm_id is involved.  The address handle is created from the receive work
+ * completion and GRH via ib_create_ah_from_wc(), and the buffer is sized for
+ * a full MAD (header plus data), allocated with GFP_ATOMIC.
+ *
+ * Returns 0 and sets *msg on success.  On failure returns a negative errno,
+ * leaves *msg untouched, and releases anything allocated here.
+ */
+static int cm_alloc_response_msg(struct cm_port *port,
+				 struct ib_mad_recv_wc *mad_recv_wc,
+				 struct ib_mad_send_buf **msg)
+{
+	struct ib_mad_send_buf *m;
+	struct ib_ah *ah;
+
+	ah = ib_create_ah_from_wc(port->mad_agent->qp->pd, mad_recv_wc->wc,
+				  mad_recv_wc->recv_buf.grh, port->port_num);
+	if (IS_ERR(ah))
+		return PTR_ERR(ah);
+
+	m = ib_create_send_mad(port->mad_agent, 1, mad_recv_wc->wc->pkey_index,
+			       ah, 0, sizeof(struct ib_mad_hdr),
+			       sizeof(struct ib_mad)-sizeof(struct ib_mad_hdr),
+			       GFP_ATOMIC);
+	if (IS_ERR(m)) {
+		ib_destroy_ah(ah);
+		return PTR_ERR(m);
+	}
+
+	/*
+	 * No cm_id reference is held for a response; clear context[0] so
+	 * cm_free_msg() does not try to drop one.
+	 */
+	m->context[0] = NULL;
+	*msg = m;
+	return 0;
+}
+
 static void cm_free_msg(struct ib_mad_send_buf *msg)
 {
 	ib_destroy_ah(msg->send_wr.wr.ud.ah);
-	cm_deref_id(msg->context[0]);
+	/* Response MADs (cm_alloc_response_msg) carry no cm_id reference. */
+	if (msg->context[0])
+		cm_deref_id(msg->context[0]);
 	ib_free_send_mad(msg);
 }
 
@@ -531,10 +556,7 @@ static void cm_cleanup_timewait(struct c
 	spin_unlock_irqrestore(&cm.lock, flags);
 }
 
-static struct cm_timewait_info * cm_create_timewait_info(u32 local_id,
-							 u32 remote_id,
-							 u64 remote_ca_guid,
-							 u32 remote_qpn)
+static struct cm_timewait_info * cm_create_timewait_info(u32 local_id)
 {
 	struct cm_timewait_info *timewait_info;
 
@@ -544,10 +566,6 @@ static struct cm_timewait_info * cm_crea
 	memset(timewait_info, 0, sizeof *timewait_info);
 
 	timewait_info->work.local_id = local_id;
-	timewait_info->work.remote_id = remote_id;
-	timewait_info->remote_ca_guid = remote_ca_guid;
-	timewait_info->remote_qpn = remote_qpn;
-
 	INIT_WORK(&timewait_info->work.work, cm_work_handler,
 		  &timewait_info->work);
 	timewait_info->work.cm_event.event = IB_CM_TIMEWAIT_EXIT;
@@ -674,30 +692,49 @@ int ib_cm_listen(struct ib_cm_id *cm_id,
 }
 EXPORT_SYMBOL(ib_cm_listen);
 
-static void cm_format_mad_hdr(struct ib_mad_hdr *hdr,
-			      struct cm_id_private *cm_id_priv,
-			      enum cm_msg_attr_id attr_id,
-			      enum cm_msg_sequence msg_seq)
+/*
+ * Form the TID, in network byte order, for a MAD sent on behalf of
+ * cm_id_priv.  The upper 32 bits hold the MAD agent's hi_tid; the lower
+ * 32 bits hold the local communication ID, with msg_seq shifted into
+ * bits 30-31.  NOTE(review): assumes local_id never uses bits 30-31 -
+ * confirm against the local ID allocator.
+ */
+static u64 cm_form_tid(struct cm_id_private *cm_id_priv,
+		       enum cm_msg_sequence msg_seq)
 {
 	u64 hi_tid, low_tid;
 
+	hi_tid   = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32;
+	/*
+	 * Shift an unsigned value: whether the enum is int or unsigned is
+	 * implementation-defined, and shifting into bit 31 of an int is UB.
+	 */
+	low_tid  = (u64) (cm_id_priv->id.local_id | ((u32) msg_seq << 30));
+	return cpu_to_be64(hi_tid | low_tid);
+}
+
+/*
+ * Fill in the common MAD header of a CM message.  @tid must already be in
+ * network byte order (e.g. from cm_form_tid() or copied from a received
+ * MAD).
+ */
+static void cm_format_mad_hdr(struct ib_mad_hdr *hdr,
+			      enum cm_msg_attr_id attr_id, u64 tid)
+{
 	hdr->base_version  = IB_MGMT_BASE_VERSION;
 	hdr->mgmt_class	   = IB_MGMT_CLASS_CM;
 	hdr->class_version = IB_CM_CLASS_VERSION;
 	hdr->method	   = IB_MGMT_METHOD_SEND;
 	hdr->attr_id	   = attr_id;
-
-	hi_tid   = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32;
-	low_tid  = (u64) (cm_id_priv->id.local_id | (msg_seq << 30));
-	hdr->tid = cpu_to_be64(hi_tid | low_tid);
+	hdr->tid	   = tid;
 }
 
 static void cm_format_req(struct cm_req_msg *req_msg,
 			  struct cm_id_private *cm_id_priv,
 			  struct ib_cm_req_param *param)
 {
-	cm_format_mad_hdr(&req_msg->hdr, cm_id_priv,
-			  CM_REQ_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+	cm_format_mad_hdr(&req_msg->hdr, CM_REQ_ATTR_ID,
+			  cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
 
 	req_msg->local_comm_id = cm_id_priv->id.local_id;
 	req_msg->service_id = param->service_id;
@@ -755,6 +776,10 @@ static void cm_format_req(struct cm_req_
 
 static inline int cm_validate_req_param(struct ib_cm_req_param *param)
 {
+	/* peer-to-peer not supported */
+	if (param->peer_to_peer)
+		return -EINVAL;
+
 	if (!param->primary_path)
 		return -EINVAL;
 
@@ -796,14 +821,23 @@ int ib_send_cm_req(struct ib_cm_id *cm_i
 	}
 	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 
+	cm_id_priv->timewait_info = cm_create_timewait_info(cm_id_priv->
+							    id.local_id);
+	if (IS_ERR(cm_id_priv->timewait_info)) {
+		ret = PTR_ERR(cm_id_priv->timewait_info);
+		/* Don't leave an ERR_PTR stored in the cm_id. */
+		cm_id_priv->timewait_info = NULL;
+		goto out;
+	}
+
 	ret = cm_init_av_by_path(param->primary_path, &cm_id_priv->av);
 	if (ret)
-		goto out;
+		goto error1;
 	if (param->alternate_path) {
 		ret = cm_init_av_by_path(param->alternate_path,
 					 &cm_id_priv->alt_av);
 		if (ret)
-			goto out;
+			goto error1;
 	}
 	cm_id->service_id = param->service_id;
 	cm_id->service_mask = ~0ULL;
@@ -819,7 +849,7 @@ int ib_send_cm_req(struct ib_cm_id *cm_i
 
 	ret = cm_alloc_msg(cm_id_priv, &cm_id_priv->msg);
 	if (ret)
-		goto out;
+		goto error1;
 
 	req_msg = (struct cm_req_msg *) cm_id_priv->msg->mad;
 	cm_format_req(req_msg, cm_id_priv, param);
@@ -831,35 +861,74 @@ int ib_send_cm_req(struct ib_cm_id *cm_i
 	cm_id_priv->local_ack_timeout =
 				cm_req_get_primary_local_ack_timeout(req_msg);
 
-	/*
-	 * Received REQs won't match until we're in REQ_SENT state.  This
-	 * simplifies error recovery if the send fails.
-	 */
-	if (param->peer_to_peer) {
-		ret = -EINVAL;
-		goto out;
-	}
-
 	spin_lock_irqsave(&cm_id_priv->lock, flags);
 	ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
 				&cm_id_priv->msg->send_wr, &bad_send_wr);
-
 	if (ret) {
 		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-		/* if (param->peer_to_peer) {
-			cleanup peer_service_table
-		} */
-		cm_free_msg(cm_id_priv->msg);
-		goto out;
+		goto error2;
 	}
 	BUG_ON(cm_id->state != IB_CM_IDLE);
 	cm_id->state = IB_CM_REQ_SENT;
 	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-out:
-	return ret;
+	return 0;
+
+error2:	cm_free_msg(cm_id_priv->msg);
+error1:	kfree(cm_id_priv->timewait_info);
+	/* Don't leave a dangling pointer in the cm_id. */
+	cm_id_priv->timewait_info = NULL;
+out:	return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_req);
 
+/*
+ * Send a REJ directly in reply to a received CM MAD, without a cm_id (used
+ * to reject stale REQs and REPs).  The reply is addressed from the receive
+ * work completion, reuses the received TID, and swaps the local and remote
+ * communication IDs of the received message.  @ari/@ari_length optionally
+ * supply additional rejection information and are ignored if NULL/0.
+ * NOTE(review): ari_length is not bounded by the size of rej_msg->ari.
+ *
+ * Returns 0 on success or a negative errno; if posting the send fails, the
+ * send buffer is freed before returning.
+ */
+static int cm_issue_rej(struct cm_port *port,
+			struct ib_mad_recv_wc *mad_recv_wc,
+			enum ib_cm_rej_reason reason,
+			enum cm_msg_response msg_rejected,
+			void *ari, u8 ari_length)
+{
+	struct ib_mad_send_buf *msg;
+	struct ib_send_wr *bad_send_wr;
+	struct cm_rej_msg *rej_msg, *rcv_msg;
+	int ret;
+
+	ret = cm_alloc_response_msg(port, mad_recv_wc, &msg);
+	if (ret)
+		return ret;
+
+	/* We just need common CM header information.  Cast to any message. */
+	rcv_msg = (struct cm_rej_msg *) mad_recv_wc->recv_buf.mad;
+	rej_msg = (struct cm_rej_msg *) msg->mad;
+
+	cm_format_mad_hdr(&rej_msg->hdr, CM_REJ_ATTR_ID, rcv_msg->hdr.tid);
+	rej_msg->remote_comm_id = rcv_msg->local_comm_id;
+	rej_msg->local_comm_id = rcv_msg->remote_comm_id;
+	cm_rej_set_msg_rejected(rej_msg, msg_rejected);
+	rej_msg->reason = reason;
+
+	if (ari && ari_length) {
+		cm_rej_set_reject_info_len(rej_msg, ari_length);
+		memcpy(rej_msg->ari, ari, ari_length);
+	}
+
+	ret = ib_post_send_mad(port->mad_agent, &msg->send_wr, &bad_send_wr);
+	if (ret)
+		cm_free_msg(msg);
+
+	return ret;
+}
+
 static inline int cm_is_active_peer(u64 local_ca_guid, u64 remote_ca_guid,
 				    u32 local_qpn, u32 remote_qpn)
 {
@@ -992,29 +1048,31 @@ static int cm_req_handler(struct cm_work
 	cm_id_priv->id.remote_id = req_msg->local_comm_id;
 	cm_init_av_for_response(work->port, work->mad_recv_wc->wc,
 				&cm_id_priv->av);
-	cm_id_priv->timewait_info = cm_create_timewait_info(
-						cm_id_priv->id.local_id,
-						cm_id_priv->id.remote_id,
-						req_msg->local_ca_guid,
-						cm_req_get_local_qpn(req_msg));
+	cm_id_priv->timewait_info = cm_create_timewait_info(cm_id_priv->
+							    id.local_id);
 	if (IS_ERR(cm_id_priv->timewait_info)) {
 		ret = PTR_ERR(cm_id_priv->timewait_info);
 		goto error1;
 	}
+	cm_id_priv->timewait_info->work.remote_id = req_msg->local_comm_id;
+	cm_id_priv->timewait_info->remote_ca_guid = req_msg->local_ca_guid;
+	cm_id_priv->timewait_info->remote_qpn = cm_req_get_local_qpn(req_msg);
 
 	spin_lock_irqsave(&cm.lock, flags);
 	/* Check for duplicate REQ. */
 	if (cm_insert_remote_id(cm_id_priv->timewait_info)) {
 		spin_unlock_irqrestore(&cm.lock, flags);
 		ret = -EINVAL;
-		goto error1;
+		goto error2;
 	}
 	/* Check for a stale connection. */
 	if (cm_insert_remote_qpn(cm_id_priv->timewait_info)) {
 		spin_unlock_irqrestore(&cm.lock, flags);
-		/* todo: reject as stale */
+		cm_issue_rej(work->port, work->mad_recv_wc,
+			     IB_CM_REJ_STALE_CONN, CM_MSG_RESPONSE_REQ,
+			     NULL, 0);
 		ret = -EINVAL;
-		goto error1;
+		goto error2;
 	}
 	/* Find matching listen request. */
 	listen_cm_id_priv = cm_find_listen(req_msg->service_id);
@@ -1035,11 +1093,11 @@ static int cm_req_handler(struct cm_work
 	cm_format_paths_from_req(req_msg, &work->path[0], &work->path[1]);
 	ret = cm_init_av_by_path(&work->path[0], &cm_id_priv->av);
 	if (ret)
-		goto error2;
+		goto error3;
 	if (req_msg->alt_local_lid) {
 		ret = cm_init_av_by_path(&work->path[1], &cm_id_priv->alt_av);
 		if (ret)
-			goto error2;
+			goto error3;
 	}
 	cm_id_priv->timeout_ms = cm_convert_to_ms(
 					cm_req_get_local_resp_timeout(req_msg));
@@ -1058,11 +1116,17 @@ static int cm_req_handler(struct cm_work
 	cm_process_work(cm_id_priv, work);
 	cm_deref_id(listen_cm_id_priv);
 	return 0;
-error2:
-	atomic_dec(&cm_id_priv->refcount);
+
+error3:	atomic_dec(&cm_id_priv->refcount);
 	cm_deref_id(listen_cm_id_priv);
-error1:
-	ib_destroy_cm_id(&cm_id_priv->id);
+error2:	cm_cleanup_timewait(cm_id_priv->timewait_info);
+	kfree(cm_id_priv->timewait_info);
+	/*
+	 * Clear the freed pointer before the cm_id is destroyed, in case the
+	 * destroy path still looks at timewait_info (TODO confirm).
+	 */
+	cm_id_priv->timewait_info = NULL;
+error1:	ib_destroy_cm_id(&cm_id_priv->id);
 	return ret;
 }
 
@@ -1070,8 +1129,9 @@ static void cm_format_rep(struct cm_rep_
 			  struct cm_id_private *cm_id_priv,
 			  struct ib_cm_rep_param *param)
 {
-	cm_format_mad_hdr(&rep_msg->hdr, cm_id_priv,
-			  CM_REP_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+	/* todo: TID should match received REQ */
+	cm_format_mad_hdr(&rep_msg->hdr, CM_REP_ATTR_ID,
+			  cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
 
 	rep_msg->local_comm_id = cm_id_priv->id.local_id;
 	rep_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1148,8 +1208,8 @@ static void cm_format_rtu(struct cm_rtu_
 			  void *private_data,
 			  u8 private_data_len)
 {
-	cm_format_mad_hdr(&rtu_msg->hdr, cm_id_priv,
-			  CM_RTU_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+	cm_format_mad_hdr(&rtu_msg->hdr, CM_RTU_ATTR_ID,
+			  cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
 
 	rtu_msg->local_comm_id = cm_id_priv->id.local_id;
 	rtu_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1262,7 +1322,6 @@ static void cm_dup_rep_handler(struct cm
 static int cm_rep_handler(struct cm_work *work)
 {
 	struct cm_id_private *cm_id_priv;
-	struct cm_timewait_info *timewait_info;
 	struct cm_rep_msg *rep_msg;
 	unsigned long flags;
 	int ret;
@@ -1274,27 +1333,25 @@ static int cm_rep_handler(struct cm_work
 		return -EINVAL;
 	}
 
-	timewait_info = cm_create_timewait_info(cm_id_priv->id.local_id,
-						rep_msg->local_comm_id,
-						rep_msg->local_ca_guid,
-						cm_rep_get_local_qpn(rep_msg));
-	if (IS_ERR(timewait_info)) {
-		ret = PTR_ERR(timewait_info);
-		goto error1;
-	}
+	cm_id_priv->timewait_info->work.remote_id = rep_msg->local_comm_id;
+	cm_id_priv->timewait_info->remote_ca_guid = rep_msg->local_ca_guid;
+	cm_id_priv->timewait_info->remote_qpn = cm_rep_get_local_qpn(rep_msg);
+
 	spin_lock_irqsave(&cm.lock, flags);
 	/* Check for duplicate REP. */
-	if (cm_insert_remote_id(timewait_info)) {
+	if (cm_insert_remote_id(cm_id_priv->timewait_info)) {
 		spin_unlock_irqrestore(&cm.lock, flags);
 		ret = -EINVAL;
-		goto error2;
+		goto error;
 	}
 	/* Check for a stale connection. */
-	if (cm_insert_remote_qpn(timewait_info)) {
+	if (cm_insert_remote_qpn(cm_id_priv->timewait_info)) {
 		spin_unlock_irqrestore(&cm.lock, flags);
-		/* todo: reject as stale */
+		cm_issue_rej(work->port, work->mad_recv_wc,
+			     IB_CM_REJ_STALE_CONN, CM_MSG_RESPONSE_REP,
+			     NULL, 0);
 		ret = -EINVAL;
-		goto error2;
+		goto error;
 	}
 	spin_unlock_irqrestore(&cm.lock, flags);
 
@@ -1308,7 +1365,7 @@ static int cm_rep_handler(struct cm_work
 	default:
 		spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 		ret = -EINVAL;
-		goto error2;
+		goto error;
 	}
 	cm_id_priv->id.state = IB_CM_REP_RCVD;
 	cm_id_priv->id.remote_id = rep_msg->local_comm_id;
@@ -1317,7 +1374,6 @@ static int cm_rep_handler(struct cm_work
 	cm_id_priv->responder_resources = rep_msg->initiator_depth;
 	cm_id_priv->sq_psn = cm_rep_get_starting_psn(rep_msg);
 	cm_id_priv->rnr_retry_count = cm_rep_get_rnr_retry_count(rep_msg);
-	cm_id_priv->timewait_info = timewait_info;
 
 	/* todo: handle peer_to_peer */
 
@@ -1333,10 +1389,8 @@ static int cm_rep_handler(struct cm_work
 	else
 		cm_deref_id(cm_id_priv);
 	return 0;
-error2:
-	cm_cleanup_timewait(timewait_info);
-	kfree(timewait_info);
-error1:
+
+error:	cm_cleanup_timewait(cm_id_priv->timewait_info);
 	cm_deref_id(cm_id_priv);
 	return ret;
 }
@@ -1420,8 +1474,8 @@ static void cm_format_dreq(struct cm_dre
 			  void *private_data,
 			  u8 private_data_len)
 {
-	cm_format_mad_hdr(&dreq_msg->hdr, cm_id_priv,
-			  CM_DREQ_ATTR_ID, CM_MSG_SEQUENCE_DREQ);
+	cm_format_mad_hdr(&dreq_msg->hdr, CM_DREQ_ATTR_ID,
+			  cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_DREQ));
 
 	dreq_msg->local_comm_id = cm_id_priv->id.local_id;
 	dreq_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1480,8 +1534,9 @@ static void cm_format_drep(struct cm_dre
 			  void *private_data,
 			  u8 private_data_len)
 {
-	cm_format_mad_hdr(&drep_msg->hdr, cm_id_priv,
-			  CM_DREP_ATTR_ID, CM_MSG_SEQUENCE_DREQ);
+	/* todo: TID should match received DREQ */
+	cm_format_mad_hdr(&drep_msg->hdr, CM_DREP_ATTR_ID,
+			  cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_DREQ));
 
 	drep_msg->local_comm_id = cm_id_priv->id.local_id;
 	drep_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1642,8 +1697,9 @@ static void cm_format_rej(struct cm_rej_
 			  void *private_data,
 			  u8 private_data_len)
 {
-	cm_format_mad_hdr(&rej_msg->hdr, cm_id_priv,
-			  CM_REJ_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+	/* todo: TID should match received REQ */
+	cm_format_mad_hdr(&rej_msg->hdr, CM_REJ_ATTR_ID,
+			  cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
 
 	rej_msg->remote_comm_id = cm_id_priv->id.remote_id;
 
@@ -1861,8 +1917,9 @@ static void cm_format_mra(struct cm_mra_
 		return;
 	}
 	spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-	cm_format_mad_hdr(&mra_msg->hdr, cm_id_priv,
-			  CM_MRA_ATTR_ID, msg_sequence);
+	/* todo: TID should match the received REQ or LAP */
+	cm_format_mad_hdr(&mra_msg->hdr, CM_MRA_ATTR_ID,
+			  cm_form_tid(cm_id_priv, msg_sequence));
 	cm_mra_set_msg_mraed(mra_msg, msg_mraed);
 
 	mra_msg->local_comm_id = cm_id_priv->id.local_id;
@@ -1946,8 +2003,8 @@ static void cm_format_lap(struct cm_lap_
 			  void *private_data,
 			  u8 private_data_len)
 {
-	cm_format_mad_hdr(&lap_msg->hdr, cm_id_priv,
-			  CM_LAP_ATTR_ID, CM_MSG_SEQUENCE_LAP);
+	cm_format_mad_hdr(&lap_msg->hdr, CM_LAP_ATTR_ID,
+			  cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_LAP));
 
 	lap_msg->local_comm_id = cm_id_priv->id.local_id;
 	lap_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -2089,8 +2146,9 @@ static void cm_format_apr(struct cm_apr_
 			  void *private_data,
 			  u8 private_data_len)
 {
-	cm_format_mad_hdr(&apr_msg->hdr, cm_id_priv,
-			  CM_APR_ATTR_ID, CM_MSG_SEQUENCE_LAP);
+	/* todo: TID should match received LAP */
+	cm_format_mad_hdr(&apr_msg->hdr, CM_APR_ATTR_ID,
+			  cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_LAP));
 
 	apr_msg->local_comm_id = cm_id_priv->id.local_id;
 	apr_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -2237,8 +2295,8 @@ static void cm_format_sidr_req(struct cm
 			       struct cm_id_private *cm_id_priv,
 			       struct ib_cm_sidr_req_param *param)
 {
-	cm_format_mad_hdr(&sidr_req_msg->hdr, cm_id_priv,
-			  CM_SIDR_REQ_ATTR_ID, CM_MSG_SEQUENCE_SIDR);
+	cm_format_mad_hdr(&sidr_req_msg->hdr, CM_SIDR_REQ_ATTR_ID,
+			  cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_SIDR));
 
 	sidr_req_msg->request_id = cm_id_priv->id.local_id;
 	sidr_req_msg->pkey = param->pkey;
@@ -2351,7 +2409,7 @@ static int cm_sidr_req_handler(struct cm
 	if (!cur_cm_id_priv) {
 		rb_erase(&cm_id_priv->sidr_id_node, &cm.remote_sidr_table);
 		spin_unlock_irqrestore(&cm.lock, flags);
-		/* todo: reject with no match */
+		/* todo: reply with no match */
 		goto out; /* No match. */
 	}
 	atomic_inc(&cur_cm_id_priv->refcount);
@@ -2375,8 +2433,9 @@ static void cm_format_sidr_rep(struct cm
 			       struct cm_id_private *cm_id_priv,
 			       struct ib_cm_sidr_rep_param *param)
 {
-	cm_format_mad_hdr(&sidr_rep_msg->hdr, cm_id_priv,
-			  CM_SIDR_REP_ATTR_ID, CM_MSG_SEQUENCE_SIDR);
+	/* todo: TID should match received SIDR REQ */
+	cm_format_mad_hdr(&sidr_rep_msg->hdr, CM_SIDR_REP_ATTR_ID,
+			  cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_SIDR));
 
 	sidr_rep_msg->request_id = cm_id_priv->id.remote_id;
 	sidr_rep_msg->status = param->status;






More information about the general mailing list