38dd4d46e6ef8a8547fb05db6e186236a8d6236c
[reactos.git] / modules / rostests / apitests / ntdll / NtAcceptConnectPort.c
1 /*
2 * PROJECT: ReactOS API tests
3 * LICENSE: LGPLv2.1+ - See COPYING.LIB in the top level directory
4 * PURPOSE: Test for NtAcceptConnectPort
5 * PROGRAMMERS: Thomas Faber <thomas.faber@reactos.org>
6 */
7
8 #include <apitest.h>
9
10 #define WIN32_NO_STATUS
11 #include <ndk/lpcfuncs.h>
12 #include <ndk/obfuncs.h>
13 #include <ndk/rtlfuncs.h>
14 #include <process.h>
15
16 #define TEST_CONNECTION_INFO_SIGNATURE1 0xaabb0123
17 #define TEST_CONNECTION_INFO_SIGNATURE2 0xaabb0124
18 typedef struct _TEST_CONNECTION_INFO
19 {
20 ULONG Signature;
21 } TEST_CONNECTION_INFO, *PTEST_CONNECTION_INFO;
22
23 #define TEST_MESSAGE_MESSAGE 0x4455cdef
24 typedef struct _TEST_MESSAGE
25 {
26 PORT_MESSAGE Header;
27 ULONG Message;
28 } TEST_MESSAGE, *PTEST_MESSAGE;
29
30 static UNICODE_STRING PortName = RTL_CONSTANT_STRING(L"\\NtdllApitestNtAcceptConnectPortTestPort");
31 static UINT ServerThreadId;
32 static UINT ClientThreadId;
33 static UCHAR Context;
34
35 UINT
36 CALLBACK
37 ServerThread(
38 _Inout_ PVOID Parameter)
39 {
40 NTSTATUS Status;
41 TEST_MESSAGE Message;
42 HANDLE PortHandle;
43 HANDLE ServerPortHandle = Parameter;
44
45 /* Listen, but refuse the connection */
46 RtlZeroMemory(&Message, sizeof(Message));
47 Status = NtListenPort(ServerPortHandle,
48 &Message.Header);
49 ok_hex(Status, STATUS_SUCCESS);
50
51 ok(Message.Header.u1.s1.TotalLength == RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message),
52 "TotalLength = %u, expected %lu\n",
53 Message.Header.u1.s1.TotalLength, RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message));
54 ok(Message.Header.u1.s1.DataLength == sizeof(TEST_CONNECTION_INFO),
55 "DataLength = %u\n", Message.Header.u1.s1.DataLength);
56 ok(Message.Header.u2.s2.Type == LPC_CONNECTION_REQUEST,
57 "Type = %x\n", Message.Header.u2.s2.Type);
58 ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(),
59 "UniqueProcess = %p, expected %lx\n",
60 Message.Header.ClientId.UniqueProcess, GetCurrentProcessId());
61 ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId,
62 "UniqueThread = %p, expected %x\n",
63 Message.Header.ClientId.UniqueThread, ClientThreadId);
64 ok(Message.Message == TEST_CONNECTION_INFO_SIGNATURE1, "Message = %lx\n", Message.Message);
65
66 PortHandle = (PVOID)(ULONG_PTR)0x55555555;
67 Status = NtAcceptConnectPort(&PortHandle,
68 &Context,
69 &Message.Header,
70 FALSE,
71 NULL,
72 NULL);
73 ok_hex(Status, STATUS_SUCCESS);
74 ok(PortHandle == (PVOID)(ULONG_PTR)0x55555555, "PortHandle = %p\n", PortHandle);
75
76 /* Listen a second time, then accept */
77 RtlZeroMemory(&Message, sizeof(Message));
78 Status = NtListenPort(ServerPortHandle,
79 &Message.Header);
80 ok_hex(Status, STATUS_SUCCESS);
81
82 ok(Message.Header.u1.s1.TotalLength == RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message),
83 "TotalLength = %u, expected %lu\n",
84 Message.Header.u1.s1.TotalLength, RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message));
85 ok(Message.Header.u1.s1.DataLength == sizeof(TEST_CONNECTION_INFO),
86 "DataLength = %u\n", Message.Header.u1.s1.DataLength);
87 ok(Message.Header.u2.s2.Type == LPC_CONNECTION_REQUEST,
88 "Type = %x\n", Message.Header.u2.s2.Type);
89 ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(),
90 "UniqueProcess = %p, expected %lx\n",
91 Message.Header.ClientId.UniqueProcess, GetCurrentProcessId());
92 ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId,
93 "UniqueThread = %p, expected %x\n",
94 Message.Header.ClientId.UniqueThread, ClientThreadId);
95 ok(Message.Message == TEST_CONNECTION_INFO_SIGNATURE2, "Message = %lx\n", Message.Message);
96
97 Status = NtAcceptConnectPort(&PortHandle,
98 &Context,
99 &Message.Header,
100 TRUE,
101 NULL,
102 NULL);
103 ok_hex(Status, STATUS_SUCCESS);
104
105 Status = NtCompleteConnectPort(PortHandle);
106 ok_hex(Status, STATUS_SUCCESS);
107
108 RtlZeroMemory(&Message, sizeof(Message));
109 Status = NtReplyWaitReceivePort(PortHandle,
110 NULL,
111 NULL,
112 &Message.Header);
113 ok_hex(Status, STATUS_SUCCESS);
114
115 ok(Message.Header.u1.s1.TotalLength == sizeof(Message),
116 "TotalLength = %u, expected %Iu\n",
117 Message.Header.u1.s1.TotalLength, sizeof(Message));
118 ok(Message.Header.u1.s1.DataLength == sizeof(Message.Message),
119 "DataLength = %u\n", Message.Header.u1.s1.DataLength);
120 ok(Message.Header.u2.s2.Type == LPC_DATAGRAM,
121 "Type = %x\n", Message.Header.u2.s2.Type);
122 ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(),
123 "UniqueProcess = %p, expected %lx\n",
124 Message.Header.ClientId.UniqueProcess, GetCurrentProcessId());
125 ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId,
126 "UniqueThread = %p, expected %x\n",
127 Message.Header.ClientId.UniqueThread, ClientThreadId);
128 ok(Message.Message == TEST_MESSAGE_MESSAGE, "Message = %lx\n", Message.Message);
129
130 Status = NtClose(PortHandle);
131 ok_hex(Status, STATUS_SUCCESS);
132
133 return 0;
134 }
135
136 UINT
137 CALLBACK
138 ClientThread(
139 _Inout_ PVOID Parameter)
140 {
141 NTSTATUS Status;
142 HANDLE PortHandle;
143 TEST_CONNECTION_INFO ConnectInfo;
144 ULONG ConnectInfoLength;
145 SECURITY_QUALITY_OF_SERVICE SecurityQos;
146 TEST_MESSAGE Message;
147
148 SecurityQos.Length = sizeof(SecurityQos);
149 SecurityQos.ImpersonationLevel = SecurityIdentification;
150 SecurityQos.EffectiveOnly = TRUE;
151 SecurityQos.ContextTrackingMode = SECURITY_STATIC_TRACKING;
152
153 /* Attempt to connect -- will be rejected */
154 ConnectInfo.Signature = TEST_CONNECTION_INFO_SIGNATURE1;
155 ConnectInfoLength = sizeof(ConnectInfo);
156 PortHandle = (PVOID)(ULONG_PTR)0x55555555;
157 Status = NtConnectPort(&PortHandle,
158 &PortName,
159 &SecurityQos,
160 NULL,
161 NULL,
162 NULL,
163 &ConnectInfo,
164 &ConnectInfoLength);
165 ok_hex(Status, STATUS_PORT_CONNECTION_REFUSED);
166 ok(PortHandle == (PVOID)(ULONG_PTR)0x55555555, "PortHandle = %p\n", PortHandle);
167
168 /* Try again, this time it will be accepted */
169 ConnectInfo.Signature = TEST_CONNECTION_INFO_SIGNATURE2;
170 ConnectInfoLength = sizeof(ConnectInfo);
171 Status = NtConnectPort(&PortHandle,
172 &PortName,
173 &SecurityQos,
174 NULL,
175 NULL,
176 NULL,
177 &ConnectInfo,
178 &ConnectInfoLength);
179 ok_hex(Status, STATUS_SUCCESS);
180 if (!NT_SUCCESS(Status))
181 {
182 skip("Failed to connect\n");
183 return 0;
184 }
185
186 RtlZeroMemory(&Message, sizeof(Message));
187 Message.Header.u1.s1.TotalLength = sizeof(Message);
188 Message.Header.u1.s1.DataLength = sizeof(Message.Message);
189 Message.Message = TEST_MESSAGE_MESSAGE;
190 Status = NtRequestPort(PortHandle,
191 &Message.Header);
192 ok_hex(Status, STATUS_SUCCESS);
193
194 Status = NtClose(PortHandle);
195 ok_hex(Status, STATUS_SUCCESS);
196
197 return 0;
198 }
199
200 START_TEST(NtAcceptConnectPort)
201 {
202 NTSTATUS Status;
203 OBJECT_ATTRIBUTES ObjectAttributes;
204 HANDLE PortHandle;
205 HANDLE ThreadHandles[2];
206
207 InitializeObjectAttributes(&ObjectAttributes,
208 &PortName,
209 OBJ_CASE_INSENSITIVE,
210 NULL,
211 NULL);
212 Status = NtCreatePort(&PortHandle,
213 &ObjectAttributes,
214 sizeof(TEST_CONNECTION_INFO),
215 sizeof(TEST_MESSAGE),
216 2 * sizeof(TEST_MESSAGE));
217 ok_hex(Status, STATUS_SUCCESS);
218 if (!NT_SUCCESS(Status))
219 {
220 skip("Failed to create port\n");
221 return;
222 }
223
224 ThreadHandles[0] = (HANDLE)_beginthreadex(NULL,
225 0,
226 ServerThread,
227 PortHandle,
228 0,
229 &ServerThreadId);
230 ok(ThreadHandles[0] != NULL, "_beginthreadex failed\n");
231
232 ThreadHandles[1] = (HANDLE)_beginthreadex(NULL,
233 0,
234 ClientThread,
235 PortHandle,
236 0,
237 &ClientThreadId);
238 ok(ThreadHandles[1] != NULL, "_beginthreadex failed\n");
239
240 Status = NtWaitForMultipleObjects(RTL_NUMBER_OF(ThreadHandles),
241 ThreadHandles,
242 WaitAll,
243 FALSE,
244 NULL);
245 ok_hex(Status, STATUS_SUCCESS);
246
247 Status = NtClose(ThreadHandles[0]);
248 ok_hex(Status, STATUS_SUCCESS);
249 Status = NtClose(ThreadHandles[1]);
250 ok_hex(Status, STATUS_SUCCESS);
251
252 Status = NtClose(PortHandle);
253 ok_hex(Status, STATUS_SUCCESS);
254 }