Cleanup isn't necessary after calling the driver in NtQueryDirectoryFile.
[reactos.git] / reactos / ntoskrnl / lpc / send.c
index f70fc96..8725e56 100644 (file)
@@ -1,22 +1,16 @@
-/* $Id: send.c,v 1.15 2004/05/31 11:47:05 gvg Exp $
- * 
+/* $Id$
+ *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
  * FILE:            ntoskrnl/lpc/send.c
  * PURPOSE:         Communication mechanism
- * PROGRAMMER:      David Welch (welch@cwcom.net)
- * UPDATE HISTORY:
- *                  Created 22/05/98
+ *
+ * PROGRAMMERS:     David Welch (welch@cwcom.net)
  */
 
 /* INCLUDES *****************************************************************/
 
-#include <ddk/ntddk.h>
-#include <internal/ob.h>
-#include <internal/port.h>
-#include <internal/dbg.h>
-#include <internal/safe.h>
-#include <internal/ps.h>
+#include <ntoskrnl.h>
 
 #define NDEBUG
 #include <internal/debug.h>
  *
  * REVISIONS
  */
-NTSTATUS STDCALL 
+NTSTATUS STDCALL
 LpcSendTerminationPort (IN PEPORT Port,
-                       IN TIME CreationTime)
+                       IN LARGE_INTEGER CreateTime)
 {
   NTSTATUS Status;
-  LPC_TERMINATION_MESSAGE Msg;
-  
+  CLIENT_DIED_MSG Msg;
+
 #ifdef __USE_NT_LPC__
-  Msg.Header.MessageType = LPC_NEW_MESSAGE;
+  Msg.h.u2.s2.Type = LPC_CLIENT_DIED;
 #endif
-  Msg.CreationTime = CreationTime;
-  Status = LpcRequestPort (Port, &Msg.Header);
+  Msg.h.u1.s1.TotalLength = sizeof(Msg);
+  Msg.h.u1.s1.DataLength = sizeof(Msg) - sizeof(PORT_MESSAGE);
+  Msg.CreateTime = CreateTime;
+  Status = LpcRequestPort (Port, &Msg.h);
   return(Status);
 }
 
@@ -62,17 +58,17 @@ LpcSendTerminationPort (IN PEPORT Port,
  *
  * REVISIONS
  */
-NTSTATUS STDCALL 
+NTSTATUS STDCALL
 LpcSendDebugMessagePort (IN PEPORT Port,
-                        IN PLPC_DBG_MESSAGE Message,
-                        OUT PLPC_DBG_MESSAGE Reply)
+                        IN PDBGKM_MSG Message,
+                        OUT PDBGKM_MSG Reply)
 {
    NTSTATUS Status;
    KIRQL oldIrql;
    PQUEUEDMESSAGE ReplyMessage;
-   
-   Status = EiReplyOrRequestPort(Port, 
-                                &Message->Header, 
+
+   Status = EiReplyOrRequestPort(Port,
+                                &Message->h,
                                 LPC_REQUEST,
                                 Port);
    if (!NT_SUCCESS(Status))
@@ -90,14 +86,14 @@ LpcSendDebugMessagePort (IN PEPORT Port,
                         UserMode,
                         FALSE,
                         NULL);
-   
+
    /*
     * Dequeue the reply
     */
    KeAcquireSpinLock(&Port->Lock, &oldIrql);
    ReplyMessage = EiDequeueMessagePort(Port);
    KeReleaseSpinLock(&Port->Lock, oldIrql);
-   memcpy(Reply, &ReplyMessage->Message, ReplyMessage->Message.MessageSize);
+   memcpy(Reply, &ReplyMessage->Message, ReplyMessage->Message.u1.s1.TotalLength);
    ExFreePool(ReplyMessage);
 
    return(STATUS_SUCCESS);
@@ -107,7 +103,7 @@ LpcSendDebugMessagePort (IN PEPORT Port,
 /**********************************************************************
  * NAME
  *     LpcRequestPort/2
- *     
+ *
  * DESCRIPTION
  *
  * ARGUMENTS
@@ -117,30 +113,30 @@ LpcSendDebugMessagePort (IN PEPORT Port,
  * REVISIONS
  *     2002-03-01 EA
  *     I investigated this function a bit more in depth.
- *     It looks like the legal values for the MessageType field in the 
+ *     It looks like the legal values for the MessageType field in the
  *     message to send are in the range LPC_NEW_MESSAGE .. LPC_CLIENT_DIED,
  *     but LPC_DATAGRAM is explicitly forbidden.
  *
  * @implemented
  */
 NTSTATUS STDCALL LpcRequestPort (IN    PEPORT          Port,
-                                IN     PLPC_MESSAGE    LpcMessage)
+                                IN     PPORT_MESSAGE   LpcMessage)
 {
    NTSTATUS Status;
-   
+
    DPRINT("LpcRequestPort(PortHandle %08x, LpcMessage %08x)\n", Port, LpcMessage);
 
 #ifdef __USE_NT_LPC__
    /* Check the message's type */
-   if (LPC_NEW_MESSAGE == LpcMessage->MessageType)
+   if (LPC_NEW_MESSAGE == LpcMessage->u2.s2.Type)
    {
-      LpcMessage->MessageType = LPC_DATAGRAM;
+      LpcMessage->u2.s2.Type = LPC_DATAGRAM;
    }
-   else if (LPC_DATAGRAM == LpcMessage->MessageType)
+   else if (LPC_DATAGRAM == LpcMessage->u2.s2.Type)
    {
       return STATUS_INVALID_PARAMETER;
    }
-   else if (LpcMessage->MessageType > LPC_CLIENT_DIED)
+   else if (LpcMessage->u2.s2.Type > LPC_CLIENT_DIED)
    {
       return STATUS_INVALID_PARAMETER;
    }
@@ -151,8 +147,8 @@ NTSTATUS STDCALL LpcRequestPort (IN PEPORT          Port,
    }
 #endif
 
-   Status = EiReplyOrRequestPort(Port, 
-                                LpcMessage, 
+   Status = EiReplyOrRequestPort(Port,
+                                LpcMessage,
                                 LPC_DATAGRAM,
                                 Port);
    KeReleaseSemaphore( &Port->Semaphore, IO_NO_INCREMENT, 1, FALSE );
@@ -176,17 +172,17 @@ NTSTATUS STDCALL LpcRequestPort (IN       PEPORT          Port,
  * @implemented
  */
 NTSTATUS STDCALL NtRequestPort (IN     HANDLE          PortHandle,
-                               IN      PLPC_MESSAGE    LpcMessage)
+                               IN      PPORT_MESSAGE   LpcMessage)
 {
    NTSTATUS Status;
    PEPORT Port;
-   
-   DPRINT("NtRequestPort(PortHandle %x LpcMessage %x)\n", PortHandle, 
+
+   DPRINT("NtRequestPort(PortHandle %x LpcMessage %x)\n", PortHandle,
          LpcMessage);
-   
+
    Status = ObReferenceObjectByHandle(PortHandle,
                                      PORT_ALL_ACCESS,
-                                     ExPortType,
+                                     LpcPortObjectType,
                                      UserMode,
                                      (PVOID*)&Port,
                                      NULL);
@@ -196,9 +192,9 @@ NTSTATUS STDCALL NtRequestPort (IN  HANDLE          PortHandle,
        return(Status);
      }
 
-   Status = LpcRequestPort(Port->OtherPort, 
+   Status = LpcRequestPort(Port->OtherPort,
                           LpcMessage);
-   
+
    ObDereferenceObject(Port);
    return(Status);
 }
@@ -218,26 +214,57 @@ NTSTATUS STDCALL NtRequestPort (IN        HANDLE          PortHandle,
  *
  * @implemented
  */
-NTSTATUS STDCALL 
+NTSTATUS STDCALL
 NtRequestWaitReplyPort (IN HANDLE PortHandle,
-                       PLPC_MESSAGE UnsafeLpcRequest,    
-                       PLPC_MESSAGE UnsafeLpcReply)
+                       PPORT_MESSAGE UnsafeLpcRequest,
+                       PPORT_MESSAGE UnsafeLpcReply)
 {
    PETHREAD CurrentThread;
-   struct _EPROCESS *AttachedProcess;
-   NTSTATUS Status;
+   struct _KPROCESS *AttachedProcess;
    PEPORT Port;
    PQUEUEDMESSAGE Message;
    KIRQL oldIrql;
-   PLPC_MESSAGE LpcRequest;
-   USHORT LpcRequestMessageSize;
+   PPORT_MESSAGE LpcRequest;
+   USHORT LpcRequestMessageSize = 0, LpcRequestDataSize = 0;
+   KPROCESSOR_MODE PreviousMode;
+   NTSTATUS Status = STATUS_SUCCESS;
+   
+   PreviousMode = ExGetPreviousMode();
+   
+   if (PreviousMode != KernelMode)
+     {
+       _SEH_TRY
+         {
+           ProbeForRead(UnsafeLpcRequest,
+                        sizeof(PORT_MESSAGE),
+                        1);
+           ProbeForWrite(UnsafeLpcReply,
+                         sizeof(PORT_MESSAGE),
+                         1);
+           LpcRequestMessageSize = UnsafeLpcRequest->u1.s1.TotalLength;
+         }
+       _SEH_HANDLE
+         {
+           Status = _SEH_GetExceptionCode();
+         }
+       _SEH_END;
+       
+       if (!NT_SUCCESS(Status))
+         {
+           return Status;
+         }
+     }
+   else
+     {
+       LpcRequestMessageSize = UnsafeLpcRequest->u1.s1.TotalLength;
+     }
 
    DPRINT("NtRequestWaitReplyPort(PortHandle %x, LpcRequest %x, "
          "LpcReply %x)\n", PortHandle, UnsafeLpcRequest, UnsafeLpcReply);
 
    Status = ObReferenceObjectByHandle(PortHandle,
-                                     PORT_ALL_ACCESS, 
-                                     ExPortType,
+                                     PORT_ALL_ACCESS,
+                                     LpcPortObjectType,
                                      UserMode,
                                      (PVOID*)&Port,
                                      NULL);
@@ -246,34 +273,28 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
        return(Status);
      }
 
+   if (EPORT_DISCONNECTED == Port->State)
+     {
+       ObDereferenceObject(Port);
+       return STATUS_PORT_DISCONNECTED;
+     }
+
    /* win32k sometimes needs to KeAttach() the CSRSS process in order to make
       the PortHandle valid. Now that we've got the EPORT structure from the
       handle we can undo this, so everything is normal again. Need to
       re-KeAttach() before returning though */
    CurrentThread = PsGetCurrentThread();
-   if (NULL == CurrentThread->OldProcess)
+   if (&CurrentThread->ThreadsProcess->Pcb == CurrentThread->Tcb.ApcState.Process)
      {
        AttachedProcess = NULL;
      }
    else
      {
-       AttachedProcess = CurrentThread->ThreadsProcess;
+       AttachedProcess = CurrentThread->Tcb.ApcState.Process;
        KeDetachProcess();
      }
 
-   Status = MmCopyFromCaller(&LpcRequestMessageSize,
-                            &UnsafeLpcRequest->MessageSize,
-                            sizeof(USHORT));
-   if (!NT_SUCCESS(Status))
-     {
-       if (NULL != AttachedProcess)
-         {
-           KeAttachProcess(AttachedProcess);
-         }
-       ObDereferenceObject(Port);
-       return(Status);
-     }
-   if (LpcRequestMessageSize > (sizeof(LPC_MESSAGE) + MAX_MESSAGE_DATA))
+   if (LpcRequestMessageSize > LPC_MAX_MESSAGE_LENGTH)
      {
        if (NULL != AttachedProcess)
          {
@@ -292,20 +313,43 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
        ObDereferenceObject(Port);
        return(STATUS_NO_MEMORY);
      }
-   Status = MmCopyFromCaller(LpcRequest, UnsafeLpcRequest,
-                            LpcRequestMessageSize);
-   if (!NT_SUCCESS(Status))
+   if (PreviousMode != KernelMode)
      {
-       ExFreePool(LpcRequest);
-       if (NULL != AttachedProcess)
+       _SEH_TRY
          {
-           KeAttachProcess(AttachedProcess);
+           RtlCopyMemory(LpcRequest,
+                         UnsafeLpcRequest,
+                         LpcRequestMessageSize);
+           LpcRequestMessageSize = LpcRequest->u1.s1.TotalLength;
+           LpcRequestDataSize = LpcRequest->u1.s1.DataLength;
+         }
+       _SEH_HANDLE
+         {
+           Status = _SEH_GetExceptionCode();
+         }
+       _SEH_END;
+       
+       if (!NT_SUCCESS(Status))
+         {
+           ExFreePool(LpcRequest);
+           if (NULL != AttachedProcess)
+             {
+               KeAttachProcess(AttachedProcess);
+             }
+           ObDereferenceObject(Port);
+           return(Status);
          }
-       ObDereferenceObject(Port);
-       return(Status);
      }
-   LpcRequestMessageSize = LpcRequest->MessageSize;
-   if (LpcRequestMessageSize > (sizeof(LPC_MESSAGE) + MAX_MESSAGE_DATA))
+   else
+     {
+       RtlCopyMemory(LpcRequest,
+                     UnsafeLpcRequest,
+                     LpcRequestMessageSize);
+       LpcRequestMessageSize = LpcRequest->u1.s1.TotalLength;
+       LpcRequestDataSize = LpcRequest->u1.s1.DataLength;
+     }
+
+   if (LpcRequestMessageSize > LPC_MAX_MESSAGE_LENGTH)
      {
        ExFreePool(LpcRequest);
        if (NULL != AttachedProcess)
@@ -315,7 +359,7 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
        ObDereferenceObject(Port);
        return(STATUS_PORT_MESSAGE_TOO_LONG);
      }
-   if (LpcRequest->DataSize != (LpcRequest->MessageSize - sizeof(LPC_MESSAGE)))
+   if (LpcRequestDataSize > LPC_MAX_DATA_LENGTH)
      {
        ExFreePool(LpcRequest);
        if (NULL != AttachedProcess)
@@ -326,8 +370,8 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
        return(STATUS_PORT_MESSAGE_TOO_LONG);
      }
 
-   Status = EiReplyOrRequestPort(Port->OtherPort, 
-                                LpcRequest, 
+   Status = EiReplyOrRequestPort(Port->OtherPort,
+                                LpcRequest,
                                 LPC_REQUEST,
                                 Port);
    if (!NT_SUCCESS(Status))
@@ -342,9 +386,9 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
        return(Status);
      }
    ExFreePool(LpcRequest);
-   KeReleaseSemaphore (&Port->OtherPort->Semaphore, IO_NO_INCREMENT, 
-                      1, FALSE);   
-   
+   KeReleaseSemaphore (&Port->OtherPort->Semaphore, IO_NO_INCREMENT,
+                      1, FALSE);
+
    /*
     * Wait for a reply
     */
@@ -355,7 +399,7 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
                                  NULL);
    if (Status == STATUS_SUCCESS)
      {
-   
+
        /*
         * Dequeue the reply
         */
@@ -364,10 +408,28 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
        KeReleaseSpinLock(&Port->Lock, oldIrql);
        if (Message)
          {
-           DPRINT("Message->Message.MessageSize %d\n",
-                 Message->Message.MessageSize);
-           Status = MmCopyToCaller(UnsafeLpcReply, &Message->Message, 
-                                  Message->Message.MessageSize);
+           DPRINT("Message->Message.u1.s1.TotalLength %d\n",
+                 Message->Message.u1.s1.TotalLength);
+           if (PreviousMode != KernelMode)
+             {
+               _SEH_TRY
+                 {
+                   RtlCopyMemory(UnsafeLpcReply,
+                                 &Message->Message,
+                                 Message->Message.u1.s1.TotalLength);
+                 }
+               _SEH_HANDLE
+                 {
+                   Status = _SEH_GetExceptionCode();
+                 }
+               _SEH_END;
+             }
+           else
+             {
+               RtlCopyMemory(UnsafeLpcReply,
+                             &Message->Message,
+                             Message->Message.u1.s1.TotalLength);
+             }
            ExFreePool(Message);
          }
        else
@@ -385,7 +447,7 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
        KeAttachProcess(AttachedProcess);
      }
    ObDereferenceObject(Port);
-   
+
    return(Status);
 }
 
@@ -393,7 +455,7 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
 /**********************************************************************
  * NAME
  *     NtWriteRequestData/6
- *     
+ *
  * DESCRIPTION
  *
  * ARGUMENTS
@@ -403,7 +465,7 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
  * REVISIONS
  */
 NTSTATUS STDCALL NtWriteRequestData (HANDLE            PortHandle,
-                                    PLPC_MESSAGE       Message,
+                                    PPORT_MESSAGE      Message,
                                     ULONG              Index,
                                     PVOID              Buffer,
                                     ULONG              BufferLength,