#include "srv.h"
-//#define NDEBUG
+#define NDEBUG
#include <debug.h>
/* GLOBALS ********************************************************************/
BOOLEAN (*CsrClientThreadSetup)(VOID) = NULL;
UNICODE_STRING CsrApiPortName;
-volatile LONG CsrpStaticThreadCount;
-volatile LONG CsrpDynamicThreadTotal;
+volatile ULONG CsrpStaticThreadCount;
+volatile ULONG CsrpDynamicThreadTotal;
extern ULONG CsrMaxApiRequestThreads;
/* FUNCTIONS ******************************************************************/
ULONG ServerId;
PCSR_SERVER_DLL ServerDll;
ULONG ApiId;
- ULONG Reply;
- NTSTATUS Status;
+ CSR_REPLY_CODE ReplyCode = CsrReplyImmediately;
/* Get the Server ID */
ServerId = CSR_API_NUMBER_TO_SERVER_ID(ReceiveMsg->ApiNumber);
{
/* We are beyond the Maximum Server ID */
DPRINT1("CSRSS: %lx is invalid ServerDllIndex (%08x)\n", ServerId, ServerDll);
- ReplyMsg->Status = (ULONG)STATUS_ILLEGAL_FUNCTION;
+ ReplyMsg->Status = STATUS_ILLEGAL_FUNCTION;
return STATUS_ILLEGAL_FUNCTION;
}
else
ServerDll->ValidTable[ApiId],
((ServerDll->NameTable) && (ServerDll->NameTable[ApiId])) ?
ServerDll->NameTable[ApiId] : "*** UNKNOWN ***", &ServerDll->Name);
- DbgBreakPoint();
- ReplyMsg->Status = (ULONG)STATUS_ILLEGAL_FUNCTION;
+ // DbgBreakPoint();
+ ReplyMsg->Status = STATUS_ILLEGAL_FUNCTION;
return STATUS_ILLEGAL_FUNCTION;
}
}
/* Validation complete, start SEH */
_SEH2_TRY
{
- /* Call the API and get the result */
- /// CsrApiCallHandler(ReplyMsg, /*ProcessData*/ &ReplyCode); ///
- Status = (ServerDll->DispatchTable[ApiId])(ReceiveMsg, &Reply);
-
- /* Return the result, no matter what it is */
- ReplyMsg->Status = Status;
+ /* Call the API, get the reply code and return the result */
+ ReplyMsg->Status = ServerDll->DispatchTable[ApiId](ReceiveMsg, &ReplyCode);
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
if (CsrProcess)
{
/* Reference the Process */
- CsrLockedReferenceProcess(CsrThread->Process);
+ CsrLockedReferenceProcess(CsrProcess);
/* Release the lock */
CsrReleaseProcessLock();
NTSTATUS Status;
/* Decrease the count, and see if we're out */
- if (!(_InterlockedDecrement(&CsrpStaticThreadCount)))
+ if (InterlockedDecrementUL(&CsrpStaticThreadCount) == 0)
{
/* Check if we've still got space for a Dynamic Thread */
if (CsrpDynamicThreadTotal < CsrMaxApiRequestThreads)
if (NT_SUCCESS(Status))
{
/* Increase the thread counts */
- _InterlockedIncrement(&CsrpStaticThreadCount);
- _InterlockedIncrement(&CsrpDynamicThreadTotal);
+ InterlockedIncrementUL(&CsrpStaticThreadCount);
+ InterlockedIncrementUL(&CsrpDynamicThreadTotal);
/* Add a new server thread */
if (CsrAddStaticServerThread(hThread,
else
{
/* Failed to create a new static thread */
- _InterlockedDecrement(&CsrpStaticThreadCount);
- _InterlockedDecrement(&CsrpDynamicThreadTotal);
+ InterlockedDecrementUL(&CsrpStaticThreadCount);
+ InterlockedDecrementUL(&CsrpDynamicThreadTotal);
/* Terminate it */
DPRINT1("Failing\n");
LARGE_INTEGER TimeOut;
PCSR_THREAD CurrentThread, CsrThread;
NTSTATUS Status;
+ CSR_REPLY_CODE ReplyCode;
PCSR_API_MESSAGE ReplyMsg;
CSR_API_MESSAGE ReceiveMsg;
PCSR_PROCESS CsrProcess;
PCSR_SERVER_DLL ServerDll;
PCLIENT_DIED_MSG ClientDiedMsg;
PDBGKM_MSG DebugMessage;
- ULONG ServerId, ApiId, Reply, MessageType, i;
+ ULONG ServerId, ApiId, MessageType, i;
HANDLE ReplyPort;
/* Setup LPC loop port and message */
ASSERT(NT_SUCCESS(Status));
/* Increase the Thread Counts */
- _InterlockedIncrement(&CsrpStaticThreadCount);
- _InterlockedIncrement(&CsrpDynamicThreadTotal);
+ InterlockedIncrementUL(&CsrpStaticThreadCount);
+ InterlockedIncrementUL(&CsrpDynamicThreadTotal);
}
/* Now start the loop */
}
else
{
- /* A bizare "success" code, just try again */
+ /* A strange "success" code, just try again */
DPRINT1("NtReplyWaitReceivePort returned \"success\" status 0x%x\n", Status);
continue;
}
{
/* Handle the Connection Request */
CsrApiHandleConnectionRequest(&ReceiveMsg);
- ReplyPort = CsrApiPort;
+
ReplyMsg = NULL;
+ ReplyPort = CsrApiPort;
continue;
}
if ((ServerDll) && (ServerDll->HardErrorCallback))
{
/* Call it */
- ServerDll->HardErrorCallback(NULL /* CsrThread == NULL */, HardErrorMsg);
+ ServerDll->HardErrorCallback(NULL /* == CsrThread */, HardErrorMsg);
/* If it's handled, get out of here */
if (HardErrorMsg->Response != ResponseNotHandled) break;
}
/* Increase the thread count */
- _InterlockedIncrement(&CsrpStaticThreadCount);
+ InterlockedIncrementUL(&CsrpStaticThreadCount);
/* If the response was 0xFFFFFFFF, we'll ignore it */
if (HardErrorMsg->Response == 0xFFFFFFFF)
else
{
ReplyMsg = &ReceiveMsg;
+ ReplyPort = CsrApiPort;
}
}
else if (MessageType == LPC_REQUEST)
/* We are beyond the Maximum Server ID */
DPRINT1("CSRSS: %lx is invalid ServerDllIndex (%08x)\n",
ServerId, ServerDll);
- DbgBreakPoint();
- ReplyPort = CsrApiPort;
+ // DbgBreakPoint();
+
ReplyMsg = NULL;
+ ReplyPort = CsrApiPort;
continue;
}
DPRINT1("CSRSS: %lx is invalid ApiTableIndex for %Z\n",
CSR_API_NUMBER_TO_API_ID(ReceiveMsg.ApiNumber),
&ServerDll->Name);
+
ReplyPort = CsrApiPort;
ReplyMsg = NULL;
continue;
/* Make sure we have enough threads */
CsrpCheckRequestThreads();
- /* Call the API and get the result */
+ /* Call the API and get the reply code */
ReplyMsg = NULL;
ReplyPort = CsrApiPort;
- ServerDll->DispatchTable[ApiId](&ReceiveMsg, &Reply);
+ ServerDll->DispatchTable[ApiId](&ReceiveMsg, &ReplyCode);
/* Increase the static thread count */
- _InterlockedIncrement(&CsrpStaticThreadCount);
+ InterlockedIncrementUL(&CsrpStaticThreadCount);
}
_SEH2_EXCEPT(CsrUnhandledExceptionFilter(_SEH2_GetExceptionInformation()))
{
ClientDiedMsg = (PCLIENT_DIED_MSG)&ReceiveMsg;
if (ClientDiedMsg->CreateTime.QuadPart == CsrThread->CreateTime.QuadPart)
{
+ /* Now we reply to the dying client */
+ ReplyPort = CsrThread->Process->ClientPort;
+
/* Reference the thread */
CsrLockedReferenceThread(CsrThread);
/* Release the lock and keep looping */
CsrReleaseProcessLock();
+
ReplyMsg = NULL;
ReplyPort = CsrApiPort;
continue;
}
/* Increase the thread count */
- _InterlockedIncrement(&CsrpStaticThreadCount);
+ InterlockedIncrementUL(&CsrpStaticThreadCount);
/* If the response was 0xFFFFFFFF, we'll ignore it */
if (HardErrorMsg->Response == 0xFFFFFFFF)
/* We are beyond the Maximum Server ID */
DPRINT1("CSRSS: %lx is invalid ServerDllIndex (%08x)\n",
ServerId, ServerDll);
- DbgBreakPoint();
+ // DbgBreakPoint();
ReplyPort = CsrApiPort;
ReplyMsg = &ReceiveMsg;
if (CsrDebug & 2)
{
- DPRINT1("[%02x] CSRSS: [%02x,%02x] - %s Api called from %08x\n",
+ DPRINT1("[%02x] CSRSS: [%02x,%02x] - %s Api called from %08x, Process %08x - %08x\n",
Teb->ClientId.UniqueThread,
ReceiveMsg.Header.ClientId.UniqueProcess,
ReceiveMsg.Header.ClientId.UniqueThread,
ServerDll->NameTable[ApiId],
- CsrThread);
+ CsrThread,
+ CsrThread->Process,
+ CsrProcess);
}
/* Assume success */
Teb->CsrClientThread = CsrThread;
- /* Call the API and get the result */
- Reply = 0;
- ServerDll->DispatchTable[ApiId](&ReceiveMsg, &Reply);
+ /* Call the API, get the reply code and return the result */
+ ReplyCode = CsrReplyImmediately;
+ ReplyMsg->Status = ServerDll->DispatchTable[ApiId](&ReceiveMsg, &ReplyCode);
/* Increase the static thread count */
- _InterlockedIncrement(&CsrpStaticThreadCount);
+ InterlockedIncrementUL(&CsrpStaticThreadCount);
Teb->CsrClientThread = CurrentThread;
- if (Reply == 3)
+ if (ReplyCode == CsrReplyAlreadySent)
{
- ReplyMsg = NULL;
if (ReceiveMsg.CsrCaptureData)
{
CsrReleaseCapturedArguments(&ReceiveMsg);
}
- CsrDereferenceThread(CsrThread);
+ ReplyMsg = NULL;
ReplyPort = CsrApiPort;
+ CsrDereferenceThread(CsrThread);
}
- else if (Reply == 2)
+ else if (ReplyCode == CsrReplyDeadClient)
{
+ /* Reply to the death message */
NtReplyPort(ReplyPort, &ReplyMsg->Header);
- ReplyPort = CsrApiPort;
+
+ /* Reply back to the API port now */
ReplyMsg = NULL;
+ ReplyPort = CsrApiPort;
+
CsrDereferenceThread(CsrThread);
}
- else if (Reply == 1)
+ else if (ReplyCode == CsrReplyPending)
{
- ReplyPort = CsrApiPort;
ReplyMsg = NULL;
+ ReplyPort = CsrApiPort;
}
else
{
*
* @param None
*
- * @return STATUS_SUCCESS in case of success, STATUS_UNSUCCESSFUL
- * otherwise.
+ * @return STATUS_SUCCESS in case of success, STATUS_UNSUCCESSFUL otherwise.
*
* @remarks None.
*
&CsrApiPortName,
0,
NULL,
- NULL /* FIXME*/);
+ NULL /* FIXME: Use the Security Descriptor */);
/* Create the Port Object */
Status = NtCreatePort(&CsrApiPort,
&ObjectAttributes,
- LPC_MAX_DATA_LENGTH, // HACK: the real value is: sizeof(CSR_CONNECTION_INFO),
- LPC_MAX_MESSAGE_LENGTH, // HACK: the real value is: sizeof(CSR_API_MESSAGE),
+ sizeof(CSR_CONNECTION_INFO),
+ sizeof(CSR_API_MESSAGE),
16 * PAGE_SIZE);
if (NT_SUCCESS(Status))
{
NTAPI
CsrConnectToUser(VOID)
{
-#if 0 // This code is OK, however it is ClientThreadSetup which sucks.
+#if 0 // FIXME: This code is OK, however it is ClientThreadSetup which sucks.
NTSTATUS Status;
ANSI_STRING DllName;
UNICODE_STRING TempName;
PCSR_THREAD CsrThread;
/* Save pointer to this thread in TEB */
+ CsrAcquireProcessLock();
CsrThread = CsrLocateThreadInProcess(NULL, &Teb->ClientId);
+ CsrReleaseProcessLock();
if (CsrThread) Teb->CsrClientThread = CsrThread;
/* Return it */
NTAPI
CsrQueryApiPort(VOID)
{
- DPRINT("CSRSRV: %s called\n", __FUNCTION__);
return CsrApiPort;
}
PCSR_CAPTURE_BUFFER LocalCaptureBuffer = NULL, RemoteCaptureBuffer = NULL;
SIZE_T BufferDistance;
ULONG Length = 0;
- ULONG i;
+ ULONG PointerCount;
+ PULONG_PTR OffsetPointer;
+ ULONG_PTR CurrentOffset;
/* Use SEH to make sure this is valid */
_SEH2_TRY
} _SEH2_END;
/* We validated the incoming buffer, now allocate the remote one */
- RemoteCaptureBuffer = RtlAllocateHeap(CsrHeap, 0, Length);
+ RemoteCaptureBuffer = RtlAllocateHeap(CsrHeap, HEAP_ZERO_MEMORY, Length);
if (!RemoteCaptureBuffer)
{
/* We're out of memory */
BufferDistance = (ULONG_PTR)RemoteCaptureBuffer - (ULONG_PTR)LocalCaptureBuffer;
/*
- * Convert all the pointer offsets into real pointers, and make
- * them point to the remote data buffer instead of the local one.
+ * All the pointer offsets correspond to pointers which point
+ * to the remote data buffer instead of the local one.
*/
- for (i = 0 ; i < RemoteCaptureBuffer->PointerCount ; ++i)
+ PointerCount = RemoteCaptureBuffer->PointerCount;
+ OffsetPointer = RemoteCaptureBuffer->PointerOffsetsArray;
+ while (PointerCount--)
{
- if (RemoteCaptureBuffer->PointerOffsetsArray[i] != 0)
+ CurrentOffset = *OffsetPointer;
+
+ if (CurrentOffset != 0)
{
- RemoteCaptureBuffer->PointerOffsetsArray[i] += (ULONG_PTR)ApiMessage;
+ /* Get the pointer corresponding to the offset */
+ CurrentOffset += (ULONG_PTR)ApiMessage;
- /* Validate the bounds of the current pointer */
- if ((*(PULONG_PTR)RemoteCaptureBuffer->PointerOffsetsArray[i] >= CsrThread->Process->ClientViewBase) &&
- (*(PULONG_PTR)RemoteCaptureBuffer->PointerOffsetsArray[i] < CsrThread->Process->ClientViewBounds))
+ /* Validate the bounds of the current pointed pointer */
+ if ((*(PULONG_PTR)CurrentOffset >= CsrThread->Process->ClientViewBase) &&
+ (*(PULONG_PTR)CurrentOffset < CsrThread->Process->ClientViewBounds))
{
- /* Modify the pointer to take into account its new position */
- *(PULONG_PTR)RemoteCaptureBuffer->PointerOffsetsArray[i] += BufferDistance;
+ /* Modify the pointed pointer to take into account its new position */
+ *(PULONG_PTR)CurrentOffset += BufferDistance;
}
else
{
ApiMessage->Status = STATUS_INVALID_PARAMETER;
}
}
+
+ ++OffsetPointer;
}
/* Check if we got success */
if (ApiMessage->Status != STATUS_SUCCESS)
{
- /* Failure. Free the buffer and return*/
+ /* Failure. Free the buffer and return */
RtlFreeHeap(CsrHeap, 0, RemoteCaptureBuffer);
return FALSE;
}
else
{
- /* Success, save the previous buffer */
+ /* Success, save the previous buffer and use the remote capture buffer */
RemoteCaptureBuffer->PreviousCaptureBuffer = LocalCaptureBuffer;
ApiMessage->CsrCaptureData = RemoteCaptureBuffer;
}
{
PCSR_CAPTURE_BUFFER RemoteCaptureBuffer, LocalCaptureBuffer;
SIZE_T BufferDistance;
- ULONG i;
+ ULONG PointerCount;
+ PULONG_PTR OffsetPointer;
+ ULONG_PTR CurrentOffset;
- /* Get the capture buffers */
+ /* Get the remote capture buffer */
RemoteCaptureBuffer = ApiMessage->CsrCaptureData;
- LocalCaptureBuffer = RemoteCaptureBuffer->PreviousCaptureBuffer;
/* Do not continue if there is no captured buffer */
if (!RemoteCaptureBuffer) return;
- /* Free the previous one */
+ /* If there is one, get the corresponding local capture buffer */
+ LocalCaptureBuffer = RemoteCaptureBuffer->PreviousCaptureBuffer;
+
+ /* Free the previous one and use again the local capture buffer */
RemoteCaptureBuffer->PreviousCaptureBuffer = NULL;
+ ApiMessage->CsrCaptureData = LocalCaptureBuffer;
/* Calculate the difference between our buffer and the client's */
BufferDistance = (ULONG_PTR)RemoteCaptureBuffer - (ULONG_PTR)LocalCaptureBuffer;
/*
- * Convert back all the pointers into pointer offsets, and make them
- * point to the local data buffer instead of the remote one (revert
+ * All the pointer offsets correspond to pointers which point
+ * to the local data buffer instead of the remote one (revert
* the logic of CsrCaptureArguments).
*/
- for (i = 0 ; i < RemoteCaptureBuffer->PointerCount ; ++i)
+ PointerCount = RemoteCaptureBuffer->PointerCount;
+ OffsetPointer = RemoteCaptureBuffer->PointerOffsetsArray;
+ while (PointerCount--)
{
- if (RemoteCaptureBuffer->PointerOffsetsArray[i] != 0)
+ CurrentOffset = *OffsetPointer;
+
+ if (CurrentOffset != 0)
{
- *(PULONG_PTR)RemoteCaptureBuffer->PointerOffsetsArray[i] -= BufferDistance;
- RemoteCaptureBuffer->PointerOffsetsArray[i] -= (ULONG_PTR)ApiMessage;
+ /* Get the pointer corresponding to the offset */
+ CurrentOffset += (ULONG_PTR)ApiMessage;
+
+ /* Modify the pointed pointer to take into account its new position */
+ *(PULONG_PTR)CurrentOffset -= BufferDistance;
}
+
+ ++OffsetPointer;
}
/* Copy the data back */
RtlFreeHeap(CsrHeap, 0, RemoteCaptureBuffer);
}
-
/*++
* @name CsrValidateMessageBuffer
* @implemented NT5.1
* @param ElementSize
* Size of each element.
*
- * @return TRUE if validation suceeded, FALSE otherwise.
+ * @return TRUE if validation succeeded, FALSE otherwise.
*
* @remarks None.
*
IN ULONG ElementSize)
{
PCSR_CAPTURE_BUFFER CaptureBuffer = ApiMessage->CsrCaptureData;
- // SIZE_T BufferDistance = (ULONG_PTR)Buffer - (ULONG_PTR)ApiMessage;
- ULONG i;
+ SIZE_T BufferDistance = (ULONG_PTR)Buffer - (ULONG_PTR)ApiMessage;
+ ULONG PointerCount;
+ PULONG_PTR OffsetPointer;
/*
* Check whether we have a valid buffer pointer, elements
if ((CaptureBuffer->Size - (ULONG_PTR)*Buffer + (ULONG_PTR)CaptureBuffer) >=
(ElementCount * ElementSize))
{
- for (i = 0 ; i < CaptureBuffer->PointerCount ; ++i)
+ /* Perform the validation test */
+ PointerCount = CaptureBuffer->PointerCount;
+ OffsetPointer = CaptureBuffer->PointerOffsetsArray;
+ while (PointerCount--)
{
/*
- * If the pointer offset is in fact equal to the
- * real address of the buffer then it's OK.
+ * The pointer offset must be equal to the delta between
+ * the addresses of the buffer and of the API message.
*/
- if (CaptureBuffer->PointerOffsetsArray[i] == (ULONG_PTR)Buffer /* BufferDistance + (ULONG_PTR)ApiMessage */)
+ if (*OffsetPointer == BufferDistance)
{
return TRUE;
}
+ ++OffsetPointer;
}
}
}
return FALSE;
}
-/*** This is what we have in consrv/server.c ***
-
-/\* Ensure that a captured buffer is safe to access *\/
-BOOL FASTCALL
-Win32CsrValidateBuffer(PCSR_PROCESS ProcessData, PVOID Buffer,
- SIZE_T NumElements, SIZE_T ElementSize)
-{
- /\* Check that the following conditions are true:
- * 1. The start of the buffer is somewhere within the process's
- * shared memory section view.
- * 2. The remaining space in the view is at least as large as the buffer.
- * (NB: Please don't try to "optimize" this by using multiplication
- * instead of division; remember that 2147483648 * 2 = 0.)
- * 3. The buffer is DWORD-aligned.
- *\/
- ULONG_PTR Offset = (BYTE *)Buffer - (BYTE *)ProcessData->ClientViewBase;
- if (Offset >= ProcessData->ClientViewBounds
- || NumElements > (ProcessData->ClientViewBounds - Offset) / ElementSize
- || (Offset & (sizeof(DWORD) - 1)) != 0)
- {
- DPRINT1("Invalid buffer %p(%u*%u); section view is %p(%u)\n",
- Buffer, NumElements, ElementSize,
- ProcessData->ClientViewBase, ProcessData->ClientViewBounds);
- return FALSE;
- }
- return TRUE;
-}
-
-***********************************************/
-
/*++
* @name CsrValidateMessageString
* @implemented NT5.1
* @param MessageString
* Pointer to the buffer containing the string to validate.
*
- * @return TRUE if validation suceeded, FALSE otherwise.
+ * @return TRUE if validation succeeded, FALSE otherwise.
*
* @remarks None.
*