[openib-general] [PATCH] [CM] reject stale connections
Sean Hefty
sean.hefty at intel.com
Fri May 13 13:36:52 PDT 2005
The following patch rejects stale connection requests (REQs) and replies (REPs).
It also fixes a couple of minor error-handling cleanup issues
found during testing, and fixes a bug in handling a stale-connection
reject message, where the cm_id would access an invalid timewait
pointer.
Signed-off-by: Sean Hefty <sean.hefty at intel.com>
Index: core/cm.c
===================================================================
--- core/cm.c (revision 2346)
+++ core/cm.c (working copy)
@@ -185,10 +185,35 @@ static int cm_alloc_msg(struct cm_id_pri
return 0;
}
+/*
+ * Allocate a send buffer for replying directly to a received MAD when no
+ * cm_id is associated with the reply (e.g. rejecting a stale REQ or REP).
+ * The address handle is created from the receive work completion and GRH.
+ * On success *msg owns the AH, which cm_free_msg() destroys together with
+ * the buffer. Returns 0 or a negative errno.
+ */
+static int cm_alloc_response_msg(struct cm_port *port,
+ struct ib_mad_recv_wc *mad_recv_wc,
+ struct ib_mad_send_buf **msg)
+{
+ struct ib_mad_send_buf *m;
+ struct ib_ah *ah;
+
+ ah = ib_create_ah_from_wc(port->mad_agent->qp->pd, mad_recv_wc->wc,
+ mad_recv_wc->recv_buf.grh, port->port_num);
+ if (IS_ERR(ah))
+ return PTR_ERR(ah);
+
+ m = ib_create_send_mad(port->mad_agent, 1, mad_recv_wc->wc->pkey_index,
+ ah, 0, sizeof(struct ib_mad_hdr),
+ sizeof(struct ib_mad)-sizeof(struct ib_mad_hdr),
+ GFP_ATOMIC);
+ if (IS_ERR(m)) {
+ ib_destroy_ah(ah);
+ return PTR_ERR(m);
+ }
+ /*
+ * No cm_id holds a reference for this message. Say so explicitly so
+ * that cm_free_msg() skips cm_deref_id() regardless of how the send
+ * buffer was initialised by ib_create_send_mad().
+ */
+ m->context[0] = NULL;
+ *msg = m;
+ return 0;
+}
+
static void cm_free_msg(struct ib_mad_send_buf *msg)
{
ib_destroy_ah(msg->send_wr.wr.ud.ah);
- cm_deref_id(msg->context[0]);
+ if (msg->context[0])
+ cm_deref_id(msg->context[0]);
ib_free_send_mad(msg);
}
@@ -531,10 +556,7 @@ static void cm_cleanup_timewait(struct c
spin_unlock_irqrestore(&cm.lock, flags);
}
-static struct cm_timewait_info * cm_create_timewait_info(u32 local_id,
- u32 remote_id,
- u64 remote_ca_guid,
- u32 remote_qpn)
+/*
+ * Allocate a zeroed timewait_info for local_id. Callers fill in the remote
+ * fields (work.remote_id, remote_ca_guid, remote_qpn) once the REQ or REP
+ * that supplies them has been received, and check the result with IS_ERR().
+ */
+static struct cm_timewait_info * cm_create_timewait_info(u32 local_id)
{
struct cm_timewait_info *timewait_info;
@@ -544,10 +566,6 @@ static struct cm_timewait_info * cm_crea
memset(timewait_info, 0, sizeof *timewait_info);
timewait_info->work.local_id = local_id;
- timewait_info->work.remote_id = remote_id;
- timewait_info->remote_ca_guid = remote_ca_guid;
- timewait_info->remote_qpn = remote_qpn;
-
INIT_WORK(&timewait_info->work.work, cm_work_handler,
&timewait_info->work);
timewait_info->work.cm_event.event = IB_CM_TIMEWAIT_EXIT;
@@ -674,30 +692,33 @@ int ib_cm_listen(struct ib_cm_id *cm_id,
}
EXPORT_SYMBOL(ib_cm_listen);
-static void cm_format_mad_hdr(struct ib_mad_hdr *hdr,
- struct cm_id_private *cm_id_priv,
- enum cm_msg_attr_id attr_id,
- enum cm_msg_sequence msg_seq)
+/*
+ * Build the MAD transaction ID for a message sent on behalf of cm_id_priv.
+ * The upper 32 bits are the MAD agent's hi_tid. The lower 32 bits are the
+ * cm_id's local ID OR'ed with msg_seq shifted left by 30. The result is
+ * passed through cpu_to_be64().
+ */
+static u64 cm_form_tid(struct cm_id_private *cm_id_priv,
+ enum cm_msg_sequence msg_seq)
{
u64 hi_tid, low_tid;
+ hi_tid = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32;
+ low_tid = (u64) (cm_id_priv->id.local_id | (msg_seq << 30));
+ return cpu_to_be64(hi_tid | low_tid);
+}
+
+/*
+ * Fill in the common CM MAD header fields. The caller supplies the TID:
+ * either one generated by cm_form_tid(), or one copied from a received MAD
+ * when replying to it directly.
+ */
+static void cm_format_mad_hdr(struct ib_mad_hdr *hdr,
+ enum cm_msg_attr_id attr_id, u64 tid)
+{
hdr->base_version = IB_MGMT_BASE_VERSION;
hdr->mgmt_class = IB_MGMT_CLASS_CM;
hdr->class_version = IB_CM_CLASS_VERSION;
hdr->method = IB_MGMT_METHOD_SEND;
hdr->attr_id = attr_id;
-
- hi_tid = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32;
- low_tid = (u64) (cm_id_priv->id.local_id | (msg_seq << 30));
- hdr->tid = cpu_to_be64(hi_tid | low_tid);
+ hdr->tid = tid;
}
static void cm_format_req(struct cm_req_msg *req_msg,
struct cm_id_private *cm_id_priv,
struct ib_cm_req_param *param)
{
- cm_format_mad_hdr(&req_msg->hdr, cm_id_priv,
- CM_REQ_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+ cm_format_mad_hdr(&req_msg->hdr, CM_REQ_ATTR_ID,
+ cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
req_msg->local_comm_id = cm_id_priv->id.local_id;
req_msg->service_id = param->service_id;
@@ -755,6 +776,10 @@ static void cm_format_req(struct cm_req_
static inline int cm_validate_req_param(struct ib_cm_req_param *param)
{
+ /* peer-to-peer not supported */
+ if (param->peer_to_peer)
+ return -EINVAL;
+
if (!param->primary_path)
return -EINVAL;
@@ -796,14 +821,19 @@ int ib_send_cm_req(struct ib_cm_id *cm_i
}
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+ /*
+ * Allocate the timewait info up front; its remote fields are filled in
+ * when the REP arrives. Nothing on this path sets ret, so propagate the
+ * allocation error explicitly. Also avoid leaving an ERR_PTR in the
+ * cm_id.
+ */
+ cm_id_priv->timewait_info = cm_create_timewait_info(cm_id_priv->
+ id.local_id);
+ if (IS_ERR(cm_id_priv->timewait_info)) {
+ ret = PTR_ERR(cm_id_priv->timewait_info);
+ cm_id_priv->timewait_info = NULL;
+ goto out;
+ }
+
ret = cm_init_av_by_path(param->primary_path, &cm_id_priv->av);
if (ret)
- goto out;
+ goto error1;
if (param->alternate_path) {
ret = cm_init_av_by_path(param->alternate_path,
&cm_id_priv->alt_av);
if (ret)
- goto out;
+ goto error1;
}
cm_id->service_id = param->service_id;
cm_id->service_mask = ~0ULL;
@@ -819,7 +849,7 @@ int ib_send_cm_req(struct ib_cm_id *cm_i
ret = cm_alloc_msg(cm_id_priv, &cm_id_priv->msg);
if (ret)
- goto out;
+ goto error1;
req_msg = (struct cm_req_msg *) cm_id_priv->msg->mad;
cm_format_req(req_msg, cm_id_priv, param);
@@ -831,35 +861,61 @@ int ib_send_cm_req(struct ib_cm_id *cm_i
cm_id_priv->local_ack_timeout =
cm_req_get_primary_local_ack_timeout(req_msg);
- /*
- * Received REQs won't match until we're in REQ_SENT state. This
- * simplifies error recovery if the send fails.
- */
- if (param->peer_to_peer) {
- ret = -EINVAL;
- goto out;
- }
-
spin_lock_irqsave(&cm_id_priv->lock, flags);
ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
&cm_id_priv->msg->send_wr, &bad_send_wr);
-
if (ret) {
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
- /* if (param->peer_to_peer) {
- cleanup peer_service_table
- } */
- cm_free_msg(cm_id_priv->msg);
- goto out;
+ goto error2;
}
BUG_ON(cm_id->state != IB_CM_IDLE);
cm_id->state = IB_CM_REQ_SENT;
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-out:
- return ret;
+ return 0;
+
+error2: cm_free_msg(cm_id_priv->msg);
+error1: kfree(cm_id_priv->timewait_info);
+ /* Don't leave a dangling pointer in the cm_id. */
+ cm_id_priv->timewait_info = NULL;
+out: return ret;
}
EXPORT_SYMBOL(ib_send_cm_req);
+/*
+ * Send a REJ directly in response to a received CM MAD, without a cm_id.
+ * Used to reject stale REQs and REPs. The REJ echoes the received TID and
+ * swaps the local/remote communication IDs. ari/ari_length are optional;
+ * ari_length is copied into the REJ without a bound check, so callers must
+ * keep it within the REJ ARI field. Returns 0 or a negative errno; if
+ * posting fails, the message is freed here.
+ */
+static int cm_issue_rej(struct cm_port *port,
+ struct ib_mad_recv_wc *mad_recv_wc,
+ enum ib_cm_rej_reason reason,
+ enum cm_msg_response msg_rejected,
+ void *ari, u8 ari_length)
+{
+ struct ib_mad_send_buf *msg;
+ struct ib_send_wr *bad_send_wr;
+ struct cm_rej_msg *rej_msg, *rcv_msg;
+ int ret;
+
+ ret = cm_alloc_response_msg(port, mad_recv_wc, &msg);
+ if (ret)
+ return ret;
+
+ /* We just need common CM header information. Cast to any message. */
+ rcv_msg = (struct cm_rej_msg *) mad_recv_wc->recv_buf.mad;
+ rej_msg = (struct cm_rej_msg *) msg->mad;
+
+ cm_format_mad_hdr(&rej_msg->hdr, CM_REJ_ATTR_ID, rcv_msg->hdr.tid);
+ rej_msg->remote_comm_id = rcv_msg->local_comm_id;
+ rej_msg->local_comm_id = rcv_msg->remote_comm_id;
+ cm_rej_set_msg_rejected(rej_msg, msg_rejected);
+ rej_msg->reason = reason;
+
+ if (ari && ari_length) {
+ cm_rej_set_reject_info_len(rej_msg, ari_length);
+ memcpy(rej_msg->ari, ari, ari_length);
+ }
+
+ ret = ib_post_send_mad(port->mad_agent, &msg->send_wr, &bad_send_wr);
+ if (ret)
+ cm_free_msg(msg);
+
+ return ret;
+}
+
static inline int cm_is_active_peer(u64 local_ca_guid, u64 remote_ca_guid,
u32 local_qpn, u32 remote_qpn)
{
@@ -992,29 +1048,31 @@ static int cm_req_handler(struct cm_work
cm_id_priv->id.remote_id = req_msg->local_comm_id;
cm_init_av_for_response(work->port, work->mad_recv_wc->wc,
&cm_id_priv->av);
- cm_id_priv->timewait_info = cm_create_timewait_info(
- cm_id_priv->id.local_id,
- cm_id_priv->id.remote_id,
- req_msg->local_ca_guid,
- cm_req_get_local_qpn(req_msg));
+ cm_id_priv->timewait_info = cm_create_timewait_info(cm_id_priv->
+ id.local_id);
if (IS_ERR(cm_id_priv->timewait_info)) {
ret = PTR_ERR(cm_id_priv->timewait_info);
+ /* Don't leave an ERR_PTR in the cm_id that is about to be destroyed. */
+ cm_id_priv->timewait_info = NULL;
goto error1;
}
+ /* Record the REQ sender's identity for the duplicate/stale checks below. */
+ cm_id_priv->timewait_info->work.remote_id = req_msg->local_comm_id;
+ cm_id_priv->timewait_info->remote_ca_guid = req_msg->local_ca_guid;
+ cm_id_priv->timewait_info->remote_qpn = cm_req_get_local_qpn(req_msg);
spin_lock_irqsave(&cm.lock, flags);
/* Check for duplicate REQ. */
if (cm_insert_remote_id(cm_id_priv->timewait_info)) {
spin_unlock_irqrestore(&cm.lock, flags);
ret = -EINVAL;
- goto error1;
+ goto error2;
}
/* Check for a stale connection. */
if (cm_insert_remote_qpn(cm_id_priv->timewait_info)) {
spin_unlock_irqrestore(&cm.lock, flags);
- /* todo: reject as stale */
+ /* Best effort: a failure to send the REJ is not reported. */
+ cm_issue_rej(work->port, work->mad_recv_wc,
+ IB_CM_REJ_STALE_CONN, CM_MSG_RESPONSE_REQ,
+ NULL, 0);
ret = -EINVAL;
- goto error1;
+ goto error2;
}
/* Find matching listen request. */
listen_cm_id_priv = cm_find_listen(req_msg->service_id);
@@ -1035,11 +1093,11 @@ static int cm_req_handler(struct cm_work
cm_format_paths_from_req(req_msg, &work->path[0], &work->path[1]);
ret = cm_init_av_by_path(&work->path[0], &cm_id_priv->av);
if (ret)
- goto error2;
+ goto error3;
if (req_msg->alt_local_lid) {
ret = cm_init_av_by_path(&work->path[1], &cm_id_priv->alt_av);
if (ret)
- goto error2;
+ goto error3;
}
cm_id_priv->timeout_ms = cm_convert_to_ms(
cm_req_get_local_resp_timeout(req_msg));
@@ -1058,11 +1116,12 @@ static int cm_req_handler(struct cm_work
cm_process_work(cm_id_priv, work);
cm_deref_id(listen_cm_id_priv);
return 0;
-error2:
- atomic_dec(&cm_id_priv->refcount);
+
+error3: atomic_dec(&cm_id_priv->refcount);
cm_deref_id(listen_cm_id_priv);
-error1:
- ib_destroy_cm_id(&cm_id_priv->id);
+error2: cm_cleanup_timewait(cm_id_priv->timewait_info);
+ kfree(cm_id_priv->timewait_info);
+ /* Don't leave a dangling pointer for ib_destroy_cm_id() to find. */
+ cm_id_priv->timewait_info = NULL;
+error1: ib_destroy_cm_id(&cm_id_priv->id);
return ret;
}
@@ -1070,8 +1129,9 @@ static void cm_format_rep(struct cm_rep_
struct cm_id_private *cm_id_priv,
struct ib_cm_rep_param *param)
{
- cm_format_mad_hdr(&rep_msg->hdr, cm_id_priv,
- CM_REP_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+ /* todo: TID should match received REQ */
+ cm_format_mad_hdr(&rep_msg->hdr, CM_REP_ATTR_ID,
+ cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
rep_msg->local_comm_id = cm_id_priv->id.local_id;
rep_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1148,8 +1208,8 @@ static void cm_format_rtu(struct cm_rtu_
void *private_data,
u8 private_data_len)
{
- cm_format_mad_hdr(&rtu_msg->hdr, cm_id_priv,
- CM_RTU_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+ /* The RTU's TID is generated with the REQ message sequence. */
+ cm_format_mad_hdr(&rtu_msg->hdr, CM_RTU_ATTR_ID,
+ cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
rtu_msg->local_comm_id = cm_id_priv->id.local_id;
rtu_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1262,7 +1322,6 @@ static void cm_dup_rep_handler(struct cm
static int cm_rep_handler(struct cm_work *work)
{
struct cm_id_private *cm_id_priv;
- struct cm_timewait_info *timewait_info;
struct cm_rep_msg *rep_msg;
unsigned long flags;
int ret;
@@ -1274,27 +1333,25 @@ static int cm_rep_handler(struct cm_work
return -EINVAL;
}
- timewait_info = cm_create_timewait_info(cm_id_priv->id.local_id,
- rep_msg->local_comm_id,
- rep_msg->local_ca_guid,
- cm_rep_get_local_qpn(rep_msg));
- if (IS_ERR(timewait_info)) {
- ret = PTR_ERR(timewait_info);
- goto error1;
- }
+ /*
+ * Record the REP sender's identity in the timewait info that
+ * ib_send_cm_req() allocated for this cm_id.
+ * NOTE(review): these writes, and the cm_cleanup_timewait() on the
+ * error: path, happen before the state switch further down (presumably
+ * on cm_id_priv->id.state) has validated the cm_id. A REP that matches
+ * a cm_id in an unexpected state would overwrite, and on failure clean
+ * up, that cm_id's timewait entry. If such a cm_id has no timewait_info,
+ * this would dereference NULL. Consider checking the state first --
+ * TODO confirm against the full function.
+ */
+ cm_id_priv->timewait_info->work.remote_id = rep_msg->local_comm_id;
+ cm_id_priv->timewait_info->remote_ca_guid = rep_msg->local_ca_guid;
+ cm_id_priv->timewait_info->remote_qpn = cm_rep_get_local_qpn(rep_msg);
+
spin_lock_irqsave(&cm.lock, flags);
/* Check for duplicate REP. */
- if (cm_insert_remote_id(timewait_info)) {
+ if (cm_insert_remote_id(cm_id_priv->timewait_info)) {
spin_unlock_irqrestore(&cm.lock, flags);
ret = -EINVAL;
- goto error2;
+ goto error;
}
/* Check for a stale connection. */
- if (cm_insert_remote_qpn(timewait_info)) {
+ if (cm_insert_remote_qpn(cm_id_priv->timewait_info)) {
spin_unlock_irqrestore(&cm.lock, flags);
- /* todo: reject as stale */
+ /* Best effort: a failure to send the REJ is not reported. */
+ cm_issue_rej(work->port, work->mad_recv_wc,
+ IB_CM_REJ_STALE_CONN, CM_MSG_RESPONSE_REP,
+ NULL, 0);
ret = -EINVAL;
- goto error2;
+ goto error;
}
spin_unlock_irqrestore(&cm.lock, flags);
@@ -1308,7 +1365,7 @@ static int cm_rep_handler(struct cm_work
default:
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
ret = -EINVAL;
- goto error2;
+ goto error;
}
cm_id_priv->id.state = IB_CM_REP_RCVD;
cm_id_priv->id.remote_id = rep_msg->local_comm_id;
@@ -1317,7 +1374,6 @@ static int cm_rep_handler(struct cm_work
cm_id_priv->responder_resources = rep_msg->initiator_depth;
cm_id_priv->sq_psn = cm_rep_get_starting_psn(rep_msg);
cm_id_priv->rnr_retry_count = cm_rep_get_rnr_retry_count(rep_msg);
- cm_id_priv->timewait_info = timewait_info;
/* todo: handle peer_to_peer */
@@ -1333,10 +1389,8 @@ static int cm_rep_handler(struct cm_work
else
cm_deref_id(cm_id_priv);
return 0;
-error2:
- cm_cleanup_timewait(timewait_info);
- kfree(timewait_info);
-error1:
+
+/* The timewait info stays owned by the cm_id; only unlink it here. */
+error: cm_cleanup_timewait(cm_id_priv->timewait_info);
cm_deref_id(cm_id_priv);
return ret;
}
@@ -1420,8 +1474,8 @@ static void cm_format_dreq(struct cm_dre
void *private_data,
u8 private_data_len)
{
- cm_format_mad_hdr(&dreq_msg->hdr, cm_id_priv,
- CM_DREQ_ATTR_ID, CM_MSG_SEQUENCE_DREQ);
+ cm_format_mad_hdr(&dreq_msg->hdr, CM_DREQ_ATTR_ID,
+ cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_DREQ));
dreq_msg->local_comm_id = cm_id_priv->id.local_id;
dreq_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1480,8 +1534,9 @@ static void cm_format_drep(struct cm_dre
void *private_data,
u8 private_data_len)
{
- cm_format_mad_hdr(&drep_msg->hdr, cm_id_priv,
- CM_DREP_ATTR_ID, CM_MSG_SEQUENCE_DREQ);
+ /* todo: TID should match received DREQ */
+ cm_format_mad_hdr(&drep_msg->hdr, CM_DREP_ATTR_ID,
+ cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_DREQ));
drep_msg->local_comm_id = cm_id_priv->id.local_id;
drep_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1642,8 +1697,9 @@ static void cm_format_rej(struct cm_rej_
void *private_data,
u8 private_data_len)
{
- cm_format_mad_hdr(&rej_msg->hdr, cm_id_priv,
- CM_REJ_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+ /* todo: TID should match received REQ */
+ cm_format_mad_hdr(&rej_msg->hdr, CM_REJ_ATTR_ID,
+ cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
rej_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1861,8 +1917,9 @@ static void cm_format_mra(struct cm_mra_
return;
}
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
- cm_format_mad_hdr(&mra_msg->hdr, cm_id_priv,
- CM_MRA_ATTR_ID, msg_sequence);
+ /* todo: TID should match received REQ or LAP */
+ cm_format_mad_hdr(&mra_msg->hdr, CM_MRA_ATTR_ID,
+ cm_form_tid(cm_id_priv, msg_sequence));
cm_mra_set_msg_mraed(mra_msg, msg_mraed);
mra_msg->local_comm_id = cm_id_priv->id.local_id;
@@ -1946,8 +2003,8 @@ static void cm_format_lap(struct cm_lap_
void *private_data,
u8 private_data_len)
{
- cm_format_mad_hdr(&lap_msg->hdr, cm_id_priv,
- CM_LAP_ATTR_ID, CM_MSG_SEQUENCE_LAP);
+ cm_format_mad_hdr(&lap_msg->hdr, CM_LAP_ATTR_ID,
+ cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_LAP));
lap_msg->local_comm_id = cm_id_priv->id.local_id;
lap_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -2089,8 +2146,9 @@ static void cm_format_apr(struct cm_apr_
void *private_data,
u8 private_data_len)
{
- cm_format_mad_hdr(&apr_msg->hdr, cm_id_priv,
- CM_APR_ATTR_ID, CM_MSG_SEQUENCE_LAP);
+ /* todo: TID should match received LAP */
+ cm_format_mad_hdr(&apr_msg->hdr, CM_APR_ATTR_ID,
+ cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_LAP));
apr_msg->local_comm_id = cm_id_priv->id.local_id;
apr_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -2237,8 +2295,8 @@ static void cm_format_sidr_req(struct cm
struct cm_id_private *cm_id_priv,
struct ib_cm_sidr_req_param *param)
{
- cm_format_mad_hdr(&sidr_req_msg->hdr, cm_id_priv,
- CM_SIDR_REQ_ATTR_ID, CM_MSG_SEQUENCE_SIDR);
+ cm_format_mad_hdr(&sidr_req_msg->hdr, CM_SIDR_REQ_ATTR_ID,
+ cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_SIDR));
sidr_req_msg->request_id = cm_id_priv->id.local_id;
sidr_req_msg->pkey = param->pkey;
@@ -2351,7 +2409,7 @@ static int cm_sidr_req_handler(struct cm
if (!cur_cm_id_priv) {
rb_erase(&cm_id_priv->sidr_id_node, &cm.remote_sidr_table);
spin_unlock_irqrestore(&cm.lock, flags);
- /* todo: reject with no match */
+ /* todo: reply with no match */
goto out; /* No match. */
}
atomic_inc(&cur_cm_id_priv->refcount);
@@ -2375,8 +2433,9 @@ static void cm_format_sidr_rep(struct cm
struct cm_id_private *cm_id_priv,
struct ib_cm_sidr_rep_param *param)
{
- cm_format_mad_hdr(&sidr_rep_msg->hdr, cm_id_priv,
- CM_SIDR_REP_ATTR_ID, CM_MSG_SEQUENCE_SIDR);
+ /* todo: TID should match received SIDR REQ */
+ cm_format_mad_hdr(&sidr_rep_msg->hdr, CM_SIDR_REP_ATTR_ID,
+ cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_SIDR));
sidr_rep_msg->request_id = cm_id_priv->id.remote_id;
sidr_rep_msg->status = param->status;
More information about the general
mailing list