2a537d3bcc98130620991ea6613179f33c5106fa
[reactos.git] / reactos / lib / atl / atlcom.h
1 /*
2 * ReactOS ATL
3 *
4 * Copyright 2009 Andrew Hill <ash77@reactos.org>
5 * Copyright 2013 Katayama Hirofumi MZ <katayama.hirofumi.mz@gmail.com>
6 *
7 * This library is free software; you can redistribute it and/or
8 * modify it under the terms of the GNU Lesser General Public
9 * License as published by the Free Software Foundation; either
10 * version 2.1 of the License, or (at your option) any later version.
11 *
12 * This library is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 * Lesser General Public License for more details.
16 *
17 * You should have received a copy of the GNU Lesser General Public
18 * License along with this library; if not, write to the Free Software
19 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22 #pragma once
23
24 namespace ATL
25 {
26
27 template <class Base, const IID *piid, class T, class Copy, class ThreadModel = CComObjectThreadModel>
28 class CComEnum;
29
// Class-factory selection macros. Each declares the _ClassFactoryCreatorClass
// typedef that OBJECT_ENTRY uses to create a coclass's class factory.
#define DECLARE_CLASSFACTORY_EX(factory) \
    typedef ATL::CComCreator<ATL::CComObjectCached<factory> > _ClassFactoryCreatorClass;
// Default: a plain CComClassFactory.
#define DECLARE_CLASSFACTORY() \
    DECLARE_CLASSFACTORY_EX(ATL::CComClassFactory)
// Singleton: every CreateInstance call hands out the same cached object.
#define DECLARE_CLASSFACTORY_SINGLETON(obj) \
    DECLARE_CLASSFACTORY_EX(ATL::CComClassFactorySingleton<obj>)
33
34 class CComObjectRootBase
35 {
36 public:
37 LONG m_dwRef;
38 public:
39 CComObjectRootBase()
40 {
41 m_dwRef = 0;
42 }
43
44 ~CComObjectRootBase()
45 {
46 }
47
48 void SetVoid(void *)
49 {
50 }
51
52 HRESULT _AtlFinalConstruct()
53 {
54 return S_OK;
55 }
56
57 HRESULT FinalConstruct()
58 {
59 return S_OK;
60 }
61
62 void InternalFinalConstructAddRef()
63 {
64 }
65
66 void InternalFinalConstructRelease()
67 {
68 }
69
70 void FinalRelease()
71 {
72 }
73
74 static void WINAPI ObjectMain(bool)
75 {
76 }
77
78 static const struct _ATL_CATMAP_ENTRY *GetCategoryMap()
79 {
80 return NULL;
81 }
82
83 static HRESULT WINAPI InternalQueryInterface(void *pThis, const _ATL_INTMAP_ENTRY *pEntries, REFIID iid, void **ppvObject)
84 {
85 return AtlInternalQueryInterface(pThis, pEntries, iid, ppvObject);
86 }
87
88 };
89
90 template <class ThreadModel>
91 class CComObjectRootEx : public CComObjectRootBase
92 {
93 private:
94 typename ThreadModel::AutoDeleteCriticalSection m_critsec;
95 public:
96 ~CComObjectRootEx()
97 {
98 }
99
100 ULONG InternalAddRef()
101 {
102 ATLASSERT(m_dwRef >= 0);
103 return ThreadModel::Increment(&m_dwRef);
104 }
105
106 ULONG InternalRelease()
107 {
108 ATLASSERT(m_dwRef > 0);
109 return ThreadModel::Decrement(&m_dwRef);
110 }
111
112 void Lock()
113 {
114 m_critsec.Lock();
115 }
116
117 void Unlock()
118 {
119 m_critsec.Unlock();
120 }
121
122 HRESULT _AtlInitialConstruct()
123 {
124 return m_critsec.Init();
125 }
126 };
127
128 template <class Base>
129 class CComObject : public Base
130 {
131 public:
132 CComObject(void * = NULL)
133 {
134 _pAtlModule->Lock();
135 }
136
137 virtual ~CComObject()
138 {
139 this->FinalRelease();
140 _pAtlModule->Unlock();
141 }
142
143 STDMETHOD_(ULONG, AddRef)()
144 {
145 return this->InternalAddRef();
146 }
147
148 STDMETHOD_(ULONG, Release)()
149 {
150 ULONG newRefCount;
151
152 newRefCount = this->InternalRelease();
153 if (newRefCount == 0)
154 delete this;
155 return newRefCount;
156 }
157
158 STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject)
159 {
160 return this->_InternalQueryInterface(iid, ppvObject);
161 }
162
163 static HRESULT WINAPI CreateInstance(CComObject<Base> **pp)
164 {
165 CComObject<Base> *newInstance;
166 HRESULT hResult;
167
168 ATLASSERT(pp != NULL);
169 if (pp == NULL)
170 return E_POINTER;
171
172 hResult = E_OUTOFMEMORY;
173 newInstance = NULL;
174 ATLTRY(newInstance = new CComObject<Base>())
175 if (newInstance != NULL)
176 {
177 newInstance->SetVoid(NULL);
178 newInstance->InternalFinalConstructAddRef();
179 hResult = newInstance->_AtlInitialConstruct();
180 if (SUCCEEDED(hResult))
181 hResult = newInstance->FinalConstruct();
182 if (SUCCEEDED(hResult))
183 hResult = newInstance->_AtlFinalConstruct();
184 newInstance->InternalFinalConstructRelease();
185 if (hResult != S_OK)
186 {
187 delete newInstance;
188 newInstance = NULL;
189 }
190 }
191 *pp = newInstance;
192 return hResult;
193 }
194
195
196 };
197
198 template <class Base>
199 class CComContainedObject : public Base
200 {
201 public:
202 IUnknown* m_pUnkOuter;
203 CComContainedObject(void * pv = NULL) : m_pUnkOuter(static_cast<IUnknown*>(pv))
204 {
205 }
206
207 STDMETHOD_(ULONG, AddRef)()
208 {
209 return m_pUnkOuter->AddRef();
210 }
211
212 STDMETHOD_(ULONG, Release)()
213 {
214 return m_pUnkOuter->Release();
215 }
216
217 STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject)
218 {
219 return m_pUnkOuter->QueryInterface(iid, ppvObject);
220 }
221
222 IUnknown* GetControllingUnknown()
223 {
224 return m_pUnkOuter;
225 }
226 };
227
228 template <class contained>
229 class CComAggObject : public contained
230 {
231 public:
232 CComContainedObject<contained> m_contained;
233
234 CComAggObject(void * pv = NULL) : m_contained(static_cast<contained*>(pv))
235 {
236 _pAtlModule->Lock();
237 }
238
239 virtual ~CComAggObject()
240 {
241 this->FinalRelease();
242 _pAtlModule->Unlock();
243 }
244
245 HRESULT FinalConstruct()
246 {
247 return m_contained.FinalConstruct();
248 }
249 void FinalRelease()
250 {
251 m_contained.FinalRelease();
252 }
253
254 STDMETHOD_(ULONG, AddRef)()
255 {
256 return this->InternalAddRef();
257 }
258
259 STDMETHOD_(ULONG, Release)()
260 {
261 ULONG newRefCount;
262 newRefCount = this->InternalRelease();
263 if (newRefCount == 0)
264 delete this;
265 return newRefCount;
266 }
267
268 STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject)
269 {
270 if (ppvObject == NULL)
271 return E_POINTER;
272 if (iid == IID_IUnknown)
273 *ppvObject = reinterpret_cast<void*>(this);
274 else
275 return m_contained._InternalQueryInterface(iid, ppvObject);
276 return S_OK;
277 }
278
279 static HRESULT WINAPI CreateInstance(IUnknown * punkOuter, CComAggObject<contained> **pp)
280 {
281 CComAggObject<contained> *newInstance;
282 HRESULT hResult;
283
284 ATLASSERT(pp != NULL);
285 if (pp == NULL)
286 return E_POINTER;
287
288 hResult = E_OUTOFMEMORY;
289 newInstance = NULL;
290 ATLTRY(newInstance = new CComAggObject<contained>(punkOuter))
291 if (newInstance != NULL)
292 {
293 newInstance->SetVoid(NULL);
294 newInstance->InternalFinalConstructAddRef();
295 hResult = newInstance->_AtlInitialConstruct();
296 if (SUCCEEDED(hResult))
297 hResult = newInstance->FinalConstruct();
298 if (SUCCEEDED(hResult))
299 hResult = newInstance->_AtlFinalConstruct();
300 newInstance->InternalFinalConstructRelease();
301 if (hResult != S_OK)
302 {
303 delete newInstance;
304 newInstance = NULL;
305 }
306 }
307 *pp = newInstance;
308 return hResult;
309 }
310 };
311
312 template <class contained>
313 class CComPolyObject : public contained
314 {
315 public:
316 CComContainedObject<contained> m_contained;
317
318 CComPolyObject(void * pv = NULL)
319 : m_contained(pv ? static_cast<contained*>(pv) : this)
320 {
321 _pAtlModule->Lock();
322 }
323
324 virtual ~CComPolyObject()
325 {
326 this->FinalRelease();
327 _pAtlModule->Unlock();
328 }
329
330 HRESULT FinalConstruct()
331 {
332 return m_contained.FinalConstruct();
333 }
334 void FinalRelease()
335 {
336 m_contained.FinalRelease();
337 }
338
339 STDMETHOD_(ULONG, AddRef)()
340 {
341 return this->InternalAddRef();
342 }
343
344 STDMETHOD_(ULONG, Release)()
345 {
346 ULONG newRefCount;
347 newRefCount = this->InternalRelease();
348 if (newRefCount == 0)
349 delete this;
350 return newRefCount;
351 }
352
353 STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject)
354 {
355 if (ppvObject == NULL)
356 return E_POINTER;
357 if (iid == IID_IUnknown)
358 *ppvObject = reinterpret_cast<void*>(this);
359 else
360 return m_contained._InternalQueryInterface(iid, ppvObject);
361 return S_OK;
362 }
363
364 static HRESULT WINAPI CreateInstance(IUnknown * punkOuter, CComPolyObject<contained> **pp)
365 {
366 CComPolyObject<contained> *newInstance;
367 HRESULT hResult;
368
369 ATLASSERT(pp != NULL);
370 if (pp == NULL)
371 return E_POINTER;
372
373 hResult = E_OUTOFMEMORY;
374 newInstance = NULL;
375 ATLTRY(newInstance = new CComPolyObject<contained>(punkOuter))
376 if (newInstance != NULL)
377 {
378 newInstance->SetVoid(NULL);
379 newInstance->InternalFinalConstructAddRef();
380 hResult = newInstance->_AtlInitialConstruct();
381 if (SUCCEEDED(hResult))
382 hResult = newInstance->FinalConstruct();
383 if (SUCCEEDED(hResult))
384 hResult = newInstance->_AtlFinalConstruct();
385 newInstance->InternalFinalConstructRelease();
386 if (hResult != S_OK)
387 {
388 delete newInstance;
389 newInstance = NULL;
390 }
391 }
392 *pp = newInstance;
393 return hResult;
394 }
395 };
396
397 template <HRESULT hResult>
398 class CComFailCreator
399 {
400 public:
401 static HRESULT WINAPI CreateInstance(void *, REFIID, LPVOID *ppv)
402 {
403 ATLASSERT(ppv != NULL);
404 if (ppv == NULL)
405 return E_POINTER;
406 *ppv = NULL;
407
408 return hResult;
409 }
410 };
411
412 template <class T1>
413 class CComCreator
414 {
415 public:
416 static HRESULT WINAPI CreateInstance(void *pv, REFIID riid, LPVOID *ppv)
417 {
418 T1 *newInstance;
419 HRESULT hResult;
420
421 ATLASSERT(ppv != NULL);
422 if (ppv == NULL)
423 return E_POINTER;
424 *ppv = NULL;
425
426 hResult = E_OUTOFMEMORY;
427 newInstance = NULL;
428 ATLTRY(newInstance = new T1(pv))
429 if (newInstance != NULL)
430 {
431 newInstance->SetVoid(pv);
432 newInstance->InternalFinalConstructAddRef();
433 hResult = newInstance->_AtlInitialConstruct();
434 if (SUCCEEDED(hResult))
435 hResult = newInstance->FinalConstruct();
436 if (SUCCEEDED(hResult))
437 hResult = newInstance->_AtlFinalConstruct();
438 newInstance->InternalFinalConstructRelease();
439 if (SUCCEEDED(hResult))
440 hResult = newInstance->QueryInterface(riid, ppv);
441 if (FAILED(hResult))
442 {
443 delete newInstance;
444 newInstance = NULL;
445 }
446 }
447 return hResult;
448 }
449 };
450
451 template <class T1, class T2>
452 class CComCreator2
453 {
454 public:
455 static HRESULT WINAPI CreateInstance(void *pv, REFIID riid, LPVOID *ppv)
456 {
457 ATLASSERT(ppv != NULL && riid != NULL);
458
459 if (pv == NULL)
460 return T1::CreateInstance(NULL, riid, ppv);
461 else
462 return T2::CreateInstance(pv, riid, ppv);
463 }
464 };
465
466 template <class Base>
467 class CComObjectCached : public Base
468 {
469 public:
470 CComObjectCached(void * = NULL)
471 {
472 }
473
474 virtual ~CComObjectCached()
475 {
476 this->FinalRelease();
477 }
478
479 STDMETHOD_(ULONG, AddRef)()
480 {
481 ULONG newRefCount;
482
483 newRefCount = this->InternalAddRef();
484 if (newRefCount == 2)
485 _pAtlModule->Lock();
486 return newRefCount;
487 }
488
489 STDMETHOD_(ULONG, Release)()
490 {
491 ULONG newRefCount;
492
493 newRefCount = this->InternalRelease();
494 if (newRefCount == 0)
495 delete this;
496 else if (newRefCount == 1)
497 _pAtlModule->Unlock();
498 return newRefCount;
499 }
500
501 STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject)
502 {
503 return this->_InternalQueryInterface(iid, ppvObject);
504 }
505
506 static HRESULT WINAPI CreateInstance(CComObjectCached<Base> **pp)
507 {
508 CComObjectCached<Base> *newInstance;
509 HRESULT hResult;
510
511 ATLASSERT(pp != NULL);
512 if (pp == NULL)
513 return E_POINTER;
514
515 hResult = E_OUTOFMEMORY;
516 newInstance = NULL;
517 ATLTRY(newInstance = new CComObjectCached<Base>())
518 if (newInstance != NULL)
519 {
520 newInstance->SetVoid(NULL);
521 newInstance->InternalFinalConstructAddRef();
522 hResult = newInstance->_AtlInitialConstruct();
523 if (SUCCEEDED(hResult))
524 hResult = newInstance->FinalConstruct();
525 if (SUCCEEDED(hResult))
526 hResult = newInstance->_AtlFinalConstruct();
527 newInstance->InternalFinalConstructRelease();
528 if (hResult != S_OK)
529 {
530 delete newInstance;
531 newInstance = NULL;
532 }
533 }
534 *pp = newInstance;
535 return hResult;
536 }
537 };
538
// BEGIN_COM_MAP(cls) opens a static _ATL_INTMAP_ENTRY table inside
// _GetEntries(). It also defines _InternalQueryInterface, which resolves an
// IID against that table through InternalQueryInterface. The
// COM_INTERFACE_ENTRY* macros supply the table rows.
#define BEGIN_COM_MAP(cls) \
    public: \
        typedef cls _ComMapClass; \
        HRESULT _InternalQueryInterface(REFIID iid, void **ppvObject) \
        { \
            return this->InternalQueryInterface(this, _GetEntries(), iid, ppvObject); \
        } \
        static const ATL::_ATL_INTMAP_ENTRY *WINAPI _GetEntries() \
        { \
            static const ATL::_ATL_INTMAP_ENTRY _entries[] = {

// END_COM_MAP() appends the NULL terminator and closes the table and
// _GetEntries(). It then redeclares the IUnknown methods as pure virtual, so
// the class stays abstract until a wrapper such as CComObject supplies them.
#define END_COM_MAP() \
                {NULL, 0, 0} \
            }; \
            return _entries; \
        } \
        virtual ULONG STDMETHODCALLTYPE AddRef() = 0; \
        virtual ULONG STDMETHODCALLTYPE Release() = 0; \
        STDMETHOD(QueryInterface)(REFIID, void **) = 0;
558
// COM map rows. Each macro expands to one _ATL_INTMAP_ENTRY initializer of
// the form {piid, dw, func}.

// Interface x exposed under an explicit IID; dw is x's offset in _ComMapClass.
#define COM_INTERFACE_ENTRY_IID(iid, x) \
    {&iid, offsetofclass(x, _ComMapClass), _ATL_SIMPLEMAPENTRY},

// Interface x exposed under _ATL_IIDOF(x).
#define COM_INTERFACE_ENTRY(x) \
    {&_ATL_IIDOF(x), \
     offsetofclass(x, _ComMapClass), \
     _ATL_SIMPLEMAPENTRY},

// Interface x reached specifically through base x2. The offset is computed by
// casting the dummy address _ATL_PACKING down the x2 branch (typically used
// when x is inherited more than once).
#define COM_INTERFACE_ENTRY2_IID(iid, x, x2) \
    {&iid, \
     reinterpret_cast<DWORD_PTR>(static_cast<x *>(static_cast<x2 *>(reinterpret_cast<_ComMapClass *>(_ATL_PACKING)))) - _ATL_PACKING, \
     _ATL_SIMPLEMAPENTRY},

// Routes queries for x to _Break, a function that issues int 3.
#define COM_INTERFACE_ENTRY_BREAK(x) \
    {&_ATL_IIDOF(x), NULL, _Break},

// Routes queries for x to _NoInterface, which returns E_NOINTERFACE.
#define COM_INTERFACE_ENTRY_NOINTERFACE(x) \
    {&_ATL_IIDOF(x), NULL, _NoInterface},

// Routes queries for iid to func, with dw stored in the row.
#define COM_INTERFACE_ENTRY_FUNC(iid, dw, func) \
    {&iid, dw, func},

// Same as COM_INTERFACE_ENTRY_FUNC but with a NULL IID pointer
// (presumably matched against any IID — verify in AtlInternalQueryInterface).
#define COM_INTERFACE_ENTRY_FUNC_BLIND(dw, func) \
    {NULL, dw, func},
591
// Chains lookups to the COM map of base class `classname` via _Chain. The
// row's dw field is pointer-sized (other rows store DWORD_PTR values), so the
// address of the chain data is cast to DWORD_PTR. The previous cast to DWORD
// truncated the pointer on 64-bit targets, and a reinterpret_cast from a
// pointer to a narrower integer type is ill-formed.
#define COM_INTERFACE_ENTRY_CHAIN(classname) \
    {NULL, \
     reinterpret_cast<DWORD_PTR>(&_CComChainData<classname, _ComMapClass>::data), \
     _Chain},
596
// UpdateRegistry implementations. OBJECT_ENTRY references the result as
// class::UpdateRegistry.

// No registration work: always reports success.
#define DECLARE_NO_REGISTRY() \
    static HRESULT WINAPI UpdateRegistry(BOOL /*bRegister*/) \
    { \
        return S_OK; \
    }

// Delegates (un)registration to the module using resource id `resid`.
#define DECLARE_REGISTRY_RESOURCEID(resid) \
    static HRESULT WINAPI UpdateRegistry(BOOL bRegister) \
    { \
        return ATL::_pAtlModule->UpdateRegistryFromResource(resid, bRegister); \
    }
608
// Aggregation policies. Each defines the _CreatorClass used by OBJECT_ENTRY.
// CComCreator2 uses its first creator when no outer unknown is given and its
// second creator otherwise.

// Stand-alone only; aggregation attempts fail with CLASS_E_NOAGGREGATION.
#define DECLARE_NOT_AGGREGATABLE(cls) \
    public: \
        typedef ATL::CComCreator2<ATL::CComCreator<ATL::CComObject<cls> >, ATL::CComFailCreator<CLASS_E_NOAGGREGATION> > _CreatorClass;

// Stand-alone via CComObject, aggregated via CComAggObject.
#define DECLARE_AGGREGATABLE(cls) \
    public: \
        typedef ATL::CComCreator2<ATL::CComCreator<ATL::CComObject<cls> >, ATL::CComCreator<ATL::CComAggObject<cls> > > _CreatorClass;

// Aggregated only; stand-alone creation fails with E_FAIL.
#define DECLARE_ONLY_AGGREGATABLE(cls) \
    public: \
        typedef ATL::CComCreator2<ATL::CComFailCreator<E_FAIL>, ATL::CComCreator<ATL::CComAggObject<cls> > > _CreatorClass;

// Both modes through a single CComPolyObject creator.
#define DECLARE_POLY_AGGREGATABLE(cls) \
    public: \
        typedef ATL::CComCreator<ATL::CComPolyObject<cls> > _CreatorClass;

// COM map row delegating iid via _Delegate. dw holds the offset of member
// `punk` within _ComMapClass.
#define COM_INTERFACE_ENTRY_AGGREGATE(iid, punk) \
    {&iid, \
     static_cast<DWORD_PTR>(offsetof(_ComMapClass, punk)), \
     _Delegate},
629
// Defines GetControllingUnknown() in terms of the class's GetUnknown().
#define DECLARE_GET_CONTROLLING_UNKNOWN() \
    public: \
        virtual IUnknown *GetControllingUnknown() \
        { \
            return GetUnknown(); \
        }

// The CreateInstance paths call these two hooks around FinalConstruct. This
// version makes them hold a real reference for that window. An AddRef/Release
// pair made inside FinalConstruct then cannot drop the count to zero, which
// would make Release delete the object.
#define DECLARE_PROTECT_FINAL_CONSTRUCT() \
    void InternalFinalConstructAddRef() \
    { \
        InternalAddRef(); \
    } \
    void InternalFinalConstructRelease() \
    { \
        InternalRelease(); \
    }
646
// Object map: a static _ATL_OBJMAP_ENTRY array named `map`, terminated by an
// all-NULL entry, with one OBJECT_ENTRY per creatable coclass.
#define BEGIN_OBJECT_MAP(map) \
    static ATL::_ATL_OBJMAP_ENTRY map[] = {

#define END_OBJECT_MAP() \
    {NULL, NULL, NULL, NULL, NULL, 0, NULL, NULL, NULL}};

// Fills one object-map entry from static members of coclass `cls`. These come
// from CComCoClass, CComObjectRootBase and the DECLARE_* macros.
#define OBJECT_ENTRY(clsid, cls) \
    { \
        &clsid, \
        cls::UpdateRegistry, \
        cls::_ClassFactoryCreatorClass::CreateInstance, \
        cls::_CreatorClass::CreateInstance, \
        NULL, \
        0, \
        cls::GetObjectDescription, \
        cls::GetCategoryMap, \
        cls::ObjectMain },
662
663 class CComClassFactory :
664 public IClassFactory,
665 public CComObjectRootEx<CComGlobalsThreadModel>
666 {
667 public:
668 _ATL_CREATORFUNC *m_pfnCreateInstance;
669
670 virtual ~CComClassFactory()
671 {
672 }
673
674 public:
675 STDMETHOD(CreateInstance)(LPUNKNOWN pUnkOuter, REFIID riid, void **ppvObj)
676 {
677 HRESULT hResult;
678
679 ATLASSERT(m_pfnCreateInstance != NULL);
680
681 if (ppvObj == NULL)
682 return E_POINTER;
683 *ppvObj = NULL;
684
685 if (pUnkOuter != NULL && InlineIsEqualUnknown(riid) == FALSE)
686 hResult = CLASS_E_NOAGGREGATION;
687 else
688 hResult = m_pfnCreateInstance(pUnkOuter, riid, ppvObj);
689 return hResult;
690 }
691
692 STDMETHOD(LockServer)(BOOL fLock)
693 {
694 if (fLock)
695 _pAtlModule->Lock();
696 else
697 _pAtlModule->Unlock();
698 return S_OK;
699 }
700
701 void SetVoid(void *pv)
702 {
703 m_pfnCreateInstance = (_ATL_CREATORFUNC *)pv;
704 }
705
706 BEGIN_COM_MAP(CComClassFactory)
707 COM_INTERFACE_ENTRY_IID(IID_IClassFactory, IClassFactory)
708 END_COM_MAP()
709 };
710
711 template <class T>
712 class CComClassFactorySingleton :
713 public CComClassFactory
714 {
715 public:
716 HRESULT m_hrCreate;
717 IUnknown *m_spObj;
718
719 public:
720 CComClassFactorySingleton() :
721 m_hrCreate(S_OK),
722 m_spObj(NULL)
723 {
724 }
725
726 STDMETHOD(CreateInstance)(LPUNKNOWN pUnkOuter, REFIID riid, void **ppvObj)
727 {
728 HRESULT hResult;
729
730 if (ppvObj == NULL)
731 return E_POINTER;
732 *ppvObj = NULL;
733
734 if (pUnkOuter != NULL)
735 hResult = CLASS_E_NOAGGREGATION;
736 else if (m_hrCreate == S_OK && m_spObj == NULL)
737 {
738 _SEH2_TRY
739 {
740 Lock();
741 if (m_hrCreate == S_OK && m_spObj == NULL)
742 {
743 CComObjectCached<T> *pObj;
744 m_hrCreate = CComObjectCached<T>::CreateInstance(&pObj);
745 if (SUCCEEDED(m_hrCreate))
746 {
747 m_hrCreate = pObj->QueryInterface(IID_IUnknown, reinterpret_cast<PVOID *>(&m_spObj));
748 if (FAILED(m_hrCreate))
749 delete pObj;
750 }
751 }
752 }
753 _SEH2_FINALLY
754 {
755 Unlock();
756 }
757 _SEH2_END;
758 }
759 if (m_hrCreate == S_OK)
760 hResult = m_spObj->QueryInterface(riid, ppvObj);
761 else
762 hResult = m_hrCreate;
763 return hResult;
764 }
765 };
766
767 template <class T, const CLSID *pclsid = &CLSID_NULL>
768 class CComCoClass
769 {
770 public:
771 DECLARE_CLASSFACTORY()
772
773 static LPCTSTR WINAPI GetObjectDescription()
774 {
775 return NULL;
776 }
777 };
778
779 template <class T>
780 class _Copy
781 {
782 public:
783 static HRESULT copy(T *pTo, const T *pFrom)
784 {
785 memcpy(pTo, pFrom, sizeof(T));
786 return S_OK;
787 }
788
789 static void init(T *)
790 {
791 }
792
793 static void destroy(T *)
794 {
795 }
796 };
797
798 template<>
799 class _Copy<CONNECTDATA>
800 {
801 public:
802 static HRESULT copy(CONNECTDATA *pTo, const CONNECTDATA *pFrom)
803 {
804 *pTo = *pFrom;
805 if (pTo->pUnk)
806 pTo->pUnk->AddRef();
807 return S_OK;
808 }
809
810 static void init(CONNECTDATA *)
811 {
812 }
813
814 static void destroy(CONNECTDATA *p)
815 {
816 if (p->pUnk)
817 p->pUnk->Release();
818 }
819 };
820
821 template <class T>
822 class _CopyInterface
823 {
824 public:
825 static HRESULT copy(T **pTo, T **pFrom)
826 {
827 *pTo = *pFrom;
828 if (*pTo)
829 (*pTo)->AddRef();
830 return S_OK;
831 }
832
833 static void init(T **)
834 {
835 }
836
837 static void destroy(T **p)
838 {
839 if (*p)
840 (*p)->Release();
841 }
842 };
843
// Ownership modes for CComEnumImpl::Init. The values are combinations of
// CComEnumImpl's BitCopy (1) and BitOwn (2) flags.
enum CComEnumFlags
{
    // Borrow the caller's range; nothing is freed by the enumerator.
    AtlFlagNoCopy = 0x0,
    // Adopt the caller's range and free it on destruction (BitOwn).
    AtlFlagTakeOwnership = 0x2,
    // Make and own a private copy of the range (BitOwn | BitCopy).
    AtlFlagCopy = 0x3
};
850
851 template <class Base, const IID *piid, class T, class Copy>
852 class CComEnumImpl : public Base
853 {
854 private:
855 typedef CComObject<CComEnum<Base, piid, T, Copy> > enumeratorClass;
856 public:
857 CComPtr<IUnknown> m_spUnk;
858 DWORD m_dwFlags;
859 T *m_begin;
860 T *m_end;
861 T *m_iter;
862 public:
863 CComEnumImpl()
864 {
865 m_dwFlags = 0;
866 m_begin = NULL;
867 m_end = NULL;
868 m_iter = NULL;
869 }
870
871 virtual ~CComEnumImpl()
872 {
873 T *x;
874
875 if ((m_dwFlags & BitOwn) != 0)
876 {
877 for (x = m_begin; x != m_end; x++)
878 Copy::destroy(x);
879 delete [] m_begin;
880 }
881 }
882
883 HRESULT Init(T *begin, T *end, IUnknown *pUnk, CComEnumFlags flags = AtlFlagNoCopy)
884 {
885 T *newBuffer;
886 T *sourcePtr;
887 T *destPtr;
888 T *cleanupPtr;
889 HRESULT hResult;
890
891 if (flags == AtlFlagCopy)
892 {
893 ATLTRY(newBuffer = new T[end - begin])
894 if (newBuffer == NULL)
895 return E_OUTOFMEMORY;
896 destPtr = newBuffer;
897 for (sourcePtr = begin; sourcePtr != end; sourcePtr++)
898 {
899 Copy::init(destPtr);
900 hResult = Copy::copy(destPtr, sourcePtr);
901 if (FAILED(hResult))
902 {
903 cleanupPtr = m_begin;
904 while (cleanupPtr < destPtr)
905 Copy::destroy(cleanupPtr++);
906 delete [] newBuffer;
907 return hResult;
908 }
909 destPtr++;
910 }
911 m_begin = newBuffer;
912 m_end = m_begin + (end - begin);
913 }
914 else
915 {
916 m_begin = begin;
917 m_end = end;
918 }
919 m_spUnk = pUnk;
920 m_dwFlags = flags;
921 m_iter = m_begin;
922 return S_OK;
923 }
924
925 STDMETHOD(Next)(ULONG celt, T *rgelt, ULONG *pceltFetched)
926 {
927 ULONG numAvailable;
928 ULONG numToFetch;
929 T *rgeltTemp;
930 HRESULT hResult;
931
932 if (pceltFetched != NULL)
933 *pceltFetched = 0;
934 if (celt == 0)
935 return E_INVALIDARG;
936 if (rgelt == NULL || (celt != 1 && pceltFetched == NULL))
937 return E_POINTER;
938 if (m_begin == NULL || m_end == NULL || m_iter == NULL)
939 return E_FAIL;
940
941 numAvailable = static_cast<ULONG>(m_end - m_iter);
942 if (celt < numAvailable)
943 numToFetch = celt;
944 else
945 numToFetch = numAvailable;
946 if (pceltFetched != NULL)
947 *pceltFetched = numToFetch;
948 rgeltTemp = rgelt;
949 while (numToFetch != 0)
950 {
951 hResult = Copy::copy(rgeltTemp, m_iter);
952 if (FAILED(hResult))
953 {
954 while (rgelt < rgeltTemp)
955 Copy::destroy(rgelt++);
956 if (pceltFetched != NULL)
957 *pceltFetched = 0;
958 return hResult;
959 }
960 rgeltTemp++;
961 m_iter++;
962 numToFetch--;
963 }
964 if (numAvailable < celt)
965 return S_FALSE;
966 return S_OK;
967 }
968
969 STDMETHOD(Skip)(ULONG celt)
970 {
971 ULONG numAvailable;
972 ULONG numToSkip;
973
974 if (celt == 0)
975 return E_INVALIDARG;
976
977 numAvailable = static_cast<ULONG>(m_end - m_iter);
978 if (celt < numAvailable)
979 numToSkip = celt;
980 else
981 numToSkip = numAvailable;
982 m_iter += numToSkip;
983 if (numAvailable < celt)
984 return S_FALSE;
985 return S_OK;
986 }
987
988 STDMETHOD(Reset)()
989 {
990 m_iter = m_begin;
991 return S_OK;
992 }
993
994 STDMETHOD(Clone)(Base **ppEnum)
995 {
996 enumeratorClass *newInstance;
997 HRESULT hResult;
998
999 hResult = E_POINTER;
1000 if (ppEnum != NULL)
1001 {
1002 *ppEnum = NULL;
1003 hResult = enumeratorClass::CreateInstance(&newInstance);
1004 if (SUCCEEDED(hResult))
1005 {
1006 hResult = newInstance->Init(m_begin, m_end, (m_dwFlags & BitOwn) ? this : m_spUnk);
1007 if (SUCCEEDED(hResult))
1008 {
1009 newInstance->m_iter = m_iter;
1010 hResult = newInstance->_InternalQueryInterface(*piid, (void **)ppEnum);
1011 }
1012 if (FAILED(hResult))
1013 delete newInstance;
1014 }
1015 }
1016 return hResult;
1017 }
1018
1019 protected:
1020 enum FlagBits
1021 {
1022 BitCopy = 1,
1023 BitOwn = 2
1024 };
1025 };
1026
1027 template <class Base, const IID *piid, class T, class Copy, class ThreadModel>
1028 class CComEnum :
1029 public CComEnumImpl<Base, piid, T, Copy>,
1030 public CComObjectRootEx<ThreadModel>
1031 {
1032 public:
1033 typedef CComEnum<Base, piid, T, Copy > _CComEnum;
1034 typedef CComEnumImpl<Base, piid, T, Copy > _CComEnumBase;
1035
1036 BEGIN_COM_MAP(_CComEnum)
1037 COM_INTERFACE_ENTRY_IID(*piid, _CComEnumBase)
1038 END_COM_MAP()
1039 };
1040
// Initial slot count of CComDynamicUnkArray. It can be overridden by defining
// the macro before this header is included.
#if !defined(_DEFAULT_VECTORLENGTH)
#define _DEFAULT_VECTORLENGTH 4
#endif
1044
1045 class CComDynamicUnkArray
1046 {
1047 public:
1048 int m_nSize;
1049 IUnknown **m_ppUnk;
1050 public:
1051 CComDynamicUnkArray()
1052 {
1053 m_nSize = 0;
1054 m_ppUnk = NULL;
1055 }
1056
1057 ~CComDynamicUnkArray()
1058 {
1059 free(m_ppUnk);
1060 }
1061
1062 IUnknown **begin()
1063 {
1064 return m_ppUnk;
1065 }
1066
1067 IUnknown **end()
1068 {
1069 return &m_ppUnk[m_nSize];
1070 }
1071
1072 IUnknown *GetAt(int nIndex)
1073 {
1074 ATLASSERT(nIndex >= 0 && nIndex < m_nSize);
1075 if (nIndex >= 0 && nIndex < m_nSize)
1076 return m_ppUnk[nIndex];
1077 else
1078 return NULL;
1079 }
1080
1081 IUnknown *WINAPI GetUnknown(DWORD dwCookie)
1082 {
1083 ATLASSERT(dwCookie != 0 && dwCookie <= static_cast<DWORD>(m_nSize));
1084 if (dwCookie != 0 && dwCookie <= static_cast<DWORD>(m_nSize))
1085 return GetAt(dwCookie - 1);
1086 else
1087 return NULL;
1088 }
1089
1090 DWORD WINAPI GetCookie(IUnknown **ppFind)
1091 {
1092 IUnknown **x;
1093 DWORD curCookie;
1094
1095 ATLASSERT(ppFind != NULL && *ppFind != NULL);
1096 if (ppFind != NULL && *ppFind != NULL)
1097 {
1098 curCookie = 1;
1099 for (x = begin(); x < end(); x++)
1100 {
1101 if (*x == *ppFind)
1102 return curCookie;
1103 curCookie++;
1104 }
1105 }
1106 return 0;
1107 }
1108
1109 DWORD Add(IUnknown *pUnk)
1110 {
1111 IUnknown **x;
1112 IUnknown **newArray;
1113 int newSize;
1114 DWORD curCookie;
1115
1116 ATLASSERT(pUnk != NULL);
1117 if (m_nSize == 0)
1118 {
1119 newSize = _DEFAULT_VECTORLENGTH * sizeof(IUnknown *);
1120 ATLTRY(newArray = reinterpret_cast<IUnknown **>(malloc(newSize)));
1121 if (newArray == NULL)
1122 return 0;
1123 memset(newArray, 0, newSize);
1124 m_ppUnk = newArray;
1125 m_nSize = _DEFAULT_VECTORLENGTH;
1126 }
1127 curCookie = 1;
1128 for (x = begin(); x < end(); x++)
1129 {
1130 if (*x == NULL)
1131 {
1132 *x = pUnk;
1133 return curCookie;
1134 }
1135 curCookie++;
1136 }
1137 newSize = m_nSize * 2;
1138 newArray = reinterpret_cast<IUnknown **>(realloc(m_ppUnk, newSize * sizeof(IUnknown *)));
1139 if (newArray == NULL)
1140 return 0;
1141 m_ppUnk = newArray;
1142 memset(&m_ppUnk[m_nSize], 0, (newSize - m_nSize) * sizeof(IUnknown *));
1143 curCookie = m_nSize + 1;
1144 m_nSize = newSize;
1145 m_ppUnk[curCookie - 1] = pUnk;
1146 return curCookie;
1147 }
1148
1149 BOOL Remove(DWORD dwCookie)
1150 {
1151 DWORD index;
1152
1153 index = dwCookie - 1;
1154 ATLASSERT(index < dwCookie && index < static_cast<DWORD>(m_nSize));
1155 if (index < dwCookie && index < static_cast<DWORD>(m_nSize) && m_ppUnk[index] != NULL)
1156 {
1157 m_ppUnk[index] = NULL;
1158 return TRUE;
1159 }
1160 return FALSE;
1161 }
1162
1163 private:
1164 CComDynamicUnkArray &operator = (const CComDynamicUnkArray &)
1165 {
1166 return *this;
1167 }
1168
1169 CComDynamicUnkArray(const CComDynamicUnkArray &)
1170 {
1171 }
1172 };
1173
1174 struct _ATL_CONNMAP_ENTRY
1175 {
1176 DWORD_PTR dwOffset;
1177 };
1178
1179 template <const IID *piid>
1180 class _ICPLocator
1181 {
1182 public:
1183 STDMETHOD(_LocCPQueryInterface)(REFIID riid, void **ppvObject) = 0;
1184 virtual ULONG STDMETHODCALLTYPE AddRef() = 0;
1185 virtual ULONG STDMETHODCALLTYPE Release() = 0;
1186 };
1187
1188 template<class T, const IID *piid, class CDV = CComDynamicUnkArray>
1189 class IConnectionPointImpl : public _ICPLocator<piid>
1190 {
1191 typedef CComEnum<IEnumConnections, &IID_IEnumConnections, CONNECTDATA, _Copy<CONNECTDATA> > CComEnumConnections;
1192 public:
1193 CDV m_vec;
1194 public:
1195 ~IConnectionPointImpl()
1196 {
1197 IUnknown **x;
1198
1199 for (x = m_vec.begin(); x < m_vec.end(); x++)
1200 if (*x != NULL)
1201 (*x)->Release();
1202 }
1203
1204 STDMETHOD(_LocCPQueryInterface)(REFIID riid, void **ppvObject)
1205 {
1206 IConnectionPointImpl<T, piid, CDV> *pThis;
1207
1208 pThis = reinterpret_cast<IConnectionPointImpl<T, piid, CDV>*>(this);
1209
1210 ATLASSERT(ppvObject != NULL);
1211 if (ppvObject == NULL)
1212 return E_POINTER;
1213
1214 if (InlineIsEqualGUID(riid, IID_IConnectionPoint) || InlineIsEqualUnknown(riid))
1215 {
1216 *ppvObject = this;
1217 pThis->AddRef();
1218 return S_OK;
1219 }
1220 else
1221 {
1222 *ppvObject = NULL;
1223 return E_NOINTERFACE;
1224 }
1225 }
1226
1227 STDMETHOD(GetConnectionInterface)(IID *piid2)
1228 {
1229 if (piid2 == NULL)
1230 return E_POINTER;
1231 *piid2 = *piid;
1232 return S_OK;
1233 }
1234
1235 STDMETHOD(GetConnectionPointContainer)(IConnectionPointContainer **ppCPC)
1236 {
1237 T *pThis;
1238
1239 pThis = static_cast<T *>(this);
1240 return pThis->QueryInterface(IID_IConnectionPointContainer, reinterpret_cast<void **>(ppCPC));
1241 }
1242
1243 STDMETHOD(Advise)(IUnknown *pUnkSink, DWORD *pdwCookie)
1244 {
1245 IUnknown *adviseTarget;
1246 IID interfaceID;
1247 HRESULT hResult;
1248
1249 if (pdwCookie != NULL)
1250 *pdwCookie = 0;
1251 if (pUnkSink == NULL || pdwCookie == NULL)
1252 return E_POINTER;
1253 GetConnectionInterface(&interfaceID); // can't fail
1254 hResult = pUnkSink->QueryInterface(interfaceID, reinterpret_cast<void **>(&adviseTarget));
1255 if (SUCCEEDED(hResult))
1256 {
1257 *pdwCookie = m_vec.Add(adviseTarget);
1258 if (*pdwCookie != 0)
1259 hResult = S_OK;
1260 else
1261 {
1262 adviseTarget->Release();
1263 hResult = CONNECT_E_ADVISELIMIT;
1264 }
1265 }
1266 else if (hResult == E_NOINTERFACE)
1267 hResult = CONNECT_E_CANNOTCONNECT;
1268 return hResult;
1269 }
1270
1271 STDMETHOD(Unadvise)(DWORD dwCookie)
1272 {
1273 IUnknown *adviseTarget;
1274 HRESULT hResult;
1275
1276 adviseTarget = m_vec.GetUnknown(dwCookie);
1277 if (m_vec.Remove(dwCookie))
1278 {
1279 if (adviseTarget != NULL)
1280 adviseTarget->Release();
1281 hResult = S_OK;
1282 }
1283 else
1284 hResult = CONNECT_E_NOCONNECTION;
1285 return hResult;
1286 }
1287
1288 STDMETHOD(EnumConnections)(IEnumConnections **ppEnum)
1289 {
1290 CComObject<CComEnumConnections> *newEnumerator;
1291 CONNECTDATA *itemBuffer;
1292 CONNECTDATA *itemBufferEnd;
1293 IUnknown **x;
1294 HRESULT hResult;
1295
1296 ATLASSERT(ppEnum != NULL);
1297 if (ppEnum == NULL)
1298 return E_POINTER;
1299 *ppEnum = NULL;
1300
1301 ATLTRY(itemBuffer = new CONNECTDATA[m_vec.end() - m_vec.begin()])
1302 if (itemBuffer == NULL)
1303 return E_OUTOFMEMORY;
1304 itemBufferEnd = itemBuffer;
1305 for (x = m_vec.begin(); x < m_vec.end(); x++)
1306 {
1307 if (*x != NULL)
1308 {
1309 (*x)->AddRef();
1310 itemBufferEnd->pUnk = *x;
1311 itemBufferEnd->dwCookie = m_vec.GetCookie(x);
1312 itemBufferEnd++;
1313 }
1314 }
1315 ATLTRY(newEnumerator = new CComObject<CComEnumConnections>)
1316 if (newEnumerator == NULL)
1317 return E_OUTOFMEMORY;
1318 newEnumerator->Init(itemBuffer, itemBufferEnd, NULL, AtlFlagTakeOwnership); // can't fail
1319 hResult = newEnumerator->_InternalQueryInterface(IID_IEnumConnections, (void **)ppEnum);
1320 if (FAILED(hResult))
1321 delete newEnumerator;
1322 return hResult;
1323 }
1324 };
1325
1326 template <class T>
1327 class IConnectionPointContainerImpl : public IConnectionPointContainer
1328 {
1329 typedef const _ATL_CONNMAP_ENTRY * (*handlerFunctionType)(int *);
1330 typedef CComEnum<IEnumConnectionPoints, &IID_IEnumConnectionPoints, IConnectionPoint *, _CopyInterface<IConnectionPoint> >
1331 CComEnumConnectionPoints;
1332
1333 public:
1334 STDMETHOD(EnumConnectionPoints)(IEnumConnectionPoints **ppEnum)
1335 {
1336 const _ATL_CONNMAP_ENTRY *entryPtr;
1337 int connectionPointCount;
1338 IConnectionPoint **itemBuffer;
1339 int destIndex;
1340 handlerFunctionType handlerFunction;
1341 CComEnumConnectionPoints *newEnumerator;
1342 HRESULT hResult;
1343
1344 ATLASSERT(ppEnum != NULL);
1345 if (ppEnum == NULL)
1346 return E_POINTER;
1347 *ppEnum = NULL;
1348
1349 entryPtr = T::GetConnMap(&connectionPointCount);
1350 ATLTRY(itemBuffer = new IConnectionPoint * [connectionPointCount])
1351 if (itemBuffer == NULL)
1352 return E_OUTOFMEMORY;
1353
1354 destIndex = 0;
1355 while (entryPtr->dwOffset != static_cast<DWORD_PTR>(-1))
1356 {
1357 if (entryPtr->dwOffset == static_cast<DWORD_PTR>(-2))
1358 {
1359 entryPtr++;
1360 handlerFunction = reinterpret_cast<handlerFunctionType>(entryPtr->dwOffset);
1361 entryPtr = handlerFunction(NULL);
1362 }
1363 else
1364 {
1365 itemBuffer[destIndex++] = reinterpret_cast<IConnectionPoint *>((char *)this + entryPtr->dwOffset);
1366 entryPtr++;
1367 }
1368 }
1369
1370 ATLTRY(newEnumerator = new CComObject<CComEnumConnectionPoints>)
1371 if (newEnumerator == NULL)
1372 {
1373 delete [] itemBuffer;
1374 return E_OUTOFMEMORY;
1375 }
1376
1377 newEnumerator->Init(&itemBuffer[0], &itemBuffer[destIndex], NULL, AtlFlagTakeOwnership); // can't fail
1378 hResult = newEnumerator->QueryInterface(IID_IEnumConnectionPoints, (void**)ppEnum);
1379 if (FAILED(hResult))
1380 delete newEnumerator;
1381 return hResult;
1382 }
1383
1384 STDMETHOD(FindConnectionPoint)(REFIID riid, IConnectionPoint **ppCP)
1385 {
1386 IID interfaceID;
1387 const _ATL_CONNMAP_ENTRY *entryPtr;
1388 handlerFunctionType handlerFunction;
1389 IConnectionPoint *connectionPoint;
1390 HRESULT hResult;
1391
1392 if (ppCP == NULL)
1393 return E_POINTER;
1394 *ppCP = NULL;
1395 hResult = CONNECT_E_NOCONNECTION;
1396 entryPtr = T::GetConnMap(NULL);
1397 while (entryPtr->dwOffset != static_cast<DWORD_PTR>(-1))
1398 {
1399 if (entryPtr->dwOffset == static_cast<DWORD_PTR>(-2))
1400 {
1401 entryPtr++;
1402 handlerFunction = reinterpret_cast<handlerFunctionType>(entryPtr->dwOffset);
1403 entryPtr = handlerFunction(NULL);
1404 }
1405 else
1406 {
1407 connectionPoint = reinterpret_cast<IConnectionPoint *>(reinterpret_cast<char *>(this) + entryPtr->dwOffset);
1408 if (SUCCEEDED(connectionPoint->GetConnectionInterface(&interfaceID)) && InlineIsEqualGUID(riid, interfaceID))
1409 {
1410 *ppCP = connectionPoint;
1411 connectionPoint->AddRef();
1412 hResult = S_OK;
1413 break;
1414 }
1415 entryPtr++;
1416 }
1417 }
1418 return hResult;
1419 }
1420 };
1421
// Connection-point map. It defines GetConnMap(pnEntries), which returns a
// static _ATL_CONNMAP_ENTRY array terminated by (DWORD_PTR)-1. When pnEntries
// is non-NULL, it also stores the number of rows before the terminator.
#define BEGIN_CONNECTION_POINT_MAP(cls) \
    typedef cls _atl_conn_classtype; \
    static const ATL::_ATL_CONNMAP_ENTRY *GetConnMap(int *pnEntries) { \
        static const ATL::_ATL_CONNMAP_ENTRY _entries[] = {

#define END_CONNECTION_POINT_MAP() \
            {(DWORD_PTR)-1} \
        }; \
        if (pnEntries) \
            *pnEntries = sizeof(_entries) / sizeof(ATL::_ATL_CONNMAP_ENTRY) - 1; \
        return _entries; \
    }

// One connection point: the offset of the _ICPLocator<&iid> base relative to
// the IConnectionPointContainerImpl base. IConnectionPointContainerImpl adds
// this offset to its own `this`.
#define CONNECTION_POINT_ENTRY(iid) \
    {offsetofclass(ATL::_ICPLocator<&iid>, _atl_conn_classtype) - \
     offsetofclass(ATL::IConnectionPointContainerImpl<_atl_conn_classtype>, _atl_conn_classtype)},
1436
1437 }; // namespace ATL