From f197af7358b49247f71e2af4750a28389838dc1d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Herm=C3=A8s=20B=C3=A9lusca-Ma=C3=AFto?= Date: Mon, 7 Nov 2016 01:24:24 +0000 Subject: [PATCH] [NTOS:LPC] - Fix the usage, or add support (for NtSecureConnectPort based on patches by Alexander Andrejevic, and for NtReplyPort, NtReplyWaitReceivePortEx, NtListenPort by me) for capturing user-mode parameters and using SEH in LPC functions. CORE-7371 #resolve - Make NtSecureConnectPort call SeQueryInformationToken, now that the latter is implemented since r73122. - Fix ObDereferenceObject usage for Port vs. ClientPort in NtSecureConnectPort. - ObCloseHandle certainly needs to be called with the actual 'PreviousMode'. svn path=/trunk/; revision=73164 --- reactos/ntoskrnl/lpc/complete.c | 48 +++-- reactos/ntoskrnl/lpc/connect.c | 331 +++++++++++++++++++++++--------- reactos/ntoskrnl/lpc/create.c | 36 ++-- reactos/ntoskrnl/lpc/listen.c | 19 +- reactos/ntoskrnl/lpc/port.c | 20 +- reactos/ntoskrnl/lpc/reply.c | 133 +++++++------ reactos/ntoskrnl/lpc/send.c | 12 +- 7 files changed, 386 insertions(+), 213 deletions(-) diff --git a/reactos/ntoskrnl/lpc/complete.c b/reactos/ntoskrnl/lpc/complete.c index 4bf8edd11b4..23472ac6d8e 100644 --- a/reactos/ntoskrnl/lpc/complete.c +++ b/reactos/ntoskrnl/lpc/complete.c @@ -46,6 +46,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, { NTSTATUS Status; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); + PORT_VIEW CapturedServerView; + PORT_MESSAGE CapturedReplyMessage; ULONG ConnectionInfoLength; PLPCP_PORT_OBJECT ConnectionPort, ServerPort, ClientPort; PLPCP_CONNECTION_MESSAGE ConnectMessage; @@ -55,8 +57,6 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, PEPROCESS ClientProcess; PETHREAD ClientThread; LARGE_INTEGER SectionOffset; - CLIENT_ID ClientId; - ULONG MessageId; PAGED_CODE(); LPCTRACE(LPC_COMPLETE_DEBUG, @@ -70,18 +70,15 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, /* Check if the call comes from user mode */ if (PreviousMode != KernelMode) { - /* Enter SEH for probing the parameters */ _SEH2_TRY { + /* Probe the PortHandle */ ProbeForWriteHandle(PortHandle); /* Probe the basic ReplyMessage structure */ - ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG)); - - /* Grab some values */ - ClientId = ReplyMessage->ClientId; - MessageId = ReplyMessage->MessageId; - ConnectionInfoLength = ReplyMessage->u1.s1.DataLength; + ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG)); + CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage; + ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength; /* Probe the connection info */ ProbeForRead(ReplyMessage + 1, ConnectionInfoLength, 1); @@ -89,10 +86,11 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, /* The following parameters are optional */ if (ServerView != NULL) { - ProbeForWrite(ServerView, sizeof(PORT_VIEW), sizeof(ULONG)); + ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG)); + CapturedServerView = *(volatile PORT_VIEW*)ServerView; /* Validate the size of the server view */ - if (ServerView->Length != sizeof(PORT_VIEW)) + if (CapturedServerView.Length != sizeof(CapturedServerView)) { /* Invalid size */ _SEH2_YIELD(return STATUS_INVALID_PARAMETER); @@ -101,10 +99,10 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, if (ClientView != NULL) { - ProbeForWrite(ClientView, sizeof(REMOTE_PORT_VIEW), sizeof(ULONG)); + ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG)); /* Validate the size of the client view */ - if (ClientView->Length != sizeof(REMOTE_PORT_VIEW)) + if (((volatile REMOTE_PORT_VIEW*)ClientView)->Length != sizeof(*ClientView)) { /* Invalid size */ _SEH2_YIELD(return STATUS_INVALID_PARAMETER); @@ -120,20 +118,19 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, } else { - /* Grab some values */ - ClientId = ReplyMessage->ClientId; - MessageId = ReplyMessage->MessageId; - ConnectionInfoLength = ReplyMessage->u1.s1.DataLength; + CapturedReplyMessage = *ReplyMessage; + ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength; /* Validate the size of the server view */ - if ((ServerView) && (ServerView->Length != sizeof(PORT_VIEW))) + if ((ServerView) && (ServerView->Length != sizeof(*ServerView))) { /* Invalid size */ return STATUS_INVALID_PARAMETER; } + CapturedServerView = *ServerView; /* Validate the size of the client view */ - if ((ClientView) && (ClientView->Length != sizeof(REMOTE_PORT_VIEW))) + if ((ClientView) && (ClientView->Length != sizeof(*ClientView))) { /* Invalid size */ return STATUS_INVALID_PARAMETER; @@ -141,7 +138,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, } /* Get the client process and thread */ - Status = PsLookupProcessThreadByCid(&ClientId, + Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId, &ClientProcess, &ClientThread); if (!NT_SUCCESS(Status)) return Status; @@ -151,8 +148,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, /* Make sure that the client wants a reply, and this is the right one */ if (!(LpcpGetMessageFromThread(ClientThread)) || - !(MessageId) || - (ClientThread->LpcReplyMessageId != MessageId)) + !(CapturedReplyMessage.MessageId) || + (ClientThread->LpcReplyMessageId != CapturedReplyMessage.MessageId)) { /* Not the reply asked for, or no reply wanted, fail */ KeReleaseGuardedMutex(&LpcpLock); @@ -203,8 +200,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, /* Setup the reply message */ Message->Request.u2.s2.Type = LPC_REPLY; Message->Request.u2.s2.DataInfoOffset = 0; - Message->Request.ClientId = ClientId; - Message->Request.MessageId = MessageId; + Message->Request.ClientId = CapturedReplyMessage.ClientId; + Message->Request.MessageId = CapturedReplyMessage.MessageId; Message->Request.ClientViewSize = 0; _SEH2_TRY @@ -310,6 +307,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, if (ServerView) { /* FIXME: TODO */ + UNREFERENCED_PARAMETER(CapturedServerView); ASSERT(FALSE); } @@ -347,7 +345,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { /* Cleanup and return the exception code */ - ObCloseHandle(Handle, UserMode); + ObCloseHandle(Handle, PreviousMode); ObDereferenceObject(ServerPort); Status = _SEH2_GetExceptionCode(); _SEH2_YIELD(goto Cleanup); diff --git a/reactos/ntoskrnl/lpc/connect.c b/reactos/ntoskrnl/lpc/connect.c index 3620c56913d..9e99538c256 100644 --- a/reactos/ntoskrnl/lpc/connect.c +++ b/reactos/ntoskrnl/lpc/connect.c @@ -90,6 +90,9 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, NTSTATUS Status = STATUS_SUCCESS; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); PETHREAD Thread = PsGetCurrentThread(); + SECURITY_QUALITY_OF_SERVICE CapturedQos; + PORT_VIEW CapturedClientView; + PSID CapturedServerSid; ULONG ConnectionInfoLength = 0; PLPCP_PORT_OBJECT Port, ClientPort; PLPCP_MESSAGE Message; @@ -110,25 +113,122 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, ServerView, ServerSid); - /* Validate client view */ - if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW))) + /* Check if the call comes from user mode */ + if (PreviousMode != KernelMode) { - /* Fail */ - return STATUS_INVALID_PARAMETER; - } + /* Enter SEH for probing the parameters */ + _SEH2_TRY + { + /* Probe the PortHandle */ + ProbeForWriteHandle(PortHandle); - /* Validate server view */ - if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW))) - { - /* Fail */ - return STATUS_INVALID_PARAMETER; - } + /* Probe and capture the QoS */ + ProbeForRead(SecurityQos, sizeof(*SecurityQos), sizeof(ULONG)); + CapturedQos = *(volatile SECURITY_QUALITY_OF_SERVICE*)SecurityQos; + /* NOTE: Do not care about CapturedQos.Length */ + + /* The following parameters are optional */ + + /* Capture the client view */ + if (ClientView != NULL) + { + ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG)); + CapturedClientView = *(volatile PORT_VIEW*)ClientView; + + /* Validate the size of the client view */ + if (CapturedClientView.Length != sizeof(CapturedClientView)) + { + /* Invalid size */ + _SEH2_YIELD(return STATUS_INVALID_PARAMETER); + } + + } + + /* Capture the server view */ + if (ServerView != NULL) + { + ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG)); + + /* Validate the size of the server view */ + if (((volatile REMOTE_PORT_VIEW*)ServerView)->Length != sizeof(*ServerView)) + { + /* Invalid size */ + _SEH2_YIELD(return STATUS_INVALID_PARAMETER); + } + } + + if (MaxMessageLength) + ProbeForWriteUlong(MaxMessageLength); + + /* Capture connection information length */ + if (ConnectionInformationLength) + { + ProbeForWriteUlong(ConnectionInformationLength); + ConnectionInfoLength = *(volatile ULONG*)ConnectionInformationLength; + } - /* Check if caller sent connection information length */ - if (ConnectionInformationLength) + /* Probe the ConnectionInformation */ + if (ConnectionInformation) + ProbeForWrite(ConnectionInformation, ConnectionInfoLength, sizeof(ULONG)); + + CapturedServerSid = ServerSid; + if (ServerSid != NULL) + { + /* Capture it */ + Status = SepCaptureSid(ServerSid, + PreviousMode, + PagedPool, + TRUE, + &CapturedServerSid); + if (!NT_SUCCESS(Status)) + { + DPRINT1("Failed to capture ServerSid!\n"); + _SEH2_YIELD(return Status); + } + } + } + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) + { + /* There was an exception, return the exception code */ + _SEH2_YIELD(return _SEH2_GetExceptionCode()); + } + _SEH2_END; + } + else { - /* Retrieve the input length */ - ConnectionInfoLength = *ConnectionInformationLength; + CapturedQos = *SecurityQos; + /* NOTE: Do not care about CapturedQos.Length */ + + /* The following parameters are optional */ + + /* Capture the client view */ + if (ClientView != NULL) + { + /* Validate the size of the client view */ + if (ClientView->Length != sizeof(*ClientView)) + { + /* Invalid size */ + return STATUS_INVALID_PARAMETER; + } + CapturedClientView = *ClientView; + } + + /* Capture the server view */ + if (ServerView != NULL) + { + /* Validate the size of the server view */ + if (ServerView->Length != sizeof(*ServerView)) + { + /* Invalid size */ + return STATUS_INVALID_PARAMETER; + } + } + + /* Capture connection information length */ + if (ConnectionInformationLength) + ConnectionInfoLength = *ConnectionInformationLength; + + CapturedServerSid = ServerSid; } /* Get the port */ @@ -143,6 +243,10 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, if (!NT_SUCCESS(Status)) { DPRINT1("Failed to reference port '%wZ': 0x%lx\n", PortName, Status); + + if (CapturedServerSid != ServerSid) + SepReleaseSid(CapturedServerSid, PreviousMode, TRUE); + return Status; } @@ -151,10 +255,14 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, { /* It isn't, so fail */ ObDereferenceObject(Port); + + if (CapturedServerSid != ServerSid) + SepReleaseSid(CapturedServerSid, PreviousMode, TRUE); + return STATUS_INVALID_PORT_HANDLE; } - /* Check if we have a SID */ + /* Check if we have a (captured) SID */ if (ServerSid) { /* Make sure that we have a server */ @@ -162,18 +270,14 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, { /* Get its token and query user information */ Token = PsReferencePrimaryToken(Port->ServerProcess); - //Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo); - // FIXME: Need SeQueryInformationToken - Status = STATUS_SUCCESS; - TokenUserInfo = ExAllocatePoolWithTag(PagedPool, sizeof(TOKEN_USER), TAG_SE); - TokenUserInfo->User.Sid = ServerSid; + Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo); PsDereferencePrimaryToken(Token); /* Check for success */ if (NT_SUCCESS(Status)) { /* Compare the SIDs */ - if (!RtlEqualSid(ServerSid, TokenUserInfo->User.Sid)) + if (!RtlEqualSid(CapturedServerSid, TokenUserInfo->User.Sid)) { /* Fail */ Status = STATUS_SERVER_SID_MISMATCH; @@ -189,6 +293,10 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, Status = STATUS_SERVER_SID_MISMATCH; } + /* Finally release the captured SID, we don't need it anymore */ + if (CapturedServerSid != ServerSid) + SepReleaseSid(CapturedServerSid, PreviousMode, TRUE); + /* Check if SID failed */ if (!NT_SUCCESS(Status)) { @@ -215,17 +323,20 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, return Status; } - /* Setup the client port */ + /* + * Setup the client port -- From now on, dereferencing the client port + * will automatically dereference the connection port too. + */ RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT)); ClientPort->Flags = LPCP_CLIENT_PORT; ClientPort->ConnectionPort = Port; ClientPort->MaxMessageLength = Port->MaxMessageLength; - ClientPort->SecurityQos = *Qos; + ClientPort->SecurityQos = CapturedQos; InitializeListHead(&ClientPort->LpcReplyChainHead); InitializeListHead(&ClientPort->LpcDataInfoChainHead); /* Check if we have dynamic security */ - if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING) + if (CapturedQos.ContextTrackingMode == SECURITY_DYNAMIC_TRACKING) { /* Remember that */ ClientPort->Flags |= LPCP_SECURITY_DYNAMIC; @@ -234,7 +345,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, { /* Create our own client security */ Status = SeCreateClientSecurity(Thread, - Qos, + &CapturedQos, FALSE, &ClientPort->StaticSecurity); if (!NT_SUCCESS(Status)) @@ -258,7 +369,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, if (ClientView) { /* Get the section handle */ - Status = ObReferenceObjectByHandle(ClientView->SectionHandle, + Status = ObReferenceObjectByHandle(CapturedClientView.SectionHandle, SECTION_MAP_READ | SECTION_MAP_WRITE, MmSectionObjectType, @@ -268,12 +379,12 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, if (!NT_SUCCESS(Status)) { /* Fail */ - ObDereferenceObject(Port); + ObDereferenceObject(ClientPort); return Status; } /* Set the section offset */ - SectionOffset.QuadPart = ClientView->SectionOffset; + SectionOffset.QuadPart = CapturedClientView.SectionOffset; /* Map it */ Status = MmMapViewOfSection(SectionToMap, @@ -282,25 +393,25 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, 0, 0, &SectionOffset, - &ClientView->ViewSize, + &CapturedClientView.ViewSize, ViewUnmap, 0, PAGE_READWRITE); /* Update the offset */ - ClientView->SectionOffset = SectionOffset.LowPart; + CapturedClientView.SectionOffset = SectionOffset.LowPart; /* Check for failure */ if (!NT_SUCCESS(Status)) { /* Fail */ ObDereferenceObject(SectionToMap); - ObDereferenceObject(Port); + ObDereferenceObject(ClientPort); return Status; } /* Update the base */ - ClientView->ViewBase = ClientPort->ClientSectionBase; + CapturedClientView.ViewBase = ClientPort->ClientSectionBase; /* Reference and remember the process */ ClientPort->MappingProcess = PsGetCurrentProcess(); @@ -337,12 +448,12 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, if (ClientView) { /* Set the view size */ - Message->Request.ClientViewSize = ClientView->ViewSize; + Message->Request.ClientViewSize = CapturedClientView.ViewSize; /* Copy the client view and clear the server view */ RtlCopyMemory(&ConnectMessage->ClientView, - ClientView, - sizeof(PORT_VIEW)); + &CapturedClientView, + sizeof(CapturedClientView)); RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW)); } else @@ -366,12 +477,33 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, /* Check if we have connection information */ if (ConnectionInformation) { - /* Copy it in */ - RtlCopyMemory(ConnectMessage + 1, - ConnectionInformation, - ConnectionInfoLength); + _SEH2_TRY + { + /* Copy it in */ + RtlCopyMemory(ConnectMessage + 1, + ConnectionInformation, + ConnectionInfoLength); + } + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) + { + /* Cleanup and return the exception code */ + + /* Free the message we have */ + LpcpFreeToPortZone(Message, 0); + + /* Dereference other objects */ + if (SectionToMap) ObDereferenceObject(SectionToMap); + ObDereferenceObject(ClientPort); + + /* Return status */ + _SEH2_YIELD(return _SEH2_GetExceptionCode()); + } + _SEH2_END; } + /* Reset the status code */ + Status = STATUS_SUCCESS; + /* Acquire the port lock */ KeAcquireGuardedMutex(&LpcpLock); @@ -433,11 +565,25 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode); } + /* Now, always free the connection message */ + SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread); + /* Check for failure */ - if (!NT_SUCCESS(Status)) goto Cleanup; + if (!NT_SUCCESS(Status)) + { + /* Check if the semaphore got signaled in the meantime */ + if (KeReadStateSemaphore(&Thread->LpcReplySemaphore)) + { + /* Wait on it */ + KeWaitForSingleObject(&Thread->LpcReplySemaphore, + WrExecutive, + KernelMode, + FALSE, + NULL); + } - /* Free the connection message */ - SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread); + goto Failure; + } /* Check if we got a message back */ if (Message) @@ -451,20 +597,27 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, sizeof(LPCP_CONNECTION_MESSAGE); } - /* Check if we had connection information */ + /* Check if the caller had connection information */ if (ConnectionInformation) { - /* Check if we had a length pointer */ - if (ConnectionInformationLength) + _SEH2_TRY { - /* Return the length */ - *ConnectionInformationLength = ConnectionInfoLength; + /* Return the connection information length if needed */ + if (ConnectionInformationLength) + *ConnectionInformationLength = ConnectionInfoLength; + + /* Return the connection information */ + RtlCopyMemory(ConnectionInformation, + ConnectMessage + 1, + ConnectionInfoLength); } - - /* Return the connection information */ - RtlCopyMemory(ConnectionInformation, - ConnectMessage + 1, - ConnectionInfoLength ); + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) + { + /* Cleanup and return the exception code */ + Status = _SEH2_GetExceptionCode(); + _SEH2_YIELD(goto Failure); + } + _SEH2_END; } /* Make sure we had a connected port */ @@ -482,33 +635,45 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, &Handle); if (NT_SUCCESS(Status)) { - /* Return the handle */ - *PortHandle = Handle; LPCTRACE(LPC_CONNECT_DEBUG, "Handle: %p. Length: %lx\n", Handle, PortMessageLength); - /* Check if maximum length was requested */ - if (MaxMessageLength) *MaxMessageLength = PortMessageLength; - - /* Check if we had a client view */ - if (ClientView) + _SEH2_TRY { - /* Copy it back */ - RtlCopyMemory(ClientView, - &ConnectMessage->ClientView, - sizeof(PORT_VIEW)); + /* Return the handle */ + *PortHandle = Handle; + + /* Check if maximum length was requested */ + if (MaxMessageLength) + *MaxMessageLength = PortMessageLength; + + /* Check if we had a client view */ + if (ClientView) + { + /* Copy it back */ + RtlCopyMemory(ClientView, + &ConnectMessage->ClientView, + sizeof(*ClientView)); + } + + /* Check if we had a server view */ + if (ServerView) + { + /* Copy it back */ + RtlCopyMemory(ServerView, + &ConnectMessage->ServerView, + sizeof(*ServerView)); + } } - - /* Check if we had a server view */ - if (ServerView) + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { - /* Copy it back */ - RtlCopyMemory(ServerView, - &ConnectMessage->ServerView, - sizeof(REMOTE_PORT_VIEW)); + /* An exception happened, close the opened handle */ + ObCloseHandle(Handle, PreviousMode); + Status = _SEH2_GetExceptionCode(); } + _SEH2_END; } } else @@ -545,39 +710,25 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, else { /* No reply message, fail */ - if (SectionToMap) ObDereferenceObject(SectionToMap); - ObDereferenceObject(ClientPort); Status = STATUS_PORT_CONNECTION_REFUSED; + goto Failure; } - /* Return status */ ObDereferenceObject(Port); - return Status; -Cleanup: - /* We failed, free the message */ - SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread); - - /* Check if the semaphore got signaled */ - if (KeReadStateSemaphore(&Thread->LpcReplySemaphore)) - { - /* Wait on it */ - KeWaitForSingleObject(&Thread->LpcReplySemaphore, - WrExecutive, - KernelMode, - FALSE, - NULL); - } + /* Return status */ + return Status; +Failure: /* Check if we had a message and free it */ if (Message) LpcpFreeToPortZone(Message, 0); /* Dereference other objects */ if (SectionToMap) ObDereferenceObject(SectionToMap); ObDereferenceObject(ClientPort); + ObDereferenceObject(Port); /* Return status */ - ObDereferenceObject(Port); return Status; } diff --git a/reactos/ntoskrnl/lpc/create.c b/reactos/ntoskrnl/lpc/create.c index 067f0a75a67..aaa5ed70518 100644 --- a/reactos/ntoskrnl/lpc/create.c +++ b/reactos/ntoskrnl/lpc/create.c @@ -49,14 +49,15 @@ LpcpCreatePort(OUT PHANDLE PortHandle, { NTSTATUS Status; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); + UNICODE_STRING CapturedObjectName, *ObjectName; PLPCP_PORT_OBJECT Port; HANDLE Handle; - PUNICODE_STRING ObjectName; - BOOLEAN NoName; PAGED_CODE(); LPCTRACE(LPC_CREATE_DEBUG, "Name: %wZ\n", ObjectAttributes->ObjectName); + RtlInitEmptyUnicodeString(&CapturedObjectName, NULL, 0); + /* Check if the call comes from user mode */ if (PreviousMode != KernelMode) { @@ -65,15 +66,14 @@ LpcpCreatePort(OUT PHANDLE PortHandle, /* Probe the PortHandle */ ProbeForWriteHandle(PortHandle); - /* Probe the ObjectAttributes */ - ProbeForRead(ObjectAttributes, sizeof(OBJECT_ATTRIBUTES), sizeof(ULONG)); - - /* Get the object name and probe the unicode string */ - ObjectName = ObjectAttributes->ObjectName; - ProbeForRead(ObjectName, sizeof(UNICODE_STRING), 1); - - /* Check if we have no name */ - NoName = (ObjectName->Buffer == NULL) || (ObjectName->Length == 0); + /* Probe the ObjectAttributes and its object name (not the buffer) */ + ProbeForRead(ObjectAttributes, sizeof(*ObjectAttributes), sizeof(ULONG)); + ObjectName = ((volatile OBJECT_ATTRIBUTES*)ObjectAttributes)->ObjectName; + if (ObjectName) + { + ProbeForRead(ObjectName, sizeof(*ObjectName), 1); + CapturedObjectName = *(volatile UNICODE_STRING*)ObjectName; + } } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { @@ -84,11 +84,14 @@ LpcpCreatePort(OUT PHANDLE PortHandle, } else { - /* Check if we have no name */ - NoName = (ObjectAttributes->ObjectName->Buffer == NULL) || - (ObjectAttributes->ObjectName->Length == 0); + if (ObjectAttributes->ObjectName) + CapturedObjectName = *(ObjectAttributes->ObjectName); } + /* Normalize the buffer pointer in case we don't have a name */ + if (CapturedObjectName.Length == 0) + CapturedObjectName.Buffer = NULL; + /* Create the Object */ Status = ObCreateObject(PreviousMode, LpcPortObjectType, @@ -109,7 +112,7 @@ LpcpCreatePort(OUT PHANDLE PortHandle, InitializeListHead(&Port->LpcReplyChainHead); /* Check if we don't have a name */ - if (NoName) + if (CapturedObjectName.Buffer == NULL) { /* Set up for an unconnected port */ Port->Flags = LPCP_UNCONNECTED_PORT; @@ -187,7 +190,8 @@ LpcpCreatePort(OUT PHANDLE PortHandle, } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { - ObCloseHandle(Handle, UserMode); + /* An exception happened, close the opened handle */ + ObCloseHandle(Handle, PreviousMode); Status = _SEH2_GetExceptionCode(); } _SEH2_END; diff --git a/reactos/ntoskrnl/lpc/listen.c b/reactos/ntoskrnl/lpc/listen.c index c4fc1ca86b7..eddcc66e28b 100644 --- a/reactos/ntoskrnl/lpc/listen.c +++ b/reactos/ntoskrnl/lpc/listen.c @@ -36,13 +36,22 @@ NtListenPort(IN HANDLE PortHandle, NULL, ConnectMessage); - /* Accept only LPC_CONNECTION_REQUEST requests */ - if ((Status != STATUS_SUCCESS) || - (LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST)) + _SEH2_TRY { - /* Break out */ - break; + /* Accept only LPC_CONNECTION_REQUEST requests */ + if ((Status != STATUS_SUCCESS) || + (LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST)) + { + /* Break out */ + _SEH2_YIELD(break); + } } + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) + { + Status = _SEH2_GetExceptionCode(); + _SEH2_YIELD(break); + } + _SEH2_END; } /* Return status */ diff --git a/reactos/ntoskrnl/lpc/port.c b/reactos/ntoskrnl/lpc/port.c index 28cca60e772..7ca93a1574f 100644 --- a/reactos/ntoskrnl/lpc/port.c +++ b/reactos/ntoskrnl/lpc/port.c @@ -136,28 +136,26 @@ NtImpersonateClientOfPort(IN HANDLE PortHandle, PAGED_CODE(); - /* Check the previous mode */ - PreviousMode = ExGetPreviousMode(); - if (PreviousMode == KernelMode) - { - ClientId = ClientMessage->ClientId; - MessageId = ClientMessage->MessageId; - } - else + /* Check if the call comes from user mode */ + if (PreviousMode != KernelMode) { _SEH2_TRY { ProbeForRead(ClientMessage, sizeof(*ClientMessage), sizeof(PVOID)); - ClientId = ClientMessage->ClientId; - MessageId = ClientMessage->MessageId; + ClientId = ((volatile PORT_MESSAGE*)ClientMessage)->ClientId; + MessageId = ((volatile PORT_MESSAGE*)ClientMessage)->MessageId; } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { - DPRINT1("Got exception!\n"); _SEH2_YIELD(return _SEH2_GetExceptionCode()); } _SEH2_END; } + else + { + ClientId = ClientMessage->ClientId; + MessageId = ClientMessage->MessageId; + } /* Reference the port handle */ Status = ObReferenceObjectByHandle(PortHandle, diff --git a/reactos/ntoskrnl/lpc/reply.c b/reactos/ntoskrnl/lpc/reply.c index 910d0cd9309..d9fa834be0d 100644 --- a/reactos/ntoskrnl/lpc/reply.c +++ b/reactos/ntoskrnl/lpc/reply.c @@ -192,7 +192,7 @@ NtReplyPort(IN HANDLE PortHandle, { NTSTATUS Status; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); - // PORT_MESSAGE CapturedReplyMessage; + PORT_MESSAGE CapturedReplyMessage; PLPCP_PORT_OBJECT Port; PLPCP_MESSAGE Message; PETHREAD Thread = PsGetCurrentThread(), WakeupThread; @@ -203,32 +203,35 @@ NtReplyPort(IN HANDLE PortHandle, PortHandle, ReplyMessage); - if (KeGetPreviousMode() == UserMode) + /* Check if the call comes from user mode */ + if (PreviousMode != KernelMode) { _SEH2_TRY { - ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG)); - /*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE)); - ReplyMessage = &CapturedReplyMessage;*/ + ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG)); + CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage; } - _SEH2_EXCEPT(ExSystemExceptionFilter()) + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { - DPRINT1("SEH crash [1]\n"); _SEH2_YIELD(return _SEH2_GetExceptionCode()); } _SEH2_END; } + else + { + CapturedReplyMessage = *ReplyMessage; + } /* Validate its length */ - if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) > - (ULONG)ReplyMessage->u1.s1.TotalLength) + if (((ULONG)CapturedReplyMessage.u1.s1.DataLength + sizeof(PORT_MESSAGE)) > + (ULONG)CapturedReplyMessage.u1.s1.TotalLength) { /* Fail */ return STATUS_INVALID_PARAMETER; } /* Make sure it has a valid ID */ - if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER; + if (!CapturedReplyMessage.MessageId) return STATUS_INVALID_PARAMETER; /* Get the Port object */ Status = ObReferenceObjectByHandle(PortHandle, @@ -240,9 +243,9 @@ NtReplyPort(IN HANDLE PortHandle, if (!NT_SUCCESS(Status)) return Status; /* Validate its length in respect to the port object */ - if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) || - ((ULONG)ReplyMessage->u1.s1.TotalLength <= - (ULONG)ReplyMessage->u1.s1.DataLength)) + if (((ULONG)CapturedReplyMessage.u1.s1.TotalLength > Port->MaxMessageLength) || + ((ULONG)CapturedReplyMessage.u1.s1.TotalLength <= + (ULONG)CapturedReplyMessage.u1.s1.DataLength)) { /* Too large, fail */ ObDereferenceObject(Port); @@ -250,7 +253,7 @@ NtReplyPort(IN HANDLE PortHandle, } /* Get the ETHREAD corresponding to it */ - Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId, + Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId, NULL, &WakeupThread); if (!NT_SUCCESS(Status)) @@ -274,7 +277,7 @@ NtReplyPort(IN HANDLE PortHandle, KeAcquireGuardedMutex(&LpcpLock); /* Make sure this is the reply the thread is waiting for */ - if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) || + if ((WakeupThread->LpcReplyMessageId != CapturedReplyMessage.MessageId) || ((LpcpGetMessageFromThread(WakeupThread)) && (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)-> Request) != LPC_REQUEST))) @@ -290,7 +293,7 @@ NtReplyPort(IN HANDLE PortHandle, _SEH2_TRY { LpcpMoveMessage(&Message->Request, - ReplyMessage, + &CapturedReplyMessage, ReplyMessage + 1, LPC_REPLY, NULL); @@ -324,7 +327,7 @@ NtReplyPort(IN HANDLE PortHandle, /* Check if this is the message the thread had received */ if ((Thread->LpcReceivedMsgIdValid) && - (Thread->LpcReceivedMessageId == ReplyMessage->MessageId)) + (Thread->LpcReceivedMessageId == CapturedReplyMessage.MessageId)) { /* Clear this data */ Thread->LpcReceivedMessageId = 0; @@ -333,9 +336,9 @@ NtReplyPort(IN HANDLE PortHandle, /* Free any data information */ LpcpFreeDataInfoMessage(Port, - ReplyMessage->MessageId, - ReplyMessage->CallbackId, - ReplyMessage->ClientId); + CapturedReplyMessage.MessageId, + CapturedReplyMessage.CallbackId, + CapturedReplyMessage.ClientId); /* Release the lock and release the LPC semaphore to wake up waiters */ KeReleaseGuardedMutex(&LpcpLock); @@ -362,7 +365,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, { NTSTATUS Status; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode; - // PORT_MESSAGE CapturedReplyMessage; + PORT_MESSAGE CapturedReplyMessage; LARGE_INTEGER CapturedTimeout; PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL; PLPCP_MESSAGE Message; @@ -378,30 +381,29 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, ReceiveMessage, PortContext); - if (KeGetPreviousMode() == UserMode) + /* Check if the call comes from user mode */ + if (PreviousMode != KernelMode) { _SEH2_TRY { + if (PortContext != NULL) + ProbeForWritePointer(PortContext); + if (ReplyMessage != NULL) { - ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG)); - /*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE)); - ReplyMessage = &CapturedReplyMessage;*/ + ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG)); + CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage; } if (Timeout != NULL) { ProbeForReadLargeInteger(Timeout); - RtlCopyMemory(&CapturedTimeout, Timeout, sizeof(LARGE_INTEGER)); + CapturedTimeout = *(volatile LARGE_INTEGER*)Timeout; Timeout = &CapturedTimeout; } - - if (PortContext != NULL) - ProbeForWritePointer(PortContext); } - _SEH2_EXCEPT(ExSystemExceptionFilter()) + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { - DPRINT1("SEH crash [1]\n"); _SEH2_YIELD(return _SEH2_GetExceptionCode()); } _SEH2_END; @@ -410,21 +412,24 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, { /* If this is a system thread, then let it page out its stack */ if (Thread->SystemThread) WaitMode = UserMode; + + if (ReplyMessage != NULL) + CapturedReplyMessage = *ReplyMessage; } /* Check if caller has a reply message */ if (ReplyMessage) { /* Validate its length */ - if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) > - (ULONG)ReplyMessage->u1.s1.TotalLength) + if (((ULONG)CapturedReplyMessage.u1.s1.DataLength + sizeof(PORT_MESSAGE)) > + (ULONG)CapturedReplyMessage.u1.s1.TotalLength) { /* Fail */ return STATUS_INVALID_PARAMETER; } /* Make sure it has a valid ID */ - if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER; + if (!CapturedReplyMessage.MessageId) return STATUS_INVALID_PARAMETER; } /* Get the Port object */ @@ -440,9 +445,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, if (ReplyMessage) { /* Validate its length in respect to the port object */ - if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) || - ((ULONG)ReplyMessage->u1.s1.TotalLength <= - (ULONG)ReplyMessage->u1.s1.DataLength)) + if (((ULONG)CapturedReplyMessage.u1.s1.TotalLength > Port->MaxMessageLength) || + ((ULONG)CapturedReplyMessage.u1.s1.TotalLength <= + (ULONG)CapturedReplyMessage.u1.s1.DataLength)) { /* Too large, fail */ ObDereferenceObject(Port); @@ -490,7 +495,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, if (ReplyMessage) { /* Get the ETHREAD corresponding to it */ - Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId, + Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId, NULL, &WakeupThread); if (!NT_SUCCESS(Status)) @@ -516,7 +521,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, KeAcquireGuardedMutex(&LpcpLock); /* Make sure this is the reply the thread is waiting for */ - if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) || + if ((WakeupThread->LpcReplyMessageId != CapturedReplyMessage.MessageId) || ((LpcpGetMessageFromThread(WakeupThread)) && (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)->Request) != LPC_REQUEST))) @@ -530,11 +535,24 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, } /* Copy the message */ - LpcpMoveMessage(&Message->Request, - ReplyMessage, - ReplyMessage + 1, - LPC_REPLY, - NULL); + _SEH2_TRY + { + LpcpMoveMessage(&Message->Request, + &CapturedReplyMessage, + ReplyMessage + 1, + LPC_REPLY, + NULL); + } + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) + { + /* Cleanup and return the exception code */ + LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE); + if (ConnectionPort) ObDereferenceObject(ConnectionPort); + ObDereferenceObject(WakeupThread); + ObDereferenceObject(Port); + _SEH2_YIELD(return _SEH2_GetExceptionCode()); + } + _SEH2_END; /* Reference the thread while we use it */ ObReferenceObject(WakeupThread); @@ -555,7 +573,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, /* Check if this is the message the thread had received */ if ((Thread->LpcReceivedMsgIdValid) && - (Thread->LpcReceivedMessageId == ReplyMessage->MessageId)) + (Thread->LpcReceivedMessageId == CapturedReplyMessage.MessageId)) { /* Clear this data */ Thread->LpcReceivedMessageId = 0; @@ -564,9 +582,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, /* Free any data information */ LpcpFreeDataInfoMessage(Port, - ReplyMessage->MessageId, - ReplyMessage->CallbackId, - ReplyMessage->ClientId); + CapturedReplyMessage.MessageId, + CapturedReplyMessage.CallbackId, + CapturedReplyMessage.ClientId); /* Release the lock and release the LPC semaphore to wake up waiters */ KeReleaseGuardedMutex(&LpcpLock); @@ -688,9 +706,8 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, ASSERT(FALSE); } } - _SEH2_EXCEPT(ExSystemExceptionFilter()) + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { - DPRINT1("SEH crash [2]\n"); Status = _SEH2_GetExceptionCode(); } _SEH2_END; @@ -771,26 +788,24 @@ LpcpCopyRequestData( PAGED_CODE(); - /* Check the previous mode */ - PreviousMode = ExGetPreviousMode(); - if (PreviousMode == KernelMode) - { - CapturedMessage = *Message; - } - else + /* Check if the call comes from user mode */ + if (PreviousMode != KernelMode) { _SEH2_TRY { ProbeForRead(Message, sizeof(*Message), sizeof(PVOID)); - CapturedMessage = *Message; + CapturedMessage = *(volatile PORT_MESSAGE*)Message; } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { - DPRINT1("Got exception!\n"); _SEH2_YIELD(return _SEH2_GetExceptionCode()); } _SEH2_END; } + else + { + CapturedMessage = *Message; + } /* Make sure there is any data to copy */ if (CapturedMessage.u2.s2.DataInfoOffset == 0) diff --git a/reactos/ntoskrnl/lpc/send.c b/reactos/ntoskrnl/lpc/send.c index be5189d7692..74589b99a98 100644 --- a/reactos/ntoskrnl/lpc/send.c +++ b/reactos/ntoskrnl/lpc/send.c @@ -461,9 +461,8 @@ NtRequestPort(IN HANDLE PortHandle, _SEH2_TRY { /* Probe and capture the LpcRequest */ - ProbeForRead(LpcRequest, sizeof(PORT_MESSAGE), sizeof(ULONG)); - ProbeForRead(LpcRequest, LpcRequest->u1.s1.TotalLength, sizeof(ULONG)); - CapturedLpcRequest = *LpcRequest; + ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG)); + CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest; } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { @@ -523,7 +522,7 @@ NtRequestPort(IN HANDLE PortHandle, { /* Copy it */ LpcpMoveMessage(&Message->Request, - LpcRequest, + &CapturedLpcRequest, LpcRequest + 1, MessageType, &Thread->Cid); @@ -725,10 +724,9 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, { _SEH2_TRY { - /* Probe the full request message and copy the base structure */ + /* Probe and capture the LpcRequest */ ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG)); - ProbeForRead(LpcRequest, LpcRequest->u1.s1.TotalLength, sizeof(ULONG)); - CapturedLpcRequest = *LpcRequest; + CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest; /* Probe the reply message for write */ ProbeForWrite(LpcReply, sizeof(*LpcReply), sizeof(ULONG)); -- 2.17.1