Synchronize with trunk r58606.
[reactos.git] / dll / win32 / rpcrt4 / rpc_assoc.c
1 /*
2 * Associations
3 *
4 * Copyright 2007 Robert Shearman (for CodeWeavers)
5 *
6 * This library is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
10 *
11 * This library is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with this library; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19 *
20 */
21
22 #define WIN32_NO_STATUS
23 #define _INC_WINDOWS
24
25 #include <stdarg.h>
26 #include <assert.h>
27
28 #include <windef.h>
29 #include <winbase.h>
30 #include <rpc.h>
31 //#include "rpcndr.h"
32 #include <winternl.h>
33
34 #include <wine/unicode.h>
35 #include <wine/debug.h>
36
37 //#include "rpc_binding.h"
38 #include "rpc_assoc.h"
39 #include "rpc_message.h"
40
41 WINE_DEFAULT_DEBUG_CHANNEL(rpc);
42
/* Guards client_assoc_list and server_assoc_list, and the refs count of every
 * association while it is looked up, referenced or released.  Initialised
 * statically (with a debug record naming it), so no runtime setup is needed. */
static CRITICAL_SECTION assoc_list_cs;
static CRITICAL_SECTION_DEBUG assoc_list_cs_debug =
{
    0, 0, &assoc_list_cs,
    { &assoc_list_cs_debug.ProcessLocksList, &assoc_list_cs_debug.ProcessLocksList },
      0, 0, { (DWORD_PTR)(__FILE__ ": assoc_list_cs") }
};
static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 };

/* Associations created on the client side (RPCRT4_GetAssociation) and on the
 * server side (RpcServerAssoc_GetAssociation); both protected by assoc_list_cs. */
static struct list client_assoc_list = LIST_INIT(client_assoc_list);
static struct list server_assoc_list = LIST_INIT(server_assoc_list);

/* Most recently issued server association group id.  Starts at zero and is
 * bumped with InterlockedIncrement, so the first id handed out is 1; an
 * assoc_group_id of 0 is treated elsewhere in this file as "not assigned". */
static LONG last_assoc_group_id;

/* Server-side context handle.  Instances are handed out cast to NDR_SCONTEXT.
 * NOTE(review): the field order presumably has to keep user_context where
 * NDR_SCONTEXT exposes its user context pointer (after two pointer-sized
 * slots, here the list entry) - confirm against rpcndr.h before reordering. */
typedef struct _RpcContextHandle
{
    struct list entry;             /* link in the owning assoc's context_handle_list */
    void *user_context;            /* passed to rundown_routine on destruction */
    NDR_RUNDOWN rundown_routine;   /* called by RpcContextHandle_Destroy when set */
    void *ctx_guard;               /* compared by RpcContextHandle_IsGuardCorrect */
    UUID uuid;                     /* nil until RpcServerAssoc_UpdateContextHandle assigns one */
    RTL_RWLOCK rw_lock;            /* held exclusively while a call is using the handle */
    unsigned int refs;             /* reference count, modified under the assoc's cs */
} RpcContextHandle;

static void RpcContextHandle_Destroy(RpcContextHandle *context_handle);
69
70 static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr,
71 LPCSTR Endpoint, LPCWSTR NetworkOptions,
72 RpcAssoc **assoc_out)
73 {
74 RpcAssoc *assoc;
75 assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc));
76 if (!assoc)
77 return RPC_S_OUT_OF_RESOURCES;
78 assoc->refs = 1;
79 list_init(&assoc->free_connection_pool);
80 list_init(&assoc->context_handle_list);
81 InitializeCriticalSection(&assoc->cs);
82 assoc->Protseq = RPCRT4_strdupA(Protseq);
83 assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
84 assoc->Endpoint = RPCRT4_strdupA(Endpoint);
85 assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL;
86 assoc->assoc_group_id = 0;
87 list_init(&assoc->entry);
88 *assoc_out = assoc;
89 return RPC_S_OK;
90 }
91
92 static BOOL compare_networkoptions(LPCWSTR opts1, LPCWSTR opts2)
93 {
94 if ((opts1 == NULL) && (opts2 == NULL))
95 return TRUE;
96 if ((opts1 == NULL) || (opts2 == NULL))
97 return FALSE;
98 return !strcmpW(opts1, opts2);
99 }
100
101 RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
102 LPCSTR Endpoint, LPCWSTR NetworkOptions,
103 RpcAssoc **assoc_out)
104 {
105 RpcAssoc *assoc;
106 RPC_STATUS status;
107
108 EnterCriticalSection(&assoc_list_cs);
109 LIST_FOR_EACH_ENTRY(assoc, &client_assoc_list, RpcAssoc, entry)
110 {
111 if (!strcmp(Protseq, assoc->Protseq) &&
112 !strcmp(NetworkAddr, assoc->NetworkAddr) &&
113 !strcmp(Endpoint, assoc->Endpoint) &&
114 compare_networkoptions(NetworkOptions, assoc->NetworkOptions))
115 {
116 assoc->refs++;
117 *assoc_out = assoc;
118 LeaveCriticalSection(&assoc_list_cs);
119 TRACE("using existing assoc %p\n", assoc);
120 return RPC_S_OK;
121 }
122 }
123
124 status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
125 if (status != RPC_S_OK)
126 {
127 LeaveCriticalSection(&assoc_list_cs);
128 return status;
129 }
130 list_add_head(&client_assoc_list, &assoc->entry);
131 *assoc_out = assoc;
132
133 LeaveCriticalSection(&assoc_list_cs);
134
135 TRACE("new assoc %p\n", assoc);
136
137 return RPC_S_OK;
138 }
139
140 RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
141 LPCSTR Endpoint, LPCWSTR NetworkOptions,
142 ULONG assoc_gid,
143 RpcAssoc **assoc_out)
144 {
145 RpcAssoc *assoc;
146 RPC_STATUS status;
147
148 EnterCriticalSection(&assoc_list_cs);
149 if (assoc_gid)
150 {
151 LIST_FOR_EACH_ENTRY(assoc, &server_assoc_list, RpcAssoc, entry)
152 {
153 /* FIXME: NetworkAddr shouldn't be NULL */
154 if (assoc->assoc_group_id == assoc_gid &&
155 !strcmp(Protseq, assoc->Protseq) &&
156 (!NetworkAddr || !assoc->NetworkAddr || !strcmp(NetworkAddr, assoc->NetworkAddr)) &&
157 !strcmp(Endpoint, assoc->Endpoint) &&
158 ((!assoc->NetworkOptions == !NetworkOptions) &&
159 (!NetworkOptions || !strcmpW(NetworkOptions, assoc->NetworkOptions))))
160 {
161 assoc->refs++;
162 *assoc_out = assoc;
163 LeaveCriticalSection(&assoc_list_cs);
164 TRACE("using existing assoc %p\n", assoc);
165 return RPC_S_OK;
166 }
167 }
168 *assoc_out = NULL;
169 LeaveCriticalSection(&assoc_list_cs);
170 return RPC_S_NO_CONTEXT_AVAILABLE;
171 }
172
173 status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
174 if (status != RPC_S_OK)
175 {
176 LeaveCriticalSection(&assoc_list_cs);
177 return status;
178 }
179 assoc->assoc_group_id = InterlockedIncrement(&last_assoc_group_id);
180 list_add_head(&server_assoc_list, &assoc->entry);
181 *assoc_out = assoc;
182
183 LeaveCriticalSection(&assoc_list_cs);
184
185 TRACE("new assoc %p\n", assoc);
186
187 return RPC_S_OK;
188 }
189
190 ULONG RpcAssoc_Release(RpcAssoc *assoc)
191 {
192 ULONG refs;
193
194 EnterCriticalSection(&assoc_list_cs);
195 refs = --assoc->refs;
196 if (!refs)
197 list_remove(&assoc->entry);
198 LeaveCriticalSection(&assoc_list_cs);
199
200 if (!refs)
201 {
202 RpcConnection *Connection, *cursor2;
203 RpcContextHandle *context_handle, *context_handle_cursor;
204
205 TRACE("destroying assoc %p\n", assoc);
206
207 LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
208 {
209 list_remove(&Connection->conn_pool_entry);
210 RPCRT4_DestroyConnection(Connection);
211 }
212
213 LIST_FOR_EACH_ENTRY_SAFE(context_handle, context_handle_cursor, &assoc->context_handle_list, RpcContextHandle, entry)
214 RpcContextHandle_Destroy(context_handle);
215
216 HeapFree(GetProcessHeap(), 0, assoc->NetworkOptions);
217 HeapFree(GetProcessHeap(), 0, assoc->Endpoint);
218 HeapFree(GetProcessHeap(), 0, assoc->NetworkAddr);
219 HeapFree(GetProcessHeap(), 0, assoc->Protseq);
220
221 DeleteCriticalSection(&assoc->cs);
222
223 HeapFree(GetProcessHeap(), 0, assoc);
224 }
225
226 return refs;
227 }
228
/* Round value up to the next multiple of alignment by adding alignment-1 and
 * masking off the low bits; only correct when alignment is a power of two.
 * alignment is evaluated twice, so it must not have side effects. */
#define ROUND_UP(value, alignment) (((value) + ((alignment) - 1)) & ~((alignment)-1))
230
231 static RPC_STATUS RpcAssoc_BindConnection(const RpcAssoc *assoc, RpcConnection *conn,
232 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
233 const RPC_SYNTAX_IDENTIFIER *TransferSyntax)
234 {
235 RpcPktHdr *hdr;
236 RpcPktHdr *response_hdr;
237 RPC_MESSAGE msg;
238 RPC_STATUS status;
239 unsigned char *auth_data = NULL;
240 ULONG auth_length;
241
242 TRACE("sending bind request to server\n");
243
244 hdr = RPCRT4_BuildBindHeader(NDR_LOCAL_DATA_REPRESENTATION,
245 RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE,
246 assoc->assoc_group_id,
247 InterfaceId, TransferSyntax);
248
249 status = RPCRT4_Send(conn, hdr, NULL, 0);
250 RPCRT4_FreeHeader(hdr);
251 if (status != RPC_S_OK)
252 return status;
253
254 status = RPCRT4_ReceiveWithAuth(conn, &response_hdr, &msg, &auth_data, &auth_length);
255 if (status != RPC_S_OK)
256 {
257 ERR("receive failed with error %d\n", status);
258 return status;
259 }
260
261 switch (response_hdr->common.ptype)
262 {
263 case PKT_BIND_ACK:
264 {
265 RpcAddressString *server_address = msg.Buffer;
266 if ((msg.BufferLength >= FIELD_OFFSET(RpcAddressString, string[0])) ||
267 (msg.BufferLength >= ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4)))
268 {
269 unsigned short remaining = msg.BufferLength -
270 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4);
271 RpcResultList *results = (RpcResultList*)((ULONG_PTR)server_address +
272 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4));
273 if ((results->num_results == 1) &&
274 (remaining >= FIELD_OFFSET(RpcResultList, results[results->num_results])))
275 {
276 switch (results->results[0].result)
277 {
278 case RESULT_ACCEPT:
279 /* respond to authorization request */
280 if (auth_length > sizeof(RpcAuthVerifier))
281 status = RPCRT4_ClientConnectionAuth(conn,
282 auth_data + sizeof(RpcAuthVerifier),
283 auth_length);
284 if (status == RPC_S_OK)
285 {
286 conn->assoc_group_id = response_hdr->bind_ack.assoc_gid;
287 conn->MaxTransmissionSize = response_hdr->bind_ack.max_tsize;
288 conn->ActiveInterface = *InterfaceId;
289 }
290 break;
291 case RESULT_PROVIDER_REJECTION:
292 switch (results->results[0].reason)
293 {
294 case REASON_ABSTRACT_SYNTAX_NOT_SUPPORTED:
295 ERR("syntax %s, %d.%d not supported\n",
296 debugstr_guid(&InterfaceId->SyntaxGUID),
297 InterfaceId->SyntaxVersion.MajorVersion,
298 InterfaceId->SyntaxVersion.MinorVersion);
299 status = RPC_S_UNKNOWN_IF;
300 break;
301 case REASON_TRANSFER_SYNTAXES_NOT_SUPPORTED:
302 ERR("transfer syntax not supported\n");
303 status = RPC_S_SERVER_UNAVAILABLE;
304 break;
305 case REASON_NONE:
306 default:
307 status = RPC_S_CALL_FAILED_DNE;
308 }
309 break;
310 case RESULT_USER_REJECTION:
311 default:
312 ERR("rejection result %d\n", results->results[0].result);
313 status = RPC_S_CALL_FAILED_DNE;
314 }
315 }
316 else
317 {
318 ERR("incorrect results size\n");
319 status = RPC_S_CALL_FAILED_DNE;
320 }
321 }
322 else
323 {
324 ERR("bind ack packet too small (%d)\n", msg.BufferLength);
325 status = RPC_S_PROTOCOL_ERROR;
326 }
327 break;
328 }
329 case PKT_BIND_NACK:
330 switch (response_hdr->bind_nack.reject_reason)
331 {
332 case REJECT_LOCAL_LIMIT_EXCEEDED:
333 case REJECT_TEMPORARY_CONGESTION:
334 ERR("server too busy\n");
335 status = RPC_S_SERVER_TOO_BUSY;
336 break;
337 case REJECT_PROTOCOL_VERSION_NOT_SUPPORTED:
338 ERR("protocol version not supported\n");
339 status = RPC_S_PROTOCOL_ERROR;
340 break;
341 case REJECT_UNKNOWN_AUTHN_SERVICE:
342 ERR("unknown authentication service\n");
343 status = RPC_S_UNKNOWN_AUTHN_SERVICE;
344 break;
345 case REJECT_INVALID_CHECKSUM:
346 ERR("invalid checksum\n");
347 status = ERROR_ACCESS_DENIED;
348 break;
349 default:
350 ERR("rejected bind for reason %d\n", response_hdr->bind_nack.reject_reason);
351 status = RPC_S_CALL_FAILED_DNE;
352 }
353 break;
354 default:
355 ERR("wrong packet type received %d\n", response_hdr->common.ptype);
356 status = RPC_S_PROTOCOL_ERROR;
357 break;
358 }
359
360 I_RpcFree(msg.Buffer);
361 RPCRT4_FreeHeader(response_hdr);
362 HeapFree(GetProcessHeap(), 0, auth_data);
363 return status;
364 }
365
366 static RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc,
367 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
368 const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo,
369 const RpcQualityOfService *QOS)
370 {
371 RpcConnection *Connection;
372 EnterCriticalSection(&assoc->cs);
373 /* try to find a compatible connection from the connection pool */
374 LIST_FOR_EACH_ENTRY(Connection, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
375 {
376 if (!memcmp(&Connection->ActiveInterface, InterfaceId,
377 sizeof(RPC_SYNTAX_IDENTIFIER)) &&
378 RpcAuthInfo_IsEqual(Connection->AuthInfo, AuthInfo) &&
379 RpcQualityOfService_IsEqual(Connection->QOS, QOS))
380 {
381 list_remove(&Connection->conn_pool_entry);
382 LeaveCriticalSection(&assoc->cs);
383 TRACE("got connection from pool %p\n", Connection);
384 return Connection;
385 }
386 }
387
388 LeaveCriticalSection(&assoc->cs);
389 return NULL;
390 }
391
392 RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc,
393 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
394 const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo,
395 RpcQualityOfService *QOS, RpcConnection **Connection)
396 {
397 RpcConnection *NewConnection;
398 RPC_STATUS status;
399
400 *Connection = RpcAssoc_GetIdleConnection(assoc, InterfaceId, TransferSyntax, AuthInfo, QOS);
401 if (*Connection)
402 return RPC_S_OK;
403
404 /* create a new connection */
405 status = RPCRT4_CreateConnection(&NewConnection, FALSE /* is this a server connection? */,
406 assoc->Protseq, assoc->NetworkAddr,
407 assoc->Endpoint, assoc->NetworkOptions,
408 AuthInfo, QOS);
409 if (status != RPC_S_OK)
410 return status;
411
412 NewConnection->assoc = assoc;
413 status = RPCRT4_OpenClientConnection(NewConnection);
414 if (status != RPC_S_OK)
415 {
416 RPCRT4_DestroyConnection(NewConnection);
417 return status;
418 }
419
420 status = RpcAssoc_BindConnection(assoc, NewConnection, InterfaceId, TransferSyntax);
421 if (status != RPC_S_OK)
422 {
423 RPCRT4_DestroyConnection(NewConnection);
424 return status;
425 }
426
427 *Connection = NewConnection;
428
429 return RPC_S_OK;
430 }
431
432 void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection)
433 {
434 assert(!Connection->server);
435 Connection->async_state = NULL;
436 EnterCriticalSection(&assoc->cs);
437 if (!assoc->assoc_group_id) assoc->assoc_group_id = Connection->assoc_group_id;
438 list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry);
439 LeaveCriticalSection(&assoc->cs);
440 }
441
442 RPC_STATUS RpcServerAssoc_AllocateContextHandle(RpcAssoc *assoc, void *CtxGuard,
443 NDR_SCONTEXT *SContext)
444 {
445 RpcContextHandle *context_handle;
446
447 context_handle = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*context_handle));
448 if (!context_handle)
449 return ERROR_OUTOFMEMORY;
450
451 context_handle->ctx_guard = CtxGuard;
452 RtlInitializeResource(&context_handle->rw_lock);
453 context_handle->refs = 1;
454
455 /* lock here to mirror unmarshall, so we don't need to special-case the
456 * freeing of a non-marshalled context handle */
457 RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
458
459 EnterCriticalSection(&assoc->cs);
460 list_add_tail(&assoc->context_handle_list, &context_handle->entry);
461 LeaveCriticalSection(&assoc->cs);
462
463 *SContext = (NDR_SCONTEXT)context_handle;
464 return RPC_S_OK;
465 }
466
467 BOOL RpcContextHandle_IsGuardCorrect(NDR_SCONTEXT SContext, void *CtxGuard)
468 {
469 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
470 return context_handle->ctx_guard == CtxGuard;
471 }
472
/* Look up a context handle on the association by UUID and context guard.
 *
 * On success *SContext is set, one reference is added to the handle (under
 * assoc->cs), and then - after assoc->cs is released - the handle's rw_lock
 * is acquired exclusively.  That acquisition may wait for another holder.
 * RpcServerAssoc_ReleaseContextHandle with release_lock set undoes both.
 *
 * Returns ERROR_INVALID_HANDLE when no matching handle is found.  The Flags
 * argument is not used here. */
RPC_STATUS RpcServerAssoc_FindContextHandle(RpcAssoc *assoc, const UUID *uuid,
                                            void *CtxGuard, ULONG Flags, NDR_SCONTEXT *SContext)
{
    RpcContextHandle *context_handle;

    EnterCriticalSection(&assoc->cs);
    LIST_FOR_EACH_ENTRY(context_handle, &assoc->context_handle_list, RpcContextHandle, entry)
    {
        if (RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard) &&
            !memcmp(&context_handle->uuid, uuid, sizeof(*uuid)))
        {
            *SContext = (NDR_SCONTEXT)context_handle;
            /* A handle whose refs was already 0 is skipped and the search
             * goes on, although the post-increment still runs.
             * NOTE(review): RpcServerAssoc_ReleaseContextHandle unlinks a
             * handle in the same critical section in which refs reaches 0,
             * so this path looks unreachable - confirm before relying on it. */
            if (context_handle->refs++)
            {
                LeaveCriticalSection(&assoc->cs);
                TRACE("found %p\n", context_handle);
                /* taken outside assoc->cs so a blocked wait doesn't stall the assoc */
                RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
                return RPC_S_OK;
            }
        }
    }
    LeaveCriticalSection(&assoc->cs);

    ERR("no context handle found for uuid %s, guard %p\n",
        debugstr_guid(uuid), CtxGuard);
    return ERROR_INVALID_HANDLE;
}
500
501 RPC_STATUS RpcServerAssoc_UpdateContextHandle(RpcAssoc *assoc,
502 NDR_SCONTEXT SContext,
503 void *CtxGuard,
504 NDR_RUNDOWN rundown_routine)
505 {
506 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
507 RPC_STATUS status;
508
509 if (!RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard))
510 return ERROR_INVALID_HANDLE;
511
512 EnterCriticalSection(&assoc->cs);
513 if (UuidIsNil(&context_handle->uuid, &status))
514 {
515 /* add a ref for the data being valid */
516 context_handle->refs++;
517 UuidCreate(&context_handle->uuid);
518 context_handle->rundown_routine = rundown_routine;
519 TRACE("allocated uuid %s for context handle %p\n",
520 debugstr_guid(&context_handle->uuid), context_handle);
521 }
522 LeaveCriticalSection(&assoc->cs);
523
524 return RPC_S_OK;
525 }
526
527 void RpcContextHandle_GetUuid(NDR_SCONTEXT SContext, UUID *uuid)
528 {
529 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
530 *uuid = context_handle->uuid;
531 }
532
533 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle)
534 {
535 TRACE("freeing %p\n", context_handle);
536
537 if (context_handle->user_context && context_handle->rundown_routine)
538 {
539 TRACE("calling rundown routine %p with user context %p\n",
540 context_handle->rundown_routine, context_handle->user_context);
541 context_handle->rundown_routine(context_handle->user_context);
542 }
543
544 RtlDeleteResource(&context_handle->rw_lock);
545
546 HeapFree(GetProcessHeap(), 0, context_handle);
547 }
548
549 unsigned int RpcServerAssoc_ReleaseContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, BOOL release_lock)
550 {
551 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
552 unsigned int refs;
553
554 if (release_lock)
555 RtlReleaseResource(&context_handle->rw_lock);
556
557 EnterCriticalSection(&assoc->cs);
558 refs = --context_handle->refs;
559 if (!refs)
560 list_remove(&context_handle->entry);
561 LeaveCriticalSection(&assoc->cs);
562
563 if (!refs)
564 RpcContextHandle_Destroy(context_handle);
565
566 return refs;
567 }