[openib-general] Re: [PATCH] ipoib_flush_paths

Roland Dreier rdreier at cisco.com
Thu Apr 6 11:20:54 PDT 2006


    Michael> Actually, it turned out to be the simplest solution - and
    Michael> quite elegant since there's no room for mistakes: if
    Michael> query is going to be running this means module is still
    Michael> loaded so we can take a reference to it without races.

Yes, this is surprisingly clean.

    Michael> As a bonus, an assertion inside __module_get increases
    Michael> the chance to catch races where user forgets to cancel
    Michael> the query - much nicer than crashing randomly.

Actually I think __module_get() will do the wrong thing if called
during module unloading -- it doesn't test module_is_live().  In other
words, calling __module_get() without already holding a ref has a
race: __try_stop_module() can see the ref count as 0, then
__module_get() can increment it, and then __try_stop_module() sets the
module state to GOING and returns.

So the right thing to do is BUG_ON(!try_module_get(owner))

Also, I don't think that a consumer of the ib_sa API would ever pass an
owner other than THIS_MODULE.  So how about if we keep the API the
same and just do the THIS_MODULE stuff in an inline wrapper?

Like the following...  it ends up being a pretty big diff, but only
because I moved some comments around and so on.  Also, I moved the
try_module_get() stuff out of line into call_sa_callback(), because
the compiled code ends up smaller that way.

Does anyone disagree with this patch?  Michael, are you happy with
this tweaked version of yours?

diff --git a/drivers/infiniband/core/sa_query.c b/drivers/infiniband/core/sa_query.c
index 501cc05..c43ed75 100644
--- a/drivers/infiniband/core/sa_query.c
+++ b/drivers/infiniband/core/sa_query.c
@@ -74,6 +74,7 @@ struct ib_sa_device {
 struct ib_sa_query {
 	void (*callback)(struct ib_sa_query *, int, struct ib_sa_mad *);
 	void (*release)(struct ib_sa_query *);
+	struct module	       *owner;
 	struct ib_sa_port      *port;
 	struct ib_mad_send_buf *mad_buf;
 	struct ib_sa_sm_ah     *sm_ah;
@@ -547,15 +548,16 @@ static void ib_sa_path_rec_release(struc
  * error code.  Otherwise it is a query ID that can be used to cancel
  * the query.
  */
-int ib_sa_path_rec_get(struct ib_device *device, u8 port_num,
-		       struct ib_sa_path_rec *rec,
-		       ib_sa_comp_mask comp_mask,
-		       int timeout_ms, gfp_t gfp_mask,
-		       void (*callback)(int status,
-					struct ib_sa_path_rec *resp,
-					void *context),
-		       void *context,
-		       struct ib_sa_query **sa_query)
+int __ib_sa_path_rec_get(struct ib_device *device, u8 port_num,
+			 struct ib_sa_path_rec *rec,
+			 ib_sa_comp_mask comp_mask,
+			 int timeout_ms, gfp_t gfp_mask,
+			 void (*callback)(int status,
+					  struct ib_sa_path_rec *resp,
+					  void *context),
+			 void *context,
+			 struct module *owner,
+			 struct ib_sa_query **sa_query)
 {
 	struct ib_sa_path_query *query;
 	struct ib_sa_device *sa_dev = ib_get_client_data(device, &sa_client);
@@ -590,6 +592,7 @@ int ib_sa_path_rec_get(struct ib_device 
 
 	query->sa_query.callback = callback ? ib_sa_path_rec_callback : NULL;
 	query->sa_query.release  = ib_sa_path_rec_release;
+	query->sa_query.owner    = owner;
 	query->sa_query.port     = port;
 	mad->mad_hdr.method	 = IB_MGMT_METHOD_GET;
 	mad->mad_hdr.attr_id	 = cpu_to_be16(IB_SA_ATTR_PATH_REC);
@@ -613,7 +616,7 @@ err1:
 	kfree(query);
 	return ret;
 }
-EXPORT_SYMBOL(ib_sa_path_rec_get);
+EXPORT_SYMBOL(__ib_sa_path_rec_get);
 
 static void ib_sa_service_rec_callback(struct ib_sa_query *sa_query,
 				    int status,
@@ -663,15 +666,16 @@ static void ib_sa_service_rec_release(st
  * error code.  Otherwise it is a request ID that can be used to cancel
  * the query.
  */
-int ib_sa_service_rec_query(struct ib_device *device, u8 port_num, u8 method,
-			    struct ib_sa_service_rec *rec,
-			    ib_sa_comp_mask comp_mask,
-			    int timeout_ms, gfp_t gfp_mask,
-			    void (*callback)(int status,
-					     struct ib_sa_service_rec *resp,
-					     void *context),
-			    void *context,
-			    struct ib_sa_query **sa_query)
+int __ib_sa_service_rec_query(struct ib_device *device, u8 port_num, u8 method,
+			      struct ib_sa_service_rec *rec,
+			      ib_sa_comp_mask comp_mask,
+			      int timeout_ms, gfp_t gfp_mask,
+			      void (*callback)(int status,
+					       struct ib_sa_service_rec *resp,
+					       void *context),
+			      void *context,
+			      struct module *owner,
+			      struct ib_sa_query **sa_query)
 {
 	struct ib_sa_service_query *query;
 	struct ib_sa_device *sa_dev = ib_get_client_data(device, &sa_client);
@@ -711,6 +715,7 @@ int ib_sa_service_rec_query(struct ib_de
 
 	query->sa_query.callback = callback ? ib_sa_service_rec_callback : NULL;
 	query->sa_query.release  = ib_sa_service_rec_release;
+	query->sa_query.owner    = owner;
 	query->sa_query.port     = port;
 	mad->mad_hdr.method	 = method;
 	mad->mad_hdr.attr_id	 = cpu_to_be16(IB_SA_ATTR_SERVICE_REC);
@@ -735,7 +740,7 @@ err1:
 	kfree(query);
 	return ret;
 }
-EXPORT_SYMBOL(ib_sa_service_rec_query);
+EXPORT_SYMBOL(__ib_sa_service_rec_query);
 
 static void ib_sa_mcmember_rec_callback(struct ib_sa_query *sa_query,
 					int status,
@@ -759,16 +764,17 @@ static void ib_sa_mcmember_rec_release(s
 	kfree(container_of(sa_query, struct ib_sa_mcmember_query, sa_query));
 }
 
-int ib_sa_mcmember_rec_query(struct ib_device *device, u8 port_num,
-			     u8 method,
-			     struct ib_sa_mcmember_rec *rec,
-			     ib_sa_comp_mask comp_mask,
-			     int timeout_ms, gfp_t gfp_mask,
-			     void (*callback)(int status,
-					      struct ib_sa_mcmember_rec *resp,
-					      void *context),
-			     void *context,
-			     struct ib_sa_query **sa_query)
+int __ib_sa_mcmember_rec_query(struct ib_device *device, u8 port_num,
+			       u8 method,
+			       struct ib_sa_mcmember_rec *rec,
+			       ib_sa_comp_mask comp_mask,
+			       int timeout_ms, gfp_t gfp_mask,
+			       void (*callback)(int status,
+						struct ib_sa_mcmember_rec *resp,
+						void *context),
+			       void *context,
+			       struct module *owner,
+			       struct ib_sa_query **sa_query)
 {
 	struct ib_sa_mcmember_query *query;
 	struct ib_sa_device *sa_dev = ib_get_client_data(device, &sa_client);
@@ -803,6 +809,7 @@ int ib_sa_mcmember_rec_query(struct ib_d
 
 	query->sa_query.callback = callback ? ib_sa_mcmember_rec_callback : NULL;
 	query->sa_query.release  = ib_sa_mcmember_rec_release;
+	query->sa_query.owner    = owner;
 	query->sa_query.port     = port;
 	mad->mad_hdr.method	 = method;
 	mad->mad_hdr.attr_id	 = cpu_to_be16(IB_SA_ATTR_MC_MEMBER_REC);
@@ -827,7 +834,28 @@ err1:
 	kfree(query);
 	return ret;
 }
-EXPORT_SYMBOL(ib_sa_mcmember_rec_query);
+EXPORT_SYMBOL(__ib_sa_mcmember_rec_query);
+
+/*
+ * Run the consumer's callback while holding a reference on its module
+ * so the module cannot be unloaded while the callback is executing.
+ *
+ * A failed try_module_get() must not be fatal (no BUG_ON here): it
+ * fails once the owner is in MODULE_STATE_GOING, i.e. while the
+ * owner's exit routine runs.  That is exactly when a careful consumer
+ * cancels its outstanding queries and waits for the resulting -EINTR
+ * callbacks, and the owner's code is still loaded at that point.  So
+ * tolerate the failure and only drop a reference we actually took.
+ */
+static void call_sa_callback(struct ib_sa_query *query, int status,
+			     struct ib_sa_mad *mad)
+{
+	int got_ref = try_module_get(query->owner);
+
+	query->callback(query, status, mad);
+	if (got_ref)
+		module_put(query->owner);
+}
 
 static void send_handler(struct ib_mad_agent *agent,
 			 struct ib_mad_send_wc *mad_send_wc)
@@ -841,13 +856,13 @@ static void send_handler(struct ib_mad_a
 			/* No callback -- already got recv */
 			break;
 		case IB_WC_RESP_TIMEOUT_ERR:
-			query->callback(query, -ETIMEDOUT, NULL);
+			call_sa_callback(query, -ETIMEDOUT, NULL);
 			break;
 		case IB_WC_WR_FLUSH_ERR:
-			query->callback(query, -EINTR, NULL);
+			call_sa_callback(query, -EINTR, NULL);
 			break;
 		default:
-			query->callback(query, -EIO, NULL);
+			call_sa_callback(query, -EIO, NULL);
 			break;
 		}
 
@@ -871,12 +886,12 @@ static void recv_handler(struct ib_mad_a
 
 	if (query->callback) {
 		if (mad_recv_wc->wc->status == IB_WC_SUCCESS)
-			query->callback(query,
-					mad_recv_wc->recv_buf.mad->mad_hdr.status ?
-					-EINVAL : 0,
-					(struct ib_sa_mad *) mad_recv_wc->recv_buf.mad);
+			call_sa_callback(query,
+					 mad_recv_wc->recv_buf.mad->mad_hdr.status ?
+					 -EINVAL : 0,
+					 (struct ib_sa_mad *) mad_recv_wc->recv_buf.mad);
 		else
-			query->callback(query, -EIO, NULL);
+			call_sa_callback(query, -EIO, NULL);
 	}
 
 	ib_free_recv_mad(mad_recv_wc);
diff --git a/include/rdma/ib_sa.h b/include/rdma/ib_sa.h
index ad63c21..6769d1b 100644
--- a/include/rdma/ib_sa.h
+++ b/include/rdma/ib_sa.h
@@ -254,37 +254,80 @@ struct ib_sa_query;
 
 void ib_sa_cancel_query(int id, struct ib_sa_query *query);
 
-int ib_sa_path_rec_get(struct ib_device *device, u8 port_num,
-		       struct ib_sa_path_rec *rec,
-		       ib_sa_comp_mask comp_mask,
-		       int timeout_ms, gfp_t gfp_mask,
-		       void (*callback)(int status,
-					struct ib_sa_path_rec *resp,
-					void *context),
-		       void *context,
-		       struct ib_sa_query **query);
-
-int ib_sa_mcmember_rec_query(struct ib_device *device, u8 port_num,
-			     u8 method,
-			     struct ib_sa_mcmember_rec *rec,
-			     ib_sa_comp_mask comp_mask,
-			     int timeout_ms, gfp_t gfp_mask,
-			     void (*callback)(int status,
-					      struct ib_sa_mcmember_rec *resp,
-					      void *context),
-			     void *context,
-			     struct ib_sa_query **query);
-
-int ib_sa_service_rec_query(struct ib_device *device, u8 port_num,
-			 u8 method,
-			 struct ib_sa_service_rec *rec,
+int __ib_sa_path_rec_get(struct ib_device *device, u8 port_num,
+			 struct ib_sa_path_rec *rec,
 			 ib_sa_comp_mask comp_mask,
 			 int timeout_ms, gfp_t gfp_mask,
 			 void (*callback)(int status,
-					  struct ib_sa_service_rec *resp,
+					  struct ib_sa_path_rec *resp,
 					  void *context),
 			 void *context,
-			 struct ib_sa_query **sa_query);
+			 struct module *owner,
+			 struct ib_sa_query **query);
+
+int __ib_sa_mcmember_rec_query(struct ib_device *device, u8 port_num,
+			       u8 method,
+			       struct ib_sa_mcmember_rec *rec,
+			       ib_sa_comp_mask comp_mask,
+			       int timeout_ms, gfp_t gfp_mask,
+			       void (*callback)(int status,
+						struct ib_sa_mcmember_rec *resp,
+						void *context),
+			       void *context,
+			       struct module *owner,
+			       struct ib_sa_query **query);
+
+int __ib_sa_service_rec_query(struct ib_device *device, u8 port_num,
+			      u8 method,
+			      struct ib_sa_service_rec *rec,
+			      ib_sa_comp_mask comp_mask,
+			      int timeout_ms, gfp_t gfp_mask,
+			      void (*callback)(int status,
+					       struct ib_sa_service_rec *resp,
+					       void *context),
+			      void *context,
+			      struct module *owner,
+			      struct ib_sa_query **sa_query);
+
+/**
+ * ib_sa_path_rec_get - Start a Path get query
+ * @device: device to send query on
+ * @port_num: port number to send query on
+ * @rec: Path Record to send in query
+ * @comp_mask: component mask to send in query
+ * @timeout_ms: time to wait for response
+ * @gfp_mask: GFP mask to use for internal allocations
+ * @callback: function called when query completes, times out or is
+ * canceled
+ * @context: opaque user context passed to callback
+ * @sa_query: query context, used to cancel query
+ *
+ * Send a Path Record Get query to the SA to look up a path.  The
+ * callback function will be called when the query completes (or
+ * fails); status is 0 for a successful response, -EINTR if the query
+ * is canceled, -ETIMEDOUT if the query timed out, or -EIO if an error
+ * occurred sending the query.  The resp parameter of the callback is
+ * only valid if status is 0.
+ *
+ * If the return value of ib_sa_path_rec_get() is negative, it is an
+ * error code.  Otherwise it is a query ID that can be used to cancel
+ * the query.  The calling module (THIS_MODULE) is recorded as owner.
+ */
+static inline int
+ib_sa_path_rec_get(struct ib_device *device, u8 port_num,
+		   struct ib_sa_path_rec *rec,
+		   ib_sa_comp_mask comp_mask,
+		   int timeout_ms, gfp_t gfp_mask,
+		   void (*callback)(int status,
+				    struct ib_sa_path_rec *resp,
+				    void *context),
+		   void *context,
+		   struct ib_sa_query **sa_query)
+{
+	return __ib_sa_path_rec_get(device, port_num, rec, comp_mask,
+				    timeout_ms, gfp_mask, callback,
+				    context, THIS_MODULE, sa_query);
+}
 
 /**
  * ib_sa_mcmember_rec_set - Start an MCMember set query
@@ -321,11 +364,11 @@ ib_sa_mcmember_rec_set(struct ib_device 
 		       void *context,
 		       struct ib_sa_query **query)
 {
-	return ib_sa_mcmember_rec_query(device, port_num,
-					IB_MGMT_METHOD_SET,
-					rec, comp_mask,
-					timeout_ms, gfp_mask, callback,
-					context, query);
+	return __ib_sa_mcmember_rec_query(device, port_num,
+					  IB_MGMT_METHOD_SET,
+					  rec, comp_mask,
+					  timeout_ms, gfp_mask, callback,
+					  context, THIS_MODULE, query);
 }
 
 /**
@@ -363,12 +406,54 @@ ib_sa_mcmember_rec_delete(struct ib_devi
 			  void *context,
 			  struct ib_sa_query **query)
 {
-	return ib_sa_mcmember_rec_query(device, port_num,
-					IB_SA_METHOD_DELETE,
-					rec, comp_mask,
-					timeout_ms, gfp_mask, callback,
-					context, query);
+	return __ib_sa_mcmember_rec_query(device, port_num,
+					  IB_SA_METHOD_DELETE,
+					  rec, comp_mask,
+					  timeout_ms, gfp_mask, callback,
+					  context, THIS_MODULE, query);
 }
 
+/**
+ * ib_sa_service_rec_query - Start Service Record operation
+ * @device: device to send request on
+ * @port_num: port number to send request on
+ * @method: SA method - should be get, set, or delete
+ * @rec: Service Record to send in request
+ * @comp_mask: component mask to send in request
+ * @timeout_ms: time to wait for response
+ * @gfp_mask: GFP mask to use for internal allocations
+ * @callback: function called when request completes, times out or is
+ * canceled
+ * @context: opaque user context passed to callback
+ * @sa_query: request context, used to cancel request
+ *
+ * Send a Service Record set/get/delete to the SA to register,
+ * unregister or query a service record.
+ * The callback function will be called when the request completes (or
+ * fails); status is 0 for a successful response, -EINTR if the query
+ * is canceled, -ETIMEDOUT if the query timed out, or -EIO if an error
+ * occurred sending the query.  The resp parameter of the callback is
+ * only valid if status is 0.
+ *
+ * If the return value of ib_sa_service_rec_query() is negative, it is an
+ * error code.  Otherwise it is a request ID that can be used to cancel
+ * the query.  The calling module (THIS_MODULE) is recorded as owner.
+ */
+static inline int
+ib_sa_service_rec_query(struct ib_device *device, u8 port_num, u8 method,
+			struct ib_sa_service_rec *rec,
+			ib_sa_comp_mask comp_mask,
+			int timeout_ms, gfp_t gfp_mask,
+			void (*callback)(int status,
+					 struct ib_sa_service_rec *resp,
+					 void *context),
+			void *context,
+			struct ib_sa_query **sa_query)
+{
+	return __ib_sa_service_rec_query(device, port_num, method, rec,
+					 comp_mask, timeout_ms, gfp_mask,
+					 callback, context, THIS_MODULE,
+					 sa_query);
+}
 
 #endif /* IB_SA_H */



More information about the general mailing list