implemented sweeping of handle tables
[reactos.git] / reactos / ntoskrnl / lpc / reply.c
index 8d002a7..e67db08 100644 (file)
@@ -1,25 +1,22 @@
-/* $Id: reply.c,v 1.2 2000/10/22 16:36:51 ekohl Exp $
- * 
+/* $Id$
+ *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
  * FILE:            ntoskrnl/lpc/reply.c
  * PURPOSE:         Communication mechanism
- * PROGRAMMER:      David Welch (welch@cwcom.net)
- * UPDATE HISTORY:
- *                  Created 22/05/98
+ *
+ * PROGRAMMERS:     David Welch (welch@cwcom.net)
  */
 
-/* INCLUDES *****************************************************************/
-
-#include <ddk/ntddk.h>
-#include <internal/ob.h>
-#include <internal/port.h>
-#include <internal/dbg.h>
+/* INCLUDES ******************************************************************/
 
+#include <ntoskrnl.h>
 #define NDEBUG
 #include <internal/debug.h>
 
+/* GLOBALS *******************************************************************/
 
+/* FUNCTIONS *****************************************************************/
 
 /**********************************************************************
  * NAME
  * RETURN VALUE
  *
  * REVISIONS
- *
  */
-NTSTATUS
-STDCALL
-EiReplyOrRequestPort (
-       IN      PEPORT          Port, 
-       IN      PLPC_MESSAGE    LpcReply, 
-       IN      ULONG           MessageType,
-       IN      PEPORT          Sender
-       )
+NTSTATUS STDCALL
+EiReplyOrRequestPort (IN       PEPORT          Port,
+                     IN        PPORT_MESSAGE   LpcReply,
+                     IN        ULONG           MessageType,
+                     IN        PEPORT          Sender)
 {
    KIRQL oldIrql;
    PQUEUEDMESSAGE MessageReply;
-   
-   MessageReply = ExAllocatePool(NonPagedPool, sizeof(QUEUEDMESSAGE));
+   ULONG Size;
+
+   if (Port == NULL)
+     {
+       KEBUGCHECK(0);
+     }
+
+   Size = sizeof(QUEUEDMESSAGE);
+   if (LpcReply && LpcReply->u1.s1.TotalLength > (CSHORT)sizeof(PORT_MESSAGE))
+     {
+       Size += LpcReply->u1.s1.TotalLength - sizeof(PORT_MESSAGE);
+     }
+   MessageReply = ExAllocatePoolWithTag(NonPagedPool, Size, 
+                                       TAG_LPC_MESSAGE);
    MessageReply->Sender = Sender;
-   
+
    if (LpcReply != NULL)
      {
-       memcpy(&MessageReply->Message, LpcReply, LpcReply->MessageSize);
+       memcpy(&MessageReply->Message, LpcReply, LpcReply->u1.s1.TotalLength);
      }
-   
-   MessageReply->Message.Cid.UniqueProcess = PsGetCurrentProcessId();
-   MessageReply->Message.Cid.UniqueThread = PsGetCurrentThreadId();
-   MessageReply->Message.MessageType = MessageType;
-   MessageReply->Message.MessageId = InterlockedIncrement(&EiNextLpcMessageId);
-   
+   else
+     {
+       MessageReply->Message.u1.s1.TotalLength = sizeof(PORT_MESSAGE);
+       MessageReply->Message.u1.s1.DataLength = 0;
+     }
+
+   MessageReply->Message.ClientId.UniqueProcess = PsGetCurrentProcessId();
+   MessageReply->Message.ClientId.UniqueThread = PsGetCurrentThreadId();
+   MessageReply->Message.u2.s2.Type = MessageType;
+   MessageReply->Message.MessageId = InterlockedIncrementUL(&LpcpNextMessageId);
+
    KeAcquireSpinLock(&Port->Lock, &oldIrql);
    EiEnqueueMessagePort(Port, MessageReply);
    KeReleaseSpinLock(&Port->Lock, oldIrql);
-   
+
    return(STATUS_SUCCESS);
 }
 
@@ -76,23 +86,19 @@ EiReplyOrRequestPort (
  * RETURN VALUE
  *
  * REVISIONS
- *
  */
-NTSTATUS
-STDCALL
-NtReplyPort (
-       IN      HANDLE          PortHandle,
-       IN      PLPC_MESSAGE    LpcReply
-       )
+NTSTATUS STDCALL
+NtReplyPort (IN        HANDLE          PortHandle,
+            IN PPORT_MESSAGE   LpcReply)
 {
    NTSTATUS Status;
    PEPORT Port;
-   
+
    DPRINT("NtReplyPort(PortHandle %x, LpcReply %x)\n", PortHandle, LpcReply);
-   
+
    Status = ObReferenceObjectByHandle(PortHandle,
                                      PORT_ALL_ACCESS,   /* AccessRequired */
-                                     ExPortType,
+                                     LpcPortObjectType,
                                      UserMode,
                                      (PVOID*)&Port,
                                      NULL);
@@ -101,108 +107,262 @@ NtReplyPort (
        DPRINT("NtReplyPort() = %x\n", Status);
        return(Status);
      }
-   
-   Status = EiReplyOrRequestPort(Port->OtherPort, 
-                                LpcReply, 
+
+   if (EPORT_DISCONNECTED == Port->State)
+     {
+       ObDereferenceObject(Port);
+       return STATUS_PORT_DISCONNECTED;
+     }
+
+   Status = EiReplyOrRequestPort(Port->OtherPort,
+                                LpcReply,
                                 LPC_REPLY,
                                 Port);
-   KeSetEvent(&Port->OtherPort->Event, IO_NO_INCREMENT, FALSE);
-   
+   KeReleaseSemaphore(&Port->OtherPort->Semaphore, IO_NO_INCREMENT, 1, FALSE);
+
    ObDereferenceObject(Port);
-   
+
    return(Status);
 }
 
 
 /**********************************************************************
  * NAME                                                        EXPORTED
+ *     NtReplyWaitReceivePortEx
  *
  * DESCRIPTION
+ *     Can be used with waitable ports.
+ *     Present only in w2k+.
  *
  * ARGUMENTS
+ *     PortHandle
+ *     PortId
+ *     LpcReply
+ *     LpcMessage
+ *     Timeout
  *
  * RETURN VALUE
  *
  * REVISIONS
- *
  */
-NTSTATUS
-STDCALL
-NtReplyWaitReceivePort (
-       HANDLE          PortHandle,
-       PULONG          PortId,
-       PLPC_MESSAGE    LpcReply,     
-       PLPC_MESSAGE    LpcMessage
-       )
+NTSTATUS STDCALL
+NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
+                         OUT PVOID *PortContext OPTIONAL,
+                         IN PPORT_MESSAGE ReplyMessage OPTIONAL,
+                         OUT PPORT_MESSAGE ReceiveMessage,
+                                    IN PLARGE_INTEGER Timeout OPTIONAL)
 {
-   NTSTATUS Status;
    PEPORT Port;
    KIRQL oldIrql;
    PQUEUEDMESSAGE Request;
+   BOOLEAN Disconnected;
+   LARGE_INTEGER to;
+   KPROCESSOR_MODE PreviousMode;
+   NTSTATUS Status = STATUS_SUCCESS;
    
-   DPRINT("NtReplyWaitReceivePort(PortHandle %x, LpcReply %x, "
-         "LpcMessage %x)\n", PortHandle, LpcReply, LpcMessage);
-   
+   PreviousMode = ExGetPreviousMode();
+
+   DPRINT("NtReplyWaitReceivePortEx(PortHandle %x, LpcReply %x, "
+         "LpcMessage %x)\n", PortHandle, ReplyMessage, ReceiveMessage);
+
+   if (PreviousMode != KernelMode)
+     {
+       _SEH_TRY
+         {
+           ProbeForWrite(ReceiveMessage,
+                         sizeof(PORT_MESSAGE),
+                         1);
+         }
+       _SEH_HANDLE
+         {
+           Status = _SEH_GetExceptionCode();
+         }
+       _SEH_END;
+       
+       if (!NT_SUCCESS(Status))
+         {
+           return Status;
+         }
+     }
+
    Status = ObReferenceObjectByHandle(PortHandle,
                                      PORT_ALL_ACCESS,
-                                     ExPortType,
+                                     LpcPortObjectType,
                                      UserMode,
                                      (PVOID*)&Port,
                                      NULL);
    if (!NT_SUCCESS(Status))
      {
-       DPRINT("NtReplyWaitReceivePort() = %x\n", Status);
+       DPRINT1("NtReplyWaitReceivePortEx() = %x\n", Status);
        return(Status);
      }
-   
+   if( Port->State == EPORT_DISCONNECTED )
+     {
+       /* If the port is disconnected, force the timeout to be 0
+        * so we don't wait for new messages, because there won't be
+        * any, only try to remove any existing messages
+       */
+       Disconnected = TRUE;
+       to.QuadPart = 0;
+       Timeout = &to;
+     }
+   else Disconnected = FALSE;
+
    /*
-    * Send the reply
+    * Send the reply, only if port is connected
     */
-   if (LpcReply != NULL)
+   if (ReplyMessage != NULL && !Disconnected)
      {
-       Status = EiReplyOrRequestPort(Port->OtherPort, 
-                                     LpcReply,
+       Status = EiReplyOrRequestPort(Port->OtherPort,
+                                     ReplyMessage,
                                      LPC_REPLY,
                                      Port);
-       KeSetEvent(&Port->OtherPort->Event, IO_NO_INCREMENT, FALSE);
-       
+       KeReleaseSemaphore(&Port->OtherPort->Semaphore, IO_NO_INCREMENT, 1,
+                          FALSE);
+
        if (!NT_SUCCESS(Status))
          {
             ObDereferenceObject(Port);
+            DPRINT1("NtReplyWaitReceivePortEx() = %x\n", Status);
             return(Status);
          }
      }
-   
+
    /*
     * Want for a message to be received
     */
-   DPRINT("Entering wait for message\n");
-   KeWaitForSingleObject(&Port->Event,
-                        UserRequest,
-                        UserMode,
-                        FALSE,
-                        NULL);
-   DPRINT("Woke from wait for message\n");
-   
+   Status = KeWaitForSingleObject(&Port->Semaphore,
+                                 UserRequest,
+                                 UserMode,
+                                 FALSE,
+                                 Timeout);
+   if( Status == STATUS_TIMEOUT )
+     {
+       /*
+       * if the port is disconnected, and there are no remaining messages,
+        * return STATUS_PORT_DISCONNECTED
+       */
+       ObDereferenceObject(Port);
+       return(Disconnected ? STATUS_PORT_DISCONNECTED : STATUS_TIMEOUT);
+     }
+
+   if (!NT_SUCCESS(Status))
+     {
+       if (STATUS_THREAD_IS_TERMINATING != Status)
+        {
+          DPRINT1("NtReplyWaitReceivePortEx() = %x\n", Status);
+        }
+       ObDereferenceObject(Port);
+       return(Status);
+     }
+
    /*
     * Dequeue the message
     */
    KeAcquireSpinLock(&Port->Lock, &oldIrql);
    Request = EiDequeueMessagePort(Port);
-   memcpy(LpcMessage, &Request->Message, Request->Message.MessageSize);
-   if (Request->Message.MessageType == LPC_CONNECTION_REQUEST)
+   KeReleaseSpinLock(&Port->Lock, oldIrql);
+
+   if (Request == NULL)
      {
-       EiEnqueueConnectMessagePort(Port, Request);
-       KeReleaseSpinLock(&Port->Lock, oldIrql);
+       ObDereferenceObject(Port);
+       return STATUS_UNSUCCESSFUL;
+     }
+
+   if (Request->Message.u2.s2.Type == LPC_CONNECTION_REQUEST)
+     {
+       PORT_MESSAGE Header;
+       PEPORT_CONNECT_REQUEST_MESSAGE CRequest;
+
+       CRequest = (PEPORT_CONNECT_REQUEST_MESSAGE)&Request->Message;
+       memcpy(&Header, &Request->Message, sizeof(PORT_MESSAGE));
+       Header.u1.s1.DataLength = CRequest->ConnectDataLength;
+       Header.u1.s1.TotalLength = Header.u1.s1.DataLength + sizeof(PORT_MESSAGE);
+       
+       if (PreviousMode != KernelMode)
+         {
+           _SEH_TRY
+             {
+               ProbeForWrite((PVOID)(ReceiveMessage + 1),
+                             CRequest->ConnectDataLength,
+                             1);
+
+               RtlCopyMemory(ReceiveMessage,
+                             &Header,
+                             sizeof(PORT_MESSAGE));
+               RtlCopyMemory((PVOID)(ReceiveMessage + 1),
+                             CRequest->ConnectData,
+                             CRequest->ConnectDataLength);
+             }
+           _SEH_HANDLE
+             {
+               Status = _SEH_GetExceptionCode();
+             }
+           _SEH_END;
+         }
+       else
+         {
+           RtlCopyMemory(ReceiveMessage,
+                         &Header,
+                         sizeof(PORT_MESSAGE));
+           RtlCopyMemory((PVOID)(ReceiveMessage + 1),
+                         CRequest->ConnectData,
+                         CRequest->ConnectDataLength);
+         }
      }
    else
      {
-       KeReleaseSpinLock(&Port->Lock, oldIrql);
-       ExFreePool(Request);
+       if (PreviousMode != KernelMode)
+         {
+           _SEH_TRY
+             {
+               ProbeForWrite(ReceiveMessage,
+                             Request->Message.u1.s1.TotalLength,
+                             1);
+
+               RtlCopyMemory(ReceiveMessage,
+                             &Request->Message,
+                             Request->Message.u1.s1.TotalLength);
+             }
+           _SEH_HANDLE
+             {
+               Status = _SEH_GetExceptionCode();
+             }
+           _SEH_END;
+         }
+       else
+         {
+           RtlCopyMemory(ReceiveMessage,
+                         &Request->Message,
+                         Request->Message.u1.s1.TotalLength);
+         }
      }
-   
+   if (!NT_SUCCESS(Status))
+     {
+       /*
+       * Copying the message to the caller's buffer failed so
+       * undo what we did and return.
+       * FIXME: Also increment semaphore.
+       */
+       KeAcquireSpinLock(&Port->Lock, &oldIrql);
+       EiEnqueueMessageAtHeadPort(Port, Request);
+       KeReleaseSpinLock(&Port->Lock, oldIrql);
+       ObDereferenceObject(Port);
+       return(Status);
+     }
+   if (Request->Message.u2.s2.Type == LPC_CONNECTION_REQUEST)
+     {
+       KeAcquireSpinLock(&Port->Lock, &oldIrql);
+       EiEnqueueConnectMessagePort(Port, Request);
+       KeReleaseSpinLock(&Port->Lock, oldIrql);
+     }
+   else
+     {
+       ExFreePool(Request);
+     }
+
    /*
-    * 
+    * Dereference the port
     */
    ObDereferenceObject(Port);
    return(STATUS_SUCCESS);
@@ -210,26 +370,67 @@ NtReplyWaitReceivePort (
 
 
 /**********************************************************************
- * NAME
+ * NAME                                                EXPORTED
+ *     NtReplyWaitReceivePort
  *
  * DESCRIPTION
+ *     Can be used with waitable ports.
  *
  * ARGUMENTS
+ *     PortHandle
+ *     PortId
+ *     LpcReply
+ *     LpcMessage
  *
  * RETURN VALUE
  *
  * REVISIONS
+ */
+NTSTATUS STDCALL
+NtReplyWaitReceivePort(IN HANDLE PortHandle,
+                       OUT PVOID *PortContext OPTIONAL,
+                       IN PPORT_MESSAGE ReplyMessage OPTIONAL,
+                       OUT PPORT_MESSAGE ReceiveMessage)
+{
+    return NtReplyWaitReceivePortEx(PortHandle,
+                                                   PortContext,
+                                                   ReplyMessage,
+                                                   ReceiveMessage,
+                                                   NULL);
+}
+
+/**********************************************************************
+ * NAME
+ *
+ * DESCRIPTION
+ *
+ * ARGUMENTS
+ *
+ * RETURN VALUE
  *
+ * REVISIONS
+ */
+NTSTATUS STDCALL
+NtReplyWaitReplyPort (HANDLE           PortHandle,
+                     PPORT_MESSAGE     ReplyMessage)
+{
+   UNIMPLEMENTED;
+   return(STATUS_NOT_IMPLEMENTED);
+}
+
+/*
+ * @unimplemented
  */
 NTSTATUS
 STDCALL
-NtReplyWaitReplyPort (
-       HANDLE          PortHandle,
-       PLPC_MESSAGE    ReplyMessage
+LpcRequestWaitReplyPort (
+       IN PEPORT               Port,
+       IN PPORT_MESSAGE        LpcMessageRequest,
+       OUT PPORT_MESSAGE       LpcMessageReply
        )
 {
        UNIMPLEMENTED;
+       return STATUS_NOT_IMPLEMENTED;
 }
 
-
 /* EOF */