Fixed race condition in ioqueue unregistration for the select and Win32 IOCP backends

git-svn-id: https://svn.pjsip.org/repos/pjproject/trunk@365 74dad513-b988-da41-8d7b-12977e46ad98
diff --git a/pjlib/src/pj/ioqueue_select.c b/pjlib/src/pj/ioqueue_select.c
index 16a511a..4aa4f91 100644
--- a/pjlib/src/pj/ioqueue_select.c
+++ b/pjlib/src/pj/ioqueue_select.c
@@ -109,12 +109,18 @@
     DECLARE_COMMON_IOQUEUE
 
     unsigned		max, count;
-    pj_ioqueue_key_t	key_list;
+    pj_ioqueue_key_t	active_list;
     pj_fd_set_t		rfdset;
     pj_fd_set_t		wfdset;
 #if PJ_HAS_TCP
     pj_fd_set_t		xfdset;
 #endif
+
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+    pj_mutex_t	       *ref_cnt_mutex;
+    pj_ioqueue_key_t	closing_list;
+    pj_ioqueue_key_t	free_list;
+#endif
 };
 
 /* Include implementation for common abstraction after we declare
@@ -141,6 +147,7 @@
 {
     pj_ioqueue_t *ioqueue;
     pj_lock_t *lock;
+    unsigned i;
     pj_status_t rc;
 
     /* Check that arguments are valid. */
@@ -152,8 +159,8 @@
     PJ_ASSERT_RETURN(sizeof(pj_ioqueue_op_key_t)-sizeof(void*) >=
                      sizeof(union operation_key), PJ_EBUG);
 
+    /* Create and init common ioqueue stuffs */
     ioqueue = pj_pool_alloc(pool, sizeof(pj_ioqueue_t));
-
     ioqueue_init(ioqueue);
 
     ioqueue->max = max_fd;
@@ -163,8 +170,49 @@
 #if PJ_HAS_TCP
     PJ_FD_ZERO(&ioqueue->xfdset);
 #endif
-    pj_list_init(&ioqueue->key_list);
+    pj_list_init(&ioqueue->active_list);
 
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+    /* When safe unregistration is used (the default), we pre-create
+     * all keys and put them in the free list.
+     */
+
+    /* Mutex to protect key's reference counter 
+     * We don't want to use key's mutex or ioqueue's mutex because
+     * that would create deadlock situation in some cases.
+     */
+    rc = pj_mutex_create_simple(pool, NULL, &ioqueue->ref_cnt_mutex);
+    if (rc != PJ_SUCCESS)
+	return rc;
+
+
+    /* Init key list */
+    pj_list_init(&ioqueue->free_list);
+    pj_list_init(&ioqueue->closing_list);
+
+
+    /* Pre-create all keys according to max_fd */
+    for (i=0; i<max_fd; ++i) {
+	pj_ioqueue_key_t *key;
+
+	key = pj_pool_alloc(pool, sizeof(pj_ioqueue_key_t));
+	key->ref_count = 0;
+	rc = pj_mutex_create_recursive(pool, NULL, &key->mutex);
+	if (rc != PJ_SUCCESS) {
+	    key = ioqueue->free_list.next;
+	    while (key != &ioqueue->free_list) {
+		pj_mutex_destroy(key->mutex);
+		key = key->next;
+	    }
+	    pj_mutex_destroy(ioqueue->ref_cnt_mutex);
+	    return rc;
+	}
+
+	pj_list_push_back(&ioqueue->free_list, key);
+    }
+#endif
+
+    /* Create and init ioqueue mutex */
     rc = pj_lock_create_simple_mutex(pool, "ioq%p", &lock);
     if (rc != PJ_SUCCESS)
 	return rc;
@@ -186,9 +234,35 @@
  */
 PJ_DEF(pj_status_t) pj_ioqueue_destroy(pj_ioqueue_t *ioqueue)
 {
+    /* Per-key mutex cleanup exists only with PJ_IOQUEUE_HAS_SAFE_UNREG and its
+     * locals are scoped to that section, so no unused variables otherwise. */
     PJ_ASSERT_RETURN(ioqueue, PJ_EINVAL);
 
     pj_lock_acquire(ioqueue->lock);
+
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+    /* Destroy the mutex of every pre-created key, whichever list
+     * (active, closing or free) it is currently on.
+     */
+    {
+	pj_ioqueue_key_t *lists[3];
+	pj_ioqueue_key_t *key;
+	unsigned i;
+
+	lists[0] = &ioqueue->active_list;
+	lists[1] = &ioqueue->closing_list;
+	lists[2] = &ioqueue->free_list;
+
+	for (i = 0; i < sizeof(lists)/sizeof(lists[0]); ++i) {
+	    for (key = lists[i]->next; key != lists[i]; key = key->next)
+		pj_mutex_destroy(key->mutex);
+	}
+    }
+
+    /* Then the mutex guarding the keys' reference counters. */
+    pj_mutex_destroy(ioqueue->ref_cnt_mutex);
+#endif
+
     return ioqueue_destroy(ioqueue);
 }
 
@@ -196,7 +270,7 @@
 /*
  * pj_ioqueue_register_sock()
  *
- * Register a handle to ioqueue.
+ * Register socket handle to ioqueue.
  */
 PJ_DEF(pj_status_t) pj_ioqueue_register_sock( pj_pool_t *pool,
 					      pj_ioqueue_t *ioqueue,
@@ -219,6 +293,28 @@
 	goto on_return;
     }
 
+    /* If safe unregistration (PJ_IOQUEUE_HAS_SAFE_UNREG) is used, get
+     * the key from the free list. Otherwise allocate a new one. 
+     */
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+    pj_assert(!pj_list_empty(&ioqueue->free_list));
+    if (pj_list_empty(&ioqueue->free_list)) {
+	rc = PJ_ETOOMANY;
+	goto on_return;
+    }
+
+    key = ioqueue->free_list.next;
+    pj_list_erase(key);
+#else
+    key = (pj_ioqueue_key_t*)pj_pool_zalloc(pool, sizeof(pj_ioqueue_key_t));
+#endif
+
+    rc = ioqueue_init_key(pool, ioqueue, key, sock, user_data, cb);
+    if (rc != PJ_SUCCESS) {
+	key = NULL;
+	goto on_return;
+    }
+
     /* Set socket to nonblocking. */
     value = 1;
 #if defined(PJ_WIN32) && PJ_WIN32!=0 || \
@@ -231,16 +327,9 @@
 	goto on_return;
     }
 
-    /* Create key. */
-    key = (pj_ioqueue_key_t*)pj_pool_zalloc(pool, sizeof(pj_ioqueue_key_t));
-    rc = ioqueue_init_key(pool, ioqueue, key, sock, user_data, cb);
-    if (rc != PJ_SUCCESS) {
-	key = NULL;
-	goto on_return;
-    }
 
-    /* Register */
-    pj_list_insert_before(&ioqueue->key_list, key);
+    /* Put in active list. */
+    pj_list_insert_before(&ioqueue->active_list, key);
     ++ioqueue->count;
 
 on_return:
@@ -251,6 +340,41 @@
     return rc;
 }
 
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+/* Increment key's reference counter. The counter is guarded only by
+ * ref_cnt_mutex, so this may be called while holding ioqueue's lock.
+ */
+static void increment_counter(pj_ioqueue_key_t *key)
+{
+    pj_mutex_lock(key->ioqueue->ref_cnt_mutex);
+    ++key->ref_count;
+    pj_mutex_unlock(key->ioqueue->ref_cnt_mutex);
+}
+
+/* Decrement the key's reference counter; at zero, move the key to the
+ * closing list (scan_closing_keys() later returns it to the free list).
+ * Lock order is ioqueue's lock, then ref_cnt_mutex -- the same order as the
+ * poll scan, which calls increment_counter() under ioqueue's lock -- so the
+ * two cannot deadlock. MUST NOT CALL THIS WHILE HOLDING ioqueue's LOCK.
+ */
+static void decrement_counter(pj_ioqueue_key_t *key)
+{
+    pj_lock_acquire(key->ioqueue->lock);
+    pj_mutex_lock(key->ioqueue->ref_cnt_mutex);
+    --key->ref_count;
+    if (key->ref_count == 0) {
+	pj_assert(key->closing == 1);
+	pj_gettimeofday(&key->free_time);
+	key->free_time.msec += PJ_IOQUEUE_KEY_FREE_DELAY;
+	pj_time_val_normalize(&key->free_time);
+	pj_list_erase(key);
+	pj_list_push_back(&key->ioqueue->closing_list, key);
+    }
+    pj_mutex_unlock(key->ioqueue->ref_cnt_mutex);
+    pj_lock_release(key->ioqueue->lock);
+}
+#endif
+
 /*
  * pj_ioqueue_unregister()
  *
@@ -264,6 +388,13 @@
 
     ioqueue = key->ioqueue;
 
+    /* Lock the key to make sure no callback is simultaneously modifying
+     * the key. We need to lock the key before ioqueue here to prevent
+     * deadlock.
+     */
+    pj_mutex_lock(key->mutex);
+
+    /* Also lock ioqueue */
     pj_lock_acquire(ioqueue->lock);
 
     pj_assert(ioqueue->count > 0);
@@ -275,15 +406,32 @@
     PJ_FD_CLR(key->fd, &ioqueue->xfdset);
 #endif
 
-    /* ioqueue_destroy may try to acquire key's mutex.
-     * Since normally the order of locking is to lock key's mutex first
-     * then ioqueue's mutex, ioqueue_destroy may deadlock unless we
-     * release ioqueue's mutex first.
+    /* Close socket. */
+    pj_sock_close(key->fd);
+
+    /* Clear callback */
+    key->cb.on_accept_complete = NULL;
+    key->cb.on_connect_complete = NULL;
+    key->cb.on_read_complete = NULL;
+    key->cb.on_write_complete = NULL;
+
+    /* Must release ioqueue lock first before decrementing counter, to
+     * prevent deadlock.
      */
     pj_lock_release(ioqueue->lock);
 
-    /* Destroy the key. */
-    ioqueue_destroy_key(key);
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+    /* Mark key is closing. */
+    key->closing = 1;
+
+    /* Decrement counter. */
+    decrement_counter(key);
+
+    /* Done. */
+    pj_mutex_unlock(key->mutex);
+#else
+    pj_mutex_destroy(key->mutex);
+#endif
 
     return PJ_SUCCESS;
 }
@@ -308,8 +456,8 @@
      */
     pj_assert(0);
 
-    key = ioqueue->key_list.next;
-    while (key != &ioqueue->key_list) {
+    key = ioqueue->active_list.next;
+    while (key != &ioqueue->active_list) {
 	if (!pj_list_empty(&key->read_list)
 #if defined(PJ_HAS_TCP) && PJ_HAS_TCP != 0
 	    || !pj_list_empty(&key->accept_list)
@@ -395,6 +543,30 @@
     pj_lock_release(ioqueue->lock);
 }
 
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+/* Return closing keys whose free_time has passed to the free list. */
+static void scan_closing_keys(pj_ioqueue_t *ioqueue)
+{
+    pj_ioqueue_key_t *key, *next_key;
+    pj_time_val now;
+
+    pj_gettimeofday(&now);
+    for (key = ioqueue->closing_list.next;
+	 key != &ioqueue->closing_list;
+	 key = next_key)
+    {
+	/* Fetch the successor first: key may be unlinked below. */
+	next_key = key->next;
+	pj_assert(key->closing != 0);
+	if (!PJ_TIME_VAL_GTE(now, key->free_time))
+	    continue;
+	pj_list_erase(key);
+	pj_list_push_back(&ioqueue->free_list, key);
+    }
+}
+#endif
+
+
 /*
  * pj_ioqueue_poll()
  *
@@ -435,7 +607,10 @@
         PJ_FD_COUNT(&ioqueue->wfdset)==0 &&
         PJ_FD_COUNT(&ioqueue->xfdset)==0)
     {
-        pj_lock_release(ioqueue->lock);
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+	scan_closing_keys(ioqueue);
+#endif
+	pj_lock_release(ioqueue->lock);
         if (timeout)
             pj_thread_sleep(PJ_TIME_VAL_MSEC(*timeout));
         return 0;
@@ -475,11 +650,15 @@
     /* Scan for writable sockets first to handle piggy-back data
      * coming with accept().
      */
-    h = ioqueue->key_list.next;
-    for ( ; h!=&ioqueue->key_list && counter<count; h = h->next) {
+    h = ioqueue->active_list.next;
+    for ( ; h!=&ioqueue->active_list && counter<count; h = h->next) {
+
 	if ( (key_has_pending_write(h) || key_has_pending_connect(h))
-	     && PJ_FD_ISSET(h->fd, &wfdset))
+	     && PJ_FD_ISSET(h->fd, &wfdset) && !h->closing)
         {
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+	    increment_counter(h);
+#endif
             event[counter].key = h;
             event[counter].event_type = WRITEABLE_EVENT;
             ++counter;
@@ -487,15 +666,23 @@
 
         /* Scan for readable socket. */
 	if ((key_has_pending_read(h) || key_has_pending_accept(h))
-            && PJ_FD_ISSET(h->fd, &rfdset))
+            && PJ_FD_ISSET(h->fd, &rfdset) && !h->closing)
         {
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+	    increment_counter(h);
+#endif
             event[counter].key = h;
             event[counter].event_type = READABLE_EVENT;
             ++counter;
 	}
 
 #if PJ_HAS_TCP
-        if (key_has_pending_connect(h) && PJ_FD_ISSET(h->fd, &xfdset)) {
+        if (key_has_pending_connect(h) && PJ_FD_ISSET(h->fd, &xfdset) &&
+	    !h->closing) 
+	{
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+	    increment_counter(h);
+#endif
             event[counter].key = h;
             event[counter].event_type = EXCEPTION_EVENT;
             ++counter;
@@ -525,8 +712,13 @@
             pj_assert(!"Invalid event!");
             break;
         }
+
+#if PJ_IOQUEUE_HAS_SAFE_UNREG
+	decrement_counter(event[counter].key);
+#endif
     }
 
+
     return count;
 }