Use macros for LPC message limits in current LPC implementation.
[reactos.git] / reactos / ntoskrnl / lpc / connect.c
index ab8ac40..0a12bad 100644 (file)
@@ -1,4 +1,4 @@
-/* $Id: connect.c,v 1.12 2002/09/08 10:23:32 chorns Exp $
+/* $Id: connect.c,v 1.24 2004/02/01 18:19:28 ea Exp $
  * 
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
@@ -11,7 +11,8 @@
 
 /* INCLUDES *****************************************************************/
 
-#include <ddk/ntddk.h>
+#define NTOS_MODE_KERNEL
+#include <ntos.h>
 #include <internal/ob.h>
 #include <internal/port.h>
 #include <internal/dbg.h>
 
 /* FUNCTIONS *****************************************************************/
 
+/**********************************************************************
+ * NAME                                                        EXPORTED
+ *     EiConnectPort/12
+ *
+ * DESCRIPTION
+ *
+ * ARGUMENTS
+ *
+ * RETURN VALUE
+ */
 NTSTATUS STDCALL
 EiConnectPort(IN PEPORT* ConnectedPort,
              IN PEPORT NamedPort,
@@ -62,10 +73,14 @@ EiConnectPort(IN PEPORT* ConnectedPort,
   /*
    * Create a port to represent our side of the connection
    */
-  Status = ObCreateObject (NULL,
-                          PORT_ALL_ACCESS,
-                          NULL,
+  Status = ObCreateObject (KernelMode,
                           ExPortType,
+                          NULL,
+                          KernelMode,
+                          NULL,
+                          sizeof(EPORT),
+                          0,
+                          0,
                           (PVOID*)&OurPort);
   if (!NT_SUCCESS(Status))
     {
@@ -90,12 +105,12 @@ EiConnectPort(IN PEPORT* ConnectedPort,
    */
   RequestMessage->MessageHeader.DataSize = 
     sizeof(EPORT_CONNECT_REQUEST_MESSAGE) + RequestConnectDataLength -
-    sizeof(LPC_MESSAGE_HEADER);
+    sizeof(LPC_MESSAGE);
   RequestMessage->MessageHeader.MessageSize = 
     sizeof(EPORT_CONNECT_REQUEST_MESSAGE) + RequestConnectDataLength;
   DPRINT("RequestMessageSize %d\n",
         RequestMessage->MessageHeader.MessageSize);
-  RequestMessage->MessageHeader.SharedSectionSize = 0;
+  RequestMessage->MessageHeader.SectionSize = 0;
   RequestMessage->ConnectingProcess = PsGetCurrentProcess();
   ObReferenceObjectByPointer(RequestMessage->ConnectingProcess,
                             PROCESS_VM_OPERATION,
@@ -214,11 +229,11 @@ EiConnectPort(IN PEPORT* ConnectedPort,
 
 /**********************************************************************
  * NAME                                                        EXPORTED
- *     NtConnectPort@32
+ *     NtConnectPort/8
  *     
  * DESCRIPTION
  *     Connect to a named port and wait for the other side to 
- *     accept the connection.
+ *     accept or reject the connection request.
  *
  * ARGUMENTS
  *     ConnectedPort
@@ -232,6 +247,7 @@ EiConnectPort(IN PEPORT* ConnectedPort,
  * 
  * RETURN VALUE
  * 
+ * @unimplemented
  */
 NTSTATUS STDCALL
 NtConnectPort (PHANDLE                         UnsafeConnectedPortHandle,
@@ -260,7 +276,8 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
    */
   if (UnsafeWriteMap != NULL)
     {
-      Status = MmCopyFromCaller(&WriteMap, UnsafeWriteMap, 
+      Status = MmCopyFromCaller(&WriteMap,
+                               UnsafeWriteMap,
                                sizeof(LPC_SECTION_WRITE));
       if (!NT_SUCCESS(Status))
        {
@@ -308,7 +325,7 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
            }
          Status = MmCopyFromCaller(ConnectData,
                                    UnsafeConnectData,
-                                   ConnectDataLength); 
+                                   ConnectDataLength);
          if (!NT_SUCCESS(Status))
            {
              ExFreePool(ConnectData);
@@ -385,11 +402,13 @@ NtConnectPort (PHANDLE                            UnsafeConnectedPortHandle,
        {
          if (ExGetPreviousMode() != KernelMode)
            {
-             MmCopyToCaller(UnsafeConnectData, ConnectData,
+             MmCopyToCaller(UnsafeConnectData,
+                            ConnectData,
                             ConnectDataLength);
              ExFreePool(ConnectData);
            }
-         MmCopyToCaller(UnsafeConnectDataLength, &ConnectDataLength,
+         MmCopyToCaller(UnsafeConnectDataLength,
+                        &ConnectDataLength,
                         sizeof(ULONG));
        }
       return(Status);
@@ -415,7 +434,8 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
        {
          if (ExGetPreviousMode() != KernelMode)
            {
-             Status = MmCopyToCaller(UnsafeConnectData, ConnectData,
+             Status = MmCopyToCaller(UnsafeConnectData,
+                                     ConnectData,
                                      ConnectDataLength);
              ExFreePool(ConnectData);
              if (!NT_SUCCESS(Status))
@@ -423,7 +443,8 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
                  return(Status);
                }
            }
-         Status = MmCopyToCaller(UnsafeConnectDataLength, &ConnectDataLength,
+         Status = MmCopyToCaller(UnsafeConnectDataLength,
+                                 &ConnectDataLength,
                                  sizeof(ULONG));
          if (!NT_SUCCESS(Status))
            {
@@ -441,7 +462,8 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
     {
       return(Status);
     }
-  Status = MmCopyToCaller(UnsafeConnectedPortHandle, &ConnectedPortHandle,
+  Status = MmCopyToCaller(UnsafeConnectedPortHandle,
+                         &ConnectedPortHandle,
                          sizeof(HANDLE));
   if (!NT_SUCCESS(Status))
     {
@@ -449,7 +471,8 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
     }
   if (UnsafeWriteMap != NULL)
     {
-      Status = MmCopyToCaller(UnsafeWriteMap, &WriteMap, 
+      Status = MmCopyToCaller(UnsafeWriteMap,
+                             &WriteMap,
                              sizeof(LPC_SECTION_WRITE));
       if (!NT_SUCCESS(Status))
        {
@@ -458,7 +481,8 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
     }
   if (UnsafeReadMap != NULL)
     {
-      Status = MmCopyToCaller(UnsafeReadMap, &ReadMap,
+      Status = MmCopyToCaller(UnsafeReadMap,
+                             &ReadMap,
                              sizeof(LPC_SECTION_READ));
       if (!NT_SUCCESS(Status))
        {
@@ -467,9 +491,9 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
     }
   if (UnsafeMaximumMessageSize != NULL)
     {
-      Status = MmCopyToCaller(UnsafeMaximumMessageSize, 
+      Status = MmCopyToCaller(UnsafeMaximumMessageSize,
                              &MaximumMessageSize,
-                             sizeof(LPC_SECTION_WRITE));
+                             sizeof(ULONG));
       if (!NT_SUCCESS(Status))
        {
          return(Status);
@@ -479,7 +503,6 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
   /*
    * All done.
    */
-  ObDereferenceObject(ConnectedPort);
 
   return(STATUS_SUCCESS);
 }
@@ -487,7 +510,7 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
 
 /**********************************************************************
  * NAME                                                        EXPORTED
- *     NtAcceptConnectPort@24
+ *     NtAcceptConnectPort/6
  *
  * DESCRIPTION
  *
@@ -500,9 +523,8 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
  *     ReadMap
  *
  * RETURN VALUE
- *
  */
-EXPORTED NTSTATUS STDCALL
+/*EXPORTED*/ NTSTATUS STDCALL
 NtAcceptConnectPort (PHANDLE                   ServerPortHandle,
                     HANDLE                     NamedPortHandle,
                     PLPC_MESSAGE               LpcMessage,
@@ -517,9 +539,15 @@ NtAcceptConnectPort (PHANDLE                       ServerPortHandle,
   KIRQL                oldIrql;
   PEPORT_CONNECT_REQUEST_MESSAGE CRequest;
   PEPORT_CONNECT_REPLY_MESSAGE CReply;
+  ULONG Size;
+
+  Size = sizeof(EPORT_CONNECT_REPLY_MESSAGE);
+  if (LpcMessage)
+  {
+     Size += LpcMessage->DataSize;
+  }
 
-  CReply = ExAllocatePool(NonPagedPool, 
-                sizeof(EPORT_CONNECT_REPLY_MESSAGE) + LpcMessage->DataSize);
+  CReply = ExAllocatePool(NonPagedPool, Size);
   if (CReply == NULL)
     {
       return(STATUS_NO_MEMORY);
@@ -540,21 +568,41 @@ NtAcceptConnectPort (PHANDLE                      ServerPortHandle,
   /*
    * Create a port object for our side of the connection
    */
-  if (AcceptIt == 1)
+  if (AcceptIt)
     {
-      Status = ObCreateObject(ServerPortHandle,
-                             PORT_ALL_ACCESS,
-                             NULL,
+      Status = ObCreateObject(ExGetPreviousMode(),
                              ExPortType,
+                             NULL,
+                             ExGetPreviousMode(),
+                             NULL,
+                             sizeof(EPORT),
+                             0,
+                             0,
                              (PVOID*)&OurPort);
       if (!NT_SUCCESS(Status))
        {
+         ExFreePool(CReply);
+         ObDereferenceObject(NamedPort);
+         return(Status);
+       }
+
+      Status = ObInsertObject ((PVOID)OurPort,
+                              NULL,
+                              PORT_ALL_ACCESS,
+                              0,
+                              NULL,
+                              ServerPortHandle);
+      if (!NT_SUCCESS(Status))
+       {
+         ObDereferenceObject(OurPort);
+         ExFreePool(CReply);
          ObDereferenceObject(NamedPort);
          return(Status);
        }
+
       NiInitializePort(OurPort);
     }
-  
+
   /*
    * Dequeue the connection request
    */
@@ -568,23 +616,23 @@ NtAcceptConnectPort (PHANDLE                      ServerPortHandle,
    */
   if (LpcMessage != NULL)
     {
-      memcpy(&CReply->MessageHeader, LpcMessage, sizeof(LPC_MESSAGE_HEADER));
+      memcpy(&CReply->MessageHeader, LpcMessage, sizeof(LPC_MESSAGE));
       memcpy(&CReply->ConnectData, (PVOID)(LpcMessage + 1), 
             LpcMessage->DataSize);
       CReply->MessageHeader.MessageSize =
        sizeof(EPORT_CONNECT_REPLY_MESSAGE) + LpcMessage->DataSize;
       CReply->MessageHeader.DataSize = CReply->MessageHeader.MessageSize -
-       sizeof(LPC_MESSAGE_HEADER);
+       sizeof(LPC_MESSAGE);
       CReply->ConnectDataLength = LpcMessage->DataSize;
     }
   else
     {
       CReply->MessageHeader.MessageSize = sizeof(EPORT_CONNECT_REPLY_MESSAGE);
       CReply->MessageHeader.DataSize = sizeof(EPORT_CONNECT_REPLY_MESSAGE) -
-       sizeof(LPC_MESSAGE_HEADER);
+       sizeof(LPC_MESSAGE);
       CReply->ConnectDataLength = 0;
     }
-  if (AcceptIt != 1)
+  if (!AcceptIt)
     {  
       EiReplyOrRequestPort(ConnectionRequest->Sender,
                           &CReply->MessageHeader,
@@ -694,7 +742,7 @@ NtAcceptConnectPort (PHANDLE                        ServerPortHandle,
     {
       CReply->ReceiveClientViewBase = WriteMap->TargetViewBase;
     }
-  CReply->MaximumMessageSize = 0x148;
+  CReply->MaximumMessageSize = PORT_MAX_MESSAGE_LENGTH;
 
 
   /*
@@ -707,6 +755,7 @@ NtAcceptConnectPort (PHANDLE                        ServerPortHandle,
                       LPC_REPLY,
                       OurPort);
   ExFreePool(ConnectionRequest);
+  ExFreePool(CReply);
    
   ObDereferenceObject(OurPort);
   ObDereferenceObject(NamedPort);
@@ -716,7 +765,7 @@ NtAcceptConnectPort (PHANDLE                        ServerPortHandle,
 
 /**********************************************************************
  * NAME                                                        EXPORTED
- *     NtSecureConnectPort@36
+ *     NtSecureConnectPort/9
  *     
  * DESCRIPTION
  *     Connect to a named port and wait for the other side to 
@@ -726,7 +775,7 @@ NtAcceptConnectPort (PHANDLE                        ServerPortHandle,
  *
  * ARGUMENTS
  *     ConnectedPort
- *     PortName
+ *     PortName: fully qualified name in the Ob name space;
  *     Qos
  *     WriteMap
  *     ServerSid
@@ -736,7 +785,6 @@ NtAcceptConnectPort (PHANDLE                        ServerPortHandle,
  *     UserConnectInfoLength
  * 
  * RETURN VALUE
- * 
  */
 NTSTATUS STDCALL
 NtSecureConnectPort (OUT    PHANDLE                            ConnectedPort,
@@ -749,7 +797,16 @@ NtSecureConnectPort (OUT    PHANDLE                                ConnectedPort,
                     IN OUT PVOID                               ConnectInfo             OPTIONAL,
                     IN OUT PULONG                              UserConnectInfoLength   OPTIONAL)
 {
-       return (STATUS_NOT_IMPLEMENTED);
+  /* TODO: implement a new object type: WaitablePort */
+  /* TODO: verify the process' SID that hosts the rendez-vous port equals ServerSid */
+  return NtConnectPort (ConnectedPort,
+                       PortName,
+                       Qos,
+                       WriteMap,
+                       ReadMap,
+                       MaxMessageSize,
+                       ConnectInfo,
+                       UserConnectInfoLength);
 }
 
 /* EOF */