[RPCRT4]
[reactos.git] / reactos / dll / win32 / rpcrt4 / rpc_assoc.c
1 /*
2 * Associations
3 *
4 * Copyright 2007 Robert Shearman (for CodeWeavers)
5 *
6 * This library is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
10 *
11 * This library is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with this library; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19 *
20 */
21
22 #include "precomp.h"
23
24 WINE_DEFAULT_DEBUG_CHANNEL(rpc);
25
/* Process-wide lock. In this file it is held while client_assoc_list and
 * server_assoc_list are searched or modified and while an association's
 * reference count is changed. It is statically initialised together with its
 * debug record so that no run-time initialisation call is needed. */
26 static CRITICAL_SECTION assoc_list_cs;
27 static CRITICAL_SECTION_DEBUG assoc_list_cs_debug =
28 {
29 0, 0, &assoc_list_cs,
30 { &assoc_list_cs_debug.ProcessLocksList, &assoc_list_cs_debug.ProcessLocksList },
31 0, 0, { (DWORD_PTR)(__FILE__ ": assoc_list_cs") }
32 };
33 static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 };
34
/* Associations handed out by RPCRT4_GetAssociation (client side) and by
 * RpcServerAssoc_GetAssociation (server side) respectively. */
35 static struct list client_assoc_list = LIST_INIT(client_assoc_list);
36 static struct list server_assoc_list = LIST_INIT(server_assoc_list);
37
/* Counter incremented with InterlockedIncrement to give each newly created
 * server association its association group id. */
38 static LONG last_assoc_group_id;
39
/* Server-side context handle kept on an association's context_handle_list.
 *
 * NOTE(review): pointers to this structure are handed to callers cast to
 * NDR_SCONTEXT, so the order of the leading fields presumably has to stay
 * compatible with that type - confirm before reordering anything. */
40 typedef struct _RpcContextHandle
41 {
/* link in RpcAssoc.context_handle_list */
42 struct list entry;
/* value passed to rundown_routine when the handle is destroyed */
43 void *user_context;
/* optional callback invoked on destruction if user_context is non-NULL */
44 NDR_RUNDOWN rundown_routine;
/* opaque guard value compared by RpcContextHandle_IsGuardCorrect */
45 void *ctx_guard;
/* identifier used for lookup; nil until RpcServerAssoc_UpdateContextHandle creates one */
46 UUID uuid;
/* taken exclusively on allocate/find, released by RpcServerAssoc_ReleaseContextHandle */
47 RTL_RWLOCK rw_lock;
/* reference count, modified under the owning association's cs */
48 unsigned int refs;
49 } RpcContextHandle;
50
/* Defined near the end of the file; declared here because RpcAssoc_Release uses it. */
51 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle);
52
53 static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr,
54 LPCSTR Endpoint, LPCWSTR NetworkOptions,
55 RpcAssoc **assoc_out)
56 {
57 RpcAssoc *assoc;
58 assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc));
59 if (!assoc)
60 return RPC_S_OUT_OF_RESOURCES;
61 assoc->refs = 1;
62 list_init(&assoc->free_connection_pool);
63 list_init(&assoc->context_handle_list);
64 InitializeCriticalSection(&assoc->cs);
65 assoc->cs.DebugInfo->Spare[0] = (DWORD_PTR)(__FILE__ ": RpcAssoc.cs");
66 assoc->Protseq = RPCRT4_strdupA(Protseq);
67 assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
68 assoc->Endpoint = RPCRT4_strdupA(Endpoint);
69 assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL;
70 assoc->assoc_group_id = 0;
71 list_init(&assoc->entry);
72 *assoc_out = assoc;
73 return RPC_S_OK;
74 }
75
76 static BOOL compare_networkoptions(LPCWSTR opts1, LPCWSTR opts2)
77 {
78 if ((opts1 == NULL) && (opts2 == NULL))
79 return TRUE;
80 if ((opts1 == NULL) || (opts2 == NULL))
81 return FALSE;
82 return !strcmpW(opts1, opts2);
83 }
84
85 RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
86 LPCSTR Endpoint, LPCWSTR NetworkOptions,
87 RpcAssoc **assoc_out)
88 {
89 RpcAssoc *assoc;
90 RPC_STATUS status;
91
92 EnterCriticalSection(&assoc_list_cs);
93 LIST_FOR_EACH_ENTRY(assoc, &client_assoc_list, RpcAssoc, entry)
94 {
95 if (!strcmp(Protseq, assoc->Protseq) &&
96 !strcmp(NetworkAddr, assoc->NetworkAddr) &&
97 !strcmp(Endpoint, assoc->Endpoint) &&
98 compare_networkoptions(NetworkOptions, assoc->NetworkOptions))
99 {
100 assoc->refs++;
101 *assoc_out = assoc;
102 LeaveCriticalSection(&assoc_list_cs);
103 TRACE("using existing assoc %p\n", assoc);
104 return RPC_S_OK;
105 }
106 }
107
108 status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
109 if (status != RPC_S_OK)
110 {
111 LeaveCriticalSection(&assoc_list_cs);
112 return status;
113 }
114 list_add_head(&client_assoc_list, &assoc->entry);
115 *assoc_out = assoc;
116
117 LeaveCriticalSection(&assoc_list_cs);
118
119 TRACE("new assoc %p\n", assoc);
120
121 return RPC_S_OK;
122 }
123
124 RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
125 LPCSTR Endpoint, LPCWSTR NetworkOptions,
126 ULONG assoc_gid,
127 RpcAssoc **assoc_out)
128 {
129 RpcAssoc *assoc;
130 RPC_STATUS status;
131
132 EnterCriticalSection(&assoc_list_cs);
133 if (assoc_gid)
134 {
135 LIST_FOR_EACH_ENTRY(assoc, &server_assoc_list, RpcAssoc, entry)
136 {
137 /* FIXME: NetworkAddr shouldn't be NULL */
138 if (assoc->assoc_group_id == assoc_gid &&
139 !strcmp(Protseq, assoc->Protseq) &&
140 (!NetworkAddr || !assoc->NetworkAddr || !strcmp(NetworkAddr, assoc->NetworkAddr)) &&
141 !strcmp(Endpoint, assoc->Endpoint) &&
142 ((!assoc->NetworkOptions == !NetworkOptions) &&
143 (!NetworkOptions || !strcmpW(NetworkOptions, assoc->NetworkOptions))))
144 {
145 assoc->refs++;
146 *assoc_out = assoc;
147 LeaveCriticalSection(&assoc_list_cs);
148 TRACE("using existing assoc %p\n", assoc);
149 return RPC_S_OK;
150 }
151 }
152 *assoc_out = NULL;
153 LeaveCriticalSection(&assoc_list_cs);
154 return RPC_S_NO_CONTEXT_AVAILABLE;
155 }
156
157 status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
158 if (status != RPC_S_OK)
159 {
160 LeaveCriticalSection(&assoc_list_cs);
161 return status;
162 }
163 assoc->assoc_group_id = InterlockedIncrement(&last_assoc_group_id);
164 list_add_head(&server_assoc_list, &assoc->entry);
165 *assoc_out = assoc;
166
167 LeaveCriticalSection(&assoc_list_cs);
168
169 TRACE("new assoc %p\n", assoc);
170
171 return RPC_S_OK;
172 }
173
174 ULONG RpcAssoc_Release(RpcAssoc *assoc)
175 {
176 ULONG refs;
177
178 EnterCriticalSection(&assoc_list_cs);
179 refs = --assoc->refs;
180 if (!refs)
181 list_remove(&assoc->entry);
182 LeaveCriticalSection(&assoc_list_cs);
183
184 if (!refs)
185 {
186 RpcConnection *Connection, *cursor2;
187 RpcContextHandle *context_handle, *context_handle_cursor;
188
189 TRACE("destroying assoc %p\n", assoc);
190
191 LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
192 {
193 list_remove(&Connection->conn_pool_entry);
194 RPCRT4_ReleaseConnection(Connection);
195 }
196
197 LIST_FOR_EACH_ENTRY_SAFE(context_handle, context_handle_cursor, &assoc->context_handle_list, RpcContextHandle, entry)
198 RpcContextHandle_Destroy(context_handle);
199
200 HeapFree(GetProcessHeap(), 0, assoc->NetworkOptions);
201 HeapFree(GetProcessHeap(), 0, assoc->Endpoint);
202 HeapFree(GetProcessHeap(), 0, assoc->NetworkAddr);
203 HeapFree(GetProcessHeap(), 0, assoc->Protseq);
204
205 assoc->cs.DebugInfo->Spare[0] = 0;
206 DeleteCriticalSection(&assoc->cs);
207
208 HeapFree(GetProcessHeap(), 0, assoc);
209 }
210
211 return refs;
212 }
213
/* Round value up to the next multiple of alignment. The bit-mask form only
 * yields a correct result when alignment is a power of two. */
#define ROUND_UP(value, alignment) \
    (((value) + ((alignment) - 1)) & ~((alignment) - 1))
215
216 static RPC_STATUS RpcAssoc_BindConnection(const RpcAssoc *assoc, RpcConnection *conn,
217 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
218 const RPC_SYNTAX_IDENTIFIER *TransferSyntax)
219 {
220 RpcPktHdr *hdr;
221 RpcPktHdr *response_hdr;
222 RPC_MESSAGE msg;
223 RPC_STATUS status;
224 unsigned char *auth_data = NULL;
225 ULONG auth_length;
226
227 TRACE("sending bind request to server\n");
228
229 hdr = RPCRT4_BuildBindHeader(NDR_LOCAL_DATA_REPRESENTATION,
230 RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE,
231 assoc->assoc_group_id,
232 InterfaceId, TransferSyntax);
233
234 status = RPCRT4_Send(conn, hdr, NULL, 0);
235 RPCRT4_FreeHeader(hdr);
236 if (status != RPC_S_OK)
237 return status;
238
239 status = RPCRT4_ReceiveWithAuth(conn, &response_hdr, &msg, &auth_data, &auth_length);
240 if (status != RPC_S_OK)
241 {
242 ERR("receive failed with error %d\n", status);
243 return status;
244 }
245
246 switch (response_hdr->common.ptype)
247 {
248 case PKT_BIND_ACK:
249 {
250 RpcAddressString *server_address = msg.Buffer;
251 if ((msg.BufferLength >= FIELD_OFFSET(RpcAddressString, string[0])) ||
252 (msg.BufferLength >= ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4)))
253 {
254 unsigned short remaining = msg.BufferLength -
255 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4);
256 RpcResultList *results = (RpcResultList*)((ULONG_PTR)server_address +
257 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4));
258 if ((results->num_results == 1) &&
259 (remaining >= FIELD_OFFSET(RpcResultList, results[results->num_results])))
260 {
261 switch (results->results[0].result)
262 {
263 case RESULT_ACCEPT:
264 /* respond to authorization request */
265 if (auth_length > sizeof(RpcAuthVerifier))
266 status = RPCRT4_ClientConnectionAuth(conn,
267 auth_data + sizeof(RpcAuthVerifier),
268 auth_length);
269 if (status == RPC_S_OK)
270 {
271 conn->assoc_group_id = response_hdr->bind_ack.assoc_gid;
272 conn->MaxTransmissionSize = response_hdr->bind_ack.max_tsize;
273 conn->ActiveInterface = *InterfaceId;
274 }
275 break;
276 case RESULT_PROVIDER_REJECTION:
277 switch (results->results[0].reason)
278 {
279 case REASON_ABSTRACT_SYNTAX_NOT_SUPPORTED:
280 ERR("syntax %s, %d.%d not supported\n",
281 debugstr_guid(&InterfaceId->SyntaxGUID),
282 InterfaceId->SyntaxVersion.MajorVersion,
283 InterfaceId->SyntaxVersion.MinorVersion);
284 status = RPC_S_UNKNOWN_IF;
285 break;
286 case REASON_TRANSFER_SYNTAXES_NOT_SUPPORTED:
287 ERR("transfer syntax not supported\n");
288 status = RPC_S_SERVER_UNAVAILABLE;
289 break;
290 case REASON_NONE:
291 default:
292 status = RPC_S_CALL_FAILED_DNE;
293 }
294 break;
295 case RESULT_USER_REJECTION:
296 default:
297 ERR("rejection result %d\n", results->results[0].result);
298 status = RPC_S_CALL_FAILED_DNE;
299 }
300 }
301 else
302 {
303 ERR("incorrect results size\n");
304 status = RPC_S_CALL_FAILED_DNE;
305 }
306 }
307 else
308 {
309 ERR("bind ack packet too small (%d)\n", msg.BufferLength);
310 status = RPC_S_PROTOCOL_ERROR;
311 }
312 break;
313 }
314 case PKT_BIND_NACK:
315 switch (response_hdr->bind_nack.reject_reason)
316 {
317 case REJECT_LOCAL_LIMIT_EXCEEDED:
318 case REJECT_TEMPORARY_CONGESTION:
319 ERR("server too busy\n");
320 status = RPC_S_SERVER_TOO_BUSY;
321 break;
322 case REJECT_PROTOCOL_VERSION_NOT_SUPPORTED:
323 ERR("protocol version not supported\n");
324 status = RPC_S_PROTOCOL_ERROR;
325 break;
326 case REJECT_UNKNOWN_AUTHN_SERVICE:
327 ERR("unknown authentication service\n");
328 status = RPC_S_UNKNOWN_AUTHN_SERVICE;
329 break;
330 case REJECT_INVALID_CHECKSUM:
331 ERR("invalid checksum\n");
332 status = ERROR_ACCESS_DENIED;
333 break;
334 default:
335 ERR("rejected bind for reason %d\n", response_hdr->bind_nack.reject_reason);
336 status = RPC_S_CALL_FAILED_DNE;
337 }
338 break;
339 default:
340 ERR("wrong packet type received %d\n", response_hdr->common.ptype);
341 status = RPC_S_PROTOCOL_ERROR;
342 break;
343 }
344
345 I_RpcFree(msg.Buffer);
346 RPCRT4_FreeHeader(response_hdr);
347 HeapFree(GetProcessHeap(), 0, auth_data);
348 return status;
349 }
350
351 static RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc,
352 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
353 const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo,
354 const RpcQualityOfService *QOS)
355 {
356 RpcConnection *Connection;
357 EnterCriticalSection(&assoc->cs);
358 /* try to find a compatible connection from the connection pool */
359 LIST_FOR_EACH_ENTRY(Connection, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
360 {
361 if (!memcmp(&Connection->ActiveInterface, InterfaceId,
362 sizeof(RPC_SYNTAX_IDENTIFIER)) &&
363 RpcAuthInfo_IsEqual(Connection->AuthInfo, AuthInfo) &&
364 RpcQualityOfService_IsEqual(Connection->QOS, QOS))
365 {
366 list_remove(&Connection->conn_pool_entry);
367 LeaveCriticalSection(&assoc->cs);
368 TRACE("got connection from pool %p\n", Connection);
369 return Connection;
370 }
371 }
372
373 LeaveCriticalSection(&assoc->cs);
374 return NULL;
375 }
376
377 RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc,
378 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
379 const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo,
380 RpcQualityOfService *QOS, LPCWSTR CookieAuth, RpcConnection **Connection)
381 {
382 RpcConnection *NewConnection;
383 RPC_STATUS status;
384
385 *Connection = RpcAssoc_GetIdleConnection(assoc, InterfaceId, TransferSyntax, AuthInfo, QOS);
386 if (*Connection)
387 return RPC_S_OK;
388
389 /* create a new connection */
390 status = RPCRT4_CreateConnection(&NewConnection, FALSE /* is this a server connection? */,
391 assoc->Protseq, assoc->NetworkAddr,
392 assoc->Endpoint, assoc->NetworkOptions,
393 AuthInfo, QOS, CookieAuth);
394 if (status != RPC_S_OK)
395 return status;
396
397 NewConnection->assoc = assoc;
398 status = RPCRT4_OpenClientConnection(NewConnection);
399 if (status != RPC_S_OK)
400 {
401 RPCRT4_ReleaseConnection(NewConnection);
402 return status;
403 }
404
405 status = RpcAssoc_BindConnection(assoc, NewConnection, InterfaceId, TransferSyntax);
406 if (status != RPC_S_OK)
407 {
408 RPCRT4_ReleaseConnection(NewConnection);
409 return status;
410 }
411
412 *Connection = NewConnection;
413
414 return RPC_S_OK;
415 }
416
417 void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection)
418 {
419 assert(!Connection->server);
420 Connection->async_state = NULL;
421 EnterCriticalSection(&assoc->cs);
422 if (!assoc->assoc_group_id) assoc->assoc_group_id = Connection->assoc_group_id;
423 list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry);
424 LeaveCriticalSection(&assoc->cs);
425 }
426
427 RPC_STATUS RpcServerAssoc_AllocateContextHandle(RpcAssoc *assoc, void *CtxGuard,
428 NDR_SCONTEXT *SContext)
429 {
430 RpcContextHandle *context_handle;
431
432 context_handle = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*context_handle));
433 if (!context_handle)
434 return ERROR_OUTOFMEMORY;
435
436 context_handle->ctx_guard = CtxGuard;
437 RtlInitializeResource(&context_handle->rw_lock);
438 context_handle->refs = 1;
439
440 /* lock here to mirror unmarshall, so we don't need to special-case the
441 * freeing of a non-marshalled context handle */
442 RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
443
444 EnterCriticalSection(&assoc->cs);
445 list_add_tail(&assoc->context_handle_list, &context_handle->entry);
446 LeaveCriticalSection(&assoc->cs);
447
448 *SContext = (NDR_SCONTEXT)context_handle;
449 return RPC_S_OK;
450 }
451
452 BOOL RpcContextHandle_IsGuardCorrect(NDR_SCONTEXT SContext, void *CtxGuard)
453 {
454 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
455 return context_handle->ctx_guard == CtxGuard;
456 }
457
458 RPC_STATUS RpcServerAssoc_FindContextHandle(RpcAssoc *assoc, const UUID *uuid,
459 void *CtxGuard, ULONG Flags, NDR_SCONTEXT *SContext)
460 {
461 RpcContextHandle *context_handle;
462
463 EnterCriticalSection(&assoc->cs);
464 LIST_FOR_EACH_ENTRY(context_handle, &assoc->context_handle_list, RpcContextHandle, entry)
465 {
466 if (RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard) &&
467 !memcmp(&context_handle->uuid, uuid, sizeof(*uuid)))
468 {
469 *SContext = (NDR_SCONTEXT)context_handle;
470 if (context_handle->refs++)
471 {
472 LeaveCriticalSection(&assoc->cs);
473 TRACE("found %p\n", context_handle);
474 RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
475 return RPC_S_OK;
476 }
477 }
478 }
479 LeaveCriticalSection(&assoc->cs);
480
481 ERR("no context handle found for uuid %s, guard %p\n",
482 debugstr_guid(uuid), CtxGuard);
483 return ERROR_INVALID_HANDLE;
484 }
485
486 RPC_STATUS RpcServerAssoc_UpdateContextHandle(RpcAssoc *assoc,
487 NDR_SCONTEXT SContext,
488 void *CtxGuard,
489 NDR_RUNDOWN rundown_routine)
490 {
491 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
492 RPC_STATUS status;
493
494 if (!RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard))
495 return ERROR_INVALID_HANDLE;
496
497 EnterCriticalSection(&assoc->cs);
498 if (UuidIsNil(&context_handle->uuid, &status))
499 {
500 /* add a ref for the data being valid */
501 context_handle->refs++;
502 UuidCreate(&context_handle->uuid);
503 context_handle->rundown_routine = rundown_routine;
504 TRACE("allocated uuid %s for context handle %p\n",
505 debugstr_guid(&context_handle->uuid), context_handle);
506 }
507 LeaveCriticalSection(&assoc->cs);
508
509 return RPC_S_OK;
510 }
511
512 void RpcContextHandle_GetUuid(NDR_SCONTEXT SContext, UUID *uuid)
513 {
514 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
515 *uuid = context_handle->uuid;
516 }
517
518 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle)
519 {
520 TRACE("freeing %p\n", context_handle);
521
522 if (context_handle->user_context && context_handle->rundown_routine)
523 {
524 TRACE("calling rundown routine %p with user context %p\n",
525 context_handle->rundown_routine, context_handle->user_context);
526 context_handle->rundown_routine(context_handle->user_context);
527 }
528
529 RtlDeleteResource(&context_handle->rw_lock);
530
531 HeapFree(GetProcessHeap(), 0, context_handle);
532 }
533
534 unsigned int RpcServerAssoc_ReleaseContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, BOOL release_lock)
535 {
536 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
537 unsigned int refs;
538
539 if (release_lock)
540 RtlReleaseResource(&context_handle->rw_lock);
541
542 EnterCriticalSection(&assoc->cs);
543 refs = --context_handle->refs;
544 if (!refs)
545 list_remove(&context_handle->entry);
546 LeaveCriticalSection(&assoc->cs);
547
548 if (!refs)
549 RpcContextHandle_Destroy(context_handle);
550
551 return refs;
552 }