[CONSRV]
[reactos.git] / subsystems / win32 / csrsrv / api.c
index 0f4c60c..80e3d9d 100644 (file)
@@ -52,8 +52,7 @@ CsrCallServerFromServer(IN PCSR_API_MESSAGE ReceiveMsg,
     ULONG ServerId;
     PCSR_SERVER_DLL ServerDll;
     ULONG ApiId;
-    ULONG Reply;
-    NTSTATUS Status;
+    CSR_REPLY_CODE ReplyCode = CsrReplyImmediately;
 
     /* Get the Server ID */
     ServerId = CSR_API_NUMBER_TO_SERVER_ID(ReceiveMsg->ApiNumber);
@@ -97,12 +96,8 @@ CsrCallServerFromServer(IN PCSR_API_MESSAGE ReceiveMsg,
     /* Validation complete, start SEH */
     _SEH2_TRY
     {
-        /* Call the API and get the result */
-        /// CsrApiCallHandler(ReplyMsg, /*ProcessData*/ &ReplyCode); ///
-        Status = (ServerDll->DispatchTable[ApiId])(ReceiveMsg, &Reply);
-
-        /* Return the result, no matter what it is */
-        ReplyMsg->Status = Status;
+        /* Call the API, get the reply code and return the result */
+        ReplyMsg->Status = ServerDll->DispatchTable[ApiId](ReceiveMsg, &ReplyCode);
     }
     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
     {
@@ -142,7 +137,7 @@ CsrApiHandleConnectionRequest(IN PCSR_API_MESSAGE ApiMessage)
     PCSR_CONNECTION_INFO ConnectInfo = &ApiMessage->ConnectionInfo;
     BOOLEAN AllowConnection = FALSE;
     REMOTE_PORT_VIEW RemotePortView;
-    HANDLE hPort;
+    HANDLE ServerPort;
 
     /* Acquire the Process Lock */
     CsrAcquireProcessLock();
@@ -205,7 +200,7 @@ CsrApiHandleConnectionRequest(IN PCSR_API_MESSAGE ApiMessage)
     ConnectInfo->ProcessId = NtCurrentTeb()->ClientId.UniqueProcess;
 
     /* Accept the Connection */
-    Status = NtAcceptConnectPort(&hPort,
+    Status = NtAcceptConnectPort(&ServerPort,
                                  AllowConnection ? UlongToPtr(CsrProcess->SequenceNumber) : 0,
                                  &ApiMessage->Header,
                                  AllowConnection,
@@ -227,13 +222,13 @@ CsrApiHandleConnectionRequest(IN PCSR_API_MESSAGE ApiMessage)
         }
 
         /* Set some Port Data in the Process */
-        CsrProcess->ClientPort = hPort;
+        CsrProcess->ClientPort = ServerPort;
         CsrProcess->ClientViewBase = (ULONG_PTR)RemotePortView.ViewBase;
         CsrProcess->ClientViewBounds = (ULONG_PTR)((ULONG_PTR)RemotePortView.ViewBase +
                                                    (ULONG_PTR)RemotePortView.ViewSize);
 
         /* Complete the connection */
-        Status = NtCompleteConnectPort(hPort);
+        Status = NtCompleteConnectPort(ServerPort);
         if (!NT_SUCCESS(Status))
         {
             DPRINT1("CSRSS: NtCompleteConnectPort - failed.  Status == %X\n", Status);
@@ -250,107 +245,6 @@ CsrApiHandleConnectionRequest(IN PCSR_API_MESSAGE ApiMessage)
     return Status;
 }
 
-// TODO: See CsrApiHandleConnectionRequest
-NTSTATUS WINAPI
-CsrpHandleConnectionRequest(PPORT_MESSAGE Request)
-{
-    NTSTATUS Status;
-    HANDLE ServerPort = NULL;//, ServerThread = NULL;
-    PCSR_PROCESS ProcessData = NULL;
-    REMOTE_PORT_VIEW RemotePortView;
-//    CLIENT_ID ClientId;
-    BOOLEAN AllowConnection = FALSE;
-    PCSR_CONNECTION_INFO ConnectInfo;
-    ServerPort = NULL;
-
-    DPRINT1("CSR: %s: Handling: %p\n", __FUNCTION__, Request);
-
-    ConnectInfo = (PCSR_CONNECTION_INFO)(Request + 1);
-
-    /* Save the process ID */
-    RtlZeroMemory(ConnectInfo, sizeof(CSR_CONNECTION_INFO));
-
-    CsrLockProcessByClientId(Request->ClientId.UniqueProcess, &ProcessData);
-    if (!ProcessData)
-    {
-        DPRINT1("CSRSRV: Unknown process: %lx. Will be rejecting connection\n",
-                Request->ClientId.UniqueProcess);
-    }
-
-    if ((ProcessData) && (ProcessData != CsrRootProcess))
-    {
-        /* Attach the Shared Section */
-        Status = CsrSrvAttachSharedSection(ProcessData, ConnectInfo);
-        if (NT_SUCCESS(Status))
-        {
-            DPRINT1("Connection ok\n");
-            AllowConnection = TRUE;
-        }
-        else
-        {
-            DPRINT1("Shared section map failed: %lx\n", Status);
-        }
-    }
-    else if (ProcessData == CsrRootProcess)
-    {
-        AllowConnection = TRUE;
-    }
-
-    /* Release the process */
-    if (ProcessData) CsrUnlockProcess(ProcessData);
-
-    /* Setup the Port View Structure */
-    RemotePortView.Length = sizeof(REMOTE_PORT_VIEW);
-    RemotePortView.ViewSize = 0;
-    RemotePortView.ViewBase = NULL;
-
-    /* Save the Process ID */
-    ConnectInfo->ProcessId = NtCurrentTeb()->ClientId.UniqueProcess;
-
-    Status = NtAcceptConnectPort(&ServerPort,
-                                 AllowConnection ? UlongToPtr(ProcessData->SequenceNumber) : 0,
-                                 Request,
-                                 AllowConnection,
-                                 NULL,
-                                 &RemotePortView);
-    if (!NT_SUCCESS(Status))
-    {
-         DPRINT1("CSRSS: NtAcceptConnectPort - failed.  Status == %X\n", Status);
-    }
-    else if (AllowConnection)
-    {
-        if (CsrDebug & 2)
-        {
-            DPRINT1("CSRSS: ClientId: %lx.%lx has ClientView: Base=%p, Size=%lx\n",
-                    Request->ClientId.UniqueProcess,
-                    Request->ClientId.UniqueThread,
-                    RemotePortView.ViewBase,
-                    RemotePortView.ViewSize);
-        }
-
-        /* Set some Port Data in the Process */
-        ProcessData->ClientPort = ServerPort;
-        ProcessData->ClientViewBase = (ULONG_PTR)RemotePortView.ViewBase;
-        ProcessData->ClientViewBounds = (ULONG_PTR)((ULONG_PTR)RemotePortView.ViewBase +
-                                                    (ULONG_PTR)RemotePortView.ViewSize);
-
-        /* Complete the connection */
-        Status = NtCompleteConnectPort(ServerPort);
-        if (!NT_SUCCESS(Status))
-        {
-            DPRINT1("CSRSS: NtCompleteConnectPort - failed.  Status == %X\n", Status);
-        }
-    }
-    else
-    {
-        DPRINT1("CSRSS: Rejecting Connection Request from ClientId: %lx.%lx\n",
-                Request->ClientId.UniqueProcess,
-                Request->ClientId.UniqueThread);
-    }
-
-    return Status;
-}
-
 /*++
  * @name CsrpCheckRequestThreads
  *
@@ -451,6 +345,7 @@ CsrApiRequestThread(IN PVOID Parameter)
     LARGE_INTEGER TimeOut;
     PCSR_THREAD CurrentThread, CsrThread;
     NTSTATUS Status;
+    CSR_REPLY_CODE ReplyCode;
     PCSR_API_MESSAGE ReplyMsg;
     CSR_API_MESSAGE ReceiveMsg;
     PCSR_PROCESS CsrProcess;
@@ -459,7 +354,7 @@ CsrApiRequestThread(IN PVOID Parameter)
     PCSR_SERVER_DLL ServerDll;
     PCLIENT_DIED_MSG ClientDiedMsg;
     PDBGKM_MSG DebugMessage;
-    ULONG ServerId, ApiId, Reply, MessageType, i;
+    ULONG ServerId, ApiId, MessageType, i;
     HANDLE ReplyPort;
 
     /* Setup LPC loop port and message */
@@ -554,8 +449,9 @@ CsrApiRequestThread(IN PVOID Parameter)
         {
             /* Handle the Connection Request */
             CsrApiHandleConnectionRequest(&ReceiveMsg);
-            ReplyPort = CsrApiPort;
+
             ReplyMsg = NULL;
+            ReplyPort = CsrApiPort;
             continue;
         }
 
@@ -608,7 +504,7 @@ CsrApiRequestThread(IN PVOID Parameter)
                         if ((ServerDll) && (ServerDll->HardErrorCallback))
                         {
                             /* Call it */
-                            ServerDll->HardErrorCallback(NULL /* CsrThread == NULL */, HardErrorMsg);
+                            ServerDll->HardErrorCallback(NULL /* == CsrThread */, HardErrorMsg);
 
                             /* If it's handled, get out of here */
                             if (HardErrorMsg->Response != ResponseNotHandled) break;
@@ -628,6 +524,7 @@ CsrApiRequestThread(IN PVOID Parameter)
                 else
                 {
                     ReplyMsg = &ReceiveMsg;
+                    ReplyPort = CsrApiPort;
                 }
             }
             else if (MessageType == LPC_REQUEST)
@@ -651,8 +548,9 @@ CsrApiRequestThread(IN PVOID Parameter)
                     DPRINT1("CSRSS: %lx is invalid ServerDllIndex (%08x)\n",
                             ServerId, ServerDll);
                     DbgBreakPoint();
-                    ReplyPort = CsrApiPort;
+
                     ReplyMsg = NULL;
+                    ReplyPort = CsrApiPort;
                     continue;
                 }
 
@@ -666,6 +564,7 @@ CsrApiRequestThread(IN PVOID Parameter)
                     DPRINT1("CSRSS: %lx is invalid ApiTableIndex for %Z\n",
                             CSR_API_NUMBER_TO_API_ID(ReceiveMsg.ApiNumber),
                             &ServerDll->Name);
+
                     ReplyPort = CsrApiPort;
                     ReplyMsg = NULL;
                     continue;
@@ -690,10 +589,10 @@ CsrApiRequestThread(IN PVOID Parameter)
                     /* Make sure we have enough threads */
                     CsrpCheckRequestThreads();
 
-                    /* Call the API and get the result */
+                    /* Call the API and get the reply code */
                     ReplyMsg = NULL;
                     ReplyPort = CsrApiPort;
-                    ServerDll->DispatchTable[ApiId](&ReceiveMsg, &Reply);
+                    ServerDll->DispatchTable[ApiId](&ReceiveMsg, &ReplyCode);
 
                     /* Increase the static thread count */
                     _InterlockedIncrement(&CsrpStaticThreadCount);
@@ -726,6 +625,9 @@ CsrApiRequestThread(IN PVOID Parameter)
                 ClientDiedMsg = (PCLIENT_DIED_MSG)&ReceiveMsg;
                 if (ClientDiedMsg->CreateTime.QuadPart == CsrThread->CreateTime.QuadPart)
                 {
+                    /* Now we reply to the dying client */
+                    ReplyPort = CsrThread->Process->ClientPort;
+
                     /* Reference the thread */
                     CsrLockedReferenceThread(CsrThread);
 
@@ -745,6 +647,7 @@ CsrApiRequestThread(IN PVOID Parameter)
 
                 /* Release the lock and keep looping */
                 CsrReleaseProcessLock();
+
                 ReplyMsg = NULL;
                 ReplyPort = CsrApiPort;
                 continue;
@@ -873,12 +776,14 @@ CsrApiRequestThread(IN PVOID Parameter)
 
         if (CsrDebug & 2)
         {
-            DPRINT1("[%02x] CSRSS: [%02x,%02x] - %s Api called from %08x\n",
+            DPRINT1("[%02x] CSRSS: [%02x,%02x] - %s Api called from %08x, Process %08x - %08x\n",
                     Teb->ClientId.UniqueThread,
                     ReceiveMsg.Header.ClientId.UniqueProcess,
                     ReceiveMsg.Header.ClientId.UniqueThread,
                     ServerDll->NameTable[ApiId],
-                    CsrThread);
+                    CsrThread,
+                    CsrThread->Process,
+                    CsrProcess);
         }
 
         /* Assume success */
@@ -908,36 +813,40 @@ CsrApiRequestThread(IN PVOID Parameter)
 
             Teb->CsrClientThread = CsrThread;
 
-            /* Call the API and get the result */
-            Reply = 0;
-            ServerDll->DispatchTable[ApiId](&ReceiveMsg, &Reply);
+            /* Call the API, get the reply code and return the result */
+            ReplyCode = CsrReplyImmediately;
+            ReplyMsg->Status = ServerDll->DispatchTable[ApiId](&ReceiveMsg, &ReplyCode);
 
             /* Increase the static thread count */
             _InterlockedIncrement(&CsrpStaticThreadCount);
 
             Teb->CsrClientThread = CurrentThread;
 
-            if (Reply == 3)
+            if (ReplyCode == CsrReplyAlreadySent)
             {
-                ReplyMsg = NULL;
                 if (ReceiveMsg.CsrCaptureData)
                 {
                     CsrReleaseCapturedArguments(&ReceiveMsg);
                 }
-                CsrDereferenceThread(CsrThread);
+                ReplyMsg = NULL;
                 ReplyPort = CsrApiPort;
+                CsrDereferenceThread(CsrThread);
             }
-            else if (Reply == 2)
+            else if (ReplyCode == CsrReplyDeadClient)
             {
+                /* Reply to the death message */
                 NtReplyPort(ReplyPort, &ReplyMsg->Header);
-                ReplyPort = CsrApiPort;
+
+                /* Reply back to the API port now */
                 ReplyMsg = NULL;
+                ReplyPort = CsrApiPort;
+
                 CsrDereferenceThread(CsrThread);
             }
-            else if (Reply == 1)
+            else if (ReplyCode == CsrReplyPending)
             {
-                ReplyPort = CsrApiPort;
                 ReplyMsg = NULL;
+                ReplyPort = CsrApiPort;
             }
             else
             {
@@ -1015,13 +924,13 @@ CsrApiPortInitialize(VOID)
                                &CsrApiPortName,
                                0,
                                NULL,
-                               NULL /* FIXME*/);
+                               NULL /* FIXME: Use the Security Descriptor */);
 
     /* Create the Port Object */
     Status = NtCreatePort(&CsrApiPort,
                           &ObjectAttributes,
-                          LPC_MAX_DATA_LENGTH, // HACK: the real value is: sizeof(CSR_CONNECTION_INFO),
-                          LPC_MAX_MESSAGE_LENGTH, // HACK: the real value is: sizeof(CSR_API_MESSAGE),
+                          sizeof(CSR_CONNECTION_INFO),
+                          sizeof(CSR_API_MESSAGE),
                           16 * PAGE_SIZE);
     if (NT_SUCCESS(Status))
     {
@@ -1221,7 +1130,9 @@ CsrCaptureArguments(IN PCSR_THREAD CsrThread,
     PCSR_CAPTURE_BUFFER LocalCaptureBuffer = NULL, RemoteCaptureBuffer = NULL;
     SIZE_T BufferDistance;
     ULONG Length = 0;
-    ULONG i;
+    ULONG PointerCount;
+    PULONG_PTR OffsetPointer;
+    ULONG_PTR CurrentOffset;
 
     /* Use SEH to make sure this is valid */
     _SEH2_TRY
@@ -1275,21 +1186,26 @@ CsrCaptureArguments(IN PCSR_THREAD CsrThread,
     BufferDistance = (ULONG_PTR)RemoteCaptureBuffer - (ULONG_PTR)LocalCaptureBuffer;
 
     /*
-     * Convert all the pointer offsets into real pointers, and make
-     * them point to the remote data buffer instead of the local one.
+     * All the pointer offsets correspond to pointers which point
+     * to the remote data buffer instead of the local one.
      */
-    for (i = 0 ; i < RemoteCaptureBuffer->PointerCount ; ++i)
+    PointerCount  = RemoteCaptureBuffer->PointerCount;
+    OffsetPointer = RemoteCaptureBuffer->PointerOffsetsArray;
+    while (PointerCount--)
     {
-        if (RemoteCaptureBuffer->PointerOffsetsArray[i] != 0)
+        CurrentOffset = *OffsetPointer;
+
+        if (CurrentOffset != 0)
         {
-            RemoteCaptureBuffer->PointerOffsetsArray[i] += (ULONG_PTR)ApiMessage;
+            /* Get the pointer corresponding to the offset */
+            CurrentOffset += (ULONG_PTR)ApiMessage;
 
-            /* Validate the bounds of the current pointer */
-            if ((*(PULONG_PTR)RemoteCaptureBuffer->PointerOffsetsArray[i] >= CsrThread->Process->ClientViewBase) &&
-                (*(PULONG_PTR)RemoteCaptureBuffer->PointerOffsetsArray[i] < CsrThread->Process->ClientViewBounds))
+            /* Validate the bounds of the current pointed pointer */
+            if ((*(PULONG_PTR)CurrentOffset >= CsrThread->Process->ClientViewBase) &&
+                (*(PULONG_PTR)CurrentOffset < CsrThread->Process->ClientViewBounds))
             {
-                /* Modify the pointer to take into account its new position */
-                *(PULONG_PTR)RemoteCaptureBuffer->PointerOffsetsArray[i] += BufferDistance;
+                /* Modify the pointed pointer to take into account its new position */
+                *(PULONG_PTR)CurrentOffset += BufferDistance;
             }
             else
             {
@@ -1299,18 +1215,20 @@ CsrCaptureArguments(IN PCSR_THREAD CsrThread,
                 ApiMessage->Status = STATUS_INVALID_PARAMETER;
             }
         }
+
+        ++OffsetPointer;
     }
 
     /* Check if we got success */
     if (ApiMessage->Status != STATUS_SUCCESS)
     {
-        /* Failure. Free the buffer and return*/
+        /* Failure. Free the buffer and return */
         RtlFreeHeap(CsrHeap, 0, RemoteCaptureBuffer);
         return FALSE;
     }
     else
     {
-        /* Success, save the previous buffer */
+        /* Success, save the previous buffer and use the remote capture buffer */
         RemoteCaptureBuffer->PreviousCaptureBuffer = LocalCaptureBuffer;
         ApiMessage->CsrCaptureData = RemoteCaptureBuffer;
     }
@@ -1341,33 +1259,47 @@ CsrReleaseCapturedArguments(IN PCSR_API_MESSAGE ApiMessage)
 {
     PCSR_CAPTURE_BUFFER RemoteCaptureBuffer, LocalCaptureBuffer;
     SIZE_T BufferDistance;
-    ULONG i;
+    ULONG PointerCount;
+    PULONG_PTR OffsetPointer;
+    ULONG_PTR CurrentOffset;
 
-    /* Get the capture buffers */
+    /* Get the remote capture buffer */
     RemoteCaptureBuffer = ApiMessage->CsrCaptureData;
-    LocalCaptureBuffer = RemoteCaptureBuffer->PreviousCaptureBuffer;
 
     /* Do not continue if there is no captured buffer */
     if (!RemoteCaptureBuffer) return;
 
-    /* Free the previous one */
+    /* If there is one, get the corresponding local capture buffer */
+    LocalCaptureBuffer = RemoteCaptureBuffer->PreviousCaptureBuffer;
+
+    /* Free the previous one and use again the local capture buffer */
     RemoteCaptureBuffer->PreviousCaptureBuffer = NULL;
+    ApiMessage->CsrCaptureData = LocalCaptureBuffer;
 
     /* Calculate the difference between our buffer and the client's */
     BufferDistance = (ULONG_PTR)RemoteCaptureBuffer - (ULONG_PTR)LocalCaptureBuffer;
 
     /*
-     * Convert back all the pointers into pointer offsets, and make them
-     * point to the local data buffer instead of the remote one (revert
+     * All the pointer offsets correspond to pointers which point
+     * to the local data buffer instead of the remote one (revert
      * the logic of CsrCaptureArguments).
      */
-    for (i = 0 ; i < RemoteCaptureBuffer->PointerCount ; ++i)
+    PointerCount  = RemoteCaptureBuffer->PointerCount;
+    OffsetPointer = RemoteCaptureBuffer->PointerOffsetsArray;
+    while (PointerCount--)
     {
-        if (RemoteCaptureBuffer->PointerOffsetsArray[i] != 0)
+        CurrentOffset = *OffsetPointer;
+
+        if (CurrentOffset != 0)
         {
-            *(PULONG_PTR)RemoteCaptureBuffer->PointerOffsetsArray[i] -= BufferDistance;
-            RemoteCaptureBuffer->PointerOffsetsArray[i] -= (ULONG_PTR)ApiMessage;
+            /* Get the pointer corresponding to the offset */
+            CurrentOffset += (ULONG_PTR)ApiMessage;
+
+            /* Modify the pointed pointer to take into account its new position */
+            *(PULONG_PTR)CurrentOffset -= BufferDistance;
         }
+
+        ++OffsetPointer;
     }
 
     /* Copy the data back */
@@ -1410,8 +1342,9 @@ CsrValidateMessageBuffer(IN PCSR_API_MESSAGE ApiMessage,
                          IN ULONG ElementSize)
 {
     PCSR_CAPTURE_BUFFER CaptureBuffer = ApiMessage->CsrCaptureData;
-    // SIZE_T BufferDistance = (ULONG_PTR)Buffer - (ULONG_PTR)ApiMessage;
-    ULONG i;
+    SIZE_T BufferDistance = (ULONG_PTR)Buffer - (ULONG_PTR)ApiMessage;
+    ULONG PointerCount;
+    PULONG_PTR OffsetPointer;
 
     /*
      * Check whether we have a valid buffer pointer, elements
@@ -1447,16 +1380,20 @@ CsrValidateMessageBuffer(IN PCSR_API_MESSAGE ApiMessage,
         if ((CaptureBuffer->Size - (ULONG_PTR)*Buffer + (ULONG_PTR)CaptureBuffer) >=
             (ElementCount * ElementSize))
         {
-            for (i = 0 ; i < CaptureBuffer->PointerCount ; ++i)
+            /* Perform the validation test */
+            PointerCount  = CaptureBuffer->PointerCount;
+            OffsetPointer = CaptureBuffer->PointerOffsetsArray;
+            while (PointerCount--)
             {
                 /*
-                 * If the pointer offset is in fact equal to the
-                 * real address of the buffer then it's OK.
+                 * The pointer offset must be equal to the delta between
+                 * the addresses of the buffer and of the API message.
                  */
-                if (CaptureBuffer->PointerOffsetsArray[i] == (ULONG_PTR)Buffer /* BufferDistance + (ULONG_PTR)ApiMessage */)
+                if (*OffsetPointer == BufferDistance)
                 {
                     return TRUE;
                 }
+                ++OffsetPointer;
             }
         }
     }