* Sync up to trunk HEAD (r62975).
[reactos.git] / dll / win32 / rpcrt4 / ndr_ole.c
1 /*
2 * OLE32 callouts, COM interface marshalling
3 *
4 * Copyright 2001 Ove Kåven, TransGaming Technologies
5 *
6 * This library is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
10 *
11 * This library is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with this library; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19 *
20 * TODO:
21 * - fix the wire-protocol to match MS/RPC
22 * - finish RpcStream_Vtbl
23 */
24
25 #include "precomp.h"
26
27 WINE_DEFAULT_DEBUG_CHANNEL(ole);
28
29 static HMODULE hOLE;
30
31 static HRESULT (WINAPI *COM_GetMarshalSizeMax)(ULONG *,REFIID,LPUNKNOWN,DWORD,LPVOID,DWORD);
32 static HRESULT (WINAPI *COM_MarshalInterface)(LPSTREAM,REFIID,LPUNKNOWN,DWORD,LPVOID,DWORD);
33 static HRESULT (WINAPI *COM_UnmarshalInterface)(LPSTREAM,REFIID,LPVOID*);
34 static HRESULT (WINAPI *COM_ReleaseMarshalData)(LPSTREAM);
35 static HRESULT (WINAPI *COM_GetClassObject)(REFCLSID,DWORD,COSERVERINFO *,REFIID,LPVOID *);
36 static HRESULT (WINAPI *COM_GetPSClsid)(REFIID,CLSID *);
37 static LPVOID (WINAPI *COM_MemAlloc)(ULONG);
38 static void (WINAPI *COM_MemFree)(LPVOID);
39
40 static HMODULE LoadCOM(void)
41 {
42 if (hOLE) return hOLE;
43 hOLE = LoadLibraryA("OLE32.DLL");
44 if (!hOLE) return 0;
45 COM_GetMarshalSizeMax = (LPVOID)GetProcAddress(hOLE, "CoGetMarshalSizeMax");
46 COM_MarshalInterface = (LPVOID)GetProcAddress(hOLE, "CoMarshalInterface");
47 COM_UnmarshalInterface = (LPVOID)GetProcAddress(hOLE, "CoUnmarshalInterface");
48 COM_ReleaseMarshalData = (LPVOID)GetProcAddress(hOLE, "CoReleaseMarshalData");
49 COM_GetClassObject = (LPVOID)GetProcAddress(hOLE, "CoGetClassObject");
50 COM_GetPSClsid = (LPVOID)GetProcAddress(hOLE, "CoGetPSClsid");
51 COM_MemAlloc = (LPVOID)GetProcAddress(hOLE, "CoTaskMemAlloc");
52 COM_MemFree = (LPVOID)GetProcAddress(hOLE, "CoTaskMemFree");
53 return hOLE;
54 }
55
/* CoMarshalInterface/CoUnmarshalInterface work on streams,
 * so implement a simple stream on top of the RPC buffer
 * (which also implements the MInterfacePointer structure) */
59 typedef struct RpcStreamImpl
60 {
61 IStream IStream_iface;
62 LONG RefCount;
63 PMIDL_STUB_MESSAGE pMsg;
64 LPDWORD size;
65 unsigned char *data;
66 DWORD pos;
67 } RpcStreamImpl;
68
69 static inline RpcStreamImpl *impl_from_IStream(IStream *iface)
70 {
71 return CONTAINING_RECORD(iface, RpcStreamImpl, IStream_iface);
72 }
73
74 static HRESULT WINAPI RpcStream_QueryInterface(LPSTREAM iface,
75 REFIID riid,
76 LPVOID *obj)
77 {
78 RpcStreamImpl *This = impl_from_IStream(iface);
79 if (IsEqualGUID(&IID_IUnknown, riid) ||
80 IsEqualGUID(&IID_ISequentialStream, riid) ||
81 IsEqualGUID(&IID_IStream, riid)) {
82 *obj = This;
83 InterlockedIncrement( &This->RefCount );
84 return S_OK;
85 }
86 return E_NOINTERFACE;
87 }
88
89 static ULONG WINAPI RpcStream_AddRef(LPSTREAM iface)
90 {
91 RpcStreamImpl *This = impl_from_IStream(iface);
92 return InterlockedIncrement( &This->RefCount );
93 }
94
95 static ULONG WINAPI RpcStream_Release(LPSTREAM iface)
96 {
97 RpcStreamImpl *This = impl_from_IStream(iface);
98 ULONG ref = InterlockedDecrement( &This->RefCount );
99 if (!ref) {
100 TRACE("size=%d\n", *This->size);
101 This->pMsg->Buffer = This->data + *This->size;
102 HeapFree(GetProcessHeap(),0,This);
103 return 0;
104 }
105 return ref;
106 }
107
108 static HRESULT WINAPI RpcStream_Read(LPSTREAM iface,
109 void *pv,
110 ULONG cb,
111 ULONG *pcbRead)
112 {
113 RpcStreamImpl *This = impl_from_IStream(iface);
114 HRESULT hr = S_OK;
115 if (This->pos + cb > *This->size)
116 {
117 cb = *This->size - This->pos;
118 hr = S_FALSE;
119 }
120 if (cb) {
121 memcpy(pv, This->data + This->pos, cb);
122 This->pos += cb;
123 }
124 if (pcbRead) *pcbRead = cb;
125 return hr;
126 }
127
128 static HRESULT WINAPI RpcStream_Write(LPSTREAM iface,
129 const void *pv,
130 ULONG cb,
131 ULONG *pcbWritten)
132 {
133 RpcStreamImpl *This = impl_from_IStream(iface);
134 if (This->data + cb > (unsigned char *)This->pMsg->RpcMsg->Buffer + This->pMsg->BufferLength)
135 return STG_E_MEDIUMFULL;
136 memcpy(This->data + This->pos, pv, cb);
137 This->pos += cb;
138 if (This->pos > *This->size) *This->size = This->pos;
139 if (pcbWritten) *pcbWritten = cb;
140 return S_OK;
141 }
142
143 static HRESULT WINAPI RpcStream_Seek(LPSTREAM iface,
144 LARGE_INTEGER move,
145 DWORD origin,
146 ULARGE_INTEGER *newPos)
147 {
148 RpcStreamImpl *This = impl_from_IStream(iface);
149 switch (origin) {
150 case STREAM_SEEK_SET:
151 This->pos = move.u.LowPart;
152 break;
153 case STREAM_SEEK_CUR:
154 This->pos = This->pos + move.u.LowPart;
155 break;
156 case STREAM_SEEK_END:
157 This->pos = *This->size + move.u.LowPart;
158 break;
159 default:
160 return STG_E_INVALIDFUNCTION;
161 }
162 if (newPos) {
163 newPos->u.LowPart = This->pos;
164 newPos->u.HighPart = 0;
165 }
166 return S_OK;
167 }
168
169 static HRESULT WINAPI RpcStream_SetSize(LPSTREAM iface,
170 ULARGE_INTEGER newSize)
171 {
172 RpcStreamImpl *This = impl_from_IStream(iface);
173 *This->size = newSize.u.LowPart;
174 return S_OK;
175 }
176
177 static const IStreamVtbl RpcStream_Vtbl =
178 {
179 RpcStream_QueryInterface,
180 RpcStream_AddRef,
181 RpcStream_Release,
182 RpcStream_Read,
183 RpcStream_Write,
184 RpcStream_Seek,
185 RpcStream_SetSize,
186 NULL, /* CopyTo */
187 NULL, /* Commit */
188 NULL, /* Revert */
189 NULL, /* LockRegion */
190 NULL, /* UnlockRegion */
191 NULL, /* Stat */
192 NULL /* Clone */
193 };
194
195 static LPSTREAM RpcStream_Create(PMIDL_STUB_MESSAGE pStubMsg, BOOL init)
196 {
197 RpcStreamImpl *This;
198 This = HeapAlloc(GetProcessHeap(),HEAP_ZERO_MEMORY,sizeof(RpcStreamImpl));
199 if (!This) return NULL;
200 This->IStream_iface.lpVtbl = &RpcStream_Vtbl;
201 This->RefCount = 1;
202 This->pMsg = pStubMsg;
203 This->size = (LPDWORD)pStubMsg->Buffer;
204 This->data = (unsigned char*)(This->size + 1);
205 This->pos = 0;
206 if (init) *This->size = 0;
207 TRACE("init size=%d\n", *This->size);
208 return (LPSTREAM)This;
209 }
210
211 static const IID* get_ip_iid(PMIDL_STUB_MESSAGE pStubMsg, unsigned char *pMemory, PFORMAT_STRING pFormat)
212 {
213 const IID *riid;
214 if (!pFormat) return &IID_IUnknown;
215 TRACE("format=%02x %02x\n", pFormat[0], pFormat[1]);
216 if (pFormat[0] != RPC_FC_IP) FIXME("format=%d\n", pFormat[0]);
217 if (pFormat[1] == RPC_FC_CONSTANT_IID) {
218 riid = (const IID *)&pFormat[2];
219 } else {
220 ComputeConformance(pStubMsg, pMemory, pFormat+2, 0);
221 riid = (const IID *)pStubMsg->MaxCount;
222 }
223 if (!riid) riid = &IID_IUnknown;
224 TRACE("got %s\n", debugstr_guid(riid));
225 return riid;
226 }
227
228 /***********************************************************************
229 * NdrInterfacePointerMarshall [RPCRT4.@]
230 */
231 unsigned char * WINAPI NdrInterfacePointerMarshall(PMIDL_STUB_MESSAGE pStubMsg,
232 unsigned char *pMemory,
233 PFORMAT_STRING pFormat)
234 {
235 const IID *riid = get_ip_iid(pStubMsg, pMemory, pFormat);
236 LPSTREAM stream;
237 HRESULT hr;
238
239 TRACE("(%p,%p,%p)\n", pStubMsg, pMemory, pFormat);
240 pStubMsg->MaxCount = 0;
241 if (!LoadCOM()) return NULL;
242 if (pStubMsg->Buffer + sizeof(DWORD) <= (unsigned char *)pStubMsg->RpcMsg->Buffer + pStubMsg->BufferLength) {
243 stream = RpcStream_Create(pStubMsg, TRUE);
244 if (stream) {
245 if (pMemory)
246 hr = COM_MarshalInterface(stream, riid, (LPUNKNOWN)pMemory,
247 pStubMsg->dwDestContext, pStubMsg->pvDestContext,
248 MSHLFLAGS_NORMAL);
249 else
250 hr = S_OK;
251
252 IStream_Release(stream);
253 if (FAILED(hr))
254 RpcRaiseException(hr);
255 }
256 }
257 return NULL;
258 }
259
260 /***********************************************************************
261 * NdrInterfacePointerUnmarshall [RPCRT4.@]
262 */
263 unsigned char * WINAPI NdrInterfacePointerUnmarshall(PMIDL_STUB_MESSAGE pStubMsg,
264 unsigned char **ppMemory,
265 PFORMAT_STRING pFormat,
266 unsigned char fMustAlloc)
267 {
268 LPSTREAM stream;
269 HRESULT hr;
270
271 TRACE("(%p,%p,%p,%d)\n", pStubMsg, ppMemory, pFormat, fMustAlloc);
272 if (!LoadCOM()) return NULL;
273 *(LPVOID*)ppMemory = NULL;
274 if (pStubMsg->Buffer + sizeof(DWORD) < (unsigned char *)pStubMsg->RpcMsg->Buffer + pStubMsg->BufferLength) {
275 stream = RpcStream_Create(pStubMsg, FALSE);
276 if (!stream) RpcRaiseException(E_OUTOFMEMORY);
277 if (*((RpcStreamImpl *)stream)->size != 0)
278 hr = COM_UnmarshalInterface(stream, &IID_NULL, (LPVOID*)ppMemory);
279 else
280 hr = S_OK;
281 IStream_Release(stream);
282 if (FAILED(hr))
283 RpcRaiseException(hr);
284 }
285 return NULL;
286 }
287
288 /***********************************************************************
289 * NdrInterfacePointerBufferSize [RPCRT4.@]
290 */
291 void WINAPI NdrInterfacePointerBufferSize(PMIDL_STUB_MESSAGE pStubMsg,
292 unsigned char *pMemory,
293 PFORMAT_STRING pFormat)
294 {
295 const IID *riid = get_ip_iid(pStubMsg, pMemory, pFormat);
296 ULONG size = 0;
297
298 TRACE("(%p,%p,%p)\n", pStubMsg, pMemory, pFormat);
299 if (!LoadCOM()) return;
300 COM_GetMarshalSizeMax(&size, riid, (LPUNKNOWN)pMemory,
301 pStubMsg->dwDestContext, pStubMsg->pvDestContext,
302 MSHLFLAGS_NORMAL);
303 TRACE("size=%d\n", size);
304 pStubMsg->BufferLength += sizeof(DWORD) + size;
305 }
306
307 /***********************************************************************
308 * NdrInterfacePointerMemorySize [RPCRT4.@]
309 */
310 ULONG WINAPI NdrInterfacePointerMemorySize(PMIDL_STUB_MESSAGE pStubMsg,
311 PFORMAT_STRING pFormat)
312 {
313 ULONG size;
314
315 TRACE("(%p,%p)\n", pStubMsg, pFormat);
316
317 size = *(ULONG *)pStubMsg->Buffer;
318 pStubMsg->Buffer += 4;
319 pStubMsg->MemorySize += 4;
320
321 pStubMsg->Buffer += size;
322
323 return pStubMsg->MemorySize;
324 }
325
326 /***********************************************************************
327 * NdrInterfacePointerFree [RPCRT4.@]
328 */
329 void WINAPI NdrInterfacePointerFree(PMIDL_STUB_MESSAGE pStubMsg,
330 unsigned char *pMemory,
331 PFORMAT_STRING pFormat)
332 {
333 LPUNKNOWN pUnk = (LPUNKNOWN)pMemory;
334 TRACE("(%p,%p,%p)\n", pStubMsg, pMemory, pFormat);
335 if (pUnk) IUnknown_Release(pUnk);
336 }
337
338 /***********************************************************************
339 * NdrOleAllocate [RPCRT4.@]
340 */
341 void * WINAPI NdrOleAllocate(SIZE_T Size)
342 {
343 if (!LoadCOM()) return NULL;
344 return COM_MemAlloc(Size);
345 }
346
347 /***********************************************************************
348 * NdrOleFree [RPCRT4.@]
349 */
350 void WINAPI NdrOleFree(void *NodeToFree)
351 {
352 if (!LoadCOM()) return;
353 COM_MemFree(NodeToFree);
354 }
355
356 /***********************************************************************
357 * Helper function to create a proxy.
358 * Probably similar to NdrpCreateProxy.
359 */
360 HRESULT create_proxy(REFIID iid, IUnknown *pUnkOuter, IRpcProxyBuffer **pproxy, void **ppv)
361 {
362 CLSID clsid;
363 IPSFactoryBuffer *psfac;
364 HRESULT r;
365
366 if(!LoadCOM()) return E_FAIL;
367
368 r = COM_GetPSClsid( iid, &clsid );
369 if(FAILED(r)) return r;
370
371 r = COM_GetClassObject( &clsid, CLSCTX_INPROC_SERVER, NULL, &IID_IPSFactoryBuffer, (void**)&psfac );
372 if(FAILED(r)) return r;
373
374 r = IPSFactoryBuffer_CreateProxy(psfac, pUnkOuter, iid, pproxy, ppv);
375
376 IPSFactoryBuffer_Release(psfac);
377 return r;
378 }
379
380 /***********************************************************************
381 * Helper function to create a stub.
382 * This probably looks very much like NdrpCreateStub.
383 */
384 HRESULT create_stub(REFIID iid, IUnknown *pUnk, IRpcStubBuffer **ppstub)
385 {
386 CLSID clsid;
387 IPSFactoryBuffer *psfac;
388 HRESULT r;
389
390 if(!LoadCOM()) return E_FAIL;
391
392 r = COM_GetPSClsid( iid, &clsid );
393 if(FAILED(r)) return r;
394
395 r = COM_GetClassObject( &clsid, CLSCTX_INPROC_SERVER, NULL, &IID_IPSFactoryBuffer, (void**)&psfac );
396 if(FAILED(r)) return r;
397
398 r = IPSFactoryBuffer_CreateStub(psfac, iid, pUnk, ppstub);
399
400 IPSFactoryBuffer_Release(psfac);
401 return r;
402 }