- ws2_32 Winetest: Define __ROS_LONG64__
[reactos.git] / rostests / winetests / rpcrt4 / server.c
index 7f25518..3b8847b 100644 (file)
@@ -36,6 +36,17 @@ static const char *progname;
 
 static HANDLE stop_event;
 
+static void (WINAPI *pNDRSContextMarshall2)(RPC_BINDING_HANDLE, NDR_SCONTEXT, void*, NDR_RUNDOWN, void*, ULONG);
+static NDR_SCONTEXT (WINAPI *pNDRSContextUnmarshall2)(RPC_BINDING_HANDLE, void*, ULONG, void*, ULONG);
+
+static void InitFunctionPointers(void)
+{
+    HMODULE hrpcrt4 = GetModuleHandleA("rpcrt4.dll");
+
+    pNDRSContextMarshall2 = (void *)GetProcAddress(hrpcrt4, "NDRSContextMarshall2");
+    pNDRSContextUnmarshall2 = (void *)GetProcAddress(hrpcrt4, "NDRSContextUnmarshall2");
+}
+
 void __RPC_FAR *__RPC_USER
 midl_user_allocate(size_t n)
 {
@@ -92,6 +103,12 @@ s_str_length(const char *s)
   return strlen(s);
 }
 
+int
+s_str_t_length(str_t s)
+{
+  return strlen(s);
+}
+
 int
 s_cstr_length(const char *s, int n)
 {
@@ -316,10 +333,22 @@ s_square_encu(encu_t *eu)
   }
 }
 
-int
-s_sum_parr(int *a[3])
+double
+s_square_unencu(int t, unencu_t *eu)
 {
-  return s_sum_pcarr(a, 3);
+  switch (t)
+  {
+  case ENCU_I: return eu->i * eu->i;
+  case ENCU_F: return eu->f * eu->f;
+  default:
+    return 0.0;
+  }
+}
+
+void
+s_check_se2(se_t *s)
+{
+  ok(s->f == E2, "check_se2\n");
 }
 
 int
@@ -331,6 +360,12 @@ s_sum_pcarr(int *a[], int n)
   return s;
 }
 
+int
+s_sum_parr(int *a[3])
+{
+  return s_sum_pcarr(a, 3);
+}
+
 int
 s_enum_ord(e_t e)
 {
@@ -497,6 +532,16 @@ s_sum_L1_norms(int n, vector_t *vs)
   return sum;
 }
 
+s123_t *
+s_get_s123(void)
+{
+  s123_t *s = MIDL_user_allocate(sizeof *s);
+  s->f1 = 1;
+  s->f2 = 2;
+  s->f3 = 3;
+  return s;
+}
+
 str_t
 s_get_filename(void)
 {
@@ -526,40 +571,40 @@ s_context_handle_test(void)
     binding = I_RpcGetCurrentCallHandle();
     ok(binding != NULL, "I_RpcGetCurrentCallHandle returned NULL\n");
 
-    h = NDRSContextUnmarshall2(binding, NULL, NDR_LOCAL_DATA_REPRESENTATION, NULL, 0);
+    h = pNDRSContextUnmarshall2(binding, NULL, NDR_LOCAL_DATA_REPRESENTATION, NULL, 0);
     ok(h != NULL, "NDRSContextUnmarshall2 returned NULL\n");
 
     /* marshal a context handle with NULL userContext */
     memset(buf, 0xcc, sizeof(buf));
-    NDRSContextMarshall2(binding, h, buf, NULL, NULL, 0);
+    pNDRSContextMarshall2(binding, h, buf, NULL, NULL, 0);
     ok(*(ULONG *)buf == 0, "attributes should have been set to 0 instead of 0x%x\n", *(ULONG *)buf);
     ok(UuidIsNil((UUID *)&buf[4], &status), "uuid should have been nil\n");
 
-    h = NDRSContextUnmarshall2(binding, NULL, NDR_LOCAL_DATA_REPRESENTATION, NULL, 0);
+    h = pNDRSContextUnmarshall2(binding, NULL, NDR_LOCAL_DATA_REPRESENTATION, NULL, 0);
     ok(h != NULL, "NDRSContextUnmarshall2 returned NULL\n");
 
     /* marshal a context handle with non-NULL userContext */
     memset(buf, 0xcc, sizeof(buf));
     h->userContext = (void *)0xdeadbeef;
-    NDRSContextMarshall2(binding, h, buf, NULL, NULL, 0);
+    pNDRSContextMarshall2(binding, h, buf, NULL, NULL, 0);
     ok(*(ULONG *)buf == 0, "attributes should have been set to 0 instead of 0x%x\n", *(ULONG *)buf);
     ok(!UuidIsNil((UUID *)&buf[4], &status), "uuid should not have been nil\n");
 
-    h = NDRSContextUnmarshall2(binding, buf, NDR_LOCAL_DATA_REPRESENTATION, NULL, 0);
+    h = pNDRSContextUnmarshall2(binding, buf, NDR_LOCAL_DATA_REPRESENTATION, NULL, 0);
     ok(h != NULL, "NDRSContextUnmarshall2 returned NULL\n");
     ok(h->userContext == (void *)0xdeadbeef, "userContext of interface didn't unmarshal properly: %p\n", h->userContext);
 
     /* marshal a context handle with an interface specified */
-    h = NDRSContextUnmarshall2(binding, NULL, NDR_LOCAL_DATA_REPRESENTATION, &server_if.InterfaceId, 0);
+    h = pNDRSContextUnmarshall2(binding, NULL, NDR_LOCAL_DATA_REPRESENTATION, &server_if.InterfaceId, 0);
     ok(h != NULL, "NDRSContextUnmarshall2 returned NULL\n");
 
     memset(buf, 0xcc, sizeof(buf));
     h->userContext = (void *)0xcafebabe;
-    NDRSContextMarshall2(binding, h, buf, NULL, &server_if.InterfaceId, 0);
+    pNDRSContextMarshall2(binding, h, buf, NULL, &server_if.InterfaceId, 0);
     ok(*(ULONG *)buf == 0, "attributes should have been set to 0 instead of 0x%x\n", *(ULONG *)buf);
     ok(!UuidIsNil((UUID *)&buf[4], &status), "uuid should not have been nil\n");
 
-    h = NDRSContextUnmarshall2(binding, buf, NDR_LOCAL_DATA_REPRESENTATION, &server_if.InterfaceId, 0);
+    h = pNDRSContextUnmarshall2(binding, buf, NDR_LOCAL_DATA_REPRESENTATION, &server_if.InterfaceId, 0);
     ok(h != NULL, "NDRSContextUnmarshall2 returned NULL\n");
     ok(h->userContext == (void *)0xcafebabe, "userContext of interface didn't unmarshal properly: %p\n", h->userContext);
 
@@ -569,7 +614,7 @@ s_context_handle_test(void)
     {
         RPC_SERVER_INTERFACE server_if_clone = server_if;
 
-        NDRSContextUnmarshall2(binding, buf, NDR_LOCAL_DATA_REPRESENTATION, &server_if_clone.InterfaceId, 0);
+        pNDRSContextUnmarshall2(binding, buf, NDR_LOCAL_DATA_REPRESENTATION, &server_if_clone.InterfaceId, 0);
     }
 
     /* test different interface data, but different pointer */
@@ -588,9 +633,9 @@ s_context_handle_test(void)
             0,
             0,
         };
-        NDRSContextMarshall2(binding, h, buf, NULL, &server_if.InterfaceId, 0);
+        pNDRSContextMarshall2(binding, h, buf, NULL, &server_if.InterfaceId, 0);
 
-        NDRSContextUnmarshall2(binding, buf, NDR_LOCAL_DATA_REPRESENTATION, &server_if2.InterfaceId, 0);
+        pNDRSContextUnmarshall2(binding, buf, NDR_LOCAL_DATA_REPRESENTATION, &server_if2.InterfaceId, 0);
     }
 }
 
@@ -620,6 +665,24 @@ s_get_numbers(int length, int size, pints_t n[])
     }
 }
 
+void
+s_get_numbers_struct(numbers_struct_t **ns)
+{
+    int i;
+    *ns = midl_user_allocate(FIELD_OFFSET(numbers_struct_t, numbers[5]));
+    if (!*ns) return;
+    (*ns)->length = 5;
+    (*ns)->size = 5;
+    for (i = 0; i < (*ns)->length; i++)
+    {
+        (*ns)->numbers[i].pi = NULL;
+        (*ns)->numbers[i].ppi = NULL;
+        (*ns)->numbers[i].pppi = NULL;
+    }
+    (*ns)->numbers[0].pi = midl_user_allocate(sizeof(*(*ns)->numbers[i].pi));
+    *(*ns)->numbers[0].pi = 5;
+}
+
 void
 s_stop(void)
 {
@@ -634,25 +697,21 @@ make_cmdline(char buffer[MAX_PATH], const char *test)
   sprintf(buffer, "%s server %s", progname, test);
 }
 
-static int
+static void
 run_client(const char *test)
 {
   char cmdline[MAX_PATH];
   PROCESS_INFORMATION info;
   STARTUPINFOA startup;
-  DWORD exitcode;
 
   memset(&startup, 0, sizeof startup);
   startup.cb = sizeof startup;
 
   make_cmdline(cmdline, test);
   ok(CreateProcessA(NULL, cmdline, NULL, NULL, FALSE, 0L, NULL, NULL, &startup, &info), "CreateProcess\n");
-  ok(WaitForSingleObject(info.hProcess, 30000) == WAIT_OBJECT_0, "Child process termination\n");
-  ok(GetExitCodeProcess(info.hProcess, &exitcode), "GetExitCodeProcess\n");
+  winetest_wait_child_process( info.hProcess );
   ok(CloseHandle(info.hProcess), "CloseHandle\n");
   ok(CloseHandle(info.hThread), "CloseHandle\n");
-
-  return exitcode == 0;
 }
 
 static void
@@ -682,6 +741,7 @@ basic_tests(void)
   str_struct_t ss = {string};
   wstr_struct_t ws = {wstring};
   str_t str;
+  se_t se;
 
   ok(int_return() == INT_CODE, "RPC int_return\n");
 
@@ -697,6 +757,7 @@ basic_tests(void)
   ok(x == 25, "RPC square_ref\n");
 
   ok(str_length(string) == strlen(string), "RPC str_length\n");
+  ok(str_t_length(string) == strlen(string), "RPC str_length\n");
   ok(dot_self(&a) == 59, "RPC dot_self\n");
 
   ok(str_struct_len(&ss) == lstrlenA(string), "RPC str_struct_len\n");
@@ -750,6 +811,9 @@ basic_tests(void)
   ok(enum_ord(E3) == 3, "RPC enum_ord\n");
   ok(enum_ord(E4) == 4, "RPC enum_ord\n");
 
+  se.f = E2;
+  check_se2(&se);
+
   memset(&aligns, 0, sizeof(aligns));
   aligns.c = 3;
   aligns.i = 4;
@@ -792,6 +856,7 @@ union_tests(void)
 {
   encue_t eue;
   encu_t eu;
+  unencu_t uneu;
   sun_t su;
   int i;
 
@@ -820,6 +885,12 @@ union_tests(void)
   eu.tagged_union.f = 3.0;
   ok(square_encu(&eu) == 9.0, "RPC square_encu\n");
 
+  uneu.i = 4;
+  ok(square_unencu(ENCU_I, &uneu) == 16.0, "RPC square_unencu\n");
+
+  uneu.f = 5.0;
+  ok(square_unencu(ENCU_F, &uneu) == 25.0, "RPC square_unencu\n");
+
   eue.t = E1;
   eue.tagged_union.i1 = 8;
   ok(square_encue(&eue) == 64.0, "RPC square_encue\n");
@@ -919,27 +990,27 @@ us_t_UserFree(ULONG *flags, us_t *pus)
 ULONG __RPC_USER
 bstr_t_UserSize(ULONG *flags, ULONG start, bstr_t *b)
 {
-  return start + FIELD_OFFSET(wire_bstr_t, data[(*b)[-1]]);
+  return start + FIELD_OFFSET(user_bstr_t, data[(*b)[-1]]);
 }
 
 unsigned char * __RPC_USER
 bstr_t_UserMarshal(ULONG *flags, unsigned char *buffer, bstr_t *b)
 {
-  wire_bstr_t *wb = (wire_bstr_t *) buffer;
+  wire_bstr_t wb = (wire_bstr_t) buffer;
   wb->n = (*b)[-1];
   memcpy(&wb->data, *b, wb->n * sizeof wb->data[0]);
-  return buffer + FIELD_OFFSET(wire_bstr_t, data[wb->n]);
+  return buffer + FIELD_OFFSET(user_bstr_t, data[wb->n]);
 }
 
 unsigned char * __RPC_USER
 bstr_t_UserUnmarshal(ULONG *flags, unsigned char *buffer, bstr_t *b)
 {
-  wire_bstr_t *wb = (wire_bstr_t *) buffer;
+  wire_bstr_t wb = (wire_bstr_t) buffer;
   short *data = HeapAlloc(GetProcessHeap(), 0, (wb->n + 1) * sizeof *data);
   data[0] = wb->n;
   memcpy(&data[1], wb->data, wb->n * sizeof data[1]);
   *b = &data[1];
-  return buffer + FIELD_OFFSET(wire_bstr_t, data[wb->n]);
+  return buffer + FIELD_OFFSET(user_bstr_t, data[wb->n]);
 }
 
 void __RPC_USER
@@ -963,6 +1034,7 @@ pointer_tests(void)
   name_t name;
   void *buffer;
   int *pa2;
+  s123_t *s123;
 
   ok(test_list_length(list) == 3, "RPC test_list_length\n");
   ok(square_puint(p1) == 121, "RPC square_puint\n");
@@ -1016,6 +1088,10 @@ pointer_tests(void)
 
   pa2 = a;
   ok(sum_pcarr2(4, &pa2) == 10, "RPC sum_pcarr2\n");
+
+  s123 = get_s123();
+  ok(s123->f1 == 1 && s123->f2 == 2 && s123->f3 == 3, "RPC get_s123\n");
+  MIDL_user_free(s123);
 }
 
 static int
@@ -1057,6 +1133,7 @@ array_tests(void)
   doub_carr_t *dc;
   int *pi;
   pints_t api[5];
+  numbers_struct_t *ns;
 
   ok(cstr_length(str1, sizeof str1) == strlen(str1), "RPC cstr_length\n");
 
@@ -1138,12 +1215,22 @@ array_tests(void)
   api[0].pi = pi;
   get_5numbers(1, api);
   ok(api[0].pi == pi, "RPC varying array [out] pointer changed from %p to %p\n", pi, api[0].pi);
-  ok(*api[0].pi == 0, "pi unmarshalled incorrectly %d\n", *pi);
+  ok(*api[0].pi == 0, "pi unmarshalled incorrectly %d\n", *api[0].pi);
 
   api[0].pi = pi;
   get_numbers(1, 1, api);
   ok(api[0].pi == pi, "RPC conformant varying array [out] pointer changed from %p to %p\n", pi, api[0].pi);
-  ok(*api[0].pi == 0, "pi unmarshalled incorrectly %d\n", *pi);
+  ok(*api[0].pi == 0, "pi unmarshalled incorrectly %d\n", *api[0].pi);
+
+  ns = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, FIELD_OFFSET(numbers_struct_t, numbers[5]));
+  ns->length = 5;
+  ns->size = 5;
+  ns->numbers[0].pi = pi;
+  get_numbers_struct(&ns);
+  ok(ns->numbers[0].pi == pi, "RPC conformant varying struct embedded pointer changed from %p to %p\n", pi, ns->numbers[0].pi);
+  ok(*ns->numbers[0].pi == 5, "pi unmarshalled incorrectly %d\n", *ns->numbers[0].pi);
+
+  HeapFree(GetProcessHeap(), 0, ns);
   HeapFree(GetProcessHeap(), 0, pi);
 }
 
@@ -1200,22 +1287,59 @@ server(void)
   static unsigned char port[] = PORT;
   static unsigned char np[] = "ncacn_np";
   static unsigned char pipe[] = PIPE;
+  RPC_STATUS status, iptcp_status, np_status;
+  RPC_STATUS (RPC_ENTRY *pRpcServerRegisterIfEx)(RPC_IF_HANDLE,UUID*,
+    RPC_MGR_EPV*, unsigned int,unsigned int,RPC_IF_CALLBACK_FN*);
+  DWORD ret;
+
+  iptcp_status = RpcServerUseProtseqEp(iptcp, 20, port, NULL);
+  ok(iptcp_status == RPC_S_OK, "RpcServerUseProtseqEp(ncacn_ip_tcp) failed with status %ld\n", iptcp_status);
+  np_status = RpcServerUseProtseqEp(np, 0, pipe, NULL);
+  if (np_status == RPC_S_PROTSEQ_NOT_SUPPORTED)
+    skip("Protocol sequence ncacn_np is not supported\n");
+  else
+    ok(np_status == RPC_S_OK, "RpcServerUseProtseqEp(ncacn_np) failed with status %ld\n", np_status);
 
-  ok(RPC_S_OK == RpcServerUseProtseqEp(iptcp, 20, port, NULL), "RpcServerUseProtseqEp\n");
-  //ok(RPC_S_OK == RpcServerRegisterIf(s_IServer_v0_0_s_ifspec, NULL, NULL), "RpcServerRegisterIf\n");
-  ok(RPC_S_OK == RpcServerListen(1, 20, TRUE), "RpcServerListen\n");
-
+  pRpcServerRegisterIfEx = (void *)GetProcAddress(GetModuleHandle("rpcrt4.dll"), "RpcServerRegisterIfEx");
+  if (pRpcServerRegisterIfEx)
+  {
+    trace("Using RpcServerRegisterIfEx\n");
+    status = pRpcServerRegisterIfEx(IServer_v0_0_s_ifspec, NULL, NULL,
+                                    RPC_IF_ALLOW_CALLBACKS_WITH_NO_AUTH,
+                                    RPC_C_LISTEN_MAX_CALLS_DEFAULT, NULL);
+  }
+  else
+    status = RpcServerRegisterIf(IServer_v0_0_s_ifspec, NULL, NULL);
+  ok(status == RPC_S_OK, "RpcServerRegisterIf failed with status %ld\n", status);
+  status = RpcServerListen(1, 20, TRUE);
+  ok(status == RPC_S_OK, "RpcServerListen failed with status %ld\n", status);
   stop_event = CreateEvent(NULL, FALSE, FALSE, NULL);
-  ok(stop_event != NULL, "CreateEvent failed\n");
+  ok(stop_event != NULL, "CreateEvent failed with error %d\n", GetLastError());
 
-  ok(run_client("tcp_basic"), "tcp_basic client test failed\n");
+  if (iptcp_status == RPC_S_OK)
+    run_client("tcp_basic");
+  else
+    skip("tcp_basic tests skipped due to earlier failure\n");
 
-  ok(RPC_S_OK == RpcServerUseProtseqEp(np, 0, pipe, NULL), "RpcServerUseProtseqEp\n");
-  ok(run_client("np_basic"), "np_basic client test failed\n");
+  if (np_status == RPC_S_OK)
+    run_client("np_basic");
+  else
+  {
+    skip("np_basic tests skipped due to earlier failure\n");
+    /* np client is what signals stop_event, so bail out if we didn't run do it */
+    return;
+  }
 
-  ok(WAIT_OBJECT_0 == WaitForSingleObject(stop_event, 60000), "WaitForSingleObject\n");
-  todo_wine {
-    ok(RPC_S_OK == RpcMgmtWaitServerListen(), "RpcMgmtWaitServerListening\n");
+  ret = WaitForSingleObject(stop_event, 1000);
+  ok(WAIT_OBJECT_0 == ret, "WaitForSingleObject\n");
+  /* if the stop event didn't fire then RpcMgmtWaitServerListen will wait
+   * forever, so don't bother calling it in this case */
+  if (ret == WAIT_OBJECT_0)
+  {
+    status = RpcMgmtWaitServerListen();
+    todo_wine {
+      ok(status == RPC_S_OK, "RpcMgmtWaitServerListening failed with status %ld\n", status);
+    }
   }
 }
 
@@ -1224,6 +1348,8 @@ START_TEST(server)
   int argc;
   char **argv;
 
+  InitFunctionPointers();
+
   argc = winetest_get_mainargs(&argv);
   progname = argv[0];