Sync to Wine-20050830:
[reactos.git] / reactos / lib / rpcrt4 / rpc_server.c
index 7f03aa8..c748ac4 100644 (file)
@@ -50,7 +50,7 @@
 
 #define MAX_THREADS 128
 
-WINE_DEFAULT_DEBUG_CHANNEL(ole);
+WINE_DEFAULT_DEBUG_CHANNEL(rpc);
 
 typedef struct _RpcPacket
 {
@@ -91,9 +91,18 @@ static CRITICAL_SECTION_DEBUG listen_cs_debug =
 };
 static CRITICAL_SECTION listen_cs = { &listen_cs_debug, -1, 0, 0, 0, 0 };
 
+/* whether the server is currently listening */
 static BOOL std_listen;
-static LONG listen_count = -1;
-static HANDLE mgr_event, server_thread;
+/* number of manual listeners (calls to RpcServerListen) */
+static LONG manual_listen_count;
+/* total listeners including auto listeners */
+static LONG listen_count;
+/* set on change of configuration (e.g. listening on new protseq) */
+static HANDLE mgr_event;
+/* mutex for ensuring only one thread can change state at a time */
+static HANDLE mgr_mutex;
+/* set when server thread has finished opening connections */
+static HANDLE server_ready_event;
 
 static CRITICAL_SECTION spacket_cs;
 static CRITICAL_SECTION_DEBUG spacket_cs_debug =
@@ -108,7 +117,7 @@ static RpcPacket* spacket_head;
 static RpcPacket* spacket_tail;
 static HANDLE server_sem;
 
-static DWORD worker_count, worker_free, worker_tls;
+static LONG worker_count, worker_free, worker_tls;
 
 static UUID uuid_nil;
 
@@ -149,7 +158,7 @@ static RpcServerInterface* RPCRT4_find_interface(UUID* object,
   while (cif) {
     if (!memcmp(if_id, &cif->If->InterfaceId, sizeof(RPC_SYNTAX_IDENTIFIER)) &&
         (check_object == FALSE || UuidEqual(MgrType, &cif->MgrTypeUuid, &status)) &&
-        (std_listen || (cif->Flags & RPC_IF_AUTOLISTEN))) break;
+        std_listen) break;
     cif = cif->Next;
   }
   LeaveCriticalSection(&server_cs);
@@ -434,7 +443,7 @@ static DWORD CALLBACK RPCRT4_io_thread(LPVOID the_arg)
 #endif
     msg = NULL;
   }
-  if (msg) HeapFree(GetProcessHeap(), 0, msg);
+  HeapFree(GetProcessHeap(), 0, msg);
   RPCRT4_DestroyConnection(conn);
   return 0;
 }
@@ -463,6 +472,7 @@ static DWORD CALLBACK RPCRT4_server_thread(LPVOID the_arg)
   RpcServerProtseq* cps;
   RpcConnection* conn;
   RpcConnection* cconn;
+  BOOL set_ready_event = FALSE;
 
   TRACE("(the_arg == ^%p)\n", the_arg);
 
@@ -475,7 +485,7 @@ static DWORD CALLBACK RPCRT4_server_thread(LPVOID the_arg)
       conn = cps->conn;
       while (conn) {
         RPCRT4_OpenConnection(conn);
-        if (conn->ovl.hEvent) count++;
+        if (conn->ovl[0].hEvent) count++;
         conn = conn->Next;
       }
       cps = cps->Next;
@@ -492,18 +502,29 @@ static DWORD CALLBACK RPCRT4_server_thread(LPVOID the_arg)
     while (cps) {
       conn = cps->conn;
       while (conn) {
-        if (conn->ovl.hEvent) objs[count++] = conn->ovl.hEvent;
+        if (conn->ovl[0].hEvent) objs[count++] = conn->ovl[0].hEvent;
         conn = conn->Next;
       }
       cps = cps->Next;
     }
     LeaveCriticalSection(&server_cs);
 
+    if (set_ready_event)
+    {
+        /* signal to function that changed state that we are now sync'ed */
+        SetEvent(server_ready_event);
+        set_ready_event = FALSE;
+    }
+
     /* start waiting */
     res = WaitForMultipleObjects(count, objs, FALSE, INFINITE);
     if (res == WAIT_OBJECT_0) {
-      ResetEvent(m_event);
-      if (!std_listen) break;
+      if (!std_listen)
+      {
+        SetEvent(server_ready_event);
+        break;
+      }
+      set_ready_event = TRUE;
     }
     else if (res == WAIT_FAILED) {
       ERR("wait failed\n");
@@ -517,7 +538,7 @@ static DWORD CALLBACK RPCRT4_server_thread(LPVOID the_arg)
       while (cps) {
         conn = cps->conn;
         while (conn) {
-          if (conn->ovl.hEvent == b_handle) break;
+          if (conn->ovl[0].hEvent == b_handle) break;
           conn = conn->Next;
         }
         if (conn) break;
@@ -548,36 +569,64 @@ static DWORD CALLBACK RPCRT4_server_thread(LPVOID the_arg)
   return 0;
 }
 
-static void RPCRT4_start_listen(void)
+/* tells the server thread that the state has changed and waits for it to
+ * make the changes */
+static void RPCRT4_sync_with_server_thread(void)
+{
+  /* make sure we are the only thread sync'ing the server state, otherwise
+   * there is a race with the server thread setting an older state and setting
+   * the server_ready_event when the new state hasn't yet been applied */
+  WaitForSingleObject(mgr_mutex, INFINITE);
+
+  SetEvent(mgr_event);
+  /* wait for server thread to make the requested changes before returning */
+  WaitForSingleObject(server_ready_event, INFINITE);
+
+  ReleaseMutex(mgr_mutex);
+}
+
+static RPC_STATUS RPCRT4_start_listen(BOOL auto_listen)
 {
+  RPC_STATUS status = RPC_S_ALREADY_LISTENING;
+
   TRACE("\n");
 
   EnterCriticalSection(&listen_cs);
-  if (! ++listen_count) {
-    if (!mgr_event) mgr_event = CreateEventA(NULL, TRUE, FALSE, NULL);
-    if (!server_sem) server_sem = CreateSemaphoreA(NULL, 0, MAX_THREADS, NULL);
-    if (!worker_tls) worker_tls = TlsAlloc();
-    std_listen = TRUE;
-    server_thread = CreateThread(NULL, 0, RPCRT4_server_thread, NULL, 0, NULL);
-    LeaveCriticalSection(&listen_cs);
-  } else {
-    LeaveCriticalSection(&listen_cs);
-    SetEvent(mgr_event);
+  if (auto_listen || (manual_listen_count++ == 0))
+  {
+    status = RPC_S_OK;
+    if (++listen_count == 1) {
+      HANDLE server_thread;
+      /* first listener creates server thread */
+      if (!mgr_mutex) mgr_mutex = CreateMutexW(NULL, FALSE, NULL);
+      if (!mgr_event) mgr_event = CreateEventW(NULL, FALSE, FALSE, NULL);
+      if (!server_ready_event) server_ready_event = CreateEventW(NULL, FALSE, FALSE, NULL);
+      if (!server_sem) server_sem = CreateSemaphoreW(NULL, 0, MAX_THREADS, NULL);
+      if (!worker_tls) worker_tls = TlsAlloc();
+      std_listen = TRUE;
+      server_thread = CreateThread(NULL, 0, RPCRT4_server_thread, NULL, 0, NULL);
+      CloseHandle(server_thread);
+    }
   }
+  LeaveCriticalSection(&listen_cs);
+
+  return status;
 }
 
-static void RPCRT4_stop_listen(void)
+static void RPCRT4_stop_listen(BOOL auto_listen)
 {
   EnterCriticalSection(&listen_cs);
-  if (listen_count == -1)
-    LeaveCriticalSection(&listen_cs);
-  else if (--listen_count == -1) {
-    std_listen = FALSE;
-    LeaveCriticalSection(&listen_cs);
-    SetEvent(mgr_event);
-  } else
-    LeaveCriticalSection(&listen_cs);
-  assert(listen_count > -2);
+  if (auto_listen || (--manual_listen_count == 0))
+  {
+    if (listen_count != 0 && --listen_count == 0) {
+      std_listen = FALSE;
+      LeaveCriticalSection(&listen_cs);
+      RPCRT4_sync_with_server_thread();
+      return;
+    }
+    assert(listen_count >= 0);
+  }
+  LeaveCriticalSection(&listen_cs);
 }
 
 static RPC_STATUS RPCRT4_use_protseq(RpcServerProtseq* ps)
@@ -589,7 +638,7 @@ static RPC_STATUS RPCRT4_use_protseq(RpcServerProtseq* ps)
   protseqs = ps;
   LeaveCriticalSection(&server_cs);
 
-  if (std_listen) SetEvent(mgr_event);
+  if (std_listen) RPCRT4_sync_with_server_thread();
 
   return RPC_S_OK;
 }
@@ -690,14 +739,14 @@ RPC_STATUS WINAPI RpcServerUseProtseqEpExA( unsigned char *Protseq, UINT MaxCall
 {
   RpcServerProtseq* ps;
 
-  TRACE("(%s,%u,%s,%p,{%u,%lu,%lu})\n", debugstr_a( Protseq ), MaxCalls,
-       debugstr_a( Endpoint ), SecurityDescriptor,
+  TRACE("(%s,%u,%s,%p,{%u,%lu,%lu})\n", debugstr_a( (char*)Protseq ), MaxCalls,
+       debugstr_a( (char*)Endpoint ), SecurityDescriptor,
        lpPolicy->Length, lpPolicy->EndpointFlags, lpPolicy->NICFlags );
 
   ps = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(RpcServerProtseq));
   ps->MaxCalls = MaxCalls;
-  ps->Protseq = RPCRT4_strdupA(Protseq);
-  ps->Endpoint = RPCRT4_strdupA(Endpoint);
+  ps->Protseq = RPCRT4_strdupA((char*)Protseq);
+  ps->Endpoint = RPCRT4_strdupA((char*)Endpoint);
 
   return RPCRT4_use_protseq(ps);
 }
@@ -727,7 +776,7 @@ RPC_STATUS WINAPI RpcServerUseProtseqEpExW( LPWSTR Protseq, UINT MaxCalls, LPWST
  */
 RPC_STATUS WINAPI RpcServerUseProtseqA(unsigned char *Protseq, unsigned int MaxCalls, void *SecurityDescriptor)
 {
-  TRACE("(Protseq == %s, MaxCalls == %d, SecurityDescriptor == ^%p)\n", debugstr_a(Protseq), MaxCalls, SecurityDescriptor);
+  TRACE("(Protseq == %s, MaxCalls == %d, SecurityDescriptor == ^%p)\n", debugstr_a((char*)Protseq), MaxCalls, SecurityDescriptor);
   return RpcServerUseProtseqEpA(Protseq, MaxCalls, NULL, SecurityDescriptor);
 }
 
@@ -767,7 +816,7 @@ RPC_STATUS WINAPI RpcServerRegisterIf2( RPC_IF_HANDLE IfSpec, UUID* MgrTypeUuid,
 {
   PRPC_SERVER_INTERFACE If = (PRPC_SERVER_INTERFACE)IfSpec;
   RpcServerInterface* sif;
-  int i;
+  unsigned int i;
 
   TRACE("(%p,%s,%p,%u,%u,%u,%p)\n", IfSpec, debugstr_guid(MgrTypeUuid), MgrEpv, Flags, MaxCalls,
          MaxRpcSize, IfCallbackFn);
@@ -809,8 +858,11 @@ RPC_STATUS WINAPI RpcServerRegisterIf2( RPC_IF_HANDLE IfSpec, UUID* MgrTypeUuid,
   LeaveCriticalSection(&server_cs);
 
   if (sif->Flags & RPC_IF_AUTOLISTEN) {
-    /* well, start listening, I think... */
-    RPCRT4_start_listen();
+    RPCRT4_start_listen(TRUE);
+
+    /* make sure server is actually listening on the interface before
+     * returning */
+    RPCRT4_sync_with_server_thread();
   }
 
   return RPC_S_OK;
@@ -905,7 +957,7 @@ RPC_STATUS WINAPI RpcObjectSetType( UUID* ObjUuid, UUID* TypeUuid )
 /***********************************************************************
  *             RpcServerRegisterAuthInfoA (RPCRT4.@)
  */
-RPC_STATUS WINAPI RpcServerRegisterAuthInfoA( unsigned char *ServerPrincName, ULONG AuthnSvc, RPC_AUTH_KEY_RETRIEVAL_FN GetKeyFn,
+RPC_STATUS WINAPI RpcServerRegisterAuthInfoA( unsigned char *ServerPrincName, unsigned long AuthnSvc, RPC_AUTH_KEY_RETRIEVAL_FN GetKeyFn,
                             LPVOID Arg )
 {
   FIXME( "(%s,%lu,%p,%p): stub\n", ServerPrincName, AuthnSvc, GetKeyFn, Arg );
@@ -916,7 +968,7 @@ RPC_STATUS WINAPI RpcServerRegisterAuthInfoA( unsigned char *ServerPrincName, UL
 /***********************************************************************
  *             RpcServerRegisterAuthInfoW (RPCRT4.@)
  */
-RPC_STATUS WINAPI RpcServerRegisterAuthInfoW( LPWSTR ServerPrincName, ULONG AuthnSvc, RPC_AUTH_KEY_RETRIEVAL_FN GetKeyFn,
+RPC_STATUS WINAPI RpcServerRegisterAuthInfoW( LPWSTR ServerPrincName, unsigned long AuthnSvc, RPC_AUTH_KEY_RETRIEVAL_FN GetKeyFn,
                             LPVOID Arg )
 {
   FIXME( "(%s,%lu,%p,%p): stub\n", debugstr_w( ServerPrincName ), AuthnSvc, GetKeyFn, Arg );
@@ -929,23 +981,16 @@ RPC_STATUS WINAPI RpcServerRegisterAuthInfoW( LPWSTR ServerPrincName, ULONG Auth
  */
 RPC_STATUS WINAPI RpcServerListen( UINT MinimumCallThreads, UINT MaxCalls, UINT DontWait )
 {
+  RPC_STATUS status;
+
   TRACE("(%u,%u,%u)\n", MinimumCallThreads, MaxCalls, DontWait);
 
   if (!protseqs)
     return RPC_S_NO_PROTSEQS_REGISTERED;
 
-  EnterCriticalSection(&listen_cs);
-
-  if (std_listen) {
-    LeaveCriticalSection(&listen_cs);
-    return RPC_S_ALREADY_LISTENING;
-  }
-
-  RPCRT4_start_listen();
+  status = RPCRT4_start_listen(FALSE);
 
-  LeaveCriticalSection(&listen_cs);
-
-  if (DontWait) return RPC_S_OK;
+  if (DontWait || (status != RPC_S_OK)) return status;
 
   return RpcMgmtWaitServerListen();
 }
@@ -955,29 +1000,20 @@ RPC_STATUS WINAPI RpcServerListen( UINT MinimumCallThreads, UINT MaxCalls, UINT
  */
 RPC_STATUS WINAPI RpcMgmtWaitServerListen( void )
 {
-  RPC_STATUS rslt = RPC_S_OK;
-
-  TRACE("\n");
+  TRACE("()\n");
 
   EnterCriticalSection(&listen_cs);
 
-  if (!std_listen)
-    if ( (rslt = RpcServerListen(1, 0, TRUE)) != RPC_S_OK ) {
-      LeaveCriticalSection(&listen_cs);
-      return rslt;
-    }
+  if (!std_listen) {
+    LeaveCriticalSection(&listen_cs);
+    return RPC_S_NOT_LISTENING;
+  }
   
   LeaveCriticalSection(&listen_cs);
 
-  while (std_listen) {
-    WaitForSingleObject(mgr_event, INFINITE);
-    if (!std_listen) {
-      Sleep(100); /* don't spin violently */
-      TRACE("spinning.\n");
-    }
-  }
+  RPCRT4_sync_with_server_thread();
 
-  return rslt;
+  return RPC_S_OK;
 }
 
 /***********************************************************************
@@ -992,11 +1028,7 @@ RPC_STATUS WINAPI RpcMgmtStopServerListening ( RPC_BINDING_HANDLE Binding )
     return RPC_S_WRONG_KIND_OF_BINDING;
   }
   
-  /* hmm... */
-  EnterCriticalSection(&listen_cs);
-  while (std_listen)
-    RPCRT4_stop_listen();
-  LeaveCriticalSection(&listen_cs);
+  RPCRT4_stop_listen(FALSE);
 
   return RPC_S_OK;
 }