- Update to r53061
[reactos.git] / base / system / services / database.c
index 0d7ec77..1c5dd29 100644 (file)
@@ -25,6 +25,7 @@
 
 /* GLOBALS *******************************************************************/
 
+LIST_ENTRY ImageListHead;
 LIST_ENTRY ServiceListHead;
 
 static RTL_RESOURCE DatabaseLock;
@@ -34,6 +35,232 @@ static CRITICAL_SECTION ControlServiceCriticalSection;
 
 /* FUNCTIONS *****************************************************************/
 
+static DWORD
+ScmCreateNewControlPipe(PSERVICE_IMAGE pServiceImage)
+{
+    WCHAR szControlPipeName[MAX_PATH + 1];
+    HKEY hServiceCurrentKey = INVALID_HANDLE_VALUE;
+    DWORD ServiceCurrent = 0;
+    DWORD KeyDisposition;
+    DWORD dwKeySize;
+    DWORD dwError;
+
+    /* Get the service number */
+    /* TODO: Create registry entry with correct write access */
+    dwError = RegCreateKeyExW(HKEY_LOCAL_MACHINE,
+                              L"SYSTEM\\CurrentControlSet\\Control\\ServiceCurrent", 0, NULL,
+                              REG_OPTION_VOLATILE,
+                              KEY_WRITE | KEY_READ,
+                              NULL,
+                              &hServiceCurrentKey,
+                              &KeyDisposition);
+    if (dwError != ERROR_SUCCESS)
+    {
+        DPRINT1("RegCreateKeyEx() failed with error %lu\n", dwError);
+        return dwError;
+    }
+
+    if (KeyDisposition == REG_OPENED_EXISTING_KEY)
+    {
+        dwKeySize = sizeof(DWORD);
+        dwError = RegQueryValueExW(hServiceCurrentKey,
+                                   L"", 0, NULL, (BYTE*)&ServiceCurrent, &dwKeySize);
+
+        if (dwError != ERROR_SUCCESS)
+        {
+            RegCloseKey(hServiceCurrentKey);
+            DPRINT1("RegQueryValueEx() failed with error %lu\n", dwError);
+            return dwError;
+        }
+
+        ServiceCurrent++;
+    }
+
+    dwError = RegSetValueExW(hServiceCurrentKey, L"", 0, REG_DWORD, (BYTE*)&ServiceCurrent, sizeof(ServiceCurrent));
+
+    RegCloseKey(hServiceCurrentKey);
+
+    if (dwError != ERROR_SUCCESS)
+    {
+        DPRINT1("RegSetValueExW() failed (Error %lu)\n", dwError);
+        return dwError;
+    }
+
+    /* Create '\\.\pipe\net\NtControlPipeXXX' instance */
+    swprintf(szControlPipeName, L"\\\\.\\pipe\\net\\NtControlPipe%u", ServiceCurrent);
+
+    DPRINT("PipeName: %S\n", szControlPipeName);
+
+    pServiceImage->hControlPipe = CreateNamedPipeW(szControlPipeName,
+                                                   PIPE_ACCESS_DUPLEX,
+                                                   PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_WAIT,
+                                                   100,
+                                                   8000,
+                                                   4,
+                                                   30000,
+                                                   NULL);
+    DPRINT("CreateNamedPipeW(%S) done\n", szControlPipeName);
+    if (pServiceImage->hControlPipe == INVALID_HANDLE_VALUE)
+    {
+        DPRINT1("Failed to create control pipe!\n");
+        return GetLastError();
+    }
+
+    return ERROR_SUCCESS;
+}
+
+
+static PSERVICE_IMAGE
+ScmGetServiceImageByImagePath(LPWSTR lpImagePath)
+{
+    PLIST_ENTRY ImageEntry;
+    PSERVICE_IMAGE CurrentImage;
+
+    DPRINT("ScmGetServiceImageByImagePath(%S) called\n", lpImagePath);
+
+    ImageEntry = ImageListHead.Flink;
+    while (ImageEntry != &ImageListHead)
+    {
+        CurrentImage = CONTAINING_RECORD(ImageEntry,
+                                         SERVICE_IMAGE,
+                                         ImageListEntry);
+        if (_wcsicmp(CurrentImage->szImagePath, lpImagePath) == 0)
+        {
+            DPRINT("Found image: '%S'\n", CurrentImage->szImagePath);
+            return CurrentImage;
+        }
+
+        ImageEntry = ImageEntry->Flink;
+    }
+
+    DPRINT1("Couldn't find a matching image\n");
+
+    return NULL;
+
+}
+
+
+static DWORD
+ScmCreateOrReferenceServiceImage(PSERVICE pService)
+{
+    RTL_QUERY_REGISTRY_TABLE QueryTable[2];
+    UNICODE_STRING ImagePath;
+    PSERVICE_IMAGE pServiceImage = NULL;
+    NTSTATUS Status;
+    DWORD dwError = ERROR_SUCCESS;
+
+    DPRINT("ScmCreateOrReferenceServiceImage(%p)\n", pService);
+
+    RtlInitUnicodeString(&ImagePath, NULL);
+
+    /* Get service data */
+    RtlZeroMemory(&QueryTable,
+                  sizeof(QueryTable));
+
+    QueryTable[0].Name = L"ImagePath";
+    QueryTable[0].Flags = RTL_QUERY_REGISTRY_DIRECT | RTL_QUERY_REGISTRY_REQUIRED;
+    QueryTable[0].EntryContext = &ImagePath;
+
+    Status = RtlQueryRegistryValues(RTL_REGISTRY_SERVICES,
+                                    pService->lpServiceName,
+                                    QueryTable,
+                                    NULL,
+                                    NULL);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("RtlQueryRegistryValues() failed (Status %lx)\n", Status);
+        return RtlNtStatusToDosError(Status);
+    }
+
+    DPRINT("ImagePath: '%wZ'\n", &ImagePath);
+
+    pServiceImage = ScmGetServiceImageByImagePath(ImagePath.Buffer);
+    if (pServiceImage == NULL)
+    {
+        /* Create a new service image */
+        pServiceImage = (PSERVICE_IMAGE)HeapAlloc(GetProcessHeap(),
+                                                  HEAP_ZERO_MEMORY,
+                                                  sizeof(SERVICE_IMAGE) + ((wcslen(ImagePath.Buffer) + 1) * sizeof(WCHAR)));
+        if (pServiceImage == NULL)
+        {
+            dwError = ERROR_NOT_ENOUGH_MEMORY;
+            goto done;
+        }
+
+        pServiceImage->dwImageRunCount = 1;
+        pServiceImage->hControlPipe = INVALID_HANDLE_VALUE;
+        pServiceImage->hProcess = INVALID_HANDLE_VALUE;
+
+        /* Set the image path */
+        wcscpy(pServiceImage->szImagePath,
+               ImagePath.Buffer);
+
+        RtlFreeUnicodeString(&ImagePath);
+
+        /* Create the control pipe */
+        dwError = ScmCreateNewControlPipe(pServiceImage);
+        if (dwError != ERROR_SUCCESS)
+        {
+            HeapFree(GetProcessHeap(), 0, pServiceImage);
+            goto done;
+        }
+
+        /* FIXME: Add more initialization code here */
+
+
+        /* Append service record */
+        InsertTailList(&ImageListHead,
+                       &pServiceImage->ImageListEntry);
+    }
+    else
+    {
+        /* Increment the run counter */
+        pServiceImage->dwImageRunCount++;
+    }
+
+    DPRINT("pServiceImage->dwImageRunCount: %lu\n", pServiceImage->dwImageRunCount);
+
+    /* Link the service image to the service */
+    pService->lpImage = pServiceImage;
+
+done:;
+    RtlFreeUnicodeString(&ImagePath);
+
+    DPRINT("ScmCreateOrReferenceServiceImage() done (Error: %lu)\n", dwError);
+
+    return dwError;
+}
+
+
+static VOID
+ScmDereferenceServiceImage(PSERVICE_IMAGE pServiceImage)
+{
+    DPRINT1("ScmDereferenceServiceImage() called\n");
+
+    pServiceImage->dwImageRunCount--;
+
+    if (pServiceImage->dwImageRunCount == 0)
+    {
+        DPRINT1("dwImageRunCount == 0\n");
+
+        /* FIXME: Terminate the process */
+
+        /* Remove the service image from the list */
+        RemoveEntryList(&pServiceImage->ImageListEntry);
+
+        /* Close the control pipe */
+        if (pServiceImage->hControlPipe != INVALID_HANDLE_VALUE)
+            CloseHandle(pServiceImage->hControlPipe);
+
+        /* Close the process handle */
+        if (pServiceImage->hProcess != INVALID_HANDLE_VALUE)
+            CloseHandle(pServiceImage->hProcess);
+
+        /* Release the service image */
+        HeapFree(GetProcessHeap(), 0, pServiceImage);
+    }
+}
+
 
 PSERVICE
 ScmGetServiceEntryByName(LPCWSTR lpServiceName)
@@ -131,7 +358,7 @@ ScmCreateNewServiceRecord(LPCWSTR lpServiceName,
     DPRINT("Service: '%S'\n", lpServiceName);
 
     /* Allocate service entry */
-    lpService = (SERVICE*) HeapAlloc(GetProcessHeap(),
+    lpService = (SERVICE*)HeapAlloc(GetProcessHeap(),
                           HEAP_ZERO_MEMORY,
                           sizeof(SERVICE) + ((wcslen(lpServiceName) + 1) * sizeof(WCHAR)));
     if (lpService == NULL)
@@ -173,9 +400,9 @@ ScmDeleteServiceRecord(PSERVICE lpService)
         lpService->lpDisplayName != lpService->lpServiceName)
         HeapFree(GetProcessHeap(), 0, lpService->lpDisplayName);
 
-    /* Decrement the image reference counter */
+    /* Dereference the service image */
     if (lpService->lpImage)
-        lpService->lpImage->dwServiceRefCount--;
+        ScmDereferenceServiceImage(lpService->lpImage);
 
     /* Decrement the group reference counter */
     if (lpService->lpGroup)
@@ -183,9 +410,6 @@ ScmDeleteServiceRecord(PSERVICE lpService)
 
     /* FIXME: SecurityDescriptor */
 
-    /* Close the control pipe */
-    if (lpService->ControlPipeHandle != INVALID_HANDLE_VALUE)
-        CloseHandle(lpService->ControlPipeHandle);
 
     /* Remove the Service from the List */
     RemoveEntryList(&lpService->ServiceListEntry);
@@ -330,6 +554,12 @@ done:;
     if (lpDisplayName != NULL)
         HeapFree(GetProcessHeap(), 0, lpDisplayName);
 
+    if (lpService != NULL)
+    {
+        if (lpService->lpImage != NULL)
+            ScmDereferenceServiceImage(lpService->lpImage);
+    }
+
     return dwError;
 }
 
@@ -479,6 +709,7 @@ ScmCreateServiceDatabase(VOID)
         return dwError;
 
     /* Initialize basic variables */
+    InitializeListHead(&ImageListHead);
     InitializeListHead(&ServiceListHead);
 
     /* Initialize the database lock */
@@ -711,14 +942,14 @@ ScmControlService(PSERVICE Service,
     wcscpy(&ControlPacket->szArguments[0], Service->lpServiceName);
 
     /* Send the control packet */
-    WriteFile(Service->ControlPipeHandle,
+    WriteFile(Service->lpImage->hControlPipe,
               ControlPacket,
               sizeof(SCM_CONTROL_PACKET) + (TotalLength * sizeof(WCHAR)),
               &dwWriteCount,
               NULL);
 
     /* Read the reply */
-    ReadFile(Service->ControlPipeHandle,
+    ReadFile(Service->lpImage->hControlPipe,
              &ReplyPacket,
              sizeof(SCM_REPLY_PACKET),
              &dwReadCount,
@@ -734,6 +965,12 @@ ScmControlService(PSERVICE Service,
         dwError = ReplyPacket.dwError;
     }
 
+    if (dwError == ERROR_SUCCESS &&
+        dwControl == SERVICE_CONTROL_STOP)
+    {
+        ScmDereferenceServiceImage(Service->lpImage);
+    }
+
     LeaveCriticalSection(&ControlServiceCriticalSection);
 
     DPRINT("ScmControlService() done\n");
@@ -804,14 +1041,14 @@ ScmSendStartCommand(PSERVICE Service,
     *Ptr = 0;
 
     /* Send the start command */
-    WriteFile(Service->ControlPipeHandle,
+    WriteFile(Service->lpImage->hControlPipe,
               ControlPacket,
               sizeof(SCM_CONTROL_PACKET) + (TotalLength - 1) * sizeof(WCHAR),
               &dwWriteCount,
               NULL);
 
     /* Read the reply */
-    ReadFile(Service->ControlPipeHandle,
+    ReadFile(Service->lpImage->hControlPipe,
              &ReplyPacket,
              sizeof(SCM_REPLY_PACKET),
              &dwReadCount,
@@ -838,106 +1075,19 @@ ScmStartUserModeService(PSERVICE Service,
                         DWORD argc,
                         LPWSTR *argv)
 {
-    RTL_QUERY_REGISTRY_TABLE QueryTable[3];
     PROCESS_INFORMATION ProcessInformation;
     STARTUPINFOW StartupInfo;
-    UNICODE_STRING ImagePath;
-    ULONG Type;
-    DWORD ServiceCurrent = 0;
     BOOL Result;
-    NTSTATUS Status;
     DWORD dwError = ERROR_SUCCESS;
-    WCHAR NtControlPipeName[MAX_PATH + 1];
-    HKEY hServiceCurrentKey = INVALID_HANDLE_VALUE;
-    DWORD KeyDisposition;
     DWORD dwProcessId;
 
-    RtlInitUnicodeString(&ImagePath, NULL);
-
-    /* Get service data */
-    RtlZeroMemory(&QueryTable,
-                  sizeof(QueryTable));
-
-    QueryTable[0].Name = L"Type";
-    QueryTable[0].Flags = RTL_QUERY_REGISTRY_DIRECT | RTL_QUERY_REGISTRY_REQUIRED;
-    QueryTable[0].EntryContext = &Type;
-
-    QueryTable[1].Name = L"ImagePath";
-    QueryTable[1].Flags = RTL_QUERY_REGISTRY_DIRECT | RTL_QUERY_REGISTRY_REQUIRED;
-    QueryTable[1].EntryContext = &ImagePath;
+    DPRINT("ScmStartUserModeService(%p)\n", Service);
 
-    Status = RtlQueryRegistryValues(RTL_REGISTRY_SERVICES,
-                                    Service->lpServiceName,
-                                    QueryTable,
-                                    NULL,
-                                    NULL);
-    if (!NT_SUCCESS(Status))
-    {
-        DPRINT1("RtlQueryRegistryValues() failed (Status %lx)\n", Status);
-        return RtlNtStatusToDosError(Status);
-    }
-    DPRINT("ImagePath: '%S'\n", ImagePath.Buffer);
-    DPRINT("Type: %lx\n", Type);
-
-    /* Get the service number */
-    /* TODO: Create registry entry with correct write access */
-    Status = RegCreateKeyExW(HKEY_LOCAL_MACHINE,
-                            L"SYSTEM\\CurrentControlSet\\Control\\ServiceCurrent", 0, NULL,
-                            REG_OPTION_VOLATILE,
-                            KEY_WRITE | KEY_READ,
-                            NULL,
-                            &hServiceCurrentKey,
-                            &KeyDisposition);
-
-    if (ERROR_SUCCESS != Status)
-    {
-        DPRINT1("RegCreateKeyEx() failed with status %u\n", Status);
-        return Status;
-    }
-
-    if (REG_OPENED_EXISTING_KEY == KeyDisposition)
-    {
-        DWORD KeySize = sizeof(ServiceCurrent);
-        Status = RegQueryValueExW(hServiceCurrentKey, L"", 0, NULL, (BYTE*)&ServiceCurrent, &KeySize);
-
-        if (ERROR_SUCCESS != Status)
-        {
-            RegCloseKey(hServiceCurrentKey);
-            DPRINT1("RegQueryValueEx() failed with status %u\n", Status);
-            return Status;
-        }
-
-        ServiceCurrent++;
-    }
-
-    Status = RegSetValueExW(hServiceCurrentKey, L"", 0, REG_DWORD, (BYTE*)&ServiceCurrent, sizeof(ServiceCurrent));
-
-    RegCloseKey(hServiceCurrentKey);
-
-    if (ERROR_SUCCESS != Status)
+    /* If the image is already running ... */
+    if (Service->lpImage->dwImageRunCount > 1)
     {
-        DPRINT1("RegSetValueExW() failed (Status %lx)\n", Status);
-        return Status;
-    }
-
-    /* Create '\\.\pipe\net\NtControlPipeXXX' instance */
-    swprintf(NtControlPipeName, L"\\\\.\\pipe\\net\\NtControlPipe%u", ServiceCurrent);
-
-    DPRINT("Service: %p  ImagePath: %wZ  PipeName: %S\n", Service, &ImagePath, NtControlPipeName);
-
-    Service->ControlPipeHandle = CreateNamedPipeW(NtControlPipeName,
-                                                  PIPE_ACCESS_DUPLEX,
-                                                  PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_WAIT,
-                                                  100,
-                                                  8000,
-                                                  4,
-                                                  30000,
-                                                  NULL);
-    DPRINT("CreateNamedPipeW(%S) done\n", NtControlPipeName);
-    if (Service->ControlPipeHandle == INVALID_HANDLE_VALUE)
-    {
-        DPRINT1("Failed to create control pipe!\n");
-        return GetLastError();
+        /* ... just send a start command */
+        return ScmSendStartCommand(Service, argc, argv);
     }
 
     StartupInfo.cb = sizeof(StartupInfo);
@@ -949,7 +1099,7 @@ ScmStartUserModeService(PSERVICE Service,
     StartupInfo.lpReserved2 = 0;
 
     Result = CreateProcessW(NULL,
-                            ImagePath.Buffer,
+                            Service->lpImage->szImagePath,
                             NULL,
                             NULL,
                             FALSE,
@@ -958,15 +1108,9 @@ ScmStartUserModeService(PSERVICE Service,
                             NULL,
                             &StartupInfo,
                             &ProcessInformation);
-    RtlFreeUnicodeString(&ImagePath);
-
     if (!Result)
     {
         dwError = GetLastError();
-        /* Close control pipe */
-        CloseHandle(Service->ControlPipeHandle);
-        Service->ControlPipeHandle = INVALID_HANDLE_VALUE;
-
         DPRINT1("Starting '%S' failed!\n", Service->lpServiceName);
         return dwError;
     }
@@ -978,15 +1122,15 @@ ScmStartUserModeService(PSERVICE Service,
            ProcessInformation.dwThreadId,
            ProcessInformation.hThread);
 
-    /* Get process and thread ids */
-    Service->ProcessId = ProcessInformation.dwProcessId;
-    Service->ThreadId = ProcessInformation.dwThreadId;
+    /* Get process handle and id */
+    Service->lpImage->dwProcessId = ProcessInformation.dwProcessId;
+    Service->lpImage->hProcess = ProcessInformation.hProcess;
 
     /* Resume Thread */
     ResumeThread(ProcessInformation.hThread);
 
     /* Connect control pipe */
-    if (ConnectNamedPipe(Service->ControlPipeHandle, NULL) ?
+    if (ConnectNamedPipe(Service->lpImage->hControlPipe, NULL) ?
         TRUE : (dwError = GetLastError()) == ERROR_PIPE_CONNECTED)
     {
         DWORD dwRead = 0;
@@ -994,7 +1138,7 @@ ScmStartUserModeService(PSERVICE Service,
         DPRINT("Control pipe connected!\n");
 
         /* Read SERVICE_STATUS_HANDLE from pipe */
-        if (!ReadFile(Service->ControlPipeHandle,
+        if (!ReadFile(Service->lpImage->hControlPipe,
                       (LPVOID)&dwProcessId,
                       sizeof(DWORD),
                       &dwRead,
@@ -1015,17 +1159,10 @@ ScmStartUserModeService(PSERVICE Service,
     else
     {
         DPRINT1("Connecting control pipe failed! (Error %lu)\n", dwError);
-
-        /* Close control pipe */
-        CloseHandle(Service->ControlPipeHandle);
-        Service->ControlPipeHandle = INVALID_HANDLE_VALUE;
-        Service->ProcessId = 0;
-        Service->ThreadId = 0;
     }
 
-    /* Close process and thread handle */
+    /* Close thread handle */
     CloseHandle(ProcessInformation.hThread);
-    CloseHandle(ProcessInformation.hProcess);
 
     return dwError;
 }
@@ -1051,7 +1188,6 @@ ScmStartService(PSERVICE Service, DWORD argc, LPWSTR *argv)
         return ERROR_SERVICE_ALREADY_RUNNING;
     }
 
-    Service->ControlPipeHandle = INVALID_HANDLE_VALUE;
     DPRINT("Service->Type: %lu\n", Service->Status.dwServiceType);
 
     if (Service->Status.dwServiceType & SERVICE_DRIVER)
@@ -1067,14 +1203,23 @@ ScmStartService(PSERVICE Service, DWORD argc, LPWSTR *argv)
     else
     {
         /* Start user-mode service */
-        dwError = ScmStartUserModeService(Service, argc, argv);
+        dwError = ScmCreateOrReferenceServiceImage(Service);
         if (dwError == ERROR_SUCCESS)
         {
+            dwError = ScmStartUserModeService(Service, argc, argv);
+            if (dwError == ERROR_SUCCESS)
+            {
 #ifdef USE_SERVICE_START_PENDING
-            Service->Status.dwCurrentState = SERVICE_START_PENDING;
+                Service->Status.dwCurrentState = SERVICE_START_PENDING;
 #else
-            Service->Status.dwCurrentState = SERVICE_RUNNING;
+                Service->Status.dwCurrentState = SERVICE_RUNNING;
 #endif
+            }
+            else
+            {
+                ScmDereferenceServiceImage(Service->lpImage);
+                Service->lpImage = NULL;
+            }
         }
     }