implemented sweeping of handle tables
[reactos.git] / reactos / ntoskrnl / lpc / reply.c
index db71510..e67db08 100644 (file)
@@ -1,12 +1,11 @@
-/* $Id: reply.c,v 1.21 2004/08/15 16:39:06 chorns Exp $
- * 
+/* $Id$
+ *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
  * FILE:            ntoskrnl/lpc/reply.c
  * PURPOSE:         Communication mechanism
- * PROGRAMMER:      David Welch (welch@cwcom.net)
- * UPDATE HISTORY:
- *                  Created 22/05/98
+ *
+ * PROGRAMMERS:     David Welch (welch@cwcom.net)
  */
 
 /* INCLUDES ******************************************************************/
@@ -17,8 +16,6 @@
 
 /* GLOBALS *******************************************************************/
 
-#define TAG_LPC_MESSAGE   TAG('L', 'P', 'C', 'M')
-
 /* FUNCTIONS *****************************************************************/
 
 /**********************************************************************
  * REVISIONS
  */
 NTSTATUS STDCALL
-EiReplyOrRequestPort (IN       PEPORT          Port, 
-                     IN        PLPC_MESSAGE    LpcReply, 
+EiReplyOrRequestPort (IN       PEPORT          Port,
+                     IN        PPORT_MESSAGE   LpcReply,
                      IN        ULONG           MessageType,
                      IN        PEPORT          Sender)
 {
    KIRQL oldIrql;
    PQUEUEDMESSAGE MessageReply;
-   
+   ULONG Size;
+
    if (Port == NULL)
      {
        KEBUGCHECK(0);
      }
 
-   MessageReply = ExAllocatePoolWithTag(NonPagedPool, sizeof(QUEUEDMESSAGE),
+   Size = sizeof(QUEUEDMESSAGE);
+   if (LpcReply && LpcReply->u1.s1.TotalLength > (CSHORT)sizeof(PORT_MESSAGE))
+     {
+       Size += LpcReply->u1.s1.TotalLength - sizeof(PORT_MESSAGE);
+     }
+   MessageReply = ExAllocatePoolWithTag(NonPagedPool, Size, 
                                        TAG_LPC_MESSAGE);
    MessageReply->Sender = Sender;
-   
+
    if (LpcReply != NULL)
      {
-       memcpy(&MessageReply->Message, LpcReply, LpcReply->MessageSize);
+       memcpy(&MessageReply->Message, LpcReply, LpcReply->u1.s1.TotalLength);
      }
-   
+   else
+     {
+       MessageReply->Message.u1.s1.TotalLength = sizeof(PORT_MESSAGE);
+       MessageReply->Message.u1.s1.DataLength = 0;
+     }
+
    MessageReply->Message.ClientId.UniqueProcess = PsGetCurrentProcessId();
    MessageReply->Message.ClientId.UniqueThread = PsGetCurrentThreadId();
-   MessageReply->Message.MessageType = MessageType;
-   MessageReply->Message.MessageId = InterlockedIncrement((LONG *)&EiNextLpcMessageId);
-   
+   MessageReply->Message.u2.s2.Type = MessageType;
+   MessageReply->Message.MessageId = InterlockedIncrementUL(&LpcpNextMessageId);
+
    KeAcquireSpinLock(&Port->Lock, &oldIrql);
    EiEnqueueMessagePort(Port, MessageReply);
    KeReleaseSpinLock(&Port->Lock, oldIrql);
-   
+
    return(STATUS_SUCCESS);
 }
 
@@ -81,16 +89,16 @@ EiReplyOrRequestPort (IN    PEPORT          Port,
  */
 NTSTATUS STDCALL
 NtReplyPort (IN        HANDLE          PortHandle,
-            IN PLPC_MESSAGE    LpcReply)
+            IN PPORT_MESSAGE   LpcReply)
 {
    NTSTATUS Status;
    PEPORT Port;
-   
+
    DPRINT("NtReplyPort(PortHandle %x, LpcReply %x)\n", PortHandle, LpcReply);
-   
+
    Status = ObReferenceObjectByHandle(PortHandle,
                                      PORT_ALL_ACCESS,   /* AccessRequired */
-                                     ExPortType,
+                                     LpcPortObjectType,
                                      UserMode,
                                      (PVOID*)&Port,
                                      NULL);
@@ -99,15 +107,21 @@ NtReplyPort (IN    HANDLE          PortHandle,
        DPRINT("NtReplyPort() = %x\n", Status);
        return(Status);
      }
-   
-   Status = EiReplyOrRequestPort(Port->OtherPort, 
-                                LpcReply, 
+
+   if (EPORT_DISCONNECTED == Port->State)
+     {
+       ObDereferenceObject(Port);
+       return STATUS_PORT_DISCONNECTED;
+     }
+
+   Status = EiReplyOrRequestPort(Port->OtherPort,
+                                LpcReply,
                                 LPC_REPLY,
                                 Port);
    KeReleaseSemaphore(&Port->OtherPort->Semaphore, IO_NO_INCREMENT, 1, FALSE);
-   
+
    ObDereferenceObject(Port);
-   
+
    return(Status);
 }
 
@@ -115,11 +129,11 @@ NtReplyPort (IN   HANDLE          PortHandle,
 /**********************************************************************
  * NAME                                                        EXPORTED
  *     NtReplyWaitReceivePortEx
- *     
+ *
  * DESCRIPTION
  *     Can be used with waitable ports.
  *     Present only in w2k+.
- *     
+ *
  * ARGUMENTS
  *     PortHandle
  *     PortId
@@ -132,25 +146,48 @@ NtReplyPort (IN   HANDLE          PortHandle,
  * REVISIONS
  */
 NTSTATUS STDCALL
-NtReplyWaitReceivePortEx(IN  HANDLE            PortHandle,
-                        OUT PULONG             PortId,
-                        IN  PLPC_MESSAGE       LpcReply,     
-                        OUT PLPC_MESSAGE       LpcMessage,
-                        IN  PLARGE_INTEGER     Timeout)
+NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
+                         OUT PVOID *PortContext OPTIONAL,
+                         IN PPORT_MESSAGE ReplyMessage OPTIONAL,
+                         OUT PPORT_MESSAGE ReceiveMessage,
+                                    IN PLARGE_INTEGER Timeout OPTIONAL)
 {
-   NTSTATUS Status;
    PEPORT Port;
    KIRQL oldIrql;
    PQUEUEDMESSAGE Request;
    BOOLEAN Disconnected;
    LARGE_INTEGER to;
+   KPROCESSOR_MODE PreviousMode;
+   NTSTATUS Status = STATUS_SUCCESS;
    
+   PreviousMode = ExGetPreviousMode();
+
    DPRINT("NtReplyWaitReceivePortEx(PortHandle %x, LpcReply %x, "
-         "LpcMessage %x)\n", PortHandle, LpcReply, LpcMessage);
-   
+         "LpcMessage %x)\n", PortHandle, ReplyMessage, ReceiveMessage);
+
+   if (PreviousMode != KernelMode)
+     {
+       _SEH_TRY
+         {
+           ProbeForWrite(ReceiveMessage,
+                         sizeof(PORT_MESSAGE),
+                         1);
+         }
+       _SEH_HANDLE
+         {
+           Status = _SEH_GetExceptionCode();
+         }
+       _SEH_END;
+       
+       if (!NT_SUCCESS(Status))
+         {
+           return Status;
+         }
+     }
+
    Status = ObReferenceObjectByHandle(PortHandle,
                                      PORT_ALL_ACCESS,
-                                     ExPortType,
+                                     LpcPortObjectType,
                                      UserMode,
                                      (PVOID*)&Port,
                                      NULL);
@@ -170,19 +207,19 @@ NtReplyWaitReceivePortEx(IN  HANDLE               PortHandle,
        Timeout = &to;
      }
    else Disconnected = FALSE;
-   
+
    /*
     * Send the reply, only if port is connected
     */
-   if (LpcReply != NULL && !Disconnected)
+   if (ReplyMessage != NULL && !Disconnected)
      {
-       Status = EiReplyOrRequestPort(Port->OtherPort, 
-                                     LpcReply,
+       Status = EiReplyOrRequestPort(Port->OtherPort,
+                                     ReplyMessage,
                                      LPC_REPLY,
                                      Port);
-       KeReleaseSemaphore(&Port->OtherPort->Semaphore, IO_NO_INCREMENT, 1, 
+       KeReleaseSemaphore(&Port->OtherPort->Semaphore, IO_NO_INCREMENT, 1,
                           FALSE);
-       
+
        if (!NT_SUCCESS(Status))
          {
             ObDereferenceObject(Port);
@@ -190,7 +227,7 @@ NtReplyWaitReceivePortEx(IN  HANDLE         PortHandle,
             return(Status);
          }
      }
-   
+
    /*
     * Want for a message to be received
     */
@@ -208,7 +245,7 @@ NtReplyWaitReceivePortEx(IN  HANDLE         PortHandle,
        ObDereferenceObject(Port);
        return(Disconnected ? STATUS_PORT_DISCONNECTED : STATUS_TIMEOUT);
      }
-   
+
    if (!NT_SUCCESS(Status))
      {
        if (STATUS_THREAD_IS_TERMINATING != Status)
@@ -226,32 +263,84 @@ NtReplyWaitReceivePortEx(IN  HANDLE               PortHandle,
    Request = EiDequeueMessagePort(Port);
    KeReleaseSpinLock(&Port->Lock, oldIrql);
 
-   if (Request->Message.MessageType == LPC_CONNECTION_REQUEST)
+   if (Request == NULL)
      {
-       LPC_MESSAGE Header;
+       ObDereferenceObject(Port);
+       return STATUS_UNSUCCESSFUL;
+     }
+
+   if (Request->Message.u2.s2.Type == LPC_CONNECTION_REQUEST)
+     {
+       PORT_MESSAGE Header;
        PEPORT_CONNECT_REQUEST_MESSAGE CRequest;
 
        CRequest = (PEPORT_CONNECT_REQUEST_MESSAGE)&Request->Message;
-       memcpy(&Header, &Request->Message, sizeof(LPC_MESSAGE));
-       Header.DataSize = CRequest->ConnectDataLength;
-       Header.MessageSize = Header.DataSize + sizeof(LPC_MESSAGE);
-       Status = MmCopyToCaller(LpcMessage, &Header, sizeof(LPC_MESSAGE));
-       if (NT_SUCCESS(Status))
-        {
-          Status = MmCopyToCaller((PVOID)(LpcMessage + 1),
-                                  CRequest->ConnectData,
-                                  CRequest->ConnectDataLength);
-        }
+       memcpy(&Header, &Request->Message, sizeof(PORT_MESSAGE));
+       Header.u1.s1.DataLength = CRequest->ConnectDataLength;
+       Header.u1.s1.TotalLength = Header.u1.s1.DataLength + sizeof(PORT_MESSAGE);
+       
+       if (PreviousMode != KernelMode)
+         {
+           _SEH_TRY
+             {
+               ProbeForWrite((PVOID)(ReceiveMessage + 1),
+                             CRequest->ConnectDataLength,
+                             1);
+
+               RtlCopyMemory(ReceiveMessage,
+                             &Header,
+                             sizeof(PORT_MESSAGE));
+               RtlCopyMemory((PVOID)(ReceiveMessage + 1),
+                             CRequest->ConnectData,
+                             CRequest->ConnectDataLength);
+             }
+           _SEH_HANDLE
+             {
+               Status = _SEH_GetExceptionCode();
+             }
+           _SEH_END;
+         }
+       else
+         {
+           RtlCopyMemory(ReceiveMessage,
+                         &Header,
+                         sizeof(PORT_MESSAGE));
+           RtlCopyMemory((PVOID)(ReceiveMessage + 1),
+                         CRequest->ConnectData,
+                         CRequest->ConnectDataLength);
+         }
      }
    else
      {
-       Status = MmCopyToCaller(LpcMessage, &Request->Message,
-                              Request->Message.MessageSize);
+       if (PreviousMode != KernelMode)
+         {
+           _SEH_TRY
+             {
+               ProbeForWrite(ReceiveMessage,
+                             Request->Message.u1.s1.TotalLength,
+                             1);
+
+               RtlCopyMemory(ReceiveMessage,
+                             &Request->Message,
+                             Request->Message.u1.s1.TotalLength);
+             }
+           _SEH_HANDLE
+             {
+               Status = _SEH_GetExceptionCode();
+             }
+           _SEH_END;
+         }
+       else
+         {
+           RtlCopyMemory(ReceiveMessage,
+                         &Request->Message,
+                         Request->Message.u1.s1.TotalLength);
+         }
      }
    if (!NT_SUCCESS(Status))
      {
        /*
-       * Copying the message to the caller's buffer failed so 
+       * Copying the message to the caller's buffer failed so
        * undo what we did and return.
        * FIXME: Also increment semaphore.
        */
@@ -261,7 +350,7 @@ NtReplyWaitReceivePortEx(IN  HANDLE         PortHandle,
        ObDereferenceObject(Port);
        return(Status);
      }
-   if (Request->Message.MessageType == LPC_CONNECTION_REQUEST)
+   if (Request->Message.u2.s2.Type == LPC_CONNECTION_REQUEST)
      {
        KeAcquireSpinLock(&Port->Lock, &oldIrql);
        EiEnqueueConnectMessagePort(Port, Request);
@@ -271,7 +360,7 @@ NtReplyWaitReceivePortEx(IN  HANDLE         PortHandle,
      {
        ExFreePool(Request);
      }
-   
+
    /*
     * Dereference the port
     */
@@ -283,10 +372,10 @@ NtReplyWaitReceivePortEx(IN  HANDLE               PortHandle,
 /**********************************************************************
  * NAME                                                EXPORTED
  *     NtReplyWaitReceivePort
- *     
+ *
  * DESCRIPTION
  *     Can be used with waitable ports.
- *     
+ *
  * ARGUMENTS
  *     PortHandle
  *     PortId
@@ -298,16 +387,16 @@ NtReplyWaitReceivePortEx(IN  HANDLE               PortHandle,
  * REVISIONS
  */
 NTSTATUS STDCALL
-NtReplyWaitReceivePort (IN  HANDLE             PortHandle,
-                       OUT PULONG              PortId,
-                       IN  PLPC_MESSAGE        LpcReply,     
-                       OUT PLPC_MESSAGE        LpcMessage)
+NtReplyWaitReceivePort(IN HANDLE PortHandle,
+                       OUT PVOID *PortContext OPTIONAL,
+                       IN PPORT_MESSAGE ReplyMessage OPTIONAL,
+                       OUT PPORT_MESSAGE ReceiveMessage)
 {
-  return(NtReplyWaitReceivePortEx (PortHandle,
-                                  PortId,
-                                  LpcReply,
-                                  LpcMessage,
-                                  NULL));
+    return NtReplyWaitReceivePortEx(PortHandle,
+                                                   PortContext,
+                                                   ReplyMessage,
+                                                   ReceiveMessage,
+                                                   NULL);
 }
 
 /**********************************************************************
@@ -323,7 +412,7 @@ NtReplyWaitReceivePort (IN  HANDLE          PortHandle,
  */
 NTSTATUS STDCALL
 NtReplyWaitReplyPort (HANDLE           PortHandle,
-                     PLPC_MESSAGE      ReplyMessage)
+                     PPORT_MESSAGE     ReplyMessage)
 {
    UNIMPLEMENTED;
    return(STATUS_NOT_IMPLEMENTED);
@@ -336,8 +425,8 @@ NTSTATUS
 STDCALL
 LpcRequestWaitReplyPort (
        IN PEPORT               Port,
-       IN PLPC_MESSAGE LpcMessageRequest,
-       OUT PLPC_MESSAGE        LpcMessageReply
+       IN PPORT_MESSAGE        LpcMessageRequest,
+       OUT PPORT_MESSAGE       LpcMessageReply
        )
 {
        UNIMPLEMENTED;