359f06382fb7da4ee1346e321e9b2c02f0e38bd3
[reactos.git] / modules / rostests / apitests / ntdll / NtAcceptConnectPort.c
1 /*
2 * PROJECT: ReactOS API tests
3 * LICENSE: LGPLv2.1+ - See COPYING.LIB in the top level directory
4 * PURPOSE: Test for NtAcceptConnectPort
5 * PROGRAMMERS: Thomas Faber <thomas.faber@reactos.org>
6 */
7
8 #include "precomp.h"
9
10 #include <process.h>
11
12 #define TEST_CONNECTION_INFO_SIGNATURE1 0xaabb0123
13 #define TEST_CONNECTION_INFO_SIGNATURE2 0xaabb0124
14 typedef struct _TEST_CONNECTION_INFO
15 {
16 ULONG Signature;
17 } TEST_CONNECTION_INFO, *PTEST_CONNECTION_INFO;
18
19 #define TEST_MESSAGE_MESSAGE 0x4455cdef
20 typedef struct _TEST_MESSAGE
21 {
22 PORT_MESSAGE Header;
23 ULONG Message;
24 } TEST_MESSAGE, *PTEST_MESSAGE;
25
26 static UNICODE_STRING PortName = RTL_CONSTANT_STRING(L"\\NtdllApitestNtAcceptConnectPortTestPort");
27 static UINT ServerThreadId;
28 static UINT ClientThreadId;
29 static UCHAR Context;
30
31 UINT
32 CALLBACK
33 ServerThread(
34 _Inout_ PVOID Parameter)
35 {
36 NTSTATUS Status;
37 TEST_MESSAGE Message;
38 HANDLE PortHandle;
39 HANDLE ServerPortHandle = Parameter;
40
41 /* Listen, but refuse the connection */
42 RtlZeroMemory(&Message, sizeof(Message));
43 Status = NtListenPort(ServerPortHandle,
44 &Message.Header);
45 ok_hex(Status, STATUS_SUCCESS);
46
47 ok(Message.Header.u1.s1.TotalLength == RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message),
48 "TotalLength = %u, expected %lu\n",
49 Message.Header.u1.s1.TotalLength, RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message));
50 ok(Message.Header.u1.s1.DataLength == sizeof(TEST_CONNECTION_INFO),
51 "DataLength = %u\n", Message.Header.u1.s1.DataLength);
52 ok(Message.Header.u2.s2.Type == LPC_CONNECTION_REQUEST,
53 "Type = %x\n", Message.Header.u2.s2.Type);
54 ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(),
55 "UniqueProcess = %p, expected %lx\n",
56 Message.Header.ClientId.UniqueProcess, GetCurrentProcessId());
57 ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId,
58 "UniqueThread = %p, expected %x\n",
59 Message.Header.ClientId.UniqueThread, ClientThreadId);
60 ok(Message.Message == TEST_CONNECTION_INFO_SIGNATURE1, "Message = %lx\n", Message.Message);
61
62 PortHandle = (PVOID)(ULONG_PTR)0x55555555;
63 Status = NtAcceptConnectPort(&PortHandle,
64 &Context,
65 &Message.Header,
66 FALSE,
67 NULL,
68 NULL);
69 ok_hex(Status, STATUS_SUCCESS);
70 ok(PortHandle == (PVOID)(ULONG_PTR)0x55555555, "PortHandle = %p\n", PortHandle);
71
72 /* Listen a second time, then accept */
73 RtlZeroMemory(&Message, sizeof(Message));
74 Status = NtListenPort(ServerPortHandle,
75 &Message.Header);
76 ok_hex(Status, STATUS_SUCCESS);
77
78 ok(Message.Header.u1.s1.TotalLength == RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message),
79 "TotalLength = %u, expected %lu\n",
80 Message.Header.u1.s1.TotalLength, RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message));
81 ok(Message.Header.u1.s1.DataLength == sizeof(TEST_CONNECTION_INFO),
82 "DataLength = %u\n", Message.Header.u1.s1.DataLength);
83 ok(Message.Header.u2.s2.Type == LPC_CONNECTION_REQUEST,
84 "Type = %x\n", Message.Header.u2.s2.Type);
85 ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(),
86 "UniqueProcess = %p, expected %lx\n",
87 Message.Header.ClientId.UniqueProcess, GetCurrentProcessId());
88 ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId,
89 "UniqueThread = %p, expected %x\n",
90 Message.Header.ClientId.UniqueThread, ClientThreadId);
91 ok(Message.Message == TEST_CONNECTION_INFO_SIGNATURE2, "Message = %lx\n", Message.Message);
92
93 Status = NtAcceptConnectPort(&PortHandle,
94 &Context,
95 &Message.Header,
96 TRUE,
97 NULL,
98 NULL);
99 ok_hex(Status, STATUS_SUCCESS);
100
101 Status = NtCompleteConnectPort(PortHandle);
102 ok_hex(Status, STATUS_SUCCESS);
103
104 RtlZeroMemory(&Message, sizeof(Message));
105 Status = NtReplyWaitReceivePort(PortHandle,
106 NULL,
107 NULL,
108 &Message.Header);
109 ok_hex(Status, STATUS_SUCCESS);
110
111 ok(Message.Header.u1.s1.TotalLength == sizeof(Message),
112 "TotalLength = %u, expected %Iu\n",
113 Message.Header.u1.s1.TotalLength, sizeof(Message));
114 ok(Message.Header.u1.s1.DataLength == sizeof(Message.Message),
115 "DataLength = %u\n", Message.Header.u1.s1.DataLength);
116 ok(Message.Header.u2.s2.Type == LPC_DATAGRAM,
117 "Type = %x\n", Message.Header.u2.s2.Type);
118 ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(),
119 "UniqueProcess = %p, expected %lx\n",
120 Message.Header.ClientId.UniqueProcess, GetCurrentProcessId());
121 ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId,
122 "UniqueThread = %p, expected %x\n",
123 Message.Header.ClientId.UniqueThread, ClientThreadId);
124 ok(Message.Message == TEST_MESSAGE_MESSAGE, "Message = %lx\n", Message.Message);
125
126 Status = NtClose(PortHandle);
127 ok_hex(Status, STATUS_SUCCESS);
128
129 return 0;
130 }
131
132 UINT
133 CALLBACK
134 ClientThread(
135 _Inout_ PVOID Parameter)
136 {
137 NTSTATUS Status;
138 HANDLE PortHandle;
139 TEST_CONNECTION_INFO ConnectInfo;
140 ULONG ConnectInfoLength;
141 SECURITY_QUALITY_OF_SERVICE SecurityQos;
142 TEST_MESSAGE Message;
143
144 SecurityQos.Length = sizeof(SecurityQos);
145 SecurityQos.ImpersonationLevel = SecurityIdentification;
146 SecurityQos.EffectiveOnly = TRUE;
147 SecurityQos.ContextTrackingMode = SECURITY_STATIC_TRACKING;
148
149 /* Attempt to connect -- will be rejected */
150 ConnectInfo.Signature = TEST_CONNECTION_INFO_SIGNATURE1;
151 ConnectInfoLength = sizeof(ConnectInfo);
152 PortHandle = (PVOID)(ULONG_PTR)0x55555555;
153 Status = NtConnectPort(&PortHandle,
154 &PortName,
155 &SecurityQos,
156 NULL,
157 NULL,
158 NULL,
159 &ConnectInfo,
160 &ConnectInfoLength);
161 ok_hex(Status, STATUS_PORT_CONNECTION_REFUSED);
162 ok(PortHandle == (PVOID)(ULONG_PTR)0x55555555, "PortHandle = %p\n", PortHandle);
163
164 /* Try again, this time it will be accepted */
165 ConnectInfo.Signature = TEST_CONNECTION_INFO_SIGNATURE2;
166 ConnectInfoLength = sizeof(ConnectInfo);
167 Status = NtConnectPort(&PortHandle,
168 &PortName,
169 &SecurityQos,
170 NULL,
171 NULL,
172 NULL,
173 &ConnectInfo,
174 &ConnectInfoLength);
175 ok_hex(Status, STATUS_SUCCESS);
176 if (!NT_SUCCESS(Status))
177 {
178 skip("Failed to connect\n");
179 return 0;
180 }
181
182 RtlZeroMemory(&Message, sizeof(Message));
183 Message.Header.u1.s1.TotalLength = sizeof(Message);
184 Message.Header.u1.s1.DataLength = sizeof(Message.Message);
185 Message.Message = TEST_MESSAGE_MESSAGE;
186 Status = NtRequestPort(PortHandle,
187 &Message.Header);
188 ok_hex(Status, STATUS_SUCCESS);
189
190 Status = NtClose(PortHandle);
191 ok_hex(Status, STATUS_SUCCESS);
192
193 return 0;
194 }
195
196 START_TEST(NtAcceptConnectPort)
197 {
198 NTSTATUS Status;
199 OBJECT_ATTRIBUTES ObjectAttributes;
200 HANDLE PortHandle;
201 HANDLE ThreadHandles[2];
202
203 InitializeObjectAttributes(&ObjectAttributes,
204 &PortName,
205 OBJ_CASE_INSENSITIVE,
206 NULL,
207 NULL);
208 Status = NtCreatePort(&PortHandle,
209 &ObjectAttributes,
210 sizeof(TEST_CONNECTION_INFO),
211 sizeof(TEST_MESSAGE),
212 2 * sizeof(TEST_MESSAGE));
213 ok_hex(Status, STATUS_SUCCESS);
214 if (!NT_SUCCESS(Status))
215 {
216 skip("Failed to create port\n");
217 return;
218 }
219
220 ThreadHandles[0] = (HANDLE)_beginthreadex(NULL,
221 0,
222 ServerThread,
223 PortHandle,
224 0,
225 &ServerThreadId);
226 ok(ThreadHandles[0] != NULL, "_beginthreadex failed\n");
227
228 ThreadHandles[1] = (HANDLE)_beginthreadex(NULL,
229 0,
230 ClientThread,
231 PortHandle,
232 0,
233 &ClientThreadId);
234 ok(ThreadHandles[1] != NULL, "_beginthreadex failed\n");
235
236 Status = NtWaitForMultipleObjects(RTL_NUMBER_OF(ThreadHandles),
237 ThreadHandles,
238 WaitAll,
239 FALSE,
240 NULL);
241 ok_hex(Status, STATUS_SUCCESS);
242
243 Status = NtClose(ThreadHandles[0]);
244 ok_hex(Status, STATUS_SUCCESS);
245 Status = NtClose(ThreadHandles[1]);
246 ok_hex(Status, STATUS_SUCCESS);
247
248 Status = NtClose(PortHandle);
249 ok_hex(Status, STATUS_SUCCESS);
250 }