- Implement ProtocolResetComplete
[reactos.git] / dll / win32 / rpcrt4 / rpc_assoc.c
1 /*
2 * Associations
3 *
4 * Copyright 2007 Robert Shearman (for CodeWeavers)
5 *
6 * This library is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
10 *
11 * This library is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with this library; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19 *
20 */
21
22 #include <stdarg.h>
23 #include <assert.h>
24
25 #include "rpc.h"
26 #include "rpcndr.h"
27 #include "winternl.h"
28
29 #include "wine/unicode.h"
30 #include "wine/debug.h"
31
32 #include "rpc_binding.h"
33 #include "rpc_assoc.h"
34 #include "rpc_message.h"
35
36 WINE_DEFAULT_DEBUG_CHANNEL(rpc);
37
/* Protects client_assoc_list, server_assoc_list and the refs count of every
 * association on those lists. It is statically initialised below, so no
 * runtime setup is needed. */
38 static CRITICAL_SECTION assoc_list_cs;
/* Debug record for assoc_list_cs; the string names the lock for debugging. */
39 static CRITICAL_SECTION_DEBUG assoc_list_cs_debug =
40 {
41 0, 0, &assoc_list_cs,
42 { &assoc_list_cs_debug.ProcessLocksList, &assoc_list_cs_debug.ProcessLocksList },
43 0, 0, { (DWORD_PTR)(__FILE__ ": assoc_list_cs") }
44 };
45 static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 };
46
/* Associations created by RPCRT4_GetAssociation (client side) and by
 * RpcServerAssoc_GetAssociation (server side). Both lists are guarded by
 * assoc_list_cs. */
47 static struct list client_assoc_list = LIST_INIT(client_assoc_list);
48 static struct list server_assoc_list = LIST_INIT(server_assoc_list);
49
/* Last association group id issued to a new server association. It is
 * zero-initialised and pre-incremented (InterlockedIncrement), so the first
 * id issued is 1. Elsewhere in this file, 0 means "no group id yet". */
50 static LONG last_assoc_group_id;
51
/* Server-side context handle. Each one lives on the owning association's
 * context_handle_list until its reference count drops to zero. */
52 typedef struct _RpcContextHandle
53 {
/* link in RpcAssoc.context_handle_list (modified under the association's cs) */
54 struct list entry;
/* user data passed to rundown_routine when the handle is destroyed */
55 void *user_context;
/* called by RpcContextHandle_Destroy only if both it and user_context are set */
56 NDR_RUNDOWN rundown_routine;
/* guard the handle was created with; lookups must present the same value */
57 void *ctx_guard;
/* nil (zeroed at allocation) until RpcServerAssoc_UpdateContextHandle assigns one */
58 UUID uuid;
/* acquired exclusively by RpcServerAssoc_AllocateContextHandle and
 * RpcServerAssoc_FindContextHandle; released by
 * RpcServerAssoc_ReleaseContextHandle when release_lock is set */
59 RTL_RWLOCK rw_lock;
/* reference count, changed under the association's cs; the handle is
 * unlinked and destroyed when it reaches zero */
60 unsigned int refs;
61 } RpcContextHandle;
62
/* Runs the rundown routine (if any) and frees the handle; defined below. */
63 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle);
64
65 static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr,
66 LPCSTR Endpoint, LPCWSTR NetworkOptions,
67 RpcAssoc **assoc_out)
68 {
69 RpcAssoc *assoc;
70 assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc));
71 if (!assoc)
72 return RPC_S_OUT_OF_RESOURCES;
73 assoc->refs = 1;
74 list_init(&assoc->free_connection_pool);
75 list_init(&assoc->context_handle_list);
76 InitializeCriticalSection(&assoc->cs);
77 assoc->Protseq = RPCRT4_strdupA(Protseq);
78 assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
79 assoc->Endpoint = RPCRT4_strdupA(Endpoint);
80 assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL;
81 assoc->assoc_group_id = 0;
82 list_init(&assoc->entry);
83 *assoc_out = assoc;
84 return RPC_S_OK;
85 }
86
87 RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
88 LPCSTR Endpoint, LPCWSTR NetworkOptions,
89 RpcAssoc **assoc_out)
90 {
91 RpcAssoc *assoc;
92 RPC_STATUS status;
93
94 EnterCriticalSection(&assoc_list_cs);
95 LIST_FOR_EACH_ENTRY(assoc, &client_assoc_list, RpcAssoc, entry)
96 {
97 if (!strcmp(Protseq, assoc->Protseq) &&
98 !strcmp(NetworkAddr, assoc->NetworkAddr) &&
99 !strcmp(Endpoint, assoc->Endpoint) &&
100 ((!assoc->NetworkOptions && !NetworkOptions) || !strcmpW(NetworkOptions, assoc->NetworkOptions)))
101 {
102 assoc->refs++;
103 *assoc_out = assoc;
104 LeaveCriticalSection(&assoc_list_cs);
105 TRACE("using existing assoc %p\n", assoc);
106 return RPC_S_OK;
107 }
108 }
109
110 status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
111 if (status != RPC_S_OK)
112 {
113 LeaveCriticalSection(&assoc_list_cs);
114 return status;
115 }
116 list_add_head(&client_assoc_list, &assoc->entry);
117 *assoc_out = assoc;
118
119 LeaveCriticalSection(&assoc_list_cs);
120
121 TRACE("new assoc %p\n", assoc);
122
123 return RPC_S_OK;
124 }
125
126 RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
127 LPCSTR Endpoint, LPCWSTR NetworkOptions,
128 unsigned long assoc_gid,
129 RpcAssoc **assoc_out)
130 {
131 RpcAssoc *assoc;
132 RPC_STATUS status;
133
134 EnterCriticalSection(&assoc_list_cs);
135 if (assoc_gid)
136 {
137 LIST_FOR_EACH_ENTRY(assoc, &server_assoc_list, RpcAssoc, entry)
138 {
139 /* FIXME: NetworkAddr shouldn't be NULL */
140 if (assoc->assoc_group_id == assoc_gid &&
141 !strcmp(Protseq, assoc->Protseq) &&
142 (!NetworkAddr || !assoc->NetworkAddr || !strcmp(NetworkAddr, assoc->NetworkAddr)) &&
143 !strcmp(Endpoint, assoc->Endpoint) &&
144 ((!assoc->NetworkOptions == !NetworkOptions) &&
145 (!NetworkOptions || !strcmpW(NetworkOptions, assoc->NetworkOptions))))
146 {
147 assoc->refs++;
148 *assoc_out = assoc;
149 LeaveCriticalSection(&assoc_list_cs);
150 TRACE("using existing assoc %p\n", assoc);
151 return RPC_S_OK;
152 }
153 }
154 *assoc_out = NULL;
155 LeaveCriticalSection(&assoc_list_cs);
156 return RPC_S_NO_CONTEXT_AVAILABLE;
157 }
158
159 status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
160 if (status != RPC_S_OK)
161 {
162 LeaveCriticalSection(&assoc_list_cs);
163 return status;
164 }
165 assoc->assoc_group_id = InterlockedIncrement(&last_assoc_group_id);
166 list_add_head(&server_assoc_list, &assoc->entry);
167 *assoc_out = assoc;
168
169 LeaveCriticalSection(&assoc_list_cs);
170
171 TRACE("new assoc %p\n", assoc);
172
173 return RPC_S_OK;
174 }
175
176 ULONG RpcAssoc_Release(RpcAssoc *assoc)
177 {
178 ULONG refs;
179
180 EnterCriticalSection(&assoc_list_cs);
181 refs = --assoc->refs;
182 if (!refs)
183 list_remove(&assoc->entry);
184 LeaveCriticalSection(&assoc_list_cs);
185
186 if (!refs)
187 {
188 RpcConnection *Connection, *cursor2;
189 RpcContextHandle *context_handle, *context_handle_cursor;
190
191 TRACE("destroying assoc %p\n", assoc);
192
193 LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
194 {
195 list_remove(&Connection->conn_pool_entry);
196 RPCRT4_DestroyConnection(Connection);
197 }
198
199 LIST_FOR_EACH_ENTRY_SAFE(context_handle, context_handle_cursor, &assoc->context_handle_list, RpcContextHandle, entry)
200 RpcContextHandle_Destroy(context_handle);
201
202 HeapFree(GetProcessHeap(), 0, assoc->NetworkOptions);
203 HeapFree(GetProcessHeap(), 0, assoc->Endpoint);
204 HeapFree(GetProcessHeap(), 0, assoc->NetworkAddr);
205 HeapFree(GetProcessHeap(), 0, assoc->Protseq);
206
207 DeleteCriticalSection(&assoc->cs);
208
209 HeapFree(GetProcessHeap(), 0, assoc);
210 }
211
212 return refs;
213 }
214
/* Round value up to the next multiple of alignment. The mask trick only works
 * when alignment is a power of two, and alignment is evaluated twice. */
#define ROUND_UP(value, alignment) (((value) + (alignment) - 1) & ~((alignment) - 1))
216
217 static RPC_STATUS RpcAssoc_BindConnection(const RpcAssoc *assoc, RpcConnection *conn,
218 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
219 const RPC_SYNTAX_IDENTIFIER *TransferSyntax)
220 {
221 RpcPktHdr *hdr;
222 RpcPktHdr *response_hdr;
223 RPC_MESSAGE msg;
224 RPC_STATUS status;
225
226 TRACE("sending bind request to server\n");
227
228 hdr = RPCRT4_BuildBindHeader(NDR_LOCAL_DATA_REPRESENTATION,
229 RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE,
230 assoc->assoc_group_id,
231 InterfaceId, TransferSyntax);
232
233 status = RPCRT4_Send(conn, hdr, NULL, 0);
234 RPCRT4_FreeHeader(hdr);
235 if (status != RPC_S_OK)
236 return status;
237
238 status = RPCRT4_Receive(conn, &response_hdr, &msg);
239 if (status != RPC_S_OK)
240 {
241 ERR("receive failed\n");
242 return status;
243 }
244
245 switch (response_hdr->common.ptype)
246 {
247 case PKT_BIND_ACK:
248 {
249 RpcAddressString *server_address = msg.Buffer;
250 if ((msg.BufferLength >= FIELD_OFFSET(RpcAddressString, string[0])) ||
251 (msg.BufferLength >= ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4)))
252 {
253 unsigned short remaining = msg.BufferLength -
254 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4);
255 RpcResults *results = (RpcResults*)((ULONG_PTR)server_address +
256 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4));
257 if ((results->num_results == 1) && (remaining >= sizeof(*results)))
258 {
259 switch (results->results[0].result)
260 {
261 case RESULT_ACCEPT:
262 conn->assoc_group_id = response_hdr->bind_ack.assoc_gid;
263 conn->MaxTransmissionSize = response_hdr->bind_ack.max_tsize;
264 conn->ActiveInterface = *InterfaceId;
265 break;
266 case RESULT_PROVIDER_REJECTION:
267 switch (results->results[0].reason)
268 {
269 case REASON_ABSTRACT_SYNTAX_NOT_SUPPORTED:
270 ERR("syntax %s, %d.%d not supported\n",
271 debugstr_guid(&InterfaceId->SyntaxGUID),
272 InterfaceId->SyntaxVersion.MajorVersion,
273 InterfaceId->SyntaxVersion.MinorVersion);
274 status = RPC_S_UNKNOWN_IF;
275 break;
276 case REASON_TRANSFER_SYNTAXES_NOT_SUPPORTED:
277 ERR("transfer syntax not supported\n");
278 status = RPC_S_SERVER_UNAVAILABLE;
279 break;
280 case REASON_NONE:
281 default:
282 status = RPC_S_CALL_FAILED_DNE;
283 }
284 break;
285 case RESULT_USER_REJECTION:
286 default:
287 ERR("rejection result %d\n", results->results[0].result);
288 status = RPC_S_CALL_FAILED_DNE;
289 }
290 }
291 else
292 {
293 ERR("incorrect results size\n");
294 status = RPC_S_CALL_FAILED_DNE;
295 }
296 }
297 else
298 {
299 ERR("bind ack packet too small (%d)\n", msg.BufferLength);
300 status = RPC_S_PROTOCOL_ERROR;
301 }
302 break;
303 }
304 case PKT_BIND_NACK:
305 switch (response_hdr->bind_nack.reject_reason)
306 {
307 case REJECT_LOCAL_LIMIT_EXCEEDED:
308 case REJECT_TEMPORARY_CONGESTION:
309 ERR("server too busy\n");
310 status = RPC_S_SERVER_TOO_BUSY;
311 break;
312 case REJECT_PROTOCOL_VERSION_NOT_SUPPORTED:
313 ERR("protocol version not supported\n");
314 status = RPC_S_PROTOCOL_ERROR;
315 break;
316 case REJECT_UNKNOWN_AUTHN_SERVICE:
317 ERR("unknown authentication service\n");
318 status = RPC_S_UNKNOWN_AUTHN_SERVICE;
319 break;
320 case REJECT_INVALID_CHECKSUM:
321 ERR("invalid checksum\n");
322 status = ERROR_ACCESS_DENIED;
323 break;
324 default:
325 ERR("rejected bind for reason %d\n", response_hdr->bind_nack.reject_reason);
326 status = RPC_S_CALL_FAILED_DNE;
327 }
328 break;
329 default:
330 ERR("wrong packet type received %d\n", response_hdr->common.ptype);
331 status = RPC_S_PROTOCOL_ERROR;
332 break;
333 }
334
335 I_RpcFree(msg.Buffer);
336 RPCRT4_FreeHeader(response_hdr);
337 return status;
338 }
339
340 static RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc,
341 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
342 const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo,
343 const RpcQualityOfService *QOS)
344 {
345 RpcConnection *Connection;
346 EnterCriticalSection(&assoc->cs);
347 /* try to find a compatible connection from the connection pool */
348 LIST_FOR_EACH_ENTRY(Connection, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
349 {
350 if (!memcmp(&Connection->ActiveInterface, InterfaceId,
351 sizeof(RPC_SYNTAX_IDENTIFIER)) &&
352 RpcAuthInfo_IsEqual(Connection->AuthInfo, AuthInfo) &&
353 RpcQualityOfService_IsEqual(Connection->QOS, QOS))
354 {
355 list_remove(&Connection->conn_pool_entry);
356 LeaveCriticalSection(&assoc->cs);
357 TRACE("got connection from pool %p\n", Connection);
358 return Connection;
359 }
360 }
361
362 LeaveCriticalSection(&assoc->cs);
363 return NULL;
364 }
365
366 RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc,
367 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
368 const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo,
369 RpcQualityOfService *QOS, RpcConnection **Connection)
370 {
371 RpcConnection *NewConnection;
372 RPC_STATUS status;
373
374 *Connection = RpcAssoc_GetIdleConnection(assoc, InterfaceId, TransferSyntax, AuthInfo, QOS);
375 if (*Connection)
376 return RPC_S_OK;
377
378 /* create a new connection */
379 status = RPCRT4_CreateConnection(&NewConnection, FALSE /* is this a server connection? */,
380 assoc->Protseq, assoc->NetworkAddr,
381 assoc->Endpoint, assoc->NetworkOptions,
382 AuthInfo, QOS);
383 if (status != RPC_S_OK)
384 return status;
385
386 status = RPCRT4_OpenClientConnection(NewConnection);
387 if (status != RPC_S_OK)
388 {
389 RPCRT4_DestroyConnection(NewConnection);
390 return status;
391 }
392
393 status = RpcAssoc_BindConnection(assoc, NewConnection, InterfaceId, TransferSyntax);
394 if (status != RPC_S_OK)
395 {
396 RPCRT4_DestroyConnection(NewConnection);
397 return status;
398 }
399
400 *Connection = NewConnection;
401
402 return RPC_S_OK;
403 }
404
405 void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection)
406 {
407 assert(!Connection->server);
408 EnterCriticalSection(&assoc->cs);
409 if (!assoc->assoc_group_id) assoc->assoc_group_id = Connection->assoc_group_id;
410 list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry);
411 LeaveCriticalSection(&assoc->cs);
412 }
413
414 RPC_STATUS RpcServerAssoc_AllocateContextHandle(RpcAssoc *assoc, void *CtxGuard,
415 NDR_SCONTEXT *SContext)
416 {
417 RpcContextHandle *context_handle;
418
419 context_handle = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*context_handle));
420 if (!context_handle)
421 return ERROR_OUTOFMEMORY;
422
423 context_handle->ctx_guard = CtxGuard;
424 RtlInitializeResource(&context_handle->rw_lock);
425 context_handle->refs = 1;
426
427 /* lock here to mirror unmarshall, so we don't need to special-case the
428 * freeing of a non-marshalled context handle */
429 RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
430
431 EnterCriticalSection(&assoc->cs);
432 list_add_tail(&assoc->context_handle_list, &context_handle->entry);
433 LeaveCriticalSection(&assoc->cs);
434
435 *SContext = (NDR_SCONTEXT)context_handle;
436 return RPC_S_OK;
437 }
438
439 BOOL RpcContextHandle_IsGuardCorrect(NDR_SCONTEXT SContext, void *CtxGuard)
440 {
441 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
442 return context_handle->ctx_guard == CtxGuard;
443 }
444
445 RPC_STATUS RpcServerAssoc_FindContextHandle(RpcAssoc *assoc, const UUID *uuid,
446 void *CtxGuard, ULONG Flags, NDR_SCONTEXT *SContext)
447 {
448 RpcContextHandle *context_handle;
449
450 EnterCriticalSection(&assoc->cs);
451 LIST_FOR_EACH_ENTRY(context_handle, &assoc->context_handle_list, RpcContextHandle, entry)
452 {
453 if (RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard) &&
454 !memcmp(&context_handle->uuid, uuid, sizeof(*uuid)))
455 {
456 *SContext = (NDR_SCONTEXT)context_handle;
457 if (context_handle->refs++)
458 {
459 LeaveCriticalSection(&assoc->cs);
460 TRACE("found %p\n", context_handle);
461 RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
462 return RPC_S_OK;
463 }
464 }
465 }
466 LeaveCriticalSection(&assoc->cs);
467
468 ERR("no context handle found for uuid %s, guard %p\n",
469 debugstr_guid(uuid), CtxGuard);
470 return ERROR_INVALID_HANDLE;
471 }
472
473 RPC_STATUS RpcServerAssoc_UpdateContextHandle(RpcAssoc *assoc,
474 NDR_SCONTEXT SContext,
475 void *CtxGuard,
476 NDR_RUNDOWN rundown_routine)
477 {
478 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
479 RPC_STATUS status;
480
481 if (!RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard))
482 return ERROR_INVALID_HANDLE;
483
484 EnterCriticalSection(&assoc->cs);
485 if (UuidIsNil(&context_handle->uuid, &status))
486 {
487 /* add a ref for the data being valid */
488 context_handle->refs++;
489 UuidCreate(&context_handle->uuid);
490 context_handle->rundown_routine = rundown_routine;
491 TRACE("allocated uuid %s for context handle %p\n",
492 debugstr_guid(&context_handle->uuid), context_handle);
493 }
494 LeaveCriticalSection(&assoc->cs);
495
496 return RPC_S_OK;
497 }
498
499 void RpcContextHandle_GetUuid(NDR_SCONTEXT SContext, UUID *uuid)
500 {
501 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
502 *uuid = context_handle->uuid;
503 }
504
505 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle)
506 {
507 TRACE("freeing %p\n", context_handle);
508
509 if (context_handle->user_context && context_handle->rundown_routine)
510 {
511 TRACE("calling rundown routine %p with user context %p\n",
512 context_handle->rundown_routine, context_handle->user_context);
513 context_handle->rundown_routine(context_handle->user_context);
514 }
515
516 RtlDeleteResource(&context_handle->rw_lock);
517
518 HeapFree(GetProcessHeap(), 0, context_handle);
519 }
520
521 unsigned int RpcServerAssoc_ReleaseContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, BOOL release_lock)
522 {
523 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
524 unsigned int refs;
525
526 if (release_lock)
527 RtlReleaseResource(&context_handle->rw_lock);
528
529 EnterCriticalSection(&assoc->cs);
530 refs = --context_handle->refs;
531 if (!refs)
532 list_remove(&context_handle->entry);
533 LeaveCriticalSection(&assoc->cs);
534
535 if (!refs)
536 RpcContextHandle_Destroy(context_handle);
537
538 return refs;
539 }