[ATL]
[reactos.git] / reactos / lib / atl / atlbase.h
1 /*
2 * ReactOS ATL
3 *
4 * Copyright 2009 Andrew Hill <ash77@reactos.org>
5 *
6 * This library is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
10 *
11 * This library is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with this library; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21 #pragma once
22
23 #include "atlcore.h"
24 #include "statreg.h"
25
#if defined(_MSC_VER)
// ATL constructors commonly pass 'this' along; they only store it for later
// use, so the usage is safe and warning C4355 is silenced.
#pragma warning(disable:4355)
#endif

// Default value for _ATL_PACKING; offsetofclass() below uses it as the dummy
// object address.
#if !defined(_ATL_PACKING)
#define _ATL_PACKING 8
#endif

// When no threading model macro has been chosen, default to free threading.
#if !defined(_ATL_FREE_THREADED) && !defined(_ATL_APARTMENT_THREADED) && !defined(_ATL_SINGLE_THREADED)
#define _ATL_FREE_THREADED
#endif

// Default ATLTRY: simply evaluate the argument as a statement.
#if !defined(ATLTRY)
#define ATLTRY(x) x;
#endif

// ATL_NO_VTABLE maps to __declspec(novtable) unless _ATL_DISABLE_NO_VTABLE is defined.
#if defined(_ATL_DISABLE_NO_VTABLE)
#define ATL_NO_VTABLE
#else
#define ATL_NO_VTABLE __declspec(novtable)
#endif

// Byte offset of the 'base' subobject inside 'derived', computed by converting
// a dummy 'derived' pointer located at address _ATL_PACKING.
#define offsetofclass(base, derived) (reinterpret_cast<DWORD_PTR>(static_cast<base *>(reinterpret_cast<derived *>(_ATL_PACKING))) - _ATL_PACKING)
54
55 namespace ATL
56 {
57
58 class CAtlModule;
59 class CComModule;
60 class CAtlComModule;
61 __declspec(selectany) CAtlModule *_pAtlModule = NULL;
62 __declspec(selectany) CComModule *_pModule = NULL;
63 extern CAtlComModule _AtlComModule;
64
65 typedef HRESULT (WINAPI _ATL_CREATORFUNC)(void *pv, REFIID riid, LPVOID *ppv);
66 typedef LPCTSTR (WINAPI _ATL_DESCRIPTIONFUNC)();
67 typedef const struct _ATL_CATMAP_ENTRY * (_ATL_CATMAPFUNC)();
68
69 struct _ATL_OBJMAP_ENTRY30
70 {
71 const CLSID *pclsid;
72 HRESULT (WINAPI *pfnUpdateRegistry)(BOOL bRegister);
73 _ATL_CREATORFUNC *pfnGetClassObject;
74 _ATL_CREATORFUNC *pfnCreateInstance;
75 IUnknown *pCF;
76 DWORD dwRegister;
77 _ATL_DESCRIPTIONFUNC *pfnGetObjectDescription;
78 _ATL_CATMAPFUNC *pfnGetCategoryMap;
79 void (WINAPI *pfnObjectMain)(bool bStarting);
80
81 HRESULT WINAPI RevokeClassObject()
82 {
83 if (dwRegister == 0)
84 return S_OK;
85 return CoRevokeClassObject(dwRegister);
86 }
87
88 HRESULT WINAPI RegisterClassObject(DWORD dwClsContext, DWORD dwFlags)
89 {
90 IUnknown *p;
91 HRESULT hResult;
92
93 p = NULL;
94 if (pfnGetClassObject == NULL)
95 return S_OK;
96
97 hResult = pfnGetClassObject(reinterpret_cast<LPVOID *>(pfnCreateInstance), IID_IUnknown, reinterpret_cast<LPVOID *>(&p));
98 if (SUCCEEDED(hResult))
99 hResult = CoRegisterClassObject(*pclsid, p, dwClsContext, dwFlags, &dwRegister);
100
101 if (p != NULL)
102 p->Release();
103
104 return hResult;
105 }
106 };
107
108 typedef _ATL_OBJMAP_ENTRY30 _ATL_OBJMAP_ENTRY;
109
110 typedef void (__stdcall _ATL_TERMFUNC)(DWORD_PTR dw);
111
112 struct _ATL_TERMFUNC_ELEM
113 {
114 _ATL_TERMFUNC *pFunc;
115 DWORD_PTR dw;
116 _ATL_TERMFUNC_ELEM *pNext;
117 };
118
119 struct _ATL_MODULE70
120 {
121 UINT cbSize;
122 LONG m_nLockCnt;
123 _ATL_TERMFUNC_ELEM *m_pTermFuncs;
124 CComCriticalSection m_csStaticDataInitAndTypeInfo;
125 };
126 typedef _ATL_MODULE70 _ATL_MODULE;
127
128 typedef HRESULT (WINAPI _ATL_CREATORARGFUNC)(void *pv, REFIID riid, LPVOID *ppv, DWORD_PTR dw);
129
130 #define _ATL_SIMPLEMAPENTRY ((ATL::_ATL_CREATORARGFUNC *)1)
131
132 struct _ATL_INTMAP_ENTRY
133 {
134 const IID *piid;
135 DWORD_PTR dw;
136 _ATL_CREATORARGFUNC *pFunc;
137 };
138
139 struct _AtlCreateWndData
140 {
141 void *m_pThis;
142 DWORD m_dwThreadID;
143 _AtlCreateWndData *m_pNext;
144 };
145
146 struct _ATL_COM_MODULE70
147 {
148 UINT cbSize;
149 HINSTANCE m_hInstTypeLib;
150 _ATL_OBJMAP_ENTRY **m_ppAutoObjMapFirst;
151 _ATL_OBJMAP_ENTRY **m_ppAutoObjMapLast;
152 CComCriticalSection m_csObjMap;
153 };
154 typedef _ATL_COM_MODULE70 _ATL_COM_MODULE;
155
156 struct _ATL_WIN_MODULE70
157 {
158 UINT cbSize;
159 CComCriticalSection m_csWindowCreate;
160 _AtlCreateWndData *m_pCreateWndList;
161 #ifdef NOTYET
162 CSimpleArray<ATOM> m_rgWindowClassAtoms;
163 #endif
164 };
165 typedef _ATL_WIN_MODULE70 _ATL_WIN_MODULE;
166
167 struct _ATL_REGMAP_ENTRY
168 {
169 LPCOLESTR szKey;
170 LPCOLESTR szData;
171 };
172
173 HRESULT __stdcall AtlWinModuleInit(_ATL_WIN_MODULE *pWinModule);
174 HRESULT __stdcall AtlWinModuleTerm(_ATL_WIN_MODULE *pWinModule, HINSTANCE hInst);
175 HRESULT __stdcall AtlInternalQueryInterface(void *pThis, const _ATL_INTMAP_ENTRY *pEntries, REFIID iid, void **ppvObject);
176 void __stdcall AtlWinModuleAddCreateWndData(_ATL_WIN_MODULE *pWinModule, _AtlCreateWndData *pData, void *pObject);
177 void *__stdcall AtlWinModuleExtractCreateWndData(_ATL_WIN_MODULE *pWinModule);
178 HRESULT __stdcall AtlComModuleGetClassObject(_ATL_COM_MODULE *pComModule, REFCLSID rclsid, REFIID riid, LPVOID *ppv);
179
180 template<class TLock>
181 class CComCritSecLock
182 {
183 private:
184 bool m_bLocked;
185 TLock &m_cs;
186 public:
187 CComCritSecLock(TLock &cs, bool bInitialLock = true) : m_cs(cs)
188 {
189 HRESULT hResult;
190
191 m_bLocked = false;
192 if (bInitialLock)
193 {
194 hResult = Lock();
195 if (FAILED(hResult))
196 {
197 ATLASSERT(false);
198 }
199 }
200 }
201
202 ~CComCritSecLock()
203 {
204 if (m_bLocked)
205 Unlock();
206 }
207
208 HRESULT Lock()
209 {
210 HRESULT hResult;
211
212 ATLASSERT(!m_bLocked);
213 hResult = m_cs.Lock();
214 if (FAILED(hResult))
215 return hResult;
216 m_bLocked = true;
217
218 return S_OK;
219 }
220
221 void Unlock()
222 {
223 HRESULT hResult;
224
225 ATLASSERT(m_bLocked);
226 hResult = m_cs.Unlock();
227 if (FAILED(hResult))
228 {
229 ATLASSERT(false);
230 }
231 m_bLocked = false;
232 }
233 };
234
235
236 class CHandle
237 {
238 public:
239 HANDLE m_handle;
240
241 public:
242 CHandle() :
243 m_handle(NULL)
244 {
245 }
246
247 CHandle(_Inout_ CHandle& handle) :
248 m_handle(NULL)
249 {
250 Attach(handle.Detach());
251 }
252
253 explicit CHandle(_In_ HANDLE handle) :
254 m_handle(handle)
255 {
256 }
257
258 ~CHandle()
259 {
260 if (m_handle)
261 {
262 Close();
263 }
264 }
265
266 CHandle& operator=(_Inout_ CHandle& handle)
267 {
268 if (this != &handle)
269 {
270 if (m_handle)
271 {
272 Close();
273 }
274 Attach(handle.Detach());
275 }
276
277 return *this;
278 }
279
280 operator HANDLE() const
281 {
282 return m_handle;
283 }
284
285 void Attach(_In_ HANDLE handle)
286 {
287 ATLASSERT(m_handle == NULL);
288 m_handle = handle;
289 }
290
291 HANDLE Detach()
292 {
293 HANDLE handle = m_handle;
294 m_handle = NULL;
295 return handle;
296 }
297
298 void Close()
299 {
300 if (m_handle)
301 {
302 ::CloseHandle(m_handle);
303 m_handle = NULL;
304 }
305 }
306 };
307
308
309 inline BOOL WINAPI InlineIsEqualUnknown(REFGUID rguid1)
310 {
311 return (
312 ((unsigned long *)&rguid1)[0] == 0 &&
313 ((unsigned long *)&rguid1)[1] == 0 &&
314 ((unsigned long *)&rguid1)[2] == 0x000000C0 &&
315 ((unsigned long *)&rguid1)[3] == 0x46000000);
316 }
317
318 class CComMultiThreadModelNoCS
319 {
320 public:
321 typedef CComFakeCriticalSection AutoCriticalSection;
322 typedef CComFakeCriticalSection CriticalSection;
323 typedef CComMultiThreadModelNoCS ThreadModelNoCS;
324 typedef CComFakeCriticalSection AutoDeleteCriticalSection;
325
326 static ULONG WINAPI Increment(LPLONG p)
327 {
328 return InterlockedIncrement(p);
329 }
330
331 static ULONG WINAPI Decrement(LPLONG p)
332 {
333 return InterlockedDecrement(p);
334 }
335 };
336
337 class CComMultiThreadModel
338 {
339 public:
340 typedef CComAutoCriticalSection AutoCriticalSection;
341 typedef CComCriticalSection CriticalSection;
342 typedef CComMultiThreadModelNoCS ThreadModelNoCS;
343 typedef CComAutoDeleteCriticalSection AutoDeleteCriticalSection;
344
345 static ULONG WINAPI Increment(LPLONG p)
346 {
347 return InterlockedIncrement(p);
348 }
349
350 static ULONG WINAPI Decrement(LPLONG p)
351 {
352 return InterlockedDecrement(p);
353 }
354 };
355
356 class CComSingleThreadModel
357 {
358 public:
359 typedef CComFakeCriticalSection AutoCriticalSection;
360 typedef CComFakeCriticalSection CriticalSection;
361 typedef CComSingleThreadModel ThreadModelNoCS;
362 typedef CComFakeCriticalSection AutoDeleteCriticalSection;
363
364 static ULONG WINAPI Increment(LPLONG p)
365 {
366 return ++*p;
367 }
368
369 static ULONG WINAPI Decrement(LPLONG p)
370 {
371 return --*p;
372 }
373 };
374
375 #if defined(_ATL_FREE_THREADED)
376
377 typedef CComMultiThreadModel CComObjectThreadModel;
378 typedef CComMultiThreadModel CComGlobalsThreadModel;
379
380 #elif defined(_ATL_APARTMENT_THREADED)
381
382 typedef CComSingleThreadModel CComObjectThreadModel;
383 typedef CComMultiThreadModel CComGlobalsThreadModel;
384
385 #elif defined(_ATL_SINGLE_THREADED)
386
387 typedef CComSingleThreadModel CComObjectThreadModel;
388 typedef CComSingleThreadModel CComGlobalsThreadModel;
389
390 #else
391 #error No threading model
392 #endif
393
394 class CAtlModule : public _ATL_MODULE
395 {
396 protected:
397 static GUID m_libid;
398 public:
399 CAtlModule()
400 {
401 ATLASSERT(_pAtlModule == NULL);
402 _pAtlModule = this;
403 cbSize = sizeof(_ATL_MODULE);
404 m_nLockCnt = 0;
405 }
406
407 virtual LONG GetLockCount()
408 {
409 return m_nLockCnt;
410 }
411
412 virtual LONG Lock()
413 {
414 return CComGlobalsThreadModel::Increment(&m_nLockCnt);
415 }
416
417 virtual LONG Unlock()
418 {
419 return CComGlobalsThreadModel::Decrement(&m_nLockCnt);
420 }
421
422 virtual HRESULT AddCommonRGSReplacements(IRegistrarBase* /*pRegistrar*/) = 0;
423
424 HRESULT WINAPI UpdateRegistryFromResource(LPCTSTR lpszRes, BOOL bRegister, struct _ATL_REGMAP_ENTRY *pMapEntries = NULL)
425 {
426 CRegObject registrar;
427 WCHAR modulePath[MAX_PATH];
428 HRESULT hResult;
429 PCWSTR lpwszRes;
430
431 hResult = CommonInitRegistrar(registrar, modulePath, sizeof(modulePath) / sizeof(modulePath[0]), pMapEntries);
432 if (FAILED(hResult))
433 return hResult;
434 #ifdef UNICODE
435 lpwszRes = lpszRes;
436 #else
437 /* FIXME: this is a bit of a hack, need to re-evaluate */
438 WCHAR resid[MAX_PATH];
439 MultiByteToWideChar(CP_ACP, 0, lpszRes, -1, resid, MAX_PATH);
440 lpwszRes = resid;
441 #endif
442 if (bRegister != FALSE)
443 hResult = registrar.ResourceRegisterSz(modulePath, lpwszRes, L"REGISTRY");
444 else
445 hResult = registrar.ResourceUnregisterSz(modulePath, lpwszRes, L"REGISTRY");
446
447 return hResult;
448 }
449
450 HRESULT WINAPI UpdateRegistryFromResource(UINT nResID, BOOL bRegister, struct _ATL_REGMAP_ENTRY *pMapEntries = NULL)
451 {
452 CRegObject registrar;
453 WCHAR modulePath[MAX_PATH];
454 HRESULT hResult;
455
456 hResult = CommonInitRegistrar(registrar, modulePath, sizeof(modulePath) / sizeof(modulePath[0]), pMapEntries);
457 if (FAILED(hResult))
458 return hResult;
459
460 if (bRegister != FALSE)
461 hResult = registrar.ResourceRegister(modulePath, nResID, L"REGISTRY");
462 else
463 hResult = registrar.ResourceUnregister(modulePath, nResID, L"REGISTRY");
464
465 return hResult;
466 }
467
468 private:
469 HRESULT CommonInitRegistrar(CRegObject &registrar, WCHAR *modulePath, DWORD modulePathCount, struct _ATL_REGMAP_ENTRY *pMapEntries)
470 {
471 HINSTANCE hInstance;
472 DWORD dwFLen;
473 HRESULT hResult;
474
475 hInstance = _AtlBaseModule.GetModuleInstance();
476 dwFLen = GetModuleFileNameW(hInstance, modulePath, modulePathCount);
477 if (dwFLen == modulePathCount)
478 return HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER);
479 else if (dwFLen == 0)
480 return HRESULT_FROM_WIN32(GetLastError());
481
482 if (pMapEntries != NULL)
483 {
484 while (pMapEntries->szKey != NULL)
485 {
486 ATLASSERT(pMapEntries->szData != NULL);
487 hResult = registrar.AddReplacement(pMapEntries->szKey, pMapEntries->szData);
488 if (FAILED(hResult))
489 return hResult;
490 pMapEntries++;
491 }
492 }
493
494 hResult = AddCommonRGSReplacements(&registrar);
495 if (FAILED(hResult))
496 return hResult;
497
498 hResult = registrar.AddReplacement(L"Module", modulePath);
499 if (FAILED(hResult))
500 return hResult;
501
502 hResult = registrar.AddReplacement(L"Module_Raw", modulePath);
503 if (FAILED(hResult))
504 return hResult;
505
506 return S_OK;
507 }
508 };
509
510 __declspec(selectany) GUID CAtlModule::m_libid = {0x0, 0x0, 0x0, {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0} };
511
512 template <class T>
513 class CAtlModuleT : public CAtlModule
514 {
515 public:
516
517 virtual HRESULT AddCommonRGSReplacements(IRegistrarBase *pRegistrar)
518 {
519 return pRegistrar->AddReplacement(L"APPID", T::GetAppId());
520 }
521
522 static LPCOLESTR GetAppId()
523 {
524 return L"";
525 }
526 };
527
528 class CAtlComModule : public _ATL_COM_MODULE
529 {
530 public:
531 CAtlComModule()
532 {
533 GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, (LPCWSTR)this, &m_hInstTypeLib);
534 m_ppAutoObjMapFirst = NULL;
535 m_ppAutoObjMapLast = NULL;
536 if (FAILED(m_csObjMap.Init()))
537 {
538 ATLASSERT(0);
539 CAtlBaseModule::m_bInitFailed = true;
540 return;
541 }
542 cbSize = sizeof(_ATL_COM_MODULE);
543 }
544
545 ~CAtlComModule()
546 {
547 Term();
548 }
549
550 void Term()
551 {
552 if (cbSize != 0)
553 {
554 ATLASSERT(m_ppAutoObjMapFirst == NULL);
555 ATLASSERT(m_ppAutoObjMapLast == NULL);
556 m_csObjMap.Term();
557 cbSize = 0;
558 }
559 }
560 };
561
562 template <class T>
563 class CAtlDllModuleT : public CAtlModuleT<T>
564 {
565 public:
566 CAtlDllModuleT()
567 {
568 }
569
570 HRESULT DllCanUnloadNow()
571 {
572 T *pThis;
573
574 pThis = static_cast<T *>(this);
575 if (pThis->GetLockCount() == 0)
576 return S_OK;
577 return S_FALSE;
578 }
579
580 HRESULT DllGetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv)
581 {
582 T *pThis;
583
584 pThis = static_cast<T *>(this);
585 return pThis->GetClassObject(rclsid, riid, ppv);
586 }
587
588 HRESULT DllRegisterServer(BOOL bRegTypeLib = TRUE)
589 {
590 T *pThis;
591 HRESULT hResult;
592
593 pThis = static_cast<T *>(this);
594 hResult = pThis->RegisterServer(bRegTypeLib);
595 return hResult;
596 }
597
598 HRESULT DllUnregisterServer(BOOL bUnRegTypeLib = TRUE)
599 {
600 T *pThis;
601 HRESULT hResult;
602
603 pThis = static_cast<T *>(this);
604 hResult = pThis->UnregisterServer(bUnRegTypeLib);
605 return hResult;
606 }
607
608 HRESULT GetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv)
609 {
610 return AtlComModuleGetClassObject(&_AtlComModule, rclsid, riid, ppv);
611 }
612 };
613
614 class CComModule : public CAtlModuleT<CComModule>
615 {
616 public:
617 _ATL_OBJMAP_ENTRY *m_pObjMap;
618 public:
619 CComModule()
620 {
621 ATLASSERT(_pModule == NULL);
622 _pModule = this;
623 _pModule->m_pObjMap = NULL;
624 }
625
626 ~CComModule()
627 {
628 _pModule = NULL;
629 }
630
631 HRESULT Init(_ATL_OBJMAP_ENTRY *p, HINSTANCE /* h */, const GUID *plibid)
632 {
633 _ATL_OBJMAP_ENTRY *objectMapEntry;
634
635 if (plibid != NULL)
636 m_libid = *plibid;
637
638 if (p != reinterpret_cast<_ATL_OBJMAP_ENTRY *>(-1))
639 {
640 m_pObjMap = p;
641 if (p != NULL)
642 {
643 objectMapEntry = p;
644 while (objectMapEntry->pclsid != NULL)
645 {
646 objectMapEntry->pfnObjectMain(true);
647 objectMapEntry++;
648 }
649 }
650 }
651 return S_OK;
652 }
653
654 void Term()
655 {
656 _ATL_OBJMAP_ENTRY *objectMapEntry;
657
658 if (m_pObjMap != NULL)
659 {
660 objectMapEntry = m_pObjMap;
661 while (objectMapEntry->pclsid != NULL)
662 {
663 if (objectMapEntry->pCF != NULL)
664 objectMapEntry->pCF->Release();
665 objectMapEntry->pCF = NULL;
666 objectMapEntry->pfnObjectMain(false);
667 objectMapEntry++;
668 }
669 }
670 }
671
672 HRESULT GetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv)
673 {
674 _ATL_OBJMAP_ENTRY *objectMapEntry;
675 HRESULT hResult;
676
677 ATLASSERT(ppv != NULL);
678 if (ppv == NULL)
679 return E_POINTER;
680 *ppv = NULL;
681 hResult = S_OK;
682 if (m_pObjMap != NULL)
683 {
684 objectMapEntry = m_pObjMap;
685 while (objectMapEntry->pclsid != NULL)
686 {
687 if (objectMapEntry->pfnGetClassObject != NULL && InlineIsEqualGUID(rclsid, *objectMapEntry->pclsid) != FALSE)
688 {
689 if (objectMapEntry->pCF == NULL)
690 {
691 CComCritSecLock<CComCriticalSection> lock(_AtlComModule.m_csObjMap, true);
692
693 if (objectMapEntry->pCF == NULL)
694 hResult = objectMapEntry->pfnGetClassObject(reinterpret_cast<void *>(objectMapEntry->pfnCreateInstance), IID_IUnknown, reinterpret_cast<LPVOID *>(&objectMapEntry->pCF));
695 }
696 if (objectMapEntry->pCF != NULL)
697 hResult = objectMapEntry->pCF->QueryInterface(riid, ppv);
698 break;
699 }
700 objectMapEntry++;
701 }
702 }
703 if (hResult == S_OK && *ppv == NULL)
704 {
705 // FIXME: call AtlComModuleGetClassObject
706 hResult = CLASS_E_CLASSNOTAVAILABLE;
707 }
708 return hResult;
709 }
710
711 HRESULT RegisterServer(BOOL bRegTypeLib = FALSE, const CLSID *pCLSID = NULL)
712 {
713 _ATL_OBJMAP_ENTRY *objectMapEntry;
714 HRESULT hResult;
715
716 hResult = S_OK;
717 objectMapEntry = m_pObjMap;
718 if (objectMapEntry != NULL)
719 {
720 while (objectMapEntry->pclsid != NULL)
721 {
722 if (pCLSID == NULL || IsEqualGUID(*pCLSID, *objectMapEntry->pclsid) != FALSE)
723 {
724 hResult = objectMapEntry->pfnUpdateRegistry(TRUE);
725 if (FAILED(hResult))
726 break;
727 }
728 objectMapEntry++;
729 }
730 }
731 return hResult;
732 }
733
734 HRESULT UnregisterServer(BOOL bUnRegTypeLib, const CLSID *pCLSID = NULL)
735 {
736 _ATL_OBJMAP_ENTRY *objectMapEntry;
737 HRESULT hResult;
738
739 hResult = S_OK;
740 objectMapEntry = m_pObjMap;
741 if (objectMapEntry != NULL)
742 {
743 while (objectMapEntry->pclsid != NULL)
744 {
745 if (pCLSID == NULL || IsEqualGUID(*pCLSID, *objectMapEntry->pclsid) != FALSE)
746 {
747 hResult = objectMapEntry->pfnUpdateRegistry(FALSE); //unregister
748 if (FAILED(hResult))
749 break;
750 }
751 objectMapEntry++;
752 }
753 }
754 return hResult;
755 }
756
757 HRESULT DllCanUnloadNow()
758 {
759 if (GetLockCount() == 0)
760 return S_OK;
761 return S_FALSE;
762 }
763
764 HRESULT DllGetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv)
765 {
766 return GetClassObject(rclsid, riid, ppv);
767 }
768
769 HRESULT DllRegisterServer(BOOL bRegTypeLib = TRUE)
770 {
771 return RegisterServer(bRegTypeLib);
772 }
773
774 HRESULT DllUnregisterServer(BOOL bUnRegTypeLib = TRUE)
775 {
776 return UnregisterServer(bUnRegTypeLib);
777 }
778
779 };
780
781 class CAtlWinModule : public _ATL_WIN_MODULE
782 {
783 public:
784 CAtlWinModule()
785 {
786 HRESULT hResult;
787
788 hResult = AtlWinModuleInit(this);
789 if (FAILED(hResult))
790 {
791 CAtlBaseModule::m_bInitFailed = true;
792 ATLASSERT(0);
793 }
794 }
795
796 ~CAtlWinModule()
797 {
798 Term();
799 }
800
801 void Term()
802 {
803 AtlWinModuleTerm(this, _AtlBaseModule.GetModuleInstance());
804 }
805
806 void AddCreateWndData(_AtlCreateWndData *pData, void *pObject)
807 {
808 AtlWinModuleAddCreateWndData(this, pData, pObject);
809 }
810
811 void *ExtractCreateWndData()
812 {
813 return AtlWinModuleExtractCreateWndData(this);
814 }
815 };
816
817 extern CAtlWinModule _AtlWinModule;
818
// Smart pointer for COM style objects: holds one reference (AddRef/Release)
// on the object it points to.
template<class T>
class CComPtr
{
public:
    T *p;   // the wrapped pointer, NULL when empty
public:
    CComPtr()
    {
        p = NULL;
    }

    // Shares ownership of lp (adds a reference).
    CComPtr(T *lp)
    {
        p = lp;
        if (p != NULL)
            p->AddRef();
    }

    CComPtr(const CComPtr<T> &lp)
    {
        p = lp.p;
        if (p != NULL)
            p->AddRef();
    }

    ~CComPtr()
    {
        if (p != NULL)
            p->Release();
    }

    // Replaces the held pointer with lp and returns the new raw pointer.
    // The new object is referenced BEFORE the old one is released, so that
    // assigning the pointer already held (or an object kept alive only by the
    // old one) never touches a destroyed object.
    T *operator = (T *lp)
    {
        T *previous;

        if (lp != NULL)
            lp->AddRef();
        previous = p;
        p = lp;
        if (previous != NULL)
            previous->Release();
        return p;
    }

    // Same as above for another smart pointer; safe for self-assignment.
    T *operator = (const CComPtr<T> &lp)
    {
        T *incoming;
        T *previous;

        incoming = lp.p;
        if (incoming != NULL)
            incoming->AddRef();
        previous = p;
        p = incoming;
        if (previous != NULL)
            previous->Release();
        return p;
    }

    // Drops the held reference, leaving the pointer empty.
    void Release()
    {
        if (p != NULL)
        {
            p->Release();
            p = NULL;
        }
    }

    // Takes over lp WITHOUT adding a reference; the previous pointer is released.
    void Attach(T *lp)
    {
        if (p != NULL)
            p->Release();
        p = lp;
    }

    // Hands the pointer to the caller without releasing it.
    T *Detach()
    {
        T *saveP;

        saveP = p;
        p = NULL;
        return saveP;
    }

    // Address of the raw pointer for use as an out parameter; the smart
    // pointer must be empty so that no reference is overwritten.
    T **operator & ()
    {
        ATLASSERT(p == NULL);
        return &p;
    }

    operator T * ()
    {
        return p;
    }

    T *operator -> ()
    {
        ATLASSERT(p != NULL);
        return p;
    }
};
912
913
914 // TODO: When someone needs it, make the allocator a template, so you can use it for both
915 // CoTask* allocations, and CRT-like allocations (malloc, realloc, free)
916 template<class T>
917 class CComHeapPtr
918 {
919 public:
920 CComHeapPtr() :
921 m_Data(NULL)
922 {
923 }
924
925 explicit CComHeapPtr(T *lp) :
926 m_Data(lp)
927 {
928 }
929
930 explicit CComHeapPtr(CComHeapPtr<T> &lp)
931 {
932 m_Data = lp.Detach();
933 }
934
935 ~CComHeapPtr()
936 {
937 Release();
938 }
939
940 T *operator = (CComHeapPtr<T> &lp)
941 {
942 if (lp.m_Data != m_Data)
943 Attach(lp.Detach());
944 return *this;
945 }
946
947 bool Allocate(size_t nElements = 1)
948 {
949 ATLASSERT(m_Data == NULL);
950 m_Data = static_cast<T*>(::CoTaskMemAlloc(nElements * sizeof(T)));
951 return m_Data != NULL;
952 }
953
954 bool Reallocate(_In_ size_t nElements)
955 {
956 T* newData = static_cast<T*>(::CoTaskMemRealloc(m_Data, nElements * sizeof(T)));
957 if (newData == NULL)
958 return false;
959 m_Data = newData;
960 return true;
961 }
962
963 void Release()
964 {
965 if (m_Data)
966 {
967 ::CoTaskMemFree(m_Data);
968 m_Data = NULL;
969 }
970 }
971
972 void Attach(T *lp)
973 {
974 Release();
975 m_Data = lp;
976 }
977
978 T *Detach()
979 {
980 T *saveP = m_Data;
981 m_Data = NULL;
982 return saveP;
983 }
984
985 T **operator &()
986 {
987 ATLASSERT(m_Data == NULL);
988 return &m_Data;
989 }
990
991 operator T* () const
992 {
993 return m_Data;
994 }
995
996 protected:
997 T *m_Data;
998 };
999
1000
1001 class CComBSTR
1002 {
1003 public:
1004 BSTR m_str;
1005 public:
1006 CComBSTR() :
1007 m_str(NULL)
1008 {
1009 }
1010
1011 CComBSTR(LPCOLESTR pSrc)
1012 {
1013 if (pSrc == NULL)
1014 m_str = NULL;
1015 else
1016 m_str = ::SysAllocString(pSrc);
1017 }
1018
1019 CComBSTR(int length)
1020 {
1021 if (length == 0)
1022 m_str = NULL;
1023 else
1024 m_str = ::SysAllocStringLen(NULL, length);
1025 }
1026
1027 CComBSTR(int length, LPCOLESTR pSrc)
1028 {
1029 if (length == 0)
1030 m_str = NULL;
1031 else
1032 m_str = ::SysAllocStringLen(pSrc, length);
1033 }
1034
1035 CComBSTR(PCSTR pSrc)
1036 {
1037 if (pSrc)
1038 {
1039 int len = MultiByteToWideChar(CP_THREAD_ACP, 0, pSrc, -1, NULL, 0);
1040 m_str = ::SysAllocStringLen(NULL, len - 1);
1041 if (m_str)
1042 {
1043 int res = MultiByteToWideChar(CP_THREAD_ACP, 0, pSrc, -1, m_str, len);
1044 ATLASSERT(res == len);
1045 if (res != len)
1046 {
1047 ::SysFreeString(m_str);
1048 m_str = NULL;
1049 }
1050 }
1051 }
1052 else
1053 {
1054 m_str = NULL;
1055 }
1056 }
1057
1058 CComBSTR(const CComBSTR &other)
1059 {
1060 m_str = other.Copy();
1061 }
1062
1063 CComBSTR(REFGUID guid)
1064 {
1065 OLECHAR szGuid[40];
1066 ::StringFromGUID2(guid, szGuid, 40);
1067 m_str = ::SysAllocString(szGuid);
1068 }
1069
1070 ~CComBSTR()
1071 {
1072 ::SysFreeString(m_str);
1073 m_str = NULL;
1074 }
1075
1076 operator BSTR () const
1077 {
1078 return m_str;
1079 }
1080
1081 BSTR *operator & ()
1082 {
1083 return &m_str;
1084 }
1085
1086 CComBSTR &operator = (const CComBSTR &other)
1087 {
1088 ::SysFreeString(m_str);
1089 m_str = other.Copy();
1090 return *this;
1091 }
1092
1093 BSTR Copy() const
1094 {
1095 if (!m_str)
1096 return NULL;
1097 return ::SysAllocStringLen(m_str, ::SysStringLen(m_str));
1098 }
1099
1100 HRESULT CopyTo(BSTR *other) const
1101 {
1102 if (!other)
1103 return E_POINTER;
1104 *other = Copy();
1105 return S_OK;
1106 }
1107
1108 bool LoadString(HMODULE module, DWORD uID)
1109 {
1110 ::SysFreeString(m_str);
1111 m_str = NULL;
1112 const wchar_t *ptr = NULL;
1113 int len = ::LoadStringW(module, uID, (PWSTR)&ptr, 0);
1114 if (len)
1115 m_str = ::SysAllocStringLen(ptr, len);
1116 return m_str != NULL;
1117 }
1118
1119 unsigned int Length() const
1120 {
1121 return ::SysStringLen(m_str);
1122 }
1123
1124 unsigned int ByteLength() const
1125 {
1126 return ::SysStringByteLen(m_str);
1127 }
1128 };
1129
1130 class CComVariant : public tagVARIANT
1131 {
1132 public:
1133 CComVariant()
1134 {
1135 ::VariantInit(this);
1136 }
1137
1138 ~CComVariant()
1139 {
1140 Clear();
1141 }
1142
1143 HRESULT Clear()
1144 {
1145 return ::VariantClear(this);
1146 }
1147 };
1148
1149 inline HRESULT __stdcall AtlAdvise(IUnknown *pUnkCP, IUnknown *pUnk, const IID &iid, LPDWORD pdw)
1150 {
1151 CComPtr<IConnectionPointContainer> container;
1152 CComPtr<IConnectionPoint> connectionPoint;
1153 HRESULT hResult;
1154
1155 if (pUnkCP == NULL)
1156 return E_INVALIDARG;
1157 hResult = pUnkCP->QueryInterface(IID_IConnectionPointContainer, (void **)&container);
1158 if (FAILED(hResult))
1159 return hResult;
1160 hResult = container->FindConnectionPoint(iid, &connectionPoint);
1161 if (FAILED(hResult))
1162 return hResult;
1163 return connectionPoint->Advise(pUnk, pdw);
1164 }
1165
1166 inline HRESULT __stdcall AtlUnadvise(IUnknown *pUnkCP, const IID &iid, DWORD dw)
1167 {
1168 CComPtr<IConnectionPointContainer> container;
1169 CComPtr<IConnectionPoint> connectionPoint;
1170 HRESULT hResult;
1171
1172 if (pUnkCP == NULL)
1173 return E_INVALIDARG;
1174 hResult = pUnkCP->QueryInterface(IID_IConnectionPointContainer, (void **)&container);
1175 if (FAILED(hResult))
1176 return hResult;
1177 hResult = container->FindConnectionPoint(iid, &connectionPoint);
1178 if (FAILED(hResult))
1179 return hResult;
1180 return connectionPoint->Unadvise(dw);
1181 }
1182
1183 inline HRESULT __stdcall AtlInternalQueryInterface(void *pThis, const _ATL_INTMAP_ENTRY *pEntries, REFIID iid, void **ppvObject)
1184 {
1185 int i;
1186 IUnknown *resultInterface;
1187 HRESULT hResult;
1188
1189 ATLASSERT(pThis != NULL && pEntries != NULL);
1190 if (pThis == NULL || pEntries == NULL)
1191 return E_INVALIDARG;
1192 ATLASSERT(ppvObject != NULL);
1193 if (ppvObject == NULL)
1194 return E_POINTER;
1195
1196 if (InlineIsEqualUnknown(iid))
1197 {
1198 resultInterface = reinterpret_cast<IUnknown *>(reinterpret_cast<char *>(pThis) + pEntries[0].dw);
1199 *ppvObject = resultInterface;
1200 resultInterface->AddRef();
1201 return S_OK;
1202 }
1203
1204 i = 0;
1205 while (pEntries[i].pFunc != 0)
1206 {
1207 if (pEntries[i].piid == NULL || InlineIsEqualGUID(iid, *pEntries[i].piid))
1208 {
1209 if (pEntries[i].pFunc == reinterpret_cast<_ATL_CREATORARGFUNC *>(1))
1210 {
1211 ATLASSERT(pEntries[i].piid != NULL);
1212 resultInterface = reinterpret_cast<IUnknown *>(reinterpret_cast<char *>(pThis) + pEntries[i].dw);
1213 *ppvObject = resultInterface;
1214 resultInterface->AddRef();
1215 return S_OK;
1216 }
1217 else
1218 {
1219 hResult = pEntries[i].pFunc(pThis, iid, ppvObject, 0);
1220 if (hResult == S_OK || (FAILED(hResult) && pEntries[i].piid != NULL))
1221 return hResult;
1222 }
1223 break;
1224 }
1225 i++;
1226 }
1227 *ppvObject = NULL;
1228 return E_NOINTERFACE;
1229 }
1230
1231 inline HRESULT __stdcall AtlWinModuleInit(_ATL_WIN_MODULE *pWinModule)
1232 {
1233 if (pWinModule == NULL)
1234 return E_INVALIDARG;
1235 pWinModule->m_pCreateWndList = NULL;
1236 return pWinModule->m_csWindowCreate.Init();
1237 }
1238
1239 inline HRESULT __stdcall AtlWinModuleTerm(_ATL_WIN_MODULE *pWinModule, HINSTANCE hInst)
1240 {
1241 if (pWinModule == NULL)
1242 return E_INVALIDARG;
1243 pWinModule->m_csWindowCreate.Term();
1244 return S_OK;
1245 }
1246
1247 inline void __stdcall AtlWinModuleAddCreateWndData(_ATL_WIN_MODULE *pWinModule, _AtlCreateWndData *pData, void *pObject)
1248 {
1249 CComCritSecLock<CComCriticalSection> lock(pWinModule->m_csWindowCreate, true);
1250
1251 ATLASSERT(pWinModule != NULL);
1252 ATLASSERT(pObject != NULL);
1253
1254 pData->m_pThis = pObject;
1255 pData->m_dwThreadID = ::GetCurrentThreadId();
1256 pData->m_pNext = pWinModule->m_pCreateWndList;
1257 pWinModule->m_pCreateWndList = pData;
1258 }
1259
1260 inline void *__stdcall AtlWinModuleExtractCreateWndData(_ATL_WIN_MODULE *pWinModule)
1261 {
1262 CComCritSecLock<CComCriticalSection> lock(pWinModule->m_csWindowCreate, true);
1263 void *result;
1264 _AtlCreateWndData *currentEntry;
1265 _AtlCreateWndData **previousLink;
1266 DWORD threadID;
1267
1268 ATLASSERT(pWinModule != NULL);
1269
1270 result = NULL;
1271 threadID = GetCurrentThreadId();
1272 currentEntry = pWinModule->m_pCreateWndList;
1273 previousLink = &pWinModule->m_pCreateWndList;
1274 while (currentEntry != NULL)
1275 {
1276 if (currentEntry->m_dwThreadID == threadID)
1277 {
1278 *previousLink = currentEntry->m_pNext;
1279 result = currentEntry->m_pThis;
1280 break;
1281 }
1282 previousLink = &currentEntry->m_pNext;
1283 currentEntry = currentEntry->m_pNext;
1284 }
1285 return result;
1286 }
1287
1288 }; // namespace ATL
1289
1290 #ifndef _ATL_NO_AUTOMATIC_NAMESPACE
1291 using namespace ATL;
1292 #endif //!_ATL_NO_AUTOMATIC_NAMESPACE