Fix invalid Object Attributes
[reactos.git] / reactos / subsys / smss / smapi.c
index 78cac41..6d2941a 100644 (file)
@@ -5,9 +5,7 @@
  * Reactos Session Manager
  *
  */
-
 #include "smss.h"
-#include <rosrtl/string.h>
 
 #define NDEBUG
 #include <debug.h>
@@ -21,7 +19,7 @@ static HANDLE SmApiPort = INVALID_HANDLE_VALUE;
 SMAPI(SmInvalid)
 {
        DPRINT("SM: %s called\n",__FUNCTION__);
-       Request->Status = STATUS_NOT_IMPLEMENTED;
+       Request->SmHeader.Status = STATUS_NOT_IMPLEMENTED;
        return STATUS_SUCCESS;
 }
 
@@ -34,14 +32,69 @@ SM_PORT_API SmApi [] =
        SmCompSes,      /* smapicomp.c */
        SmInvalid,      /* obsolete */
        SmInvalid,      /* unknown */
-       SmExecPgm       /* smapiexec.c */
+       SmExecPgm,      /* smapiexec.c */
+       SmQryInfo       /* smapyqry.c */
 };
 
+/* TODO: optimize this address computation (it should be done
+ * with a macro) */
+PSM_CONNECT_DATA FASTCALL SmpGetConnectData (PSM_PORT_MESSAGE Request)
+{
+       PLPC_MAX_MESSAGE LpcMaxMessage = (PLPC_MAX_MESSAGE) Request;
+       return (PSM_CONNECT_DATA) & LpcMaxMessage->Data[0];
+}
+
 #if !defined(__USE_NT_LPC__)
 NTSTATUS STDCALL
 SmpHandleConnectionRequest (PSM_PORT_MESSAGE Request);
 #endif
 
+/**********************************************************************
+ * SmpCallbackServer/2
+ *
+ * DESCRIPTION
+ *     The SM calls back a previously connected subsystem process to
+ *     authorize it to bootstrap (initialize). The SM connects to a
+ *     named LPC port which name was sent in the connection data by
+ *     the candidate subsystem server process.
+ */
+static NTSTATUS
+SmpCallbackServer (PSM_PORT_MESSAGE Request,
+                  PSM_CLIENT_DATA ClientData)
+{
+       NTSTATUS          Status = STATUS_SUCCESS;
+       PSM_CONNECT_DATA  ConnectData = SmpGetConnectData (Request);
+       UNICODE_STRING    CallbackPortName;
+       ULONG             CallbackPortNameLength = SM_SB_NAME_MAX_LENGTH; /* TODO: compute length */
+       SB_CONNECT_DATA   SbConnectData;
+       ULONG             SbConnectDataLength = sizeof SbConnectData;
+
+       DPRINT("SM: %s called\n", __FUNCTION__);
+
+       if (    (IMAGE_SUBSYSTEM_UNKNOWN == ConnectData->SubSystemId) ||
+               (IMAGE_SUBSYSTEM_NATIVE  == ConnectData->SubSystemId))
+       {
+               DPRINT("SM: %s: we do not need calling back SM!\n",
+                               __FUNCTION__);
+               return STATUS_SUCCESS;
+       }
+       RtlCopyMemory (ClientData->SbApiPortName,
+                      ConnectData->SbName,
+                      CallbackPortNameLength);
+       RtlInitUnicodeString (& CallbackPortName,
+                             ClientData->SbApiPortName);
+
+       SbConnectData.SmApiMax = (sizeof SmApi / sizeof SmApi[0]);
+       Status = NtConnectPort (& ClientData->SbApiPort,
+                               & CallbackPortName,
+                               NULL,
+                               NULL,
+                               NULL,
+                               NULL,
+                               & SbConnectData,
+                               & SbConnectDataLength);
+       return Status;
+}
 
 /**********************************************************************
  * NAME
@@ -51,20 +104,21 @@ SmpHandleConnectionRequest (PSM_PORT_MESSAGE Request);
  *     Entry point for the listener thread of LPC port "\SmApiPort".
  */
 VOID STDCALL
-SmpApiConnectedThread(PVOID dummy)
+SmpApiConnectedThread(PVOID pConnectedPort)
 {
        NTSTATUS        Status = STATUS_SUCCESS;
        PVOID           Unknown = NULL;
        PLPC_MESSAGE    Reply = NULL;
        SM_PORT_MESSAGE Request = {{0}};
+       HANDLE          ConnectedPort = * (PHANDLE) pConnectedPort;
 
-       DPRINT("SM: %s running\n",__FUNCTION__);
+       DPRINT("SM: %s called\n", __FUNCTION__);
 
        while (TRUE)
        {
                DPRINT("SM: %s: waiting for message\n",__FUNCTION__);
 
-               Status = NtReplyWaitReceivePort(SmApiPort,
+               Status = NtReplyWaitReceivePort(ConnectedPort,
                                                (PULONG) & Unknown,
                                                Reply,
                                                (PLPC_MESSAGE) & Request);
@@ -88,18 +142,23 @@ SmpApiConnectedThread(PVOID dummy)
                              Reply = NULL;
                              break;
                        default:
-                               if ((Request.ApiIndex) &&
-                                       (Request.ApiIndex < (sizeof SmApi / sizeof SmApi[0])))
+                               if ((Request.SmHeader.ApiIndex) &&
+                                       (Request.SmHeader.ApiIndex < (sizeof SmApi / sizeof SmApi[0])))
                                {
-                                       Status = SmApi[Request.ApiIndex](&Request);
+                                       Status = SmApi[Request.SmHeader.ApiIndex](&Request);
                                        Reply = (PLPC_MESSAGE) & Request;
                                } else {
-                                       Request.Status = STATUS_NOT_IMPLEMENTED;
+                                       Request.SmHeader.Status = STATUS_NOT_IMPLEMENTED;
                                        Reply = (PLPC_MESSAGE) & Request;
                                }
                        }
+               } else {
+                       /* LPC failed */
+                       break;
                }
        }
+       NtClose (ConnectedPort);
+       NtTerminateThread (NtCurrentThread(), Status);
 }
 
 /**********************************************************************
@@ -115,84 +174,119 @@ SmpApiConnectedThread(PVOID dummy)
 NTSTATUS STDCALL
 SmpHandleConnectionRequest (PSM_PORT_MESSAGE Request)
 {
-#if defined(__USE_NT_LPC__)
+       PSM_CONNECT_DATA ConnectData = SmpGetConnectData (Request);
        NTSTATUS         Status = STATUS_SUCCESS;
+       BOOL             Accept = FALSE;
        PSM_CLIENT_DATA  ClientData = NULL;
+       HANDLE           hClientDataApiPort = (HANDLE) 0;
+       PHANDLE          ClientDataApiPort = & hClientDataApiPort;
+       HANDLE           hClientDataApiPortThread = (HANDLE) 0;
+       PHANDLE          ClientDataApiPortThread = & hClientDataApiPortThread;
        PVOID            Context = NULL;
-       
-       DPRINT("SM: %s called\n",__FUNCTION__);
 
-       /*
-        * SmCreateClient/2 is called here explicitly to *fail*.
-        * If it succeeds, there is something wrong in the
-        * connection request. An environment subsystem *never*
-        * registers twice. (Security issue: maybe we will
-        * write this event into the security log).
-        */
-       Status = SmCreateClient (Request, & ClientData);
-       if(STATUS_SUCCESS == Status)
+       DPRINT("SM: %s called:\n  SubSystemID=%d\n  SbName=\"%S\"\n",
+                       __FUNCTION__, ConnectData->SubSystemId, ConnectData->SbName);
+
+       if(sizeof (SM_CONNECT_DATA) == Request->Header.DataSize)
        {
-               /* OK: the client is an environment subsystem
-                * willing to manage a free image type.
-                * Accept it.
-                */
-               Status = NtAcceptConnectPort (& ClientData->ApiPort,
-                                             Context,
-                                             (PLPC_MESSAGE) Request,
-                                             TRUE, //accept
-                                             NULL,
-                                             NULL);
-               if(NT_SUCCESS(Status))
+               if(IMAGE_SUBSYSTEM_UNKNOWN == ConnectData->SubSystemId)
                {
-                       Status = NtCompleteConnectPort(ClientData->ApiPort);
+                       /*
+                        * This is not a call to register an image set,
+                        * but a simple connection request from a process
+                        * that will use the SM API.
+                        */
+                       DPRINT("SM: %s: simple request\n", __FUNCTION__);
+                       ClientDataApiPort = & hClientDataApiPort;
+                       ClientDataApiPortThread = & hClientDataApiPortThread;
+                       Accept = TRUE;
+               } else {
+                       DPRINT("SM: %s: request to register an image set\n", __FUNCTION__);
+                       /*
+                        *  Reject GUIs classes: only odd subsystem IDs are
+                        *  allowed to register here (tty mode images).
+                        */
+                       if(1 == (ConnectData->SubSystemId % 2))
+                       {
+                               DPRINT("SM: %s: id = %d\n", __FUNCTION__, ConnectData->SubSystemId);
+                               /*
+                                * SmBeginClientInitialization/2 will succeed only if there
+                                * is a candidate client ready.
+                                */
+                               Status = SmBeginClientInitialization (Request, & ClientData);
+                               if(STATUS_SUCCESS == Status)
+                               {
+                                       DPRINT("SM: %s: ClientData = 0x%08lx\n",
+                                               __FUNCTION__, ClientData);
+                                       /*
+                                        * OK: the client is an environment subsystem
+                                        * willing to manage a free image type.
+                                        */
+                                       ClientDataApiPort = & ClientData->ApiPort;
+                                       ClientDataApiPortThread = & ClientData->ApiPortThread;
+                                       /*
+                                        * Call back the candidate environment subsystem
+                                        * server (use the port name sent in in the
+                                        * connection request message).
+                                        */
+                                       Status = SmpCallbackServer (Request, ClientData);
+                                       if(NT_SUCCESS(Status))
+                                       {
+                                               DPRINT("SM: %s: SmpCallbackServer OK\n",
+                                                       __FUNCTION__);
+                                               Accept = TRUE;
+                                       } else {
+                                               DPRINT("SM: %s: SmpCallbackServer failed (Status=%08lx)\n",
+                                                       __FUNCTION__, Status);
+                                               Status = SmDestroyClient (ConnectData->SubSystemId);
+                                       }
+                               }
+                       }
                }
-               return STATUS_SUCCESS;
-       } else {
-               /* Reject the subsystem */
-               Status = NtAcceptConnectPort (& ClientData->ApiPort,
-                                             Context,
-                                             (PLPC_MESSAGE) Request,
-                                             FALSE, //reject
-                                             NULL,
-                                             NULL);
        }
+       DPRINT("SM: %s: before NtAcceptConnectPort\n", __FUNCTION__);
+#if defined(__USE_NT_LPC__)
+       Status = NtAcceptConnectPort (ClientDataApiPort,
+                                     Context,
+                                     (PLPC_MESSAGE) Request,
+                                     Accept,
+                                     NULL,
+                                     NULL);
 #else /* ReactOS LPC */
-       NTSTATUS         Status = STATUS_SUCCESS;
-       PSM_CLIENT_DATA  ClientData = NULL;
-       
-       DPRINT("SM: %s called\n",__FUNCTION__);
-
-       Status = SmCreateClient (Request, & ClientData);
-       if(STATUS_SUCCESS == Status)
+       Status = NtAcceptConnectPort (ClientDataApiPort,
+                                     SmApiPort, // ROS LPC requires the listen port here
+                                     Context,
+                                     Accept,
+                                     NULL,
+                                     NULL);
+#endif
+       if(Accept)
        {
-               Status = NtAcceptConnectPort (& ClientData->ApiPort,
-                                             SmApiPort,
-                                             NULL,
-                                             TRUE, //accept
-                                             NULL,
-                                             NULL);
-               if (!NT_SUCCESS(Status))
+               if(!NT_SUCCESS(Status))
                {
                        DPRINT1("SM: %s: NtAcceptConnectPort() failed (Status=0x%08lx)\n",
-                                       __FUNCTION__, Status);
+                               __FUNCTION__, Status);
                        return Status;
                } else {
-                       Status = NtCompleteConnectPort(ClientData->ApiPort);
+                       DPRINT("SM: %s: completing conn req\n", __FUNCTION__);
+                       Status = NtCompleteConnectPort (*ClientDataApiPort);
                        if (!NT_SUCCESS(Status))
                        {
                                DPRINT1("SM: %s: NtCompleteConnectPort() failed (Status=0x%08lx)\n",
                                        __FUNCTION__, Status);
                                return Status;
                        }
+#if !defined(__USE_NT_LPC__) /* ReactOS LPC */
+                       DPRINT("SM: %s: server side comm port thread (ROS LPC)\n", __FUNCTION__);
                        Status = RtlCreateUserThread(NtCurrentProcess(),
                                             NULL,
                                             FALSE,
                                             0,
-                                            NULL,
-                                            NULL,
+                                            0,
+                                            0,
                                             (PTHREAD_START_ROUTINE) SmpApiConnectedThread,
-                                            ClientData->ApiPort,
-                                            & ClientData->ApiPortThread,
+                                            ClientDataApiPort,
+                                            ClientDataApiPortThread,
                                             NULL);
                        if (!NT_SUCCESS(Status))
                        {
@@ -200,18 +294,11 @@ SmpHandleConnectionRequest (PSM_PORT_MESSAGE Request)
                                        __FUNCTION__, Status);
                                return Status;
                        }
+#endif
                }
-               return STATUS_SUCCESS;
-       } else {
-               /* Reject the subsystem */
-               Status = NtAcceptConnectPort (& ClientData->ApiPort,
-                                             SmApiPort,
-                                             NULL,
-                                             FALSE, //reject
-                                             NULL,
-                                             NULL);
+               Status = STATUS_SUCCESS;
        }
-#endif /* defined __USE_NT_LPC__ */
+       DPRINT("SM: %s done\n", __FUNCTION__);
        return Status;
 }
 
@@ -221,17 +308,19 @@ SmpHandleConnectionRequest (PSM_PORT_MESSAGE Request)
  *
  * DECRIPTION
  *     Due to differences in LPC implementation between NT and ROS,
- *     we need a thread to listen for connection request that
+ *     we need a thread to listen to for connection request that
  *     creates a new thread for each connected port. This is not
  *     necessary in NT LPC, because server side connected ports are
- *     never used to receive requests. 
+ *     never used to receive requests.
  */
 VOID STDCALL
 SmpApiThread (HANDLE ListeningPort)
 {
        NTSTATUS        Status = STATUS_SUCCESS;
        LPC_MAX_MESSAGE Request = {{0}};
-    
+
+       DPRINT("SM: %s called\n", __FUNCTION__);
+
        while (TRUE)
        {
                Status = NtListenPort (ListeningPort, & Request.Header);
@@ -268,14 +357,12 @@ NTSTATUS
 SmCreateApiPort(VOID)
 {
   OBJECT_ATTRIBUTES  ObjectAttributes = {0};
-  UNICODE_STRING     UnicodeString = {0};
+  UNICODE_STRING     UnicodeString = RTL_CONSTANT_STRING(L"\\SmApiPort");
   NTSTATUS           Status = STATUS_SUCCESS;
 
-  RtlRosInitUnicodeStringFromLiteral(&UnicodeString,
-                      L"\\SmApiPort");
   InitializeObjectAttributes(&ObjectAttributes,
                             &UnicodeString,
-                            PORT_ALL_ACCESS,
+                            0,
                             NULL,
                             NULL);
 
@@ -296,8 +383,8 @@ SmCreateApiPort(VOID)
                      NULL,
                      FALSE,
                      0,
-                     NULL,
-                     NULL,
+                     0,
+                     0,
                      (PTHREAD_START_ROUTINE)SmpApiThread,
                      (PVOID)SmApiPort,
                      NULL,