[CSRSRV] Only when CSRSRV is compiled in debugging mode, should we display debugging...
[reactos.git] / subsystems / win32 / csrsrv / api.c
index 5bb6b76..9c86c58 100644 (file)
@@ -3,7 +3,7 @@
  * PROJECT:         ReactOS Client/Server Runtime SubSystem
  * FILE:            subsystems/win32/csrsrv/api.c
  * PURPOSE:         CSR Server DLL API LPC Implementation
- *                  "\windows\ApiPort" port process management functions
+ *                  "\Windows\ApiPort" port process management functions
  * PROGRAMMERS:     Alex Ionescu (alex@relsoft.net)
  */
 
@@ -11,6 +11,8 @@
 
 #include "srv.h"
 
+#include <ndk/kefuncs.h>
+
 #define NDEBUG
 #include <debug.h>
 
@@ -18,8 +20,8 @@
 
 BOOLEAN (*CsrClientThreadSetup)(VOID) = NULL;
 UNICODE_STRING CsrApiPortName;
-volatile LONG CsrpStaticThreadCount;
-volatile LONG CsrpDynamicThreadTotal;
+volatile ULONG CsrpStaticThreadCount;
+volatile ULONG CsrpDynamicThreadTotal;
 extern ULONG CsrMaxApiRequestThreads;
 
 /* FUNCTIONS ******************************************************************/
@@ -63,7 +65,7 @@ CsrCallServerFromServer(IN PCSR_API_MESSAGE ReceiveMsg,
     {
         /* We are beyond the Maximum Server ID */
         DPRINT1("CSRSS: %lx is invalid ServerDllIndex (%08x)\n", ServerId, ServerDll);
-        ReplyMsg->Status = (ULONG)STATUS_ILLEGAL_FUNCTION;
+        ReplyMsg->Status = STATUS_ILLEGAL_FUNCTION;
         return STATUS_ILLEGAL_FUNCTION;
     }
     else
@@ -76,22 +78,28 @@ CsrCallServerFromServer(IN PCSR_API_MESSAGE ReceiveMsg,
             ((ServerDll->ValidTable) && !(ServerDll->ValidTable[ApiId])))
         {
             /* We are beyond the Maximum API ID, or it doesn't exist */
+#ifdef CSR_DBG
+            DPRINT1("API: %d\n", ApiId);
             DPRINT1("CSRSS: %lx (%s) is invalid ApiTableIndex for %Z or is an "
                     "invalid API to call from the server.\n",
-                    ServerDll->ValidTable[ApiId],
+                    ApiId,
                     ((ServerDll->NameTable) && (ServerDll->NameTable[ApiId])) ?
-                    ServerDll->NameTable[ApiId] : "*** UNKNOWN ***", &ServerDll->Name);
-            // DbgBreakPoint();
-            ReplyMsg->Status = (ULONG)STATUS_ILLEGAL_FUNCTION;
+                    ServerDll->NameTable[ApiId] : "*** UNKNOWN ***",
+                    &ServerDll->Name);
+            if (NtCurrentPeb()->BeingDebugged) DbgBreakPoint();
+#endif
+            ReplyMsg->Status = STATUS_ILLEGAL_FUNCTION;
             return STATUS_ILLEGAL_FUNCTION;
         }
     }
 
+#ifdef CSR_DBG
     if (CsrDebug & 2)
     {
         DPRINT1("CSRSS: %s Api Request received from server process\n",
                 ServerDll->NameTable[ApiId]);
     }
+#endif
 
     /* Validation complete, start SEH */
     _SEH2_TRY
@@ -134,7 +142,7 @@ CsrApiHandleConnectionRequest(IN PCSR_API_MESSAGE ApiMessage)
     PCSR_THREAD CsrThread = NULL;
     PCSR_PROCESS CsrProcess = NULL;
     NTSTATUS Status = STATUS_SUCCESS;
-    PCSR_CONNECTION_INFO ConnectInfo = &ApiMessage->ConnectionInfo;
+    PCSR_API_CONNECTINFO ConnectInfo = &ApiMessage->ConnectionInfo;
     BOOLEAN AllowConnection = FALSE;
     REMOTE_PORT_VIEW RemotePortView;
     HANDLE ServerPort;
@@ -148,47 +156,28 @@ CsrApiHandleConnectionRequest(IN PCSR_API_MESSAGE ApiMessage)
     /* Check if we have a thread */
     if (CsrThread)
     {
-        /* Get the Process */
+        /* Get the Process and make sure we have it as well */
         CsrProcess = CsrThread->Process;
-
-        /* Make sure we have a Process as well */
         if (CsrProcess)
         {
             /* Reference the Process */
-            CsrLockedReferenceProcess(CsrThread->Process);
+            CsrLockedReferenceProcess(CsrProcess);
 
-            /* Release the lock */
-            CsrReleaseProcessLock();
-
-            /* Duplicate the Object Directory */
-            Status = NtDuplicateObject(NtCurrentProcess(),
-                                       CsrObjectDirectory,
-                                       CsrProcess->ProcessHandle,
-                                       &ConnectInfo->ObjectDirectory,
-                                       0,
-                                       0,
-                                       DUPLICATE_SAME_ACCESS |
-                                       DUPLICATE_SAME_ATTRIBUTES);
-
-            /* Acquire the lock */
-            CsrAcquireProcessLock();
-
-            /* Check for success */
+            /* Attach the Shared Section */
+            Status = CsrSrvAttachSharedSection(CsrProcess, ConnectInfo);
             if (NT_SUCCESS(Status))
             {
-                /* Attach the Shared Section */
-                Status = CsrSrvAttachSharedSection(CsrProcess, ConnectInfo);
-
-                /* Check how this went */
-                if (NT_SUCCESS(Status)) AllowConnection = TRUE;
+                /* Allow the connection and return debugging flag */
+                ConnectInfo->DebugFlags = CsrDebug;
+                AllowConnection = TRUE;
             }
 
-            /* Dereference the project */
+            /* Dereference the Process */
             CsrLockedDereferenceProcess(CsrProcess);
         }
     }
 
-    /* Release the lock */
+    /* Release the Process Lock */
     CsrReleaseProcessLock();
 
     /* Setup the Port View Structure */
@@ -197,9 +186,10 @@ CsrApiHandleConnectionRequest(IN PCSR_API_MESSAGE ApiMessage)
     RemotePortView.ViewBase = NULL;
 
     /* Save the Process ID */
-    ConnectInfo->ProcessId = NtCurrentTeb()->ClientId.UniqueProcess;
+    ConnectInfo->ServerProcessId = NtCurrentTeb()->ClientId.UniqueProcess;
 
     /* Accept the Connection */
+    ASSERT(!AllowConnection || (AllowConnection && CsrProcess));
     Status = NtAcceptConnectPort(&ServerPort,
                                  AllowConnection ? UlongToPtr(CsrProcess->SequenceNumber) : 0,
                                  &ApiMessage->Header,
@@ -269,7 +259,7 @@ CsrpCheckRequestThreads(VOID)
     NTSTATUS Status;
 
     /* Decrease the count, and see if we're out */
-    if (_InterlockedDecrement(&CsrpStaticThreadCount) == 0)
+    if (InterlockedDecrementUL(&CsrpStaticThreadCount) == 0)
     {
         /* Check if we've still got space for a Dynamic Thread */
         if (CsrpDynamicThreadTotal < CsrMaxApiRequestThreads)
@@ -289,8 +279,8 @@ CsrpCheckRequestThreads(VOID)
             if (NT_SUCCESS(Status))
             {
                 /* Increase the thread counts */
-                _InterlockedIncrement(&CsrpStaticThreadCount);
-                _InterlockedIncrement(&CsrpDynamicThreadTotal);
+                InterlockedIncrementUL(&CsrpStaticThreadCount);
+                InterlockedIncrementUL(&CsrpDynamicThreadTotal);
 
                 /* Add a new server thread */
                 if (CsrAddStaticServerThread(hThread,
@@ -303,8 +293,8 @@ CsrpCheckRequestThreads(VOID)
                 else
                 {
                     /* Failed to create a new static thread */
-                    _InterlockedDecrement(&CsrpStaticThreadCount);
-                    _InterlockedDecrement(&CsrpDynamicThreadTotal);
+                    InterlockedDecrementUL(&CsrpStaticThreadCount);
+                    InterlockedDecrementUL(&CsrpDynamicThreadTotal);
 
                     /* Terminate it */
                     DPRINT1("Failing\n");
@@ -383,8 +373,8 @@ CsrApiRequestThread(IN PVOID Parameter)
         ASSERT(NT_SUCCESS(Status));
 
         /* Increase the Thread Counts */
-        _InterlockedIncrement(&CsrpStaticThreadCount);
-        _InterlockedIncrement(&CsrpDynamicThreadTotal);
+        InterlockedIncrementUL(&CsrpStaticThreadCount);
+        InterlockedIncrementUL(&CsrpDynamicThreadTotal);
     }
 
     /* Now start the loop */
@@ -393,6 +383,7 @@ CsrApiRequestThread(IN PVOID Parameter)
         /* Make sure the real CID is set */
         Teb->RealClientId = Teb->ClientId;
 
+#ifdef CSR_DBG
         /* Debug check */
         if (Teb->CountOfOwnedCriticalSections)
         {
@@ -402,6 +393,7 @@ CsrApiRequestThread(IN PVOID Parameter)
                     &ReceiveMsg, ReplyMsg);
             DbgBreakPoint();
         }
+#endif
 
         /* Wait for a message to come through */
         Status = NtReplyWaitReceivePort(ReplyPort,
@@ -415,15 +407,17 @@ CsrApiRequestThread(IN PVOID Parameter)
             /* Was it a failure or another success code? */
             if (!NT_SUCCESS(Status))
             {
+#ifdef CSR_DBG
                 /* Check for specific status cases */
                 if ((Status != STATUS_INVALID_CID) &&
                     (Status != STATUS_UNSUCCESSFUL) &&
-                    ((Status == STATUS_INVALID_HANDLE) || (ReplyPort == CsrApiPort)))
+                    ((Status != STATUS_INVALID_HANDLE) || (ReplyPort == CsrApiPort)))
                 {
                     /* Notify the debugger */
                     DPRINT1("CSRSS: ReceivePort failed - Status == %X\n", Status);
                     DPRINT1("CSRSS: ReplyPortHandle %lx CsrApiPort %lx\n", ReplyPort, CsrApiPort);
                 }
+#endif
 
                 /* We failed big time, so start out fresh */
                 ReplyMsg = NULL;
@@ -432,12 +426,15 @@ CsrApiRequestThread(IN PVOID Parameter)
             }
             else
             {
-                /* A bizare "success" code, just try again */
+                /* A strange "success" code, just try again */
                 DPRINT1("NtReplyWaitReceivePort returned \"success\" status 0x%x\n", Status);
                 continue;
             }
         }
 
+        // ASSERT(ReceiveMsg.Header.u1.s1.TotalLength >= sizeof(PORT_MESSAGE));
+        // ASSERT(ReceiveMsg.Header.u1.s1.TotalLength <  sizeof(ReceiveMsg));
+
         /* Use whatever Client ID we got */
         Teb->RealClientId = ReceiveMsg.Header.ClientId;
 
@@ -513,7 +510,7 @@ CsrApiRequestThread(IN PVOID Parameter)
                 }
 
                 /* Increase the thread count */
-                _InterlockedIncrement(&CsrpStaticThreadCount);
+                InterlockedIncrementUL(&CsrpStaticThreadCount);
 
                 /* If the response was 0xFFFFFFFF, we'll ignore it */
                 if (HardErrorMsg->Response == 0xFFFFFFFF)
@@ -545,9 +542,11 @@ CsrApiRequestThread(IN PVOID Parameter)
                     (!(ServerDll = CsrLoadedServerDll[ServerId])))
                 {
                     /* We are beyond the Maximum Server ID */
+#ifdef CSR_DBG
                     DPRINT1("CSRSS: %lx is invalid ServerDllIndex (%08x)\n",
                             ServerId, ServerDll);
-                    // DbgBreakPoint();
+                    if (NtCurrentPeb()->BeingDebugged) DbgBreakPoint();
+#endif
 
                     ReplyMsg = NULL;
                     ReplyPort = CsrApiPort;
@@ -570,6 +569,7 @@ CsrApiRequestThread(IN PVOID Parameter)
                     continue;
                 }
 
+#ifdef CSR_DBG
                 if (CsrDebug & 2)
                 {
                     DPRINT1("[%02x] CSRSS: [%02x,%02x] - %s Api called from %08x\n",
@@ -579,6 +579,7 @@ CsrApiRequestThread(IN PVOID Parameter)
                             ServerDll->NameTable[ApiId],
                             NULL);
                 }
+#endif
 
                 /* Assume success */
                 ReceiveMsg.Status = STATUS_SUCCESS;
@@ -595,7 +596,7 @@ CsrApiRequestThread(IN PVOID Parameter)
                     ServerDll->DispatchTable[ApiId](&ReceiveMsg, &ReplyCode);
 
                     /* Increase the static thread count */
-                    _InterlockedIncrement(&CsrpStaticThreadCount);
+                    InterlockedIncrementUL(&CsrpStaticThreadCount);
                 }
                 _SEH2_EXCEPT(CsrUnhandledExceptionFilter(_SEH2_GetExceptionInformation()))
                 {
@@ -706,7 +707,7 @@ CsrApiRequestThread(IN PVOID Parameter)
                 }
 
                 /* Increase the thread count */
-                _InterlockedIncrement(&CsrpStaticThreadCount);
+                InterlockedIncrementUL(&CsrpStaticThreadCount);
 
                 /* If the response was 0xFFFFFFFF, we'll ignore it */
                 if (HardErrorMsg->Response == 0xFFFFFFFF)
@@ -745,9 +746,11 @@ CsrApiRequestThread(IN PVOID Parameter)
             (!(ServerDll = CsrLoadedServerDll[ServerId])))
         {
             /* We are beyond the Maximum Server ID */
+#ifdef CSR_DBG
             DPRINT1("CSRSS: %lx is invalid ServerDllIndex (%08x)\n",
                     ServerId, ServerDll);
-            // DbgBreakPoint();
+            if (NtCurrentPeb()->BeingDebugged) DbgBreakPoint();
+#endif
 
             ReplyPort = CsrApiPort;
             ReplyMsg = &ReceiveMsg;
@@ -774,6 +777,7 @@ CsrApiRequestThread(IN PVOID Parameter)
             continue;
         }
 
+#ifdef CSR_DBG
         if (CsrDebug & 2)
         {
             DPRINT1("[%02x] CSRSS: [%02x,%02x] - %s Api called from %08x, Process %08x - %08x\n",
@@ -785,6 +789,7 @@ CsrApiRequestThread(IN PVOID Parameter)
                     CsrThread->Process,
                     CsrProcess);
         }
+#endif
 
         /* Assume success */
         ReplyMsg = &ReceiveMsg;
@@ -818,7 +823,7 @@ CsrApiRequestThread(IN PVOID Parameter)
             ReplyMsg->Status = ServerDll->DispatchTable[ApiId](&ReceiveMsg, &ReplyCode);
 
             /* Increase the static thread count */
-            _InterlockedIncrement(&CsrpStaticThreadCount);
+            InterlockedIncrementUL(&CsrpStaticThreadCount);
 
             Teb->CsrClientThread = CurrentThread;
 
@@ -835,7 +840,10 @@ CsrApiRequestThread(IN PVOID Parameter)
             else if (ReplyCode == CsrReplyDeadClient)
             {
                 /* Reply to the death message */
-                NtReplyPort(ReplyPort, &ReplyMsg->Header);
+                NTSTATUS Status2;
+                Status2 = NtReplyPort(ReplyPort, &ReplyMsg->Header);
+                if (!NT_SUCCESS(Status2))
+                    DPRINT1("CSRSS: Error while replying to the death message, Status 0x%lx\n", Status2);
 
                 /* Reply back to the API port now */
                 ReplyMsg = NULL;
@@ -913,7 +921,7 @@ CsrApiPortInitialize(VOID)
     {
         DPRINT1("CSRSS: Creating %wZ port and associated threads\n", &CsrApiPortName);
         DPRINT1("CSRSS: sizeof( CONNECTINFO ) == %ld  sizeof( API_MSG ) == %ld\n",
-                sizeof(CSR_CONNECTION_INFO), sizeof(CSR_API_MESSAGE));
+                sizeof(CSR_API_CONNECTINFO), sizeof(CSR_API_MESSAGE));
     }
 
     /* FIXME: Create a Security Descriptor */
@@ -928,7 +936,7 @@ CsrApiPortInitialize(VOID)
     /* Create the Port Object */
     Status = NtCreatePort(&CsrApiPort,
                           &ObjectAttributes,
-                          sizeof(CSR_CONNECTION_INFO),
+                          sizeof(CSR_API_CONNECTINFO),
                           sizeof(CSR_API_MESSAGE),
                           16 * PAGE_SIZE);
     if (NT_SUCCESS(Status))
@@ -1009,7 +1017,6 @@ PCSR_THREAD
 NTAPI
 CsrConnectToUser(VOID)
 {
-#if 0 // FIXME: This code is OK, however it is ClientThreadSetup which sucks.
     NTSTATUS Status;
     ANSI_STRING DllName;
     UNICODE_STRING TempName;
@@ -1050,8 +1057,9 @@ CsrConnectToUser(VOID)
     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
     {
         Connected = FALSE;
-    } _SEH2_END;
-    
+    }
+    _SEH2_END;
+
     if (!Connected)
     {
         DPRINT1("CSRSS: CsrConnectToUser failed\n");
@@ -1066,21 +1074,6 @@ CsrConnectToUser(VOID)
 
     /* Return it */
     return CsrThread;
-
-#else
-
-    PTEB Teb = NtCurrentTeb();
-    PCSR_THREAD CsrThread;
-
-    /* Save pointer to this thread in TEB */
-    CsrAcquireProcessLock();
-    CsrThread = CsrLocateThreadInProcess(NULL, &Teb->ClientId);
-    CsrReleaseProcessLock();
-    if (CsrThread) Teb->CsrClientThread = CsrThread;
-
-    /* Return it */
-    return CsrThread;
-#endif
 }
 
 /*++
@@ -1154,11 +1147,13 @@ CsrCaptureArguments(IN PCSR_THREAD CsrThread,
         /* Check if the Length is valid */
         if ((FIELD_OFFSET(CSR_CAPTURE_BUFFER, PointerOffsetsArray) +
                 (LocalCaptureBuffer->PointerCount * sizeof(PVOID)) > Length) ||
-            (Length > MAXWORD))
+            (LocalCaptureBuffer->PointerCount > MAXUSHORT))
         {
-            /* Return failure */
+#ifdef CSR_DBG
             DPRINT1("*** CSRSS: CaptureBuffer %p has bad length\n", LocalCaptureBuffer);
-            DbgBreakPoint();
+            if (NtCurrentPeb()->BeingDebugged) DbgBreakPoint();
+#endif
+            /* Return failure */
             ApiMessage->Status = STATUS_INVALID_PARAMETER;
             _SEH2_YIELD(return FALSE);
         }
@@ -1209,9 +1204,11 @@ CsrCaptureArguments(IN PCSR_THREAD CsrThread,
             }
             else
             {
-                /* Invalid pointer, fail */
+#ifdef CSR_DBG
                 DPRINT1("*** CSRSS: CaptureBuffer MessagePointer outside of ClientView\n");
-                DbgBreakPoint();
+                if (NtCurrentPeb()->BeingDebugged) DbgBreakPoint();
+#endif
+                /* Invalid pointer, fail */
                 ApiMessage->Status = STATUS_INVALID_PARAMETER;
             }
         }
@@ -1398,8 +1395,10 @@ CsrValidateMessageBuffer(IN PCSR_API_MESSAGE ApiMessage,
     }
 
     /* Failure */
+#ifdef CSR_DBG
     DPRINT1("CSRSRV: Bad message buffer %p\n", ApiMessage);
-    DbgBreakPoint();
+    if (NtCurrentPeb()->BeingDebugged) DbgBreakPoint();
+#endif
     return FALSE;
 }
 
@@ -1424,7 +1423,7 @@ CsrValidateMessageBuffer(IN PCSR_API_MESSAGE ApiMessage,
 BOOLEAN
 NTAPI
 CsrValidateMessageString(IN PCSR_API_MESSAGE ApiMessage,
-                         IN LPWSTR *MessageString)
+                         IN PWSTR *MessageString)
 {
     if (MessageString)
     {