62263aa1fdedd57db183025a5685a5a220e42050
[reactos.git] / reactos / drivers / network / afd / afd / lock.c
1 /* $Id$
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: drivers/network/afd/afd/lock.c
5 * PURPOSE: Ancillary functions driver
6 * PROGRAMMER: Art Yerkes (ayerkes@speakeasy.net)
7 * UPDATE HISTORY:
8 * 20040708 Created
9 */
10 #include "afd.h"
11 #include "tdi_proto.h"
12 #include "tdiconn.h"
13 #include "debug.h"
14 #include "pseh/pseh2.h"
15
16 /* Lock a method_neither request so it'll be available from DISPATCH_LEVEL */
17 PVOID LockRequest( PIRP Irp, PIO_STACK_LOCATION IrpSp ) {
18 BOOLEAN LockFailed = FALSE;
19
20 Irp->MdlAddress =
21 IoAllocateMdl( IrpSp->Parameters.DeviceIoControl.Type3InputBuffer,
22 IrpSp->Parameters.DeviceIoControl.InputBufferLength,
23 FALSE,
24 FALSE,
25 NULL );
26 if( Irp->MdlAddress ) {
27 _SEH2_TRY {
28 MmProbeAndLockPages( Irp->MdlAddress, Irp->RequestorMode, IoModifyAccess );
29 } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) {
30 LockFailed = TRUE;
31 } _SEH2_END;
32
33 if( LockFailed ) {
34 IoFreeMdl( Irp->MdlAddress );
35 Irp->MdlAddress = NULL;
36 return NULL;
37 }
38
39 IrpSp->Parameters.DeviceIoControl.Type3InputBuffer =
40 MmGetSystemAddressForMdlSafe( Irp->MdlAddress, NormalPagePriority );
41
42 if( !IrpSp->Parameters.DeviceIoControl.Type3InputBuffer ) {
43 IoFreeMdl( Irp->MdlAddress );
44 Irp->MdlAddress = NULL;
45 return NULL;
46 }
47
48 return IrpSp->Parameters.DeviceIoControl.Type3InputBuffer;
49 } else return NULL;
50 }
51
52 VOID UnlockRequest( PIRP Irp, PIO_STACK_LOCATION IrpSp ) {
53 PVOID Buffer = MmGetSystemAddressForMdlSafe( Irp->MdlAddress, NormalPagePriority );
54 if( IrpSp->Parameters.DeviceIoControl.Type3InputBuffer == Buffer || Buffer == NULL ) {
55 MmUnmapLockedPages( IrpSp->Parameters.DeviceIoControl.Type3InputBuffer, Irp->MdlAddress );
56 MmUnlockPages( Irp->MdlAddress );
57 IoFreeMdl( Irp->MdlAddress );
58 }
59
60 Irp->MdlAddress = NULL;
61 }
62
63 /* Note: We add an extra buffer if LockAddress is true. This allows us to
64 * treat the address buffer as an ordinary client buffer. It's only used
65 * for datagrams. */
66
67 PAFD_WSABUF LockBuffers( PAFD_WSABUF Buf, UINT Count,
68 PVOID AddressBuf, PINT AddressLen,
69 BOOLEAN Write, BOOLEAN LockAddress ) {
70 UINT i;
71 /* Copy the buffer array so we don't lose it */
72 UINT Lock = LockAddress ? 2 : 0;
73 UINT Size = sizeof(AFD_WSABUF) * (Count + Lock);
74 PAFD_WSABUF NewBuf = ExAllocatePool( PagedPool, Size * 2 );
75 BOOLEAN LockFailed = FALSE;
76 PAFD_MAPBUF MapBuf;
77
78 AFD_DbgPrint(MID_TRACE,("Called(%08x)\n", NewBuf));
79
80 if( NewBuf ) {
81 RtlZeroMemory(NewBuf, Size * 2);
82
83 MapBuf = (PAFD_MAPBUF)(NewBuf + Count + Lock);
84
85 _SEH2_TRY {
86 RtlCopyMemory( NewBuf, Buf, sizeof(AFD_WSABUF) * Count );
87 if( LockAddress ) {
88 if (AddressBuf && AddressLen) {
89 NewBuf[Count].buf = AddressBuf;
90 NewBuf[Count].len = *AddressLen;
91 NewBuf[Count + 1].buf = (PVOID)AddressLen;
92 NewBuf[Count + 1].len = sizeof(*AddressLen);
93 }
94 Count += 2;
95 }
96 } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) {
97 AFD_DbgPrint(MIN_TRACE,("Access violation copying buffer info "
98 "from userland (%x %x)\n",
99 Buf, AddressLen));
100 ExFreePool( NewBuf );
101 _SEH2_YIELD(return NULL);
102 } _SEH2_END;
103
104 for( i = 0; i < Count; i++ ) {
105 AFD_DbgPrint(MID_TRACE,("Locking buffer %d (%x:%d)\n",
106 i, NewBuf[i].buf, NewBuf[i].len));
107
108 if( NewBuf[i].buf && NewBuf[i].len ) {
109 MapBuf[i].Mdl = IoAllocateMdl( NewBuf[i].buf,
110 NewBuf[i].len,
111 FALSE,
112 FALSE,
113 NULL );
114 } else {
115 MapBuf[i].Mdl = NULL;
116 continue;
117 }
118
119 AFD_DbgPrint(MID_TRACE,("NewMdl @ %x\n", MapBuf[i].Mdl));
120
121 if( MapBuf[i].Mdl ) {
122 AFD_DbgPrint(MID_TRACE,("Probe and lock pages\n"));
123 _SEH2_TRY {
124 MmProbeAndLockPages( MapBuf[i].Mdl, KernelMode,
125 Write ? IoModifyAccess : IoReadAccess );
126 } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) {
127 LockFailed = TRUE;
128 } _SEH2_END;
129 AFD_DbgPrint(MID_TRACE,("MmProbeAndLock finished\n"));
130
131 if( LockFailed ) {
132 IoFreeMdl( MapBuf[i].Mdl );
133 MapBuf[i].Mdl = NULL;
134 ExFreePool( NewBuf );
135 return NULL;
136 }
137 } else {
138 ExFreePool( NewBuf );
139 return NULL;
140 }
141 }
142 }
143
144 AFD_DbgPrint(MID_TRACE,("Leaving %x\n", NewBuf));
145
146 return NewBuf;
147 }
148
149 VOID UnlockBuffers( PAFD_WSABUF Buf, UINT Count, BOOL Address ) {
150 UINT Lock = Address ? 2 : 0;
151 PAFD_MAPBUF Map = (PAFD_MAPBUF)(Buf + Count + Lock);
152 UINT i;
153
154 if( !Buf ) return;
155
156 for( i = 0; i < Count + Lock; i++ ) {
157 if( Map[i].Mdl ) {
158 MmUnlockPages( Map[i].Mdl );
159 IoFreeMdl( Map[i].Mdl );
160 Map[i].Mdl = NULL;
161 }
162 }
163
164 ExFreePool( Buf );
165 Buf = NULL;
166 }
167
168 /* Produce a kernel-land handle array with handles replaced by object
169 * pointers. This will allow the system to do proper alerting */
170 PAFD_HANDLE LockHandles( PAFD_HANDLE HandleArray, UINT HandleCount ) {
171 UINT i;
172 NTSTATUS Status = STATUS_SUCCESS;
173
174 PAFD_HANDLE FileObjects = ExAllocatePool
175 ( NonPagedPool, HandleCount * sizeof(AFD_HANDLE) );
176
177 for( i = 0; FileObjects && i < HandleCount; i++ ) {
178 FileObjects[i].Status = 0;
179 FileObjects[i].Events = HandleArray[i].Events;
180 FileObjects[i].Handle = 0;
181 if( !HandleArray[i].Handle ) continue;
182 if( NT_SUCCESS(Status) ) {
183 Status = ObReferenceObjectByHandle
184 ( (PVOID)HandleArray[i].Handle,
185 FILE_ALL_ACCESS,
186 NULL,
187 KernelMode,
188 (PVOID*)&FileObjects[i].Handle,
189 NULL );
190 }
191
192 if( !NT_SUCCESS(Status) )
193 FileObjects[i].Handle = 0;
194 }
195
196 if( !NT_SUCCESS(Status) ) {
197 UnlockHandles( FileObjects, HandleCount );
198 return NULL;
199 }
200
201 return FileObjects;
202 }
203
204 VOID UnlockHandles( PAFD_HANDLE HandleArray, UINT HandleCount ) {
205 UINT i;
206
207 for( i = 0; i < HandleCount; i++ ) {
208 if( HandleArray[i].Handle )
209 ObDereferenceObject( (PVOID)HandleArray[i].Handle );
210 }
211
212 ExFreePool( HandleArray );
213 HandleArray = NULL;
214 }
215
216 /* Returns transitioned state or SOCKET_STATE_INVALID_TRANSITION */
217 UINT SocketAcquireStateLock( PAFD_FCB FCB ) {
218 NTSTATUS Status = STATUS_SUCCESS;
219 PVOID CurrentThread = KeGetCurrentThread();
220
221 ASSERT(KeGetCurrentIrql() <= APC_LEVEL);
222
223 AFD_DbgPrint(MAX_TRACE,("Called on %x, attempting to lock\n", FCB));
224
225 /* Wait for the previous user to unlock the FCB state. There might be
226 * multiple waiters waiting to change the state. We need to check each
227 * time we get the event whether somebody still has the state locked */
228
229 if( !FCB ) return FALSE;
230
231 if( CurrentThread == FCB->CurrentThread ) {
232 FCB->LockCount++;
233 AFD_DbgPrint(MID_TRACE,
234 ("Same thread, lock count %d\n", FCB->LockCount));
235 return TRUE;
236 } else {
237 AFD_DbgPrint(MID_TRACE,
238 ("Thread %x opposes lock thread %x\n",
239 CurrentThread, FCB->CurrentThread));
240 }
241
242
243 ExAcquireFastMutex( &FCB->Mutex );
244
245 while( FCB->Locked ) {
246 AFD_DbgPrint
247 (MID_TRACE,("FCB %x is locked, waiting for notification\n",
248 FCB));
249 ExReleaseFastMutex( &FCB->Mutex );
250 Status = KeWaitForSingleObject( &FCB->StateLockedEvent,
251 UserRequest,
252 KernelMode,
253 FALSE,
254 NULL );
255 ExAcquireFastMutex( &FCB->Mutex );
256 }
257 FCB->Locked = TRUE;
258 FCB->CurrentThread = CurrentThread;
259 FCB->LockCount++;
260 ExReleaseFastMutex( &FCB->Mutex );
261
262 AFD_DbgPrint(MAX_TRACE,("Got lock (%d).\n", FCB->LockCount));
263
264 return TRUE;
265 }
266
267 VOID SocketStateUnlock( PAFD_FCB FCB ) {
268 #if DBG
269 PVOID CurrentThread = KeGetCurrentThread();
270 #endif
271 ASSERT(FCB->LockCount > 0);
272 ASSERT(KeGetCurrentIrql() <= APC_LEVEL);
273
274 ExAcquireFastMutex( &FCB->Mutex );
275 FCB->LockCount--;
276
277 if( !FCB->LockCount ) {
278 FCB->CurrentThread = NULL;
279 FCB->Locked = FALSE;
280
281 AFD_DbgPrint(MAX_TRACE,("Unlocked.\n"));
282 KePulseEvent( &FCB->StateLockedEvent, IO_NETWORK_INCREMENT, FALSE );
283 } else {
284 AFD_DbgPrint(MAX_TRACE,("New lock count: %d (Thr: %x)\n",
285 FCB->LockCount, CurrentThread));
286 }
287 ExReleaseFastMutex( &FCB->Mutex );
288 }
289
290 NTSTATUS NTAPI UnlockAndMaybeComplete
291 ( PAFD_FCB FCB, NTSTATUS Status, PIRP Irp,
292 UINT Information ) {
293
294 Irp->IoStatus.Status = Status;
295 Irp->IoStatus.Information = Information;
296
297 if( Status == STATUS_PENDING ) {
298 /* We should firstly mark this IRP as pending, because
299 otherwise it may be completed by StreamSocketConnectComplete()
300 before we return from SocketStateUnlock(). */
301 IoMarkIrpPending( Irp );
302 SocketStateUnlock( FCB );
303 } else {
304 if ( Irp->MdlAddress ) UnlockRequest( Irp, IoGetCurrentIrpStackLocation( Irp ) );
305 (void)IoSetCancelRoutine(Irp, NULL);
306 SocketStateUnlock( FCB );
307 IoCompleteRequest( Irp, IO_NETWORK_INCREMENT );
308 }
309 return Status;
310 }
311
312
313 NTSTATUS LostSocket( PIRP Irp ) {
314 NTSTATUS Status = STATUS_FILE_CLOSED;
315 AFD_DbgPrint(MIN_TRACE,("Called.\n"));
316 Irp->IoStatus.Information = 0;
317 Irp->IoStatus.Status = Status;
318 if ( Irp->MdlAddress ) UnlockRequest( Irp, IoGetCurrentIrpStackLocation( Irp ) );
319 IoCompleteRequest( Irp, IO_NO_INCREMENT );
320 return Status;
321 }
322
323 NTSTATUS LeaveIrpUntilLater( PAFD_FCB FCB, PIRP Irp, UINT Function ) {
324 InsertTailList( &FCB->PendingIrpList[Function],
325 &Irp->Tail.Overlay.ListEntry );
326 IoMarkIrpPending(Irp);
327 (void)IoSetCancelRoutine(Irp, AfdCancelHandler);
328 return UnlockAndMaybeComplete( FCB, STATUS_PENDING, Irp, 0 );
329 }