Sync trunk r40500
[reactos.git] / reactos / drivers / network / afd / afd / lock.c
1 /* $Id$
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: drivers/network/afd/afd/lock.c
5 * PURPOSE: Ancillary functions driver
6 * PROGRAMMER: Art Yerkes (ayerkes@speakeasy.net)
7 * UPDATE HISTORY:
8 * 20040708 Created
9 */
10 #include "afd.h"
11 #include "tdi_proto.h"
12 #include "tdiconn.h"
13 #include "debug.h"
14 #include "pseh/pseh2.h"
15
16 /* Lock a method_neither request so it'll be available from DISPATCH_LEVEL */
17 PVOID LockRequest( PIRP Irp, PIO_STACK_LOCATION IrpSp ) {
18 BOOLEAN LockFailed = FALSE;
19
20 Irp->MdlAddress =
21 IoAllocateMdl( IrpSp->Parameters.DeviceIoControl.Type3InputBuffer,
22 IrpSp->Parameters.DeviceIoControl.InputBufferLength,
23 FALSE,
24 FALSE,
25 NULL );
26 if( Irp->MdlAddress ) {
27 _SEH2_TRY {
28 MmProbeAndLockPages( Irp->MdlAddress, Irp->RequestorMode, IoModifyAccess );
29 } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) {
30 LockFailed = TRUE;
31 } _SEH2_END;
32
33 if( LockFailed ) {
34 IoFreeMdl( Irp->MdlAddress );
35 Irp->MdlAddress = NULL;
36 return NULL;
37 }
38
39 IrpSp->Parameters.DeviceIoControl.Type3InputBuffer =
40 MmGetSystemAddressForMdlSafe( Irp->MdlAddress, NormalPagePriority );
41
42 if( !IrpSp->Parameters.DeviceIoControl.Type3InputBuffer ) {
43 IoFreeMdl( Irp->MdlAddress );
44 Irp->MdlAddress = NULL;
45 return NULL;
46 }
47
48 return IrpSp->Parameters.DeviceIoControl.Type3InputBuffer;
49 } else return NULL;
50 }
51
52 VOID UnlockRequest( PIRP Irp, PIO_STACK_LOCATION IrpSp ) {
53 PVOID Buffer = MmGetSystemAddressForMdlSafe( Irp->MdlAddress, NormalPagePriority );
54 if( IrpSp->Parameters.DeviceIoControl.Type3InputBuffer == Buffer || Buffer == NULL ) {
55 MmUnmapLockedPages( IrpSp->Parameters.DeviceIoControl.Type3InputBuffer, Irp->MdlAddress );
56 MmUnlockPages( Irp->MdlAddress );
57 IoFreeMdl( Irp->MdlAddress );
58 }
59
60 Irp->MdlAddress = NULL;
61 }
62
63 /* Note: We add an extra buffer if LockAddress is true. This allows us to
64 * treat the address buffer as an ordinary client buffer. It's only used
65 * for datagrams. */
66
67 PAFD_WSABUF LockBuffers( PAFD_WSABUF Buf, UINT Count,
68 PVOID AddressBuf, PINT AddressLen,
69 BOOLEAN Write, BOOLEAN LockAddress ) {
70 UINT i;
71 /* Copy the buffer array so we don't lose it */
72 UINT Lock = LockAddress ? 2 : 0;
73 UINT Size = sizeof(AFD_WSABUF) * (Count + Lock);
74 PAFD_WSABUF NewBuf = ExAllocatePool( PagedPool, Size * 2 );
75 BOOLEAN LockFailed = FALSE;
76
77 AFD_DbgPrint(MID_TRACE,("Called(%08x)\n", NewBuf));
78
79 if( NewBuf ) {
80 RtlZeroMemory(NewBuf, Size * 2);
81
82 PAFD_MAPBUF MapBuf = (PAFD_MAPBUF)(NewBuf + Count + Lock);
83
84 _SEH2_TRY {
85 RtlCopyMemory( NewBuf, Buf, sizeof(AFD_WSABUF) * Count );
86 if( LockAddress ) {
87 if (AddressBuf && AddressLen) {
88 NewBuf[Count].buf = AddressBuf;
89 NewBuf[Count].len = *AddressLen;
90 NewBuf[Count + 1].buf = (PVOID)AddressLen;
91 NewBuf[Count + 1].len = sizeof(*AddressLen);
92 }
93 Count += 2;
94 }
95 } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) {
96 AFD_DbgPrint(MIN_TRACE,("Access violation copying buffer info "
97 "from userland (%x %x)\n",
98 Buf, AddressLen));
99 ExFreePool( NewBuf );
100 _SEH2_YIELD(return NULL);
101 } _SEH2_END;
102
103 for( i = 0; i < Count; i++ ) {
104 AFD_DbgPrint(MID_TRACE,("Locking buffer %d (%x:%d)\n",
105 i, NewBuf[i].buf, NewBuf[i].len));
106
107 if( NewBuf[i].buf && NewBuf[i].len ) {
108 MapBuf[i].Mdl = IoAllocateMdl( NewBuf[i].buf,
109 NewBuf[i].len,
110 FALSE,
111 FALSE,
112 NULL );
113 } else {
114 MapBuf[i].Mdl = NULL;
115 continue;
116 }
117
118 AFD_DbgPrint(MID_TRACE,("NewMdl @ %x\n", MapBuf[i].Mdl));
119
120 if( MapBuf[i].Mdl ) {
121 AFD_DbgPrint(MID_TRACE,("Probe and lock pages\n"));
122 _SEH2_TRY {
123 MmProbeAndLockPages( MapBuf[i].Mdl, KernelMode,
124 Write ? IoModifyAccess : IoReadAccess );
125 } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) {
126 LockFailed = TRUE;
127 } _SEH2_END;
128 AFD_DbgPrint(MID_TRACE,("MmProbeAndLock finished\n"));
129
130 if( LockFailed ) {
131 IoFreeMdl( MapBuf[i].Mdl );
132 MapBuf[i].Mdl = NULL;
133 ExFreePool( NewBuf );
134 return NULL;
135 }
136 } else {
137 ExFreePool( NewBuf );
138 return NULL;
139 }
140 }
141 }
142
143 AFD_DbgPrint(MID_TRACE,("Leaving %x\n", NewBuf));
144
145 return NewBuf;
146 }
147
148 VOID UnlockBuffers( PAFD_WSABUF Buf, UINT Count, BOOL Address ) {
149 UINT Lock = Address ? 2 : 0;
150 PAFD_MAPBUF Map = (PAFD_MAPBUF)(Buf + Count + Lock);
151 UINT i;
152
153 if( !Buf ) return;
154
155 for( i = 0; i < Count + Lock; i++ ) {
156 if( Map[i].Mdl ) {
157 MmUnlockPages( Map[i].Mdl );
158 IoFreeMdl( Map[i].Mdl );
159 Map[i].Mdl = NULL;
160 }
161 }
162
163 ExFreePool( Buf );
164 Buf = NULL;
165 }
166
167 /* Produce a kernel-land handle array with handles replaced by object
168 * pointers. This will allow the system to do proper alerting */
169 PAFD_HANDLE LockHandles( PAFD_HANDLE HandleArray, UINT HandleCount ) {
170 UINT i;
171 NTSTATUS Status = STATUS_SUCCESS;
172
173 PAFD_HANDLE FileObjects = ExAllocatePool
174 ( NonPagedPool, HandleCount * sizeof(AFD_HANDLE) );
175
176 for( i = 0; FileObjects && i < HandleCount; i++ ) {
177 FileObjects[i].Status = 0;
178 FileObjects[i].Events = HandleArray[i].Events;
179 FileObjects[i].Handle = 0;
180 if( !HandleArray[i].Handle ) continue;
181 if( NT_SUCCESS(Status) ) {
182 Status = ObReferenceObjectByHandle
183 ( (PVOID)(ULONG_PTR)HandleArray[i].Handle,
184 FILE_ALL_ACCESS,
185 NULL,
186 KernelMode,
187 (PVOID*)&FileObjects[i].Handle,
188 NULL );
189 }
190
191 if( !NT_SUCCESS(Status) )
192 FileObjects[i].Handle = 0;
193 }
194
195 if( !NT_SUCCESS(Status) ) {
196 UnlockHandles( FileObjects, HandleCount );
197 return NULL;
198 }
199
200 return FileObjects;
201 }
202
203 VOID UnlockHandles( PAFD_HANDLE HandleArray, UINT HandleCount ) {
204 UINT i;
205
206 for( i = 0; i < HandleCount; i++ ) {
207 if( HandleArray[i].Handle )
208 ObDereferenceObject( (PVOID)(ULONG_PTR)HandleArray[i].Handle );
209 }
210
211 ExFreePool( HandleArray );
212 HandleArray = NULL;
213 }
214
215 /* Returns transitioned state or SOCKET_STATE_INVALID_TRANSITION */
216 UINT SocketAcquireStateLock( PAFD_FCB FCB ) {
217 NTSTATUS Status = STATUS_SUCCESS;
218 PVOID CurrentThread = KeGetCurrentThread();
219
220 ASSERT(KeGetCurrentIrql() <= APC_LEVEL);
221
222 AFD_DbgPrint(MAX_TRACE,("Called on %x, attempting to lock\n", FCB));
223
224 /* Wait for the previous user to unlock the FCB state. There might be
225 * multiple waiters waiting to change the state. We need to check each
226 * time we get the event whether somebody still has the state locked */
227
228 if( !FCB ) return FALSE;
229
230 if( CurrentThread == FCB->CurrentThread ) {
231 FCB->LockCount++;
232 AFD_DbgPrint(MID_TRACE,
233 ("Same thread, lock count %d\n", FCB->LockCount));
234 return TRUE;
235 } else {
236 AFD_DbgPrint(MID_TRACE,
237 ("Thread %x opposes lock thread %x\n",
238 CurrentThread, FCB->CurrentThread));
239 }
240
241
242 ExAcquireFastMutex( &FCB->Mutex );
243
244 while( FCB->Locked ) {
245 AFD_DbgPrint
246 (MID_TRACE,("FCB %x is locked, waiting for notification\n",
247 FCB));
248 ExReleaseFastMutex( &FCB->Mutex );
249 Status = KeWaitForSingleObject( &FCB->StateLockedEvent,
250 UserRequest,
251 KernelMode,
252 FALSE,
253 NULL );
254 ExAcquireFastMutex( &FCB->Mutex );
255 }
256 FCB->Locked = TRUE;
257 FCB->CurrentThread = CurrentThread;
258 FCB->LockCount++;
259 ExReleaseFastMutex( &FCB->Mutex );
260
261 AFD_DbgPrint(MAX_TRACE,("Got lock (%d).\n", FCB->LockCount));
262
263 return TRUE;
264 }
265
266 VOID SocketStateUnlock( PAFD_FCB FCB ) {
267 #ifdef DBG
268 PVOID CurrentThread = KeGetCurrentThread();
269 #endif
270 ASSERT(FCB->LockCount > 0);
271 ASSERT(KeGetCurrentIrql() <= APC_LEVEL);
272
273 ExAcquireFastMutex( &FCB->Mutex );
274 FCB->LockCount--;
275
276 if( !FCB->LockCount ) {
277 FCB->CurrentThread = NULL;
278 FCB->Locked = FALSE;
279
280 AFD_DbgPrint(MAX_TRACE,("Unlocked.\n"));
281 KePulseEvent( &FCB->StateLockedEvent, IO_NETWORK_INCREMENT, FALSE );
282 } else {
283 AFD_DbgPrint(MAX_TRACE,("New lock count: %d (Thr: %x)\n",
284 FCB->LockCount, CurrentThread));
285 }
286 ExReleaseFastMutex( &FCB->Mutex );
287 }
288
289 NTSTATUS NTAPI UnlockAndMaybeComplete
290 ( PAFD_FCB FCB, NTSTATUS Status, PIRP Irp,
291 UINT Information ) {
292
293 Irp->IoStatus.Status = Status;
294 Irp->IoStatus.Information = Information;
295
296 if( Status == STATUS_PENDING ) {
297 /* We should firstly mark this IRP as pending, because
298 otherwise it may be completed by StreamSocketConnectComplete()
299 before we return from SocketStateUnlock(). */
300 IoMarkIrpPending( Irp );
301 SocketStateUnlock( FCB );
302 } else {
303 if ( Irp->MdlAddress ) UnlockRequest( Irp, IoGetCurrentIrpStackLocation( Irp ) );
304 SocketStateUnlock( FCB );
305 IoCompleteRequest( Irp, IO_NETWORK_INCREMENT );
306 }
307 return Status;
308 }
309
310
311 NTSTATUS LostSocket( PIRP Irp ) {
312 NTSTATUS Status = STATUS_FILE_CLOSED;
313 AFD_DbgPrint(MIN_TRACE,("Called.\n"));
314 Irp->IoStatus.Information = 0;
315 Irp->IoStatus.Status = Status;
316 if ( Irp->MdlAddress ) UnlockRequest( Irp, IoGetCurrentIrpStackLocation( Irp ) );
317 IoCompleteRequest( Irp, IO_NO_INCREMENT );
318 return Status;
319 }
320
321 NTSTATUS LeaveIrpUntilLater( PAFD_FCB FCB, PIRP Irp, UINT Function ) {
322 InsertTailList( &FCB->PendingIrpList[Function],
323 &Irp->Tail.Overlay.ListEntry );
324 IoMarkIrpPending(Irp);
325 Irp->IoStatus.Status = STATUS_PENDING;
326 return UnlockAndMaybeComplete( FCB, STATUS_PENDING, Irp, 0 );
327 }