e7e6d34faf7d9df2d57f644f31a42a82cc6b367e
[reactos.git] / reactos / lib / atl / atlcom.h
1 /*
2 * ReactOS ATL
3 *
4 * Copyright 2009 Andrew Hill <ash77@reactos.org>
5 *
6 * This library is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
10 *
11 * This library is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with this library; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21 #ifndef _atlcom_h
22 #define _atlcom_h
23
24 namespace ATL
25 {
26
27 template <class Base, const IID *piid, class T, class Copy, class ThreadModel = CComObjectThreadModel>
28 class CComEnum;
29
30 #define DECLARE_CLASSFACTORY_EX(cf) typedef ATL::CComCreator<ATL::CComObjectCached<cf> > _ClassFactoryCreatorClass;
31 #define DECLARE_CLASSFACTORY() DECLARE_CLASSFACTORY_EX(ATL::CComClassFactory)
32
33 class CComObjectRootBase
34 {
35 public:
36 long m_dwRef;
37 public:
38 CComObjectRootBase()
39 {
40 m_dwRef = 0;
41 }
42
43 ~CComObjectRootBase()
44 {
45 }
46
47 void SetVoid(void *)
48 {
49 }
50
51 HRESULT _AtlFinalConstruct()
52 {
53 return S_OK;
54 }
55
56 HRESULT FinalConstruct()
57 {
58 return S_OK;
59 }
60
61 void InternalFinalConstructAddRef()
62 {
63 }
64
65 void InternalFinalConstructRelease()
66 {
67 }
68
69 void FinalRelease()
70 {
71 }
72
73 static void WINAPI ObjectMain(bool)
74 {
75 }
76
77 static const struct _ATL_CATMAP_ENTRY *GetCategoryMap()
78 {
79 return NULL;
80 }
81
82 static HRESULT WINAPI InternalQueryInterface(void *pThis, const _ATL_INTMAP_ENTRY *pEntries, REFIID iid, void **ppvObject)
83 {
84 return AtlInternalQueryInterface(pThis, pEntries, iid, ppvObject);
85 }
86
87 };
88
89 template <class ThreadModel>
90 class CComObjectRootEx : public CComObjectRootBase
91 {
92 private:
93 typename ThreadModel::AutoDeleteCriticalSection m_critsec;
94 public:
95 ~CComObjectRootEx()
96 {
97 }
98
99 ULONG InternalAddRef()
100 {
101 ATLASSERT(m_dwRef >= 0);
102 return ThreadModel::Increment(reinterpret_cast<int *>(&m_dwRef));
103 }
104
105 ULONG InternalRelease()
106 {
107 ATLASSERT(m_dwRef > 0);
108 return ThreadModel::Decrement(reinterpret_cast<int *>(&m_dwRef));
109 }
110
111 void Lock()
112 {
113 m_critsec.Lock();
114 }
115
116 void Unlock()
117 {
118 m_critsec.Unlock();
119 }
120
121 HRESULT _AtlInitialConstruct()
122 {
123 return m_critsec.Init();
124 }
125 };
126
127 template <class Base>
128 class CComObject : public Base
129 {
130 public:
131 CComObject(void * = NULL)
132 {
133 _pAtlModule->Lock();
134 }
135
136 virtual ~CComObject()
137 {
138 CComObject<Base> *pThis;
139
140 pThis = reinterpret_cast<CComObject<Base> *>(this);
141 pThis->FinalRelease();
142 _pAtlModule->Unlock();
143 }
144
145 STDMETHOD_(ULONG, AddRef)()
146 {
147 CComObject<Base> *pThis;
148
149 pThis = reinterpret_cast<CComObject<Base> *>(this);
150 return pThis->InternalAddRef();
151 }
152
153 STDMETHOD_(ULONG, Release)()
154 {
155 CComObject<Base> *pThis;
156 ULONG l;
157
158 pThis = reinterpret_cast<CComObject<Base> *>(this);
159 l = pThis->InternalRelease();
160 if (l == 0)
161 delete this;
162 return l;
163 }
164
165 STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject)
166 {
167 CComObject<Base> *pThis;
168
169 pThis = reinterpret_cast<CComObject<Base> *>(this);
170 return pThis->_InternalQueryInterface(iid, ppvObject);
171 }
172
173 static HRESULT WINAPI CreateInstance(CComObject<Base> **pp)
174 {
175 CComObject<Base> *newInstance;
176 HRESULT hResult;
177
178 ATLASSERT(pp != NULL);
179 if (pp == NULL)
180 return E_POINTER;
181
182 hResult = E_OUTOFMEMORY;
183 newInstance = NULL;
184 ATLTRY(newInstance = new CComObject<Base>())
185 if (newInstance != NULL)
186 {
187 newInstance->SetVoid(NULL);
188 newInstance->InternalFinalConstructAddRef();
189 hResult = newInstance->_AtlInitialConstruct();
190 if (SUCCEEDED(hResult))
191 hResult = newInstance->FinalConstruct();
192 if (SUCCEEDED(hResult))
193 hResult = newInstance->_AtlFinalConstruct();
194 newInstance->InternalFinalConstructRelease();
195 if (hResult != S_OK)
196 {
197 delete newInstance;
198 newInstance = NULL;
199 }
200 }
201 *pp = newInstance;
202 return hResult;
203 }
204
205
206 };
207
208 template <HRESULT hResult>
209 class CComFailCreator
210 {
211 public:
212 static HRESULT WINAPI CreateInstance(void *, REFIID, LPVOID *ppv)
213 {
214 ATLASSERT(ppv != NULL);
215 if (ppv == NULL)
216 return E_POINTER;
217 *ppv = NULL;
218
219 return hResult;
220 }
221 };
222
223 template <class T1>
224 class CComCreator
225 {
226 public:
227 static HRESULT WINAPI CreateInstance(void *pv, REFIID riid, LPVOID *ppv)
228 {
229 T1 *newInstance;
230 HRESULT hResult;
231
232 ATLASSERT(ppv != NULL);
233 if (ppv == NULL)
234 return E_POINTER;
235 *ppv = NULL;
236
237 hResult = E_OUTOFMEMORY;
238 newInstance = NULL;
239 ATLTRY(newInstance = new T1())
240 if (newInstance != NULL)
241 {
242 newInstance->SetVoid(pv);
243 newInstance->InternalFinalConstructAddRef();
244 hResult = newInstance->_AtlInitialConstruct();
245 if (SUCCEEDED(hResult))
246 hResult = newInstance->FinalConstruct();
247 if (SUCCEEDED(hResult))
248 hResult = newInstance->_AtlFinalConstruct();
249 newInstance->InternalFinalConstructRelease();
250 if (SUCCEEDED(hResult))
251 hResult = newInstance->QueryInterface(riid, ppv);
252 if (FAILED(hResult))
253 {
254 delete newInstance;
255 newInstance = NULL;
256 }
257 }
258 return hResult;
259 }
260 };
261
262 template <class T1, class T2>
263 class CComCreator2
264 {
265 public:
266 static HRESULT WINAPI CreateInstance(void *pv, REFIID riid, LPVOID *ppv)
267 {
268 ATLASSERT(ppv != NULL && ppv != NULL);
269
270 if (pv == NULL)
271 return T1::CreateInstance(NULL, riid, ppv);
272 else
273 return T2::CreateInstance(pv, riid, ppv);
274 }
275 };
276
277 template <class Base>
278 class CComObjectCached : public Base
279 {
280 public:
281 CComObjectCached(void * = NULL)
282 {
283 }
284
285 STDMETHOD_(ULONG, AddRef)()
286 {
287 CComObjectCached<Base> *pThis;
288 ULONG newRefCount;
289
290 pThis = reinterpret_cast<CComObjectCached<Base>*>(this);
291 newRefCount = pThis->InternalAddRef();
292 if (newRefCount == 2)
293 _pAtlModule->Lock();
294 return newRefCount;
295 }
296
297 STDMETHOD_(ULONG, Release)()
298 {
299 CComObjectCached<Base> *pThis;
300 ULONG newRefCount;
301
302 pThis = reinterpret_cast<CComObjectCached<Base>*>(this);
303 newRefCount = pThis->InternalRelease();
304 if (newRefCount == 0)
305 delete this;
306 else if (newRefCount == 1)
307 _pAtlModule->Unlock();
308 return newRefCount;
309 }
310
311 STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject)
312 {
313 CComObjectCached<Base> *pThis;
314
315 pThis = reinterpret_cast<CComObjectCached<Base>*>(this);
316 return pThis->_InternalQueryInterface(iid, ppvObject);
317 }
318 };
319
// BEGIN_COM_MAP(klass) opens a class's interface map. It:
// - records klass as _ComMapClass (used by the COM_INTERFACE_ENTRY* macros);
// - defines _InternalQueryInterface on top of the table;
// - opens the static _ATL_INTMAP_ENTRY array that _GetEntries() returns.
#define BEGIN_COM_MAP(klass) \
public: \
    typedef klass _ComMapClass; \
    HRESULT _InternalQueryInterface(REFIID iid, void **ppvObject) \
    { \
        return InternalQueryInterface(this, _GetEntries(), iid, ppvObject); \
    } \
    const static ATL::_ATL_INTMAP_ENTRY *WINAPI _GetEntries() \
    { \
        static const ATL::_ATL_INTMAP_ENTRY _entries[] = {

// END_COM_MAP() terminates the table with an all-NULL entry, closes
// _GetEntries(), and redeclares the IUnknown methods as pure virtual, so a
// wrapper such as CComObject must provide them.
#define END_COM_MAP() \
            {NULL, 0, 0} \
        }; \
        return _entries; \
    } \
    virtual ULONG STDMETHODCALLTYPE AddRef() = 0; \
    virtual ULONG STDMETHODCALLTYPE Release() = 0; \
    STDMETHOD(QueryInterface)(REFIID, void **) = 0;
339
// Interface-map entry macros. Each expands to one brace-initialized
// {piid, dw, pFunc} entry followed by a comma, for use between BEGIN_COM_MAP
// and END_COM_MAP. The dw slot is pointer-sized (DWORD_PTR), so integer
// entries use 0, not NULL.

// Simple entry: iid maps to the x base subobject of _ComMapClass.
#define COM_INTERFACE_ENTRY_IID(iid, x) \
    {&iid, offsetofclass(x, _ComMapClass), _ATL_SIMPLEMAPENTRY},

// Like COM_INTERFACE_ENTRY_IID, but reaches x through x2 to disambiguate
// multiple inheritance paths. _ATL_PACKING is used as a fake non-null object
// address so the casts compute the subobject offset.
#define COM_INTERFACE_ENTRY2_IID(iid, x, x2) \
    {&iid, \
     reinterpret_cast<DWORD_PTR>(static_cast<x *>(static_cast<x2 *>(reinterpret_cast<_ComMapClass *>(_ATL_PACKING)))) - _ATL_PACKING, \
     _ATL_SIMPLEMAPENTRY},

// _Break is meant to issue int 3 for this IID, and _NoInterface to return
// E_NOINTERFACE. NOTE(review): neither _Break, _NoInterface nor _Chain is
// declared in this header — confirm where they come from.
#define COM_INTERFACE_ENTRY_BREAK(x) \
    {&_ATL_IIDOF(x), 0, _Break},

#define COM_INTERFACE_ENTRY_NOINTERFACE(x) \
    {&_ATL_IIDOF(x), 0, _NoInterface},

// Custom handler func for iid, with caller-supplied data dw.
#define COM_INTERFACE_ENTRY_FUNC(iid, dw, func) \
    {&iid, dw, func},

// Custom handler func consulted for any IID (piid == NULL).
#define COM_INTERFACE_ENTRY_FUNC_BLIND(dw, func) \
    {NULL, dw, func},

// Chains to classname's interface map. The dw slot carries a pointer, so it
// must be cast to DWORD_PTR. A 32-bit DWORD would truncate it on 64-bit
// targets, and reinterpret_cast to a narrower integer is ill-formed there.
#define COM_INTERFACE_ENTRY_CHAIN(classname) \
    {NULL, \
     reinterpret_cast<DWORD_PTR>(&_CComChainData<classname, _ComMapClass>::data), \
     _Chain},
372
// DECLARE_REGISTRY_RESOURCEID(x): defines UpdateRegistry, which forwards to
// the module's UpdateRegistryFromResource with resource id x.
#define DECLARE_REGISTRY_RESOURCEID(x) \
    static HRESULT WINAPI UpdateRegistry(BOOL bRegister) \
    { \
        return ATL::_pAtlModule->UpdateRegistryFromResource(x, bRegister); \
    }

// The DECLARE_*AGGREGATABLE macros choose the _CreatorClass used by
// OBJECT_ENTRY. Every name is ATL::-qualified, as in DECLARE_NOT_AGGREGATABLE,
// so the macros also work in classes declared outside namespace ATL.
// NOTE(review): CComAggObject and CComPolyObject are not defined in this
// header — confirm they are provided elsewhere in namespace ATL.
#define DECLARE_NOT_AGGREGATABLE(x) \
public: \
    typedef ATL::CComCreator2<ATL::CComCreator<ATL::CComObject<x> >, ATL::CComFailCreator<CLASS_E_NOAGGREGATION> > _CreatorClass;

#define DECLARE_AGGREGATABLE(x) \
public: \
    typedef ATL::CComCreator2<ATL::CComCreator<ATL::CComObject<x> >, ATL::CComCreator<ATL::CComAggObject<x> > > _CreatorClass;

#define DECLARE_ONLY_AGGREGATABLE(x) \
public: \
    typedef ATL::CComCreator2<ATL::CComFailCreator<E_FAIL>, ATL::CComCreator<ATL::CComAggObject<x> > > _CreatorClass;

#define DECLARE_POLY_AGGREGATABLE(x) \
public: \
    typedef ATL::CComCreator<ATL::CComPolyObject<x> > _CreatorClass;

// Defines GetControllingUnknown() in terms of GetUnknown().
#define DECLARE_GET_CONTROLLING_UNKNOWN() \
public: \
    virtual IUnknown *GetControllingUnknown() \
    { \
        return GetUnknown(); \
    }

// Replaces the no-op construction hooks with a real AddRef/Release pair. The
// object then holds a reference while FinalConstruct runs, so a transient
// AddRef/Release inside FinalConstruct cannot drive the count to zero.
#define DECLARE_PROTECT_FINAL_CONSTRUCT() \
    void InternalFinalConstructAddRef() \
    { \
        InternalAddRef(); \
    } \
    void InternalFinalConstructRelease() \
    { \
        InternalRelease(); \
    }
411
// Object map: a static _ATL_OBJMAP_ENTRY array named by BEGIN_OBJECT_MAP's
// argument and terminated by END_OBJECT_MAP's all-empty entry.
#define BEGIN_OBJECT_MAP(x) static ATL::_ATL_OBJMAP_ENTRY x[] = {

#define END_OBJECT_MAP() {NULL, NULL, NULL, NULL, NULL, 0, NULL, NULL, NULL}};

// One object-map entry for coclass klass registered under clsid. It collects
// klass's registration, factory-creator, creator, description, category-map
// and ObjectMain hooks. The fifth and sixth fields start out NULL and 0.
#define OBJECT_ENTRY(clsid, klass) \
    { \
        &clsid, \
        klass::UpdateRegistry, \
        klass::_ClassFactoryCreatorClass::CreateInstance, \
        klass::_CreatorClass::CreateInstance, \
        NULL, \
        0, \
        klass::GetObjectDescription, \
        klass::GetCategoryMap, \
        klass::ObjectMain \
    },
427
428 class CComClassFactory :
429 public IClassFactory,
430 public CComObjectRootEx<CComGlobalsThreadModel>
431 {
432 public:
433 _ATL_CREATORFUNC *m_pfnCreateInstance;
434 public:
435 STDMETHOD(CreateInstance)(LPUNKNOWN pUnkOuter, REFIID riid, void **ppvObj)
436 {
437 HRESULT hResult;
438
439 ATLASSERT(m_pfnCreateInstance != NULL);
440
441 if (ppvObj == NULL)
442 return E_POINTER;
443 *ppvObj = NULL;
444
445 if (pUnkOuter != NULL && InlineIsEqualUnknown(riid) == FALSE)
446 hResult = CLASS_E_NOAGGREGATION;
447 else
448 hResult = m_pfnCreateInstance(pUnkOuter, riid, ppvObj);
449 return hResult;
450 }
451
452 STDMETHOD(LockServer)(BOOL fLock)
453 {
454 if (fLock)
455 _pAtlModule->Lock();
456 else
457 _pAtlModule->Unlock();
458 return S_OK;
459 }
460
461 void SetVoid(void *pv)
462 {
463 m_pfnCreateInstance = (_ATL_CREATORFUNC *)pv;
464 }
465
466 BEGIN_COM_MAP(CComClassFactory)
467 COM_INTERFACE_ENTRY_IID(IID_IClassFactory, IClassFactory)
468 END_COM_MAP()
469 };
470
471 template <class T, const CLSID *pclsid = &CLSID_NULL>
472 class CComCoClass
473 {
474 public:
475 DECLARE_CLASSFACTORY()
476
477 static LPCTSTR WINAPI GetObjectDescription()
478 {
479 return NULL;
480 }
481 };
482
483 template <class T>
484 class _Copy
485 {
486 public:
487 static HRESULT copy(T *pTo, const T *pFrom)
488 {
489 memcpy(pTo, pFrom, sizeof(T));
490 return S_OK;
491 }
492
493 static void init(T *)
494 {
495 }
496
497 static void destroy(T *)
498 {
499 }
500 };
501
502 template<>
503 class _Copy<CONNECTDATA>
504 {
505 public:
506 static HRESULT copy(CONNECTDATA *pTo, const CONNECTDATA *pFrom)
507 {
508 *pTo = *pFrom;
509 if (pTo->pUnk)
510 pTo->pUnk->AddRef();
511 return S_OK;
512 }
513
514 static void init(CONNECTDATA *)
515 {
516 }
517
518 static void destroy(CONNECTDATA *p)
519 {
520 if (p->pUnk)
521 p->pUnk->Release();
522 }
523 };
524
525 template <class T>
526 class _CopyInterface
527 {
528 public:
529 static HRESULT copy(T **pTo, T **pFrom)
530 {
531 *pTo = *pFrom;
532 if (*pTo)
533 (*pTo)->AddRef();
534 return S_OK;
535 }
536
537 static void init(T **)
538 {
539 }
540
541 static void destroy(T **p)
542 {
543 if (*p)
544 (*p)->Release();
545 }
546 };
547
// Ownership modes for CComEnumImpl::Init. The values are combinations of
// CComEnumImpl's FlagBits (BitCopy = 1, BitOwn = 2).
enum CComEnumFlags
{
    AtlFlagNoCopy        = 0x0,   // borrow the caller's array
    AtlFlagTakeOwnership = 0x2,   // BitOwn: adopt the caller's new[]'d array
    AtlFlagCopy          = 0x3    // BitOwn | BitCopy: duplicate the array
};
554
555 template <class Base, const IID *piid, class T, class Copy>
556 class CComEnumImpl : public Base
557 {
558 private:
559 typedef CComObject<CComEnum<Base, piid, T, Copy> > enumeratorClass;
560 public:
561 CComPtr<IUnknown> m_spUnk;
562 DWORD m_dwFlags;
563 T *m_begin;
564 T *m_end;
565 T *m_iter;
566 public:
567 CComEnumImpl()
568 {
569 m_dwFlags = 0;
570 m_begin = NULL;
571 m_end = NULL;
572 m_iter = NULL;
573 }
574
575 virtual ~CComEnumImpl()
576 {
577 T *x;
578
579 if ((m_dwFlags & BitOwn) != 0)
580 {
581 for (x = m_begin; x != m_end; x++)
582 Copy::destroy(x);
583 delete [] m_begin;
584 }
585 }
586
587 HRESULT Init(T *begin, T *end, IUnknown *pUnk, CComEnumFlags flags = AtlFlagNoCopy)
588 {
589 T *newBuffer;
590 T *sourcePtr;
591 T *destPtr;
592 T *cleanupPtr;
593 HRESULT hResult;
594
595 if (flags == AtlFlagCopy)
596 {
597 ATLTRY(newBuffer = new T[end - begin])
598 if (newBuffer == NULL)
599 return E_OUTOFMEMORY;
600 destPtr = newBuffer;
601 for (sourcePtr = begin; sourcePtr != end; sourcePtr++)
602 {
603 Copy::init(destPtr);
604 hResult = Copy::copy(destPtr, sourcePtr);
605 if (FAILED(hResult))
606 {
607 cleanupPtr = m_begin;
608 while (cleanupPtr < destPtr)
609 Copy::destroy(cleanupPtr++);
610 delete [] newBuffer;
611 return hResult;
612 }
613 destPtr++;
614 }
615 m_begin = newBuffer;
616 m_end = m_begin + (end - begin);
617 }
618 else
619 {
620 m_begin = begin;
621 m_end = end;
622 }
623 m_spUnk = pUnk;
624 m_dwFlags = flags;
625 m_iter = m_begin;
626 return S_OK;
627 }
628
629 STDMETHOD(Next)(ULONG celt, T *rgelt, ULONG *pceltFetched)
630 {
631 ULONG numAvailable;
632 ULONG numToFetch;
633 T *rgeltTemp;
634 HRESULT hResult;
635
636 if (pceltFetched != NULL)
637 *pceltFetched = 0;
638 if (celt == 0)
639 return E_INVALIDARG;
640 if (rgelt == NULL || (celt != 1 && pceltFetched == NULL))
641 return E_POINTER;
642 if (m_begin == NULL || m_end == NULL || m_iter == NULL)
643 return E_FAIL;
644
645 numAvailable = static_cast<ULONG>(m_end - m_iter);
646 if (celt < numAvailable)
647 numToFetch = celt;
648 else
649 numToFetch = numAvailable;
650 if (pceltFetched != NULL)
651 *pceltFetched = numToFetch;
652 rgeltTemp = rgelt;
653 while (numToFetch != 0)
654 {
655 hResult = Copy::copy(rgeltTemp, m_iter);
656 if (FAILED(hResult))
657 {
658 while (rgelt < rgeltTemp)
659 Copy::destroy(rgelt++);
660 if (pceltFetched != NULL)
661 *pceltFetched = 0;
662 return hResult;
663 }
664 rgeltTemp++;
665 m_iter++;
666 numToFetch--;
667 }
668 if (numAvailable < celt)
669 return S_FALSE;
670 return S_OK;
671 }
672
673 STDMETHOD(Skip)(ULONG celt)
674 {
675 ULONG numAvailable;
676 ULONG numToSkip;
677
678 if (celt == 0)
679 return E_INVALIDARG;
680
681 numAvailable = static_cast<ULONG>(m_end - m_iter);
682 if (celt < numAvailable)
683 numToSkip = celt;
684 else
685 numToSkip = numAvailable;
686 m_iter += numToSkip;
687 if (numAvailable < celt)
688 return S_FALSE;
689 return S_OK;
690 }
691
692 STDMETHOD(Reset)()
693 {
694 m_iter = m_begin;
695 return S_OK;
696 }
697
698 STDMETHOD(Clone)(Base **ppEnum)
699 {
700 enumeratorClass *newInstance;
701 HRESULT hResult;
702
703 hResult = E_POINTER;
704 if (ppEnum != NULL)
705 {
706 *ppEnum = NULL;
707 hResult = enumeratorClass::CreateInstance(&newInstance);
708 if (SUCCEEDED(hResult))
709 {
710 hResult = newInstance->Init(m_begin, m_end, (m_dwFlags & BitOwn) ? this : m_spUnk);
711 if (SUCCEEDED(hResult))
712 {
713 newInstance->m_iter = m_iter;
714 hResult = newInstance->_InternalQueryInterface(*piid, (void **)ppEnum);
715 }
716 if (FAILED(hResult))
717 delete newInstance;
718 }
719 }
720 return hResult;
721 }
722
723 protected:
724 enum FlagBits
725 {
726 BitCopy = 1,
727 BitOwn = 2
728 };
729 };
730
731 template <class Base, const IID *piid, class T, class Copy, class ThreadModel>
732 class CComEnum :
733 public CComEnumImpl<Base, piid, T, Copy>,
734 public CComObjectRootEx<ThreadModel>
735 {
736 public:
737 typedef CComEnum<Base, piid, T, Copy > _CComEnum;
738 typedef CComEnumImpl<Base, piid, T, Copy > _CComEnumBase;
739
740 BEGIN_COM_MAP(_CComEnum)
741 COM_INTERFACE_ENTRY_IID(*piid, _CComEnumBase)
742 END_COM_MAP()
743 };
744
// Initial slot count for CComDynamicUnkArray; may be overridden by defining it
// before this header is included.
#if !defined(_DEFAULT_VECTORLENGTH)
#define _DEFAULT_VECTORLENGTH 4
#endif
748
749 class CComDynamicUnkArray
750 {
751 public:
752 int m_nSize;
753 IUnknown **m_ppUnk;
754 public:
755 CComDynamicUnkArray()
756 {
757 m_nSize = 0;
758 m_ppUnk = NULL;
759 }
760
761 ~CComDynamicUnkArray()
762 {
763 free(m_ppUnk);
764 }
765
766 IUnknown **begin()
767 {
768 return m_ppUnk;
769 }
770
771 IUnknown **end()
772 {
773 return &m_ppUnk[m_nSize];
774 }
775
776 IUnknown *GetAt(int nIndex)
777 {
778 ATLASSERT(nIndex >= 0 && nIndex < m_nSize);
779 if (nIndex >= 0 && nIndex < m_nSize)
780 return m_ppUnk[nIndex];
781 else
782 return NULL;
783 }
784
785 IUnknown *WINAPI GetUnknown(DWORD dwCookie)
786 {
787 ATLASSERT(dwCookie != 0 && dwCookie <= static_cast<DWORD>(m_nSize));
788 if (dwCookie != 0 && dwCookie <= static_cast<DWORD>(m_nSize))
789 return GetAt(dwCookie - 1);
790 else
791 return NULL;
792 }
793
794 DWORD WINAPI GetCookie(IUnknown **ppFind)
795 {
796 IUnknown **x;
797 DWORD curCookie;
798
799 ATLASSERT(ppFind != NULL && *ppFind != NULL);
800 if (ppFind != NULL && *ppFind != NULL)
801 {
802 curCookie = 1;
803 for (x = begin(); x < end(); x++)
804 {
805 if (*x == *ppFind)
806 return curCookie;
807 curCookie++;
808 }
809 }
810 return 0;
811 }
812
813 DWORD Add(IUnknown *pUnk)
814 {
815 IUnknown **x;
816 IUnknown **newArray;
817 int newSize;
818 DWORD curCookie;
819
820 ATLASSERT(pUnk != NULL);
821 if (m_nSize == 0)
822 {
823 newSize = _DEFAULT_VECTORLENGTH * sizeof(IUnknown *);
824 ATLTRY(newArray = reinterpret_cast<IUnknown **>(malloc(newSize)));
825 if (newArray == NULL)
826 return 0;
827 memset(newArray, 0, newSize);
828 m_ppUnk = newArray;
829 m_nSize = _DEFAULT_VECTORLENGTH;
830 }
831 curCookie = 1;
832 for (x = begin(); x < end(); x++)
833 {
834 if (*x == NULL)
835 {
836 *x = pUnk;
837 return curCookie;
838 }
839 curCookie++;
840 }
841 newSize = m_nSize * 2;
842 newArray = reinterpret_cast<IUnknown **>(realloc(m_ppUnk, newSize * sizeof(IUnknown *)));
843 if (newArray == NULL)
844 return 0;
845 memset(&m_ppUnk[m_nSize], 0, (newSize - m_nSize) * sizeof(IUnknown *));
846 m_ppUnk = newArray;
847 m_nSize = newSize;
848 m_ppUnk[m_nSize] = pUnk;
849 return m_nSize + 1;
850 }
851
852 BOOL Remove(DWORD dwCookie)
853 {
854 DWORD index;
855
856 index = dwCookie - 1;
857 ATLASSERT(index < dwCookie && index < static_cast<DWORD>(m_nSize));
858 if (index < dwCookie && index < static_cast<DWORD>(m_nSize) && m_ppUnk[index] != NULL)
859 {
860 m_ppUnk[index] = NULL;
861 return TRUE;
862 }
863 return FALSE;
864 }
865
866 };
867
868 struct _ATL_CONNMAP_ENTRY
869 {
870 DWORD_PTR dwOffset;
871 };
872
873 template <const IID *piid>
874 class _ICPLocator
875 {
876 public:
877 STDMETHOD(_LocCPQueryInterface)(REFIID riid, void **ppvObject) = 0;
878 virtual ULONG STDMETHODCALLTYPE AddRef() = 0;
879 virtual ULONG STDMETHODCALLTYPE Release() = 0;
880 };
881
882 template<class T, const IID *piid, class CDV = CComDynamicUnkArray>
883 class IConnectionPointImpl : public _ICPLocator<piid>
884 {
885 typedef CComEnum<IEnumConnections, &IID_IEnumConnections, CONNECTDATA, _Copy<CONNECTDATA> > CComEnumConnections;
886 public:
887 CDV m_vec;
888 public:
889 ~IConnectionPointImpl()
890 {
891 IUnknown **x;
892
893 for (x = m_vec.begin(); x < m_vec.end(); x++)
894 if (*x != NULL)
895 (*x)->Release();
896 }
897
898 STDMETHOD(_LocCPQueryInterface)(REFIID riid, void **ppvObject)
899 {
900 IConnectionPointImpl<T, piid, CDV> *pThis;
901
902 pThis = reinterpret_cast<IConnectionPointImpl<T, piid, CDV>*>(this);
903
904 ATLASSERT(ppvObject != NULL);
905 if (ppvObject == NULL)
906 return E_POINTER;
907
908 if (InlineIsEqualGUID(riid, IID_IConnectionPoint) || InlineIsEqualUnknown(riid))
909 {
910 *ppvObject = this;
911 pThis->AddRef();
912 return S_OK;
913 }
914 else
915 {
916 *ppvObject = NULL;
917 return E_NOINTERFACE;
918 }
919 }
920
921 STDMETHOD(GetConnectionInterface)(IID *piid2)
922 {
923 if (piid2 == NULL)
924 return E_POINTER;
925 *piid2 = *piid;
926 return S_OK;
927 }
928
929 STDMETHOD(GetConnectionPointContainer)(IConnectionPointContainer **ppCPC)
930 {
931 T *pThis;
932
933 pThis = static_cast<T *>(this);
934 return pThis->QueryInterface(IID_IConnectionPointContainer, reinterpret_cast<void **>(ppCPC));
935 }
936
937 STDMETHOD(Advise)(IUnknown *pUnkSink, DWORD *pdwCookie)
938 {
939 IUnknown *adviseTarget;
940 IID interfaceID;
941 HRESULT hResult;
942
943 if (pdwCookie != NULL)
944 *pdwCookie = 0;
945 if (pUnkSink == NULL || pdwCookie == NULL)
946 return E_POINTER;
947 GetConnectionInterface(&interfaceID); // can't fail
948 hResult = pUnkSink->QueryInterface(interfaceID, reinterpret_cast<void **>(&adviseTarget));
949 if (SUCCEEDED(hResult))
950 {
951 *pdwCookie = m_vec.Add(adviseTarget);
952 if (*pdwCookie != 0)
953 hResult = S_OK;
954 else
955 {
956 adviseTarget->Release();
957 hResult = CONNECT_E_ADVISELIMIT;
958 }
959 }
960 else if (hResult == E_NOINTERFACE)
961 hResult = CONNECT_E_CANNOTCONNECT;
962 return hResult;
963 }
964
965 STDMETHOD(Unadvise)(DWORD dwCookie)
966 {
967 IUnknown *adviseTarget;
968 HRESULT hResult;
969
970 adviseTarget = m_vec.GetUnknown(dwCookie);
971 if (m_vec.Remove(dwCookie))
972 {
973 if (adviseTarget != NULL)
974 adviseTarget->Release();
975 hResult = S_OK;
976 }
977 else
978 hResult = CONNECT_E_NOCONNECTION;
979 return hResult;
980 }
981
982 STDMETHOD(EnumConnections)(IEnumConnections **ppEnum)
983 {
984 CComObject<CComEnumConnections> *newEnumerator;
985 CONNECTDATA *itemBuffer;
986 CONNECTDATA *itemBufferEnd;
987 IUnknown **x;
988 HRESULT hResult;
989
990 ATLASSERT(ppEnum != NULL);
991 if (ppEnum == NULL)
992 return E_POINTER;
993 *ppEnum = NULL;
994
995 ATLTRY(itemBuffer = new CONNECTDATA[m_vec.end() - m_vec.begin()])
996 if (itemBuffer == NULL)
997 return E_OUTOFMEMORY;
998 itemBufferEnd = itemBuffer;
999 for (x = m_vec.begin(); x < m_vec.end(); x++)
1000 {
1001 if (*x != NULL)
1002 {
1003 (*x)->AddRef();
1004 itemBufferEnd->pUnk = *x;
1005 itemBufferEnd->dwCookie = m_vec.GetCookie(x);
1006 itemBufferEnd++;
1007 }
1008 }
1009 ATLTRY(newEnumerator = new CComObject<CComEnumConnections>)
1010 if (newEnumerator == NULL)
1011 return E_OUTOFMEMORY;
1012 newEnumerator->Init(itemBuffer, itemBufferEnd, NULL, AtlFlagTakeOwnership); // can't fail
1013 hResult = newEnumerator->_InternalQueryInterface(IID_IEnumConnections, (void **)ppEnum);
1014 if (FAILED(hResult))
1015 delete newEnumerator;
1016 return hResult;
1017 }
1018 };
1019
1020 template <class T>
1021 class IConnectionPointContainerImpl : public IConnectionPointContainer
1022 {
1023 typedef const _ATL_CONNMAP_ENTRY * (*handlerFunctionType)(int *);
1024 typedef CComEnum<IEnumConnectionPoints, &IID_IEnumConnectionPoints, IConnectionPoint *, _CopyInterface<IConnectionPoint> >
1025 CComEnumConnectionPoints;
1026
1027 public:
1028 STDMETHOD(EnumConnectionPoints)(IEnumConnectionPoints **ppEnum)
1029 {
1030 const _ATL_CONNMAP_ENTRY *entryPtr;
1031 int connectionPointCount;
1032 IConnectionPoint **itemBuffer;
1033 int destIndex;
1034 handlerFunctionType handlerFunction;
1035 CComEnumConnectionPoints *newEnumerator;
1036 HRESULT hResult;
1037
1038 ATLASSERT(ppEnum != NULL);
1039 if (ppEnum == NULL)
1040 return E_POINTER;
1041 *ppEnum = NULL;
1042
1043 entryPtr = T::GetConnMap(&connectionPointCount);
1044 ATLTRY(itemBuffer = new IConnectionPoint * [connectionPointCount])
1045 if (itemBuffer == NULL)
1046 return E_OUTOFMEMORY;
1047
1048 destIndex = 0;
1049 while (entryPtr->dwOffset != static_cast<DWORD_PTR>(-1))
1050 {
1051 if (entryPtr->dwOffset == static_cast<DWORD_PTR>(-2))
1052 {
1053 entryPtr++;
1054 handlerFunction = reinterpret_cast<handlerFunctionType>(entryPtr->dwOffset);
1055 entryPtr = handlerFunction(NULL);
1056 }
1057 else
1058 {
1059 itemBuffer[destIndex++] = reinterpret_cast<IConnectionPoint *>((char *)this + entryPtr->dwOffset);
1060 entryPtr++;
1061 }
1062 }
1063
1064 ATLTRY(newEnumerator = new CComObject<CComEnumConnectionPoints>)
1065 if (newEnumerator == NULL)
1066 {
1067 delete [] itemBuffer;
1068 return E_OUTOFMEMORY;
1069 }
1070
1071 newEnumerator->Init(&itemBuffer[0], &itemBuffer[destIndex], NULL, AtlFlagTakeOwnership); // can't fail
1072 hResult = newEnumerator->QueryInterface(IID_IEnumConnectionPoints, (void**)ppEnum);
1073 if (FAILED(hResult))
1074 delete newEnumerator;
1075 return hResult;
1076 }
1077
1078 STDMETHOD(FindConnectionPoint)(REFIID riid, IConnectionPoint **ppCP)
1079 {
1080 IID interfaceID;
1081 const _ATL_CONNMAP_ENTRY *entryPtr;
1082 handlerFunctionType handlerFunction;
1083 IConnectionPoint *connectionPoint;
1084 HRESULT hResult;
1085
1086 if (ppCP == NULL)
1087 return E_POINTER;
1088 *ppCP = NULL;
1089 hResult = CONNECT_E_NOCONNECTION;
1090 entryPtr = T::GetConnMap(NULL);
1091 while (entryPtr->dwOffset != static_cast<DWORD_PTR>(-1))
1092 {
1093 if (entryPtr->dwOffset == static_cast<DWORD_PTR>(-2))
1094 {
1095 entryPtr++;
1096 handlerFunction = reinterpret_cast<handlerFunctionType>(entryPtr->dwOffset);
1097 entryPtr = handlerFunction(NULL);
1098 }
1099 else
1100 {
1101 connectionPoint = reinterpret_cast<IConnectionPoint *>(reinterpret_cast<char *>(this) + entryPtr->dwOffset);
1102 if (SUCCEEDED(connectionPoint->GetConnectionInterface(&interfaceID)) && InlineIsEqualGUID(riid, interfaceID))
1103 {
1104 *ppCP = connectionPoint;
1105 connectionPoint->AddRef();
1106 hResult = S_OK;
1107 break;
1108 }
1109 entryPtr++;
1110 }
1111 }
1112 return hResult;
1113 }
1114 };
1115
// BEGIN_CONNECTION_POINT_MAP(klass) records klass as _atl_conn_classtype and
// opens the static table returned by GetConnMap.
#define BEGIN_CONNECTION_POINT_MAP(klass) \
    typedef klass _atl_conn_classtype; \
    static const ATL::_ATL_CONNMAP_ENTRY *GetConnMap(int *pnEntries) \
    { \
        static const ATL::_ATL_CONNMAP_ENTRY _entries[] = {

// END_CONNECTION_POINT_MAP() adds the (DWORD_PTR)-1 terminator. If pnEntries is
// non-NULL, it stores the number of entries before the terminator there.
#define END_CONNECTION_POINT_MAP() \
            {static_cast<DWORD_PTR>(-1)} \
        }; \
        if (pnEntries != NULL) \
            *pnEntries = sizeof(_entries) / sizeof(ATL::_ATL_CONNMAP_ENTRY) - 1; \
        return _entries; \
    }

// CONNECTION_POINT_ENTRY(iid): offset of the class's _ICPLocator<&iid> base,
// relative to its IConnectionPointContainerImpl base.
#define CONNECTION_POINT_ENTRY(iid) \
    {offsetofclass(ATL::_ICPLocator<&iid>, _atl_conn_classtype) - \
     offsetofclass(ATL::IConnectionPointContainerImpl<_atl_conn_classtype>, _atl_conn_classtype)},
1130
1131 }; // namespace ATL
1132
1133 #endif // _atlcom_h