[atlnew]
authorAndrew Hill <ash77@reactos.org>
Tue, 20 Oct 2009 12:26:51 +0000 (12:26 +0000)
committerAndrew Hill <ash77@reactos.org>
Tue, 20 Oct 2009 12:26:51 +0000 (12:26 +0000)
- Initial checkin of minimal atl library
- To prevent conflicts with wine atl dll, this library is named atlnew

svn path=/trunk/; revision=43645

reactos/lib/atl/atl.rbuild [new file with mode: 0644]
reactos/lib/atl/atlbase.cpp [new file with mode: 0644]
reactos/lib/atl/atlbase.h [new file with mode: 0644]
reactos/lib/atl/atlcom.h [new file with mode: 0644]
reactos/lib/atl/atlcore.cpp [new file with mode: 0644]
reactos/lib/atl/atlcore.h [new file with mode: 0644]
reactos/lib/atl/atlwin.h [new file with mode: 0644]
reactos/lib/atl/statreg.h [new file with mode: 0644]
reactos/lib/lib.rbuild

diff --git a/reactos/lib/atl/atl.rbuild b/reactos/lib/atl/atl.rbuild
new file mode 100644 (file)
index 0000000..33eb6a5
--- /dev/null
@@ -0,0 +1,10 @@
+<?xml version="1.0"?>
+<!DOCTYPE module SYSTEM "../../../tools/rbuild/project.dtd">
+<module name="atlnew" type="staticlibrary">
+       <include base="atlnew">.</include>
+       <define name="UNICODE" />
+       <define name="_UNICODE" />
+       <define name="ROS_Headers" />
+       <file>atlbase.cpp</file>
+       <file>atlcore.cpp</file>
+</module>
diff --git a/reactos/lib/atl/atlbase.cpp b/reactos/lib/atl/atlbase.cpp
new file mode 100644 (file)
index 0000000..7ef85d1
--- /dev/null
@@ -0,0 +1,33 @@
+/*
+ * ReactOS ATL
+ *
+ * Copyright 2009 Andrew Hill <ash77@reactos.org>
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+#include <windows.h>
+#include <tchar.h>
+#include <ocidl.h>
+#include "atlbase.h"
+
+namespace ATL
+{
+
+CAtlBaseModule                                                         _AtlBaseModule;
+CAtlWinModule                                                          _AtlWinModule = CAtlWinModule();
+CAtlComModule                                                          _AtlComModule;
+
+};  // namespace ATL
diff --git a/reactos/lib/atl/atlbase.h b/reactos/lib/atl/atlbase.h
new file mode 100644 (file)
index 0000000..865050c
--- /dev/null
@@ -0,0 +1,965 @@
+/*
+ * ReactOS ATL
+ *
+ * Copyright 2009 Andrew Hill <ash77@reactos.org>
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+#ifndef _atlbase_h
+#define _atlbase_h
+
+#include "atlcore.h"
+#include "statreg.h"
+
+#ifdef _MSC_VER
+// It is common to use this in ATL constructors. They only store this for later use, so the usage is safe.
+#pragma warning(disable:4355)
+#endif
+
+#ifndef _ATL_PACKING
+#define _ATL_PACKING 8
+#endif
+
+#ifndef _ATL_FREE_THREADED
+#ifndef _ATL_APARTMENT_THREADED
+#ifndef _ATL_SINGLE_THREADED
+#define _ATL_FREE_THREADED
+#endif
+#endif
+#endif
+
+#ifndef ATLTRY
+#define ATLTRY(x) x;
+#endif
+
+#ifdef _ATL_DISABLE_NO_VTABLE
+#define ATL_NO_VTABLE
+#else
+#define ATL_NO_VTABLE __declspec(novtable)
+#endif
+
+#define offsetofclass(base, derived) (reinterpret_cast<DWORD_PTR>(static_cast<base *>(reinterpret_cast<derived *>(_ATL_PACKING))) - _ATL_PACKING)
+
+extern "C" IMAGE_DOS_HEADER __ImageBase;
+
+namespace ATL
+{
+
+class CAtlModule;
+class CComModule;
+class CAtlComModule;
+__declspec(selectany) CAtlModule                       *_pAtlModule = NULL;
+__declspec(selectany) CComModule                       *_pModule = NULL;
+extern CAtlComModule _AtlComModule;
+
+typedef HRESULT (WINAPI _ATL_CREATORFUNC)(void *pv, REFIID riid, LPVOID *ppv);
+typedef LPCTSTR (WINAPI _ATL_DESCRIPTIONFUNC)();
+typedef const struct _ATL_CATMAP_ENTRY * (_ATL_CATMAPFUNC)();
+
+struct _ATL_OBJMAP_ENTRY30
+{
+       const CLSID                                                             *pclsid;
+       HRESULT (WINAPI *pfnUpdateRegistry)(BOOL bRegister);
+       _ATL_CREATORFUNC                                                *pfnGetClassObject;
+       _ATL_CREATORFUNC                                                *pfnCreateInstance;
+       IUnknown                                                                *pCF;
+       DWORD                                                                   dwRegister;
+       _ATL_DESCRIPTIONFUNC                                    *pfnGetObjectDescription;
+       _ATL_CATMAPFUNC                                                 *pfnGetCategoryMap;
+       void (WINAPI *pfnObjectMain)(bool bStarting);
+
+       HRESULT WINAPI RevokeClassObject()
+       {
+               if (dwRegister == 0)
+                       return S_OK;
+               return CoRevokeClassObject(dwRegister);
+       }
+
+       HRESULT WINAPI RegisterClassObject(DWORD dwClsContext, DWORD dwFlags)
+       {
+               IUnknown                                                        *p;
+               HRESULT                                                         hResult;
+
+               p = NULL;
+               if (pfnGetClassObject == NULL)
+                       return S_OK;
+
+               hResult = pfnGetClassObject(reinterpret_cast<LPVOID *>(pfnCreateInstance), IID_IUnknown, reinterpret_cast<LPVOID *>(&p));
+               if (SUCCEEDED(hResult))
+                       hResult = CoRegisterClassObject(*pclsid, p, dwClsContext, dwFlags, &dwRegister);
+
+               if (p != NULL)
+                       p->Release();
+
+               return hResult;
+       }
+};
+
+typedef _ATL_OBJMAP_ENTRY30 _ATL_OBJMAP_ENTRY;
+
+typedef void (__stdcall _ATL_TERMFUNC)(DWORD_PTR dw);
+
+struct _ATL_TERMFUNC_ELEM
+{
+       _ATL_TERMFUNC                                                   *pFunc;
+       DWORD_PTR                                                               dw;
+       _ATL_TERMFUNC_ELEM                                              *pNext;
+};
+
+struct _ATL_MODULE70
+{
+       UINT                                                                    cbSize;
+       LONG                                                                    m_nLockCnt;
+       _ATL_TERMFUNC_ELEM                                              *m_pTermFuncs;
+       CComCriticalSection                                             m_csStaticDataInitAndTypeInfo;
+};
+typedef _ATL_MODULE70 _ATL_MODULE;
+
+typedef HRESULT (WINAPI _ATL_CREATORARGFUNC)(void *pv, REFIID riid, LPVOID *ppv, DWORD_PTR dw);
+
+#define _ATL_SIMPLEMAPENTRY ((ATL::_ATL_CREATORARGFUNC *)1)
+
+struct _ATL_INTMAP_ENTRY
+{
+       const IID                                                               *piid;
+       DWORD_PTR                                                               dw;
+       _ATL_CREATORARGFUNC                                             *pFunc;
+};
+
+struct _AtlCreateWndData
+{
+       void                                                                    *m_pThis;
+       DWORD                                                                   m_dwThreadID;
+       _AtlCreateWndData                                               *m_pNext;
+};
+
+struct _ATL_COM_MODULE70
+{
+       UINT                                                                    cbSize;
+       HINSTANCE                                                               m_hInstTypeLib;
+       _ATL_OBJMAP_ENTRY                                               **m_ppAutoObjMapFirst;
+       _ATL_OBJMAP_ENTRY                                               **m_ppAutoObjMapLast;
+       CComCriticalSection                                             m_csObjMap;
+};
+typedef _ATL_COM_MODULE70 _ATL_COM_MODULE;
+
+struct _ATL_WIN_MODULE70
+{
+       UINT                                                                    cbSize;
+       CComCriticalSection                                             m_csWindowCreate;
+       _AtlCreateWndData                                               *m_pCreateWndList;
+#ifdef NOTYET
+       CSimpleArray<ATOM>                                              m_rgWindowClassAtoms;
+#endif
+};
+typedef _ATL_WIN_MODULE70 _ATL_WIN_MODULE;
+
+struct _ATL_REGMAP_ENTRY
+{
+       LPCOLESTR                                                               szKey;
+       LPCOLESTR                                                               szData;
+};
+
+HRESULT __stdcall AtlWinModuleInit(_ATL_WIN_MODULE *pWinModule);
+HRESULT __stdcall AtlWinModuleTerm(_ATL_WIN_MODULE *pWinModule, HINSTANCE hInst);
+HRESULT __stdcall AtlInternalQueryInterface(void *pThis, const _ATL_INTMAP_ENTRY *pEntries, REFIID iid, void **ppvObject);
+void __stdcall AtlWinModuleAddCreateWndData(_ATL_WIN_MODULE *pWinModule, _AtlCreateWndData *pData, void *pObject);
+void *__stdcall AtlWinModuleExtractCreateWndData(_ATL_WIN_MODULE *pWinModule);
+HRESULT __stdcall AtlComModuleGetClassObject(_ATL_COM_MODULE *pComModule, REFCLSID rclsid, REFIID riid, LPVOID *ppv);
+
+template<class TLock>
+class CComCritSecLock
+{
+private:
+       bool                                                                            m_bLocked;
+       TLock                                                                           &m_cs;
+public:
+       CComCritSecLock(TLock &cs, bool bInitialLock = true) : m_cs(cs)
+       {
+               HRESULT                                                                 hResult;
+
+               m_bLocked = false;
+               if (bInitialLock)
+               {
+                       hResult = Lock();
+                       if (FAILED(hResult))
+                       {
+                               ATLASSERT(false);
+                       }
+               }
+       }
+
+       ~CComCritSecLock()
+       {
+               if (m_bLocked)
+                       Unlock();
+       }
+
+       HRESULT Lock()
+       {
+               HRESULT                                                                 hResult;
+
+               ATLASSERT(!m_bLocked);
+               hResult = m_cs.Lock();
+               if (FAILED(hResult))
+                       return hResult;
+               m_bLocked = true;
+
+               return S_OK;
+       }
+
+       void Unlock()
+       {
+               HRESULT                                                                 hResult;
+
+               ATLASSERT(m_bLocked);
+               hResult = m_cs.Unlock();
+               if (FAILED(hResult))
+               {
+                       ATLASSERT(false);
+               }
+               m_bLocked = false;
+       }
+};
+
+inline BOOL WINAPI InlineIsEqualUnknown(REFGUID rguid1)
+{
+   return (
+         ((unsigned long *)&rguid1)[0] == 0 &&
+         ((unsigned long *)&rguid1)[1] == 0 &&
+         ((unsigned long *)&rguid1)[2] == 0x000000C0 &&
+         ((unsigned long *)&rguid1)[3] == 0x46000000);
+}
+
+class CComMultiThreadModelNoCS
+{
+public:
+       typedef CComFakeCriticalSection AutoCriticalSection;
+       typedef CComFakeCriticalSection CriticalSection;
+       typedef CComMultiThreadModelNoCS ThreadModelNoCS;
+       typedef CComFakeCriticalSection AutoDeleteCriticalSection;
+
+       static ULONG WINAPI Increment(LPLONG p)
+       {
+               return InterlockedIncrement(p);
+       }
+
+       static ULONG WINAPI Decrement(LPLONG p)
+       {
+               return InterlockedDecrement(p);
+       }
+};
+
+class CComMultiThreadModel
+{
+public:
+       typedef CComAutoCriticalSection AutoCriticalSection;
+       typedef CComCriticalSection CriticalSection;
+       typedef CComMultiThreadModelNoCS ThreadModelNoCS;
+       typedef CComAutoDeleteCriticalSection AutoDeleteCriticalSection;
+
+       static ULONG WINAPI Increment(LPLONG p)
+       {
+               return InterlockedIncrement(p);
+       }
+
+       static ULONG WINAPI Decrement(LPLONG p)
+       {
+               return InterlockedDecrement(p);
+       }
+};
+
+class CComSingleThreadModel
+{
+public:
+       typedef CComFakeCriticalSection AutoCriticalSection;
+       typedef CComFakeCriticalSection CriticalSection;
+       typedef CComSingleThreadModel ThreadModelNoCS;
+       typedef CComFakeCriticalSection AutoDeleteCriticalSection;
+
+       static ULONG WINAPI Increment(LPLONG p)
+       {
+               return ++*p;
+       }
+
+       static ULONG WINAPI Decrement(LPLONG p)
+       {
+               return --*p;
+       }
+};
+
+#if defined(_ATL_FREE_THREADED)
+
+       typedef CComMultiThreadModel CComObjectThreadModel;
+       typedef CComMultiThreadModel CComGlobalsThreadModel;
+
+#elif defined(_ATL_APARTMENT_THREADED)
+
+       typedef CComSingleThreadModel CComObjectThreadModel;
+       typedef CComMultiThreadModel CComGlobalsThreadModel;
+
+#elif defined(_ATL_SINGLE_THREADED)
+
+       typedef CComSingleThreadModel CComObjectThreadModel;
+       typedef CComSingleThreadModel CComGlobalsThreadModel;
+
+#else
+#error No threading model
+#endif
+
+class CAtlModule : public _ATL_MODULE
+{
+public:
+       CAtlModule()
+       {
+               ATLASSERT(_pAtlModule == NULL);
+               _pAtlModule = this;
+               cbSize = sizeof(_ATL_MODULE);
+               m_nLockCnt = 0;
+       }
+
+       virtual LONG GetLockCount()
+       {
+               return m_nLockCnt;
+       }
+
+       virtual LONG Lock()
+       {
+               return CComGlobalsThreadModel::Increment(&m_nLockCnt);
+       }
+
+       virtual LONG Unlock()
+       {
+               return CComGlobalsThreadModel::Decrement(&m_nLockCnt);
+       }
+
+       virtual HRESULT AddCommonRGSReplacements(IRegistrarBase* /*pRegistrar*/) = 0;
+
+       HRESULT WINAPI UpdateRegistryFromResource(LPCTSTR lpszRes, BOOL bRegister, struct _ATL_REGMAP_ENTRY *pMapEntries = NULL)
+       {
+               CRegObject                                                              registrar;
+               TCHAR                                                                   modulePath[MAX_PATH];
+               HRESULT                                                                 hResult;
+
+               hResult = CommonInitRegistrar(registrar, modulePath, sizeof(modulePath) / sizeof(modulePath[0]), pMapEntries);
+               if (FAILED(hResult))
+                       return hResult;
+
+               if (bRegister != FALSE)
+                       hResult = registrar.ResourceRegisterSz(modulePath, lpszRes, _T("REGISTRY"));
+               else
+                       hResult = registrar.ResourceUnregisterSz(modulePath, lpszRes, _T("REGISTRY"));
+
+               return hResult;
+       }
+
+       HRESULT WINAPI UpdateRegistryFromResource(UINT nResID, BOOL bRegister, struct _ATL_REGMAP_ENTRY *pMapEntries = NULL)
+       {
+               CRegObject                                                              registrar;
+               TCHAR                                                                   modulePath[MAX_PATH];
+               HRESULT                                                                 hResult;
+
+               hResult = CommonInitRegistrar(registrar, modulePath, sizeof(modulePath) / sizeof(modulePath[0]), pMapEntries);
+               if (FAILED(hResult))
+                       return hResult;
+
+               if (bRegister != FALSE)
+                       hResult = registrar.ResourceRegister(modulePath, nResID, _T("REGISTRY"));
+               else
+                       hResult = registrar.ResourceRegister(modulePath, nResID, _T("REGISTRY"));
+
+               return hResult;
+       }
+
+private:
+       HRESULT CommonInitRegistrar(CRegObject &registrar, TCHAR *modulePath, DWORD modulePathCount, struct _ATL_REGMAP_ENTRY *pMapEntries)
+       {
+               HINSTANCE                                                               hInstance;
+               DWORD                                                                   dwFLen;
+               HRESULT                                                                 hResult;
+
+               hInstance = _AtlBaseModule.GetModuleInstance();
+               dwFLen = GetModuleFileName(hInstance, modulePath, modulePathCount);
+               if (dwFLen == modulePathCount)
+                       return HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER);
+               else if (dwFLen == 0)
+                       return HRESULT_FROM_WIN32(GetLastError());
+
+               if (pMapEntries != NULL)
+               {
+                       while (pMapEntries->szKey != NULL)
+                       {
+                               ATLASSERT(pMapEntries->szData != NULL);
+                               hResult = registrar.AddReplacement(pMapEntries->szKey, pMapEntries->szData);
+                               if (FAILED(hResult))
+                                       return hResult;
+                               pMapEntries++;
+                       }
+               }
+
+               hResult = AddCommonRGSReplacements(&registrar);
+               if (FAILED(hResult))
+                       return hResult;
+
+               hResult = registrar.AddReplacement(_T("Module"), modulePath);
+               if (FAILED(hResult))
+                       return hResult;
+
+               hResult = registrar.AddReplacement(_T("Module_Raw"), modulePath);
+               if (FAILED(hResult))
+                       return hResult;
+
+               return S_OK;
+       }
+};
+
+template <class T>
+class CAtlModuleT : public CAtlModule
+{
+public:
+
+       virtual HRESULT AddCommonRGSReplacements(IRegistrarBase *pRegistrar)
+       {
+               return pRegistrar->AddReplacement(L"APPID", T::GetAppId());
+       }
+
+       static LPCOLESTR GetAppId()
+       {
+               return L"";
+       }
+};
+
+class CAtlComModule : public _ATL_COM_MODULE
+{
+public:
+       CAtlComModule()
+       {
+               m_hInstTypeLib = reinterpret_cast<HINSTANCE>(&__ImageBase);
+               m_ppAutoObjMapFirst = NULL;
+               m_ppAutoObjMapLast = NULL;
+               if (FAILED(m_csObjMap.Init()))
+               {
+                       ATLASSERT(0);
+                       CAtlBaseModule::m_bInitFailed = true;
+                       return;
+               }
+               cbSize = sizeof(_ATL_COM_MODULE);
+       }
+
+       ~CAtlComModule()
+       {
+               Term();
+       }
+
+       void Term()
+       {
+               if (cbSize != 0)
+               {
+                       ATLASSERT(m_ppAutoObjMapFirst == NULL);
+                       ATLASSERT(m_ppAutoObjMapLast == NULL);
+                       m_csObjMap.Term();
+                       cbSize = 0;
+               }
+       }
+};
+
+template <class T>
+class CAtlDllModuleT : public CAtlModuleT<T>
+{
+public:
+       CAtlDllModuleT()
+       {
+       }
+
+       HRESULT DllCanUnloadNow()
+       {
+               T                                                                       *pThis;
+
+               pThis = static_cast<T *>(this);
+               if (pThis->GetLockCount() == 0)
+                       return S_OK;
+               return S_FALSE;
+       }
+
+       HRESULT DllGetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv)
+       {
+               T                                                                       *pThis;
+
+               pThis = static_cast<T *>(this);
+               return pThis->GetClassObject(rclsid, riid, ppv);
+       }
+
+       HRESULT DllRegisterServer(BOOL bRegTypeLib = TRUE)
+       {
+               T                                                                       *pThis;
+               HRESULT                                                         hResult;
+
+               pThis = static_cast<T *>(this);
+               hResult = pThis->RegisterServer(bRegTypeLib);
+               return hResult;
+       }
+
+       HRESULT DllUnregisterServer(BOOL bUnRegTypeLib = TRUE)
+       {
+               T                                                                       *pThis;
+               HRESULT                                                         hResult;
+
+               pThis = static_cast<T *>(this);
+               hResult = pThis->UnregisterServer(bUnRegTypeLib);
+               return hResult;
+       }
+
+       HRESULT GetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv)
+       {
+               return AtlComModuleGetClassObject(&_AtlComModule, rclsid, riid, ppv);
+       }
+};
+
+class CComModule : public CAtlModuleT<CComModule>
+{
+public:
+       _ATL_OBJMAP_ENTRY                                               *m_pObjMap;
+public:
+       CComModule()
+       {
+               ATLASSERT(_pModule == NULL);
+               _pModule = this;
+       }
+
+       ~CComModule()
+       {
+               _pModule = NULL;
+       }
+
+       HRESULT GetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv)
+       {
+               _ATL_OBJMAP_ENTRY                                       *objectMapEntry;
+               HRESULT                                                         hResult;
+
+               ATLASSERT(ppv != NULL);
+               if (ppv == NULL)
+                       return E_POINTER;
+               hResult = S_OK;
+               if (m_pObjMap != NULL)
+               {
+                       objectMapEntry = m_pObjMap;
+                       while (objectMapEntry->pclsid != NULL)
+                       {
+                               if (objectMapEntry->pfnGetClassObject != NULL && InlineIsEqualGUID(rclsid, *objectMapEntry->pclsid) != FALSE)
+                               {
+                                       if (objectMapEntry->pCF == NULL)
+                                       {
+                                               CComCritSecLock<CComCriticalSection> lock(_AtlComModule.m_csObjMap, true);
+
+                                               if (objectMapEntry->pCF == NULL)
+                                                       hResult = objectMapEntry->pfnGetClassObject(reinterpret_cast<void *>(objectMapEntry->pfnCreateInstance), IID_IUnknown, reinterpret_cast<LPVOID *>(&objectMapEntry->pCF));
+                                       }
+                                       if (objectMapEntry->pCF != NULL)
+                                               hResult = objectMapEntry->pCF->QueryInterface(riid, ppv);
+                                       break;
+                               }
+                               objectMapEntry++;
+                       }
+               }
+               return hResult;
+       }
+
+       HRESULT RegisterServer(BOOL bRegTypeLib = FALSE, const CLSID *pCLSID = NULL)
+       {
+               _ATL_OBJMAP_ENTRY                                       *objectMapEntry;
+               HRESULT                                                         hResult;
+
+               hResult = S_OK;
+               objectMapEntry = m_pObjMap;
+               if (objectMapEntry != NULL)
+               {
+                       while (objectMapEntry->pclsid != NULL)
+                       {
+                               if (pCLSID == NULL || IsEqualGUID(*pCLSID, *objectMapEntry->pclsid) != FALSE)
+                               {
+                                       hResult = objectMapEntry->pfnUpdateRegistry(TRUE);
+                                       if (FAILED(hResult))
+                                               break;
+                               }
+                               objectMapEntry++;
+                       }
+               }
+               return hResult;
+       }
+
+       HRESULT UnregisterServer(BOOL bUnRegTypeLib, const CLSID *pCLSID = NULL)
+       {
+               _ATL_OBJMAP_ENTRY                                       *objectMapEntry;
+               HRESULT                                                         hResult;
+
+               hResult = S_OK;
+               objectMapEntry = m_pObjMap;
+               if (objectMapEntry != NULL)
+               {
+                       while (objectMapEntry->pclsid != NULL)
+                       {
+                               if (pCLSID == NULL || IsEqualGUID(*pCLSID, *objectMapEntry->pclsid) != FALSE)
+                               {
+                                       hResult = objectMapEntry->pfnUpdateRegistry(FALSE); //unregister
+                                       if (FAILED(hResult))
+                                               break;
+                               }
+                               objectMapEntry++;
+                       }
+               }
+               return hResult;
+       }
+
+       HRESULT DllCanUnloadNow()
+       {
+               if (GetLockCount() == 0)
+                       return S_OK;
+               return S_FALSE;
+       }
+
+       HRESULT DllGetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv)
+       {
+               return GetClassObject(rclsid, riid, ppv);
+       }
+
+       HRESULT DllRegisterServer(BOOL bRegTypeLib = TRUE)
+       {
+               return RegisterServer(bRegTypeLib);
+       }
+
+       HRESULT DllUnregisterServer(BOOL bUnRegTypeLib = TRUE)
+       {
+               return UnregisterServer(bUnRegTypeLib);
+       }
+
+};
+
+class CAtlWinModule : public _ATL_WIN_MODULE
+{
+public:
+       CAtlWinModule()
+       {
+               HRESULT                                                         hResult;
+
+               hResult = AtlWinModuleInit(this);
+               if (FAILED(hResult))
+               {
+                       CAtlBaseModule::m_bInitFailed = true;
+                       ATLASSERT(0);
+               }
+       }
+
+       ~CAtlWinModule()
+       {
+               Term();
+       }
+
+       void Term()
+       {
+        AtlWinModuleTerm(this, _AtlBaseModule.GetModuleInstance());
+       }
+
+       void AddCreateWndData(_AtlCreateWndData *pData, void *pObject)
+       {
+               AtlWinModuleAddCreateWndData(this, pData, pObject);
+       }
+
+       void *ExtractCreateWndData()
+       {
+               return AtlWinModuleExtractCreateWndData(this);
+       }
+};
+
+extern CAtlWinModule _AtlWinModule;
+
+template<class T>
+class CComPtr
+{
+public:
+       T                                                                               *p;
+public:
+       CComPtr()
+       {
+               p = NULL;
+       }
+
+       CComPtr(T *lp)
+       {
+               p = lp;
+               if (p != NULL)
+                       p->AddRef();
+       }
+
+       CComPtr(const CComPtr<T> &lp)
+       {
+               p = lp.p;
+               if (p != NULL)
+                       p->AddRef();
+       }
+
+       ~CComPtr()
+       {
+               if (p != NULL)
+                       p->Release();
+       }
+
+       T *operator = (T *lp)
+       {
+               if (p != NULL)
+                       p->Release();
+               p = lp;
+               if (p != NULL)
+                       p->AddRef();
+        return *this;
+       }
+
+       T *operator = (const CComPtr<T> &lp)
+       {
+               if (p != NULL)
+                       p->Release();
+               p = lp.p;
+               if (p != NULL)
+                       p->AddRef();
+        return *this;
+       }
+
+       void Release()
+       {
+               if (p != NULL)
+               {
+                       p->Release();
+                       p = NULL;
+               }
+       }
+
+       void Attach(T *lp)
+       {
+               if (p != NULL)
+                       p->Release();
+               p = lp;
+       }
+
+       T *Detach()
+       {
+               T                                                                       *saveP;
+
+               saveP = p;
+               p = NULL;
+               return saveP;
+       }
+
+       T **operator & ()
+       {
+               ATLASSERT(p == NULL);
+               return &p;
+       }
+
+       operator T * ()
+       {
+               return p;
+       }
+
+       T *operator -> ()
+       {
+               ATLASSERT(p != NULL);
+               return p;
+       }
+};
+
+class CComBSTR
+{
+public:
+       BSTR                                                                    m_str;
+public:
+       CComBSTR(LPCOLESTR pSrc)
+       {
+               if (pSrc == NULL)
+                       m_str = NULL;
+               else
+                       m_str = ::SysAllocString(pSrc);
+       }
+       ~CComBSTR()
+       {
+               ::SysFreeString(m_str);
+               m_str = NULL;
+       }
+};
+
+class CComVariant : public tagVARIANT
+{
+public:
+       CComVariant()
+       {
+               ::VariantInit(this);
+       }
+
+       ~CComVariant()
+       {
+               Clear();
+       }
+
+       HRESULT Clear()
+       {
+               return ::VariantClear(this);
+       }
+};
+
+inline HRESULT __stdcall AtlAdvise(IUnknown *pUnkCP, IUnknown *pUnk, const IID &iid, LPDWORD pdw)
+{
+       CComPtr<IConnectionPointContainer>              container;
+       CComPtr<IConnectionPoint>                               connectionPoint;
+       HRESULT                                                                 hResult;
+
+       if (pUnkCP == NULL)
+               return E_INVALIDARG;
+       hResult = pUnkCP->QueryInterface(IID_IConnectionPointContainer, (void **)&container);
+       if (FAILED(hResult))
+               return hResult;
+       hResult = container->FindConnectionPoint(iid, &connectionPoint);
+       if (FAILED(hResult))
+               return hResult;
+       return connectionPoint->Advise(pUnk, pdw);
+}
+
+inline HRESULT __stdcall AtlUnadvise(IUnknown *pUnkCP, const IID &iid, DWORD dw)
+{
+       CComPtr<IConnectionPointContainer>              container;
+       CComPtr<IConnectionPoint>                               connectionPoint;
+       HRESULT                                                                 hResult;
+
+       if (pUnkCP == NULL)
+               return E_INVALIDARG;
+       hResult = pUnkCP->QueryInterface(IID_IConnectionPointContainer, (void **)&container);
+       if (FAILED(hResult))
+               return hResult;
+       hResult = container->FindConnectionPoint(iid, &connectionPoint);
+       if (FAILED(hResult))
+               return hResult;
+       return connectionPoint->Unadvise(dw);
+}
+
+inline HRESULT __stdcall AtlInternalQueryInterface(void *pThis, const _ATL_INTMAP_ENTRY *pEntries, REFIID iid, void **ppvObject)
+{
+       int                                                     i;
+       IUnknown                                        *resultInterface;
+       HRESULT                                         hResult;
+
+       ATLASSERT(pThis != NULL && pEntries != NULL);
+       if (pThis == NULL || pEntries == NULL)
+               return E_INVALIDARG;
+       ATLASSERT(ppvObject != NULL);
+       if (ppvObject == NULL)
+               return E_POINTER;
+
+       if (InlineIsEqualUnknown(iid))
+       {
+               resultInterface = reinterpret_cast<IUnknown *>(reinterpret_cast<char *>(pThis) + pEntries[0].dw);
+               *ppvObject = resultInterface;
+               resultInterface->AddRef();
+               return S_OK;
+       }
+
+       i = 0;
+       while (pEntries[i].pFunc != 0)
+       {
+               if (pEntries[i].piid == NULL || InlineIsEqualGUID(iid, *pEntries[i].piid))
+               {
+                       if (pEntries[i].pFunc == reinterpret_cast<_ATL_CREATORARGFUNC *>(1))
+                       {
+                               ATLASSERT(pEntries[i].piid != NULL);
+                               resultInterface = reinterpret_cast<IUnknown *>(reinterpret_cast<char *>(pThis) + pEntries[i].dw);
+                               *ppvObject = resultInterface;
+                               resultInterface->AddRef();
+                               return S_OK;
+                       }
+                       else
+                       {
+                               hResult = pEntries[i].pFunc(pThis, iid, ppvObject, 0);
+                               if (hResult == S_OK || (FAILED(hResult) && pEntries[i].piid != NULL))
+                                       return hResult;
+                       }
+                       break;
+               }
+               i++;
+       }
+       *ppvObject = NULL;
+       return E_NOINTERFACE;
+}
+
+inline HRESULT __stdcall AtlWinModuleInit(_ATL_WIN_MODULE *pWinModule)
+{
+       if (pWinModule == NULL)
+               return E_INVALIDARG;
+       pWinModule->m_pCreateWndList = NULL;
+       return pWinModule->m_csWindowCreate.Init();
+}
+
+inline HRESULT __stdcall AtlWinModuleTerm(_ATL_WIN_MODULE *pWinModule, HINSTANCE hInst)
+{
+       if (pWinModule == NULL)
+               return E_INVALIDARG;
+       pWinModule->m_csWindowCreate.Term();
+       return S_OK;
+}
+
+inline void __stdcall AtlWinModuleAddCreateWndData(_ATL_WIN_MODULE *pWinModule, _AtlCreateWndData *pData, void *pObject)
+{
+       CComCritSecLock<CComCriticalSection>    lock(pWinModule->m_csWindowCreate, true);
+
+       ATLASSERT(pWinModule != NULL);
+       ATLASSERT(pObject != NULL);
+
+       pData->m_pThis = pObject;
+       pData->m_dwThreadID = ::GetCurrentThreadId();
+       pData->m_pNext = pWinModule->m_pCreateWndList;
+       pWinModule->m_pCreateWndList = pData;
+}
+
+inline void *__stdcall AtlWinModuleExtractCreateWndData(_ATL_WIN_MODULE *pWinModule)
+{
+       CComCritSecLock<CComCriticalSection>    lock(pWinModule->m_csWindowCreate, true);
+       void                                                                    *result;
+       _AtlCreateWndData                                               *currentEntry;
+       _AtlCreateWndData                                               **previousLink;
+       DWORD                                                                   threadID;
+
+       ATLASSERT(pWinModule != NULL);
+
+       result = NULL;
+       threadID = GetCurrentThreadId();
+       currentEntry = pWinModule->m_pCreateWndList;
+       previousLink = &pWinModule->m_pCreateWndList;
+       while (currentEntry != NULL)
+       {
+               if (currentEntry->m_dwThreadID == threadID)
+               {
+                       *previousLink = currentEntry->m_pNext;
+                       result = currentEntry->m_pThis;
+                       break;
+               }
+               previousLink = &currentEntry->m_pNext;
+               currentEntry = currentEntry->m_pNext;
+       }
+       return result;
+}
+
+}; // namespace ATL
+
+#ifndef _ATL_NO_AUTOMATIC_NAMESPACE
+using namespace ATL;
+#endif //!_ATL_NO_AUTOMATIC_NAMESPACE
+
+#endif // _atlbase_h
diff --git a/reactos/lib/atl/atlcom.h b/reactos/lib/atl/atlcom.h
new file mode 100644 (file)
index 0000000..289594e
--- /dev/null
@@ -0,0 +1,1133 @@
+/*
+ * ReactOS ATL
+ *
+ * Copyright 2009 Andrew Hill <ash77@reactos.org>
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+#ifndef _atlcom_h
+#define _atlcom_h
+
+namespace ATL
+{
+
+template <class Base, const IID *piid, class T, class Copy, class ThreadModel = CComObjectThreadModel>
+class CComEnum;
+
+#define DECLARE_CLASSFACTORY_EX(cf) typedef ATL::CComCreator<ATL::CComObjectCached<cf> > _ClassFactoryCreatorClass;
+#define DECLARE_CLASSFACTORY() DECLARE_CLASSFACTORY_EX(ATL::CComClassFactory)
+
+class CComObjectRootBase
+{
+public:
+       long                                                                    m_dwRef;
+public:
+       CComObjectRootBase()
+       {
+               m_dwRef = 0;
+       }
+
+       ~CComObjectRootBase()
+       {
+       }
+
+       void SetVoid(void *)
+       {
+       }
+
+       HRESULT _AtlFinalConstruct()
+       {
+               return S_OK;
+       }
+
+       HRESULT FinalConstruct()
+       {
+               return S_OK;
+       }
+
+       void InternalFinalConstructAddRef()
+       {
+       }
+
+       void InternalFinalConstructRelease()
+       {
+       }
+
+       void FinalRelease()
+       {
+       }
+
+       static void WINAPI ObjectMain(bool)
+       {
+       }
+
+       static const struct _ATL_CATMAP_ENTRY *GetCategoryMap()
+       {
+               return NULL;
+       }
+
+       static HRESULT WINAPI InternalQueryInterface(void *pThis, const _ATL_INTMAP_ENTRY *pEntries, REFIID iid, void **ppvObject)
+       {
+               return AtlInternalQueryInterface(pThis, pEntries, iid, ppvObject);
+       }
+
+};
+
+template <class ThreadModel>
+class CComObjectRootEx : public CComObjectRootBase
+{
+private:
+       typename ThreadModel::AutoDeleteCriticalSection         m_critsec;
+public:
+       ~CComObjectRootEx()
+       {
+       }
+
+       ULONG InternalAddRef()
+       {
+               ATLASSERT(m_dwRef >= 0);
+               return ThreadModel::Increment(reinterpret_cast<int *>(&m_dwRef));
+       }
+
+       ULONG InternalRelease()
+       {
+               ATLASSERT(m_dwRef > 0);
+               return ThreadModel::Decrement(reinterpret_cast<int *>(&m_dwRef));
+       }
+
+       void Lock()
+       {
+               m_critsec.Lock();
+       }
+
+       void Unlock()
+       {
+               m_critsec.Unlock();
+       }
+
+       HRESULT _AtlInitialConstruct()
+       {
+               return m_critsec.Init();
+       }
+};
+
+template <class Base>
+class CComObject : public Base
+{
+public:
+       CComObject(void * = NULL)
+       {
+               _pAtlModule->Lock();
+       }
+
+       virtual ~CComObject()
+       {
+               CComObject<Base>                                        *pThis;
+
+               pThis = reinterpret_cast<CComObject<Base> *>(this);
+               pThis->FinalRelease();
+               _pAtlModule->Unlock();
+       }
+
+       STDMETHOD_(ULONG, AddRef)()
+       {
+               CComObject<Base>                                        *pThis;
+
+               pThis = reinterpret_cast<CComObject<Base> *>(this);
+               return pThis->InternalAddRef();
+       }
+
+       STDMETHOD_(ULONG, Release)()
+       {
+               CComObject<Base>                                        *pThis;
+               ULONG                                                           l;
+
+               pThis = reinterpret_cast<CComObject<Base> *>(this);
+               l = pThis->InternalRelease();
+               if (l == 0)
+                       delete this;
+               return l;
+       }
+
+       STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject)
+       {
+               CComObject<Base>                                        *pThis;
+
+               pThis = reinterpret_cast<CComObject<Base> *>(this);
+               return pThis->_InternalQueryInterface(iid, ppvObject);
+       }
+
+       static HRESULT WINAPI CreateInstance(CComObject<Base> **pp)
+       {
+               CComObject<Base>                                        *newInstance;
+               HRESULT                                                         hResult;
+
+               ATLASSERT(pp != NULL);
+               if (pp == NULL)
+                       return E_POINTER;
+
+               hResult = E_OUTOFMEMORY;
+               newInstance = NULL;
+               ATLTRY(newInstance = new CComObject<Base>())
+               if (newInstance != NULL)
+               {
+                       newInstance->SetVoid(NULL);
+                       newInstance->InternalFinalConstructAddRef();
+                       hResult = newInstance->_AtlInitialConstruct();
+                       if (SUCCEEDED(hResult))
+                               hResult = newInstance->FinalConstruct();
+                       if (SUCCEEDED(hResult))
+                               hResult = newInstance->_AtlFinalConstruct();
+                       newInstance->InternalFinalConstructRelease();
+                       if (hResult != S_OK)
+                       {
+                               delete newInstance;
+                               newInstance = NULL;
+                       }
+               }
+               *pp = newInstance;
+               return hResult;
+       }
+
+
+};
+
+template <HRESULT hResult>
+class CComFailCreator
+{
+public:
+       static HRESULT WINAPI CreateInstance(void *, REFIID, LPVOID *ppv)
+       {
+               ATLASSERT(ppv != NULL);
+               if (ppv == NULL)
+                       return E_POINTER;
+               *ppv = NULL;
+
+               return hResult;
+       }
+};
+
+template <class T1>
+class CComCreator
+{
+public:
+       static HRESULT WINAPI CreateInstance(void *pv, REFIID riid, LPVOID *ppv)
+       {
+               T1                                                                      *newInstance;
+               HRESULT                                                         hResult;
+
+               ATLASSERT(ppv != NULL);
+               if (ppv == NULL)
+                       return E_POINTER;
+               *ppv = NULL;
+
+               hResult = E_OUTOFMEMORY;
+               newInstance = NULL;
+               ATLTRY(newInstance = new T1())
+               if (newInstance != NULL)
+               {
+                       newInstance->SetVoid(pv);
+                       newInstance->InternalFinalConstructAddRef();
+                       hResult = newInstance->_AtlInitialConstruct();
+                       if (SUCCEEDED(hResult))
+                               hResult = newInstance->FinalConstruct();
+                       if (SUCCEEDED(hResult))
+                               hResult = newInstance->_AtlFinalConstruct();
+                       newInstance->InternalFinalConstructRelease();
+                       if (SUCCEEDED(hResult))
+                               hResult = newInstance->QueryInterface(riid, ppv);
+                       if (FAILED(hResult))
+                       {
+                               delete newInstance;
+                               newInstance = NULL;
+                       }
+               }
+               return hResult;
+       }
+};
+
+template <class T1, class T2>
+class CComCreator2
+{
+public:
+       static HRESULT WINAPI CreateInstance(void *pv, REFIID riid, LPVOID *ppv)
+       {
+               ATLASSERT(ppv != NULL && ppv != NULL);
+
+               if (pv == NULL)
+                       return T1::CreateInstance(NULL, riid, ppv);
+               else
+                       return T2::CreateInstance(pv, riid, ppv);
+       }
+};
+
+template <class Base>
+class CComObjectCached : public Base
+{
+public:
+       CComObjectCached(void * = NULL)
+       {
+       }
+
+       STDMETHOD_(ULONG, AddRef)()
+       {
+               CComObjectCached<Base>                          *pThis;
+               ULONG                                                           newRefCount;
+
+               pThis = reinterpret_cast<CComObjectCached<Base>*>(this);
+               newRefCount = pThis->InternalAddRef();
+               if (newRefCount == 2)
+                       _pAtlModule->Lock();
+               return newRefCount;
+       }
+
+       STDMETHOD_(ULONG, Release)()
+       {
+               CComObjectCached<Base>                          *pThis;
+               ULONG                                                           newRefCount;
+
+               pThis = reinterpret_cast<CComObjectCached<Base>*>(this);
+               newRefCount = pThis->InternalRelease();
+               if (newRefCount == 0)
+                       delete this;
+               else if (newRefCount == 1)
+                       _pAtlModule->Unlock();
+               return newRefCount;
+       }
+
+       STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject)
+       {
+               CComObjectCached<Base>                          *pThis;
+
+               pThis = reinterpret_cast<CComObjectCached<Base>*>(this);
+               return pThis->_InternalQueryInterface(iid, ppvObject);
+       }
+};
+
+#define BEGIN_COM_MAP(x)                                                                                                               \
+public:                                                                                                                                                        \
+       typedef x _ComMapClass;                                                                                                         \
+       HRESULT _InternalQueryInterface(REFIID iid, void **ppvObject)                           \
+       {                                                                                                                                                       \
+               return InternalQueryInterface(this, _GetEntries(), iid, ppvObject);             \
+       }                                                                                                                                                       \
+       const static ATL::_ATL_INTMAP_ENTRY *WINAPI _GetEntries()                                       \
+       {                                                                                                                                                       \
+               static const ATL::_ATL_INTMAP_ENTRY _entries[] = {
+
+#define END_COM_MAP()                                                                                                                  \
+                       {NULL, 0, 0}                                                                                                            \
+               };                                                                                                                                              \
+               return _entries;                                                                                                                \
+       }                                                                                                                                                       \
+       virtual ULONG STDMETHODCALLTYPE AddRef() = 0;                                                           \
+       virtual ULONG STDMETHODCALLTYPE Release() = 0;                                                          \
+       STDMETHOD(QueryInterface)(REFIID, void **) = 0;
+
+#define COM_INTERFACE_ENTRY_IID(iid, x)                                                                                        \
+       {&iid, offsetofclass(x, _ComMapClass), _ATL_SIMPLEMAPENTRY},
+
+#define COM_INTERFACE_ENTRY2_IID(iid, x, x2)                                                                   \
+       {&iid,                                                                                                                                          \
+                       reinterpret_cast<DWORD_PTR>(static_cast<x *>(static_cast<x2 *>(reinterpret_cast<_ComMapClass *>(_ATL_PACKING)))) - _ATL_PACKING,        \
+                       _ATL_SIMPLEMAPENTRY},
+
+#define COM_INTERFACE_ENTRY_BREAK(x)                                                                                   \
+    {&_ATL_IIDOF(x),                                                                                                                   \
+    NULL,                                                                                                                                              \
+    _Break},    // Break is a function that issues int 3.
+
+#define COM_INTERFACE_ENTRY_NOINTERFACE(x)                                                                             \
+    {&_ATL_IIDOF(x),                                                                                                                   \
+    NULL,                                                                                                                                              \
+    _NoInterface}, // NoInterface returns E_NOINTERFACE.
+
+#define COM_INTERFACE_ENTRY_FUNC(iid, dw, func)                                                                        \
+    {&iid,                                                                                                                                             \
+    dw,                                                                                                                                                        \
+    func},
+
+#define COM_INTERFACE_ENTRY_FUNC_BLIND(dw, func)                                                               \
+    {NULL,                                                                                                                                             \
+    dw,                                                                                                                                                        \
+    func},
+
+#define COM_INTERFACE_ENTRY_CHAIN(classname)                                                                   \
+    {NULL,                                                                                                                                             \
+    reinterpret_cast<DWORD>(&_CComChainData<classname, _ComMapClass>::data),   \
+    _Chain},
+
+#define DECLARE_REGISTRY_RESOURCEID(x)\
+       static HRESULT WINAPI UpdateRegistry(BOOL bRegister)\
+       {\
+               return ATL::_pAtlModule->UpdateRegistryFromResource(x, bRegister); \
+       }
+
+#define DECLARE_NOT_AGGREGATABLE(x)                                                                                            \
+public:                                                                                                                                                        \
+       typedef ATL::CComCreator2<ATL::CComCreator<ATL::CComObject<x> >, ATL::CComFailCreator<CLASS_E_NOAGGREGATION> > _CreatorClass;
+
+#define DECLARE_AGGREGATABLE(x)                                                                                                        \
+public:                                                                                                                                                        \
+       typedef CComCreator2<CComCreator<CComObject<x> >, CComCreator<CComAggObject<x> > > _CreatorClass;
+
+#define DECLARE_ONLY_AGGREGATABLE(x)                                                                                   \
+public:                                                                                                                                                        \
+       typedef CComCreator2<CComFailCreator<E_FAIL>, CComCreator<CComAggObject<x> > > _CreatorClass;
+
+#define DECLARE_POLY_AGGREGATABLE(x)                                                                                   \
+public:                                                                                                                                                        \
+       typedef CComCreator<CComPolyObject<x> > _CreatorClass;
+
+#define DECLARE_GET_CONTROLLING_UNKNOWN()                                                                              \
+public:                                                                                                                                                        \
+       virtual IUnknown *GetControllingUnknown()                                                                       \
+       {                                                                                                                                                       \
+               return GetUnknown();                                                                                                    \
+       }
+
+#define DECLARE_PROTECT_FINAL_CONSTRUCT()                                                                              \
+       void InternalFinalConstructAddRef()                                                                                     \
+       {                                                                                                                                                       \
+               InternalAddRef();                                                                                                               \
+       }                                                                                                                                                       \
+       void InternalFinalConstructRelease()                                                                            \
+       {                                                                                                                                                       \
+               InternalRelease();                                                                                                              \
+       }
+
+#define BEGIN_OBJECT_MAP(x) static ATL::_ATL_OBJMAP_ENTRY x[] = {
+
+#define END_OBJECT_MAP()   {NULL, NULL, NULL, NULL, NULL, 0, NULL, NULL, NULL}};
+
+#define OBJECT_ENTRY(clsid, class)                                                                                             \
+{                                                                                                                                                              \
+       &clsid,                                                                                                                                         \
+       class::UpdateRegistry,                                                                                                          \
+       class::_ClassFactoryCreatorClass::CreateInstance,                                                       \
+       class::_CreatorClass::CreateInstance,                                                                           \
+       NULL,                                                                                                                                           \
+       0,                                                                                                                                                      \
+       class::GetObjectDescription,                                                                                            \
+       class::GetCategoryMap,                                                                                                          \
+       class::ObjectMain },
+
+class CComClassFactory :
+       public IClassFactory,
+       public CComObjectRootEx<CComGlobalsThreadModel>
+{
+public:
+       _ATL_CREATORFUNC                                                *m_pfnCreateInstance;
+public:
+       STDMETHOD(CreateInstance)(LPUNKNOWN pUnkOuter, REFIID riid, void **ppvObj)
+       {
+               HRESULT                                                         hResult;
+
+               ATLASSERT(m_pfnCreateInstance != NULL);
+
+               if (ppvObj == NULL)
+                       return E_POINTER;
+               *ppvObj = NULL;
+
+               if (pUnkOuter != NULL && InlineIsEqualUnknown(riid) == FALSE)
+                       hResult = CLASS_E_NOAGGREGATION;
+               else
+                       hResult = m_pfnCreateInstance(pUnkOuter, riid, ppvObj);
+               return hResult;
+       }
+
+       STDMETHOD(LockServer)(BOOL fLock)
+       {
+               if (fLock)
+                       _pAtlModule->Lock();
+               else
+                       _pAtlModule->Unlock();
+               return S_OK;
+       }
+
+       void SetVoid(void *pv)
+       {
+               m_pfnCreateInstance = (_ATL_CREATORFUNC *)pv;
+       }
+
+       BEGIN_COM_MAP(CComClassFactory)
+               COM_INTERFACE_ENTRY_IID(IID_IClassFactory, IClassFactory)
+       END_COM_MAP()
+};
+
+template <class T, const CLSID *pclsid = &CLSID_NULL>
+class CComCoClass
+{
+public:
+       DECLARE_CLASSFACTORY()
+
+       static LPCTSTR WINAPI GetObjectDescription()
+       {
+               return NULL;
+       }
+};
+
+template <class T>
+class _Copy
+{
+public:
+       static HRESULT copy(T *pTo, const T *pFrom)
+       {
+               memcpy(pTo, pFrom, sizeof(T));
+               return S_OK;
+       }
+
+       static void init(T *)
+       {
+       }
+
+       static void destroy(T *)
+       {
+       }
+};
+
+template<>
+class _Copy<CONNECTDATA>
+{
+public:
+       static HRESULT copy(CONNECTDATA *pTo, const CONNECTDATA *pFrom)
+       {
+               *pTo = *pFrom;
+               if (pTo->pUnk)
+                       pTo->pUnk->AddRef();
+               return S_OK;
+       }
+
+       static void init(CONNECTDATA *)
+       {
+       }
+
+       static void destroy(CONNECTDATA *p)
+       {
+               if (p->pUnk)
+                       p->pUnk->Release();
+       }
+};
+
+template <class T>
+class _CopyInterface
+{
+public:
+       static HRESULT copy(T **pTo, T **pFrom)
+       {
+               *pTo = *pFrom;
+               if (*pTo)
+                       (*pTo)->AddRef();
+               return S_OK;
+       }
+
+       static void init(T **)
+       {
+       }
+
+       static void destroy(T **p)
+       {
+               if (*p)
+                       (*p)->Release();
+       }
+};
+
+enum CComEnumFlags
+{
+       AtlFlagNoCopy = 0,
+       AtlFlagTakeOwnership = 2,               // BitOwn
+       AtlFlagCopy = 3                                 // BitOwn | BitCopy
+};
+
+template <class Base, const IID *piid, class T, class Copy>
+class CComEnumImpl : public Base
+{
+private:
+       typedef CComObject<CComEnum<Base, piid, T, Copy> > enumeratorClass;
+public:
+       CComPtr<IUnknown>                                               m_spUnk;
+       DWORD                                                                   m_dwFlags;
+       T                                                                               *m_begin;
+       T                                                                               *m_end;
+       T                                                                               *m_iter;
+public:
+       CComEnumImpl()
+       {
+               m_dwFlags = 0;
+               m_begin = NULL;
+               m_end = NULL;
+               m_iter = NULL;
+       }
+
+       virtual ~CComEnumImpl()
+       {
+               T                                                                       *x;
+
+               if ((m_dwFlags & BitOwn) != 0)
+               {
+                       for (x = m_begin; x != m_end; x++)
+                               Copy::destroy(x);
+                       delete [] m_begin;
+               }
+       }
+
+       HRESULT Init(T *begin, T *end, IUnknown *pUnk, CComEnumFlags flags = AtlFlagNoCopy)
+       {
+               T                                                                       *newBuffer;
+               T                                                                       *sourcePtr;
+               T                                                                       *destPtr;
+               T                                                                       *cleanupPtr;
+               HRESULT                                                         hResult;
+
+               if (flags == AtlFlagCopy)
+               {
+                       ATLTRY(newBuffer = new T[end - begin])
+                       if (newBuffer == NULL)
+                               return E_OUTOFMEMORY;
+                       destPtr = newBuffer;
+                       for (sourcePtr = begin; sourcePtr != end; sourcePtr++)
+                       {
+                               Copy::init(destPtr);
+                               hResult = Copy::copy(destPtr, sourcePtr);
+                               if (FAILED(hResult))
+                               {
+                                       cleanupPtr = m_begin;
+                                       while (cleanupPtr < destPtr)
+                                               Copy::destroy(cleanupPtr++);
+                                       delete [] newBuffer;
+                                       return hResult;
+                               }
+                               destPtr++;
+                       }
+                       m_begin = newBuffer;
+                       m_end = m_begin + (end - begin);
+               }
+               else
+               {
+                       m_begin = begin;
+                       m_end = end;
+               }
+               m_spUnk = pUnk;
+               m_dwFlags = flags;
+               m_iter = m_begin;
+               return S_OK;
+       }
+
+       STDMETHOD(Next)(ULONG celt, T *rgelt, ULONG *pceltFetched)
+       {
+               ULONG                                                           numAvailable;
+               ULONG                                                           numToFetch;
+               T                                                                       *rgeltTemp;
+               HRESULT                                                         hResult;
+
+               if (pceltFetched != NULL)
+                       *pceltFetched = 0;
+               if (celt == 0)
+                       return E_INVALIDARG;
+               if (rgelt == NULL || (celt != 1 && pceltFetched == NULL))
+                       return E_POINTER;
+               if (m_begin == NULL || m_end == NULL || m_iter == NULL)
+                       return E_FAIL;
+
+               numAvailable = static_cast<ULONG>(m_end - m_iter);
+               if (celt < numAvailable)
+                       numToFetch = celt;
+               else
+                       numToFetch = numAvailable;
+               if (pceltFetched != NULL)
+                       *pceltFetched = numToFetch;
+               rgeltTemp = rgelt;
+               while (numToFetch != 0)
+               {
+                       hResult = Copy::copy(rgeltTemp, m_iter);
+                       if (FAILED(hResult))
+                       {
+                               while (rgelt < rgeltTemp)
+                                       Copy::destroy(rgelt++);
+                               if (pceltFetched != NULL)
+                                       *pceltFetched = 0;
+                               return hResult;
+                       }
+                       rgeltTemp++;
+                       m_iter++;
+                       numToFetch--;
+               }
+               if (numAvailable < celt)
+                       return S_FALSE;
+               return S_OK;
+       }
+
+       STDMETHOD(Skip)(ULONG celt)
+       {
+               ULONG                                                           numAvailable;
+               ULONG                                                           numToSkip;
+
+               if (celt == 0)
+                       return E_INVALIDARG;
+
+               numAvailable = static_cast<ULONG>(m_end - m_iter);
+               if (celt < numAvailable)
+                       numToSkip = celt;
+               else
+                       numToSkip = numAvailable;
+               m_iter += numToSkip;
+               if (numAvailable < celt)
+                       return S_FALSE;
+               return S_OK;
+       }
+
+       STDMETHOD(Reset)()
+       {
+               m_iter = m_begin;
+               return S_OK;
+       }
+
+       STDMETHOD(Clone)(Base **ppEnum)
+       {
+               enumeratorClass                                         *newInstance;
+               HRESULT                                                         hResult;
+
+               hResult = E_POINTER;
+               if (ppEnum != NULL)
+               {
+                       *ppEnum = NULL;
+                       hResult = enumeratorClass::CreateInstance(&newInstance);
+                       if (SUCCEEDED(hResult))
+                       {
+                               hResult = newInstance->Init(m_begin, m_end, (m_dwFlags & BitOwn) ? this : m_spUnk);
+                               if (SUCCEEDED(hResult))
+                               {
+                                       newInstance->m_iter = m_iter;
+                                       hResult = newInstance->_InternalQueryInterface(*piid, (void **)ppEnum);
+                               }
+                               if (FAILED(hResult))
+                                       delete newInstance;
+                       }
+               }
+               return hResult;
+       }
+
+protected:
+       enum FlagBits
+       {
+               BitCopy = 1,
+               BitOwn = 2
+       };
+};
+
+template <class Base, const IID *piid, class T, class Copy, class ThreadModel>
+class CComEnum :
+       public CComEnumImpl<Base, piid, T, Copy>,
+       public CComObjectRootEx<ThreadModel>
+{
+public:
+       typedef CComEnum<Base, piid, T, Copy > _CComEnum;
+       typedef CComEnumImpl<Base, piid, T, Copy > _CComEnumBase;
+
+       BEGIN_COM_MAP(_CComEnum)
+               COM_INTERFACE_ENTRY_IID(*piid, _CComEnumBase)
+       END_COM_MAP()
+};
+
+#ifndef _DEFAULT_VECTORLENGTH
+#define _DEFAULT_VECTORLENGTH 4
+#endif
+
+class CComDynamicUnkArray
+{
+public:
+       int                                                                             m_nSize;
+       IUnknown                                                                **m_ppUnk;
+public:
+       CComDynamicUnkArray()
+       {
+               m_nSize = 0;
+               m_ppUnk = NULL;
+       }
+
+       ~CComDynamicUnkArray()
+       {
+               free(m_ppUnk);
+       }
+
+       IUnknown **begin()
+       {
+               return m_ppUnk;
+       }
+
+       IUnknown **end()
+       {
+               return &m_ppUnk[m_nSize];
+       }
+
+       IUnknown *GetAt(int nIndex)
+       {
+               ATLASSERT(nIndex >= 0 && nIndex < m_nSize);
+               if (nIndex >= 0 && nIndex < m_nSize)
+                       return m_ppUnk[nIndex];
+               else
+                       return NULL;
+       }
+
+       IUnknown *WINAPI GetUnknown(DWORD dwCookie)
+       {
+               ATLASSERT(dwCookie != 0 && dwCookie <= static_cast<DWORD>(m_nSize));
+               if (dwCookie != 0 && dwCookie <= static_cast<DWORD>(m_nSize))
+                       return GetAt(dwCookie - 1);
+               else
+                       return NULL;
+       }
+
+       DWORD WINAPI GetCookie(IUnknown **ppFind)
+       {
+               IUnknown                                                        **x;
+               DWORD                                                           curCookie;
+
+               ATLASSERT(ppFind != NULL && *ppFind != NULL);
+               if (ppFind != NULL && *ppFind != NULL)
+               {
+                       curCookie = 1;
+                       for (x = begin(); x < end(); x++)
+                       {
+                               if (*x == *ppFind)
+                                       return curCookie;
+                               curCookie++;
+                       }
+               }
+               return 0;
+       }
+
+       DWORD Add(IUnknown *pUnk)
+       {
+               IUnknown                                                        **x;
+               IUnknown                                                        **newArray;
+               int                                                                     newSize;
+               DWORD                                                           curCookie;
+
+               ATLASSERT(pUnk != NULL);
+               if (m_nSize == 0)
+               {
+                       newSize = _DEFAULT_VECTORLENGTH * sizeof(IUnknown *);
+                       ATLTRY(newArray = reinterpret_cast<IUnknown **>(malloc(newSize)));
+                       if (newArray == NULL)
+                               return 0;
+                       memset(newArray, 0, newSize);
+                       m_ppUnk = newArray;
+                       m_nSize = _DEFAULT_VECTORLENGTH;
+               }
+               curCookie = 1;
+               for (x = begin(); x < end(); x++)
+               {
+                       if (*x == NULL)
+                       {
+                               *x = pUnk;
+                               return curCookie;
+                       }
+                       curCookie++;
+               }
+               newSize = m_nSize * 2;
+               newArray = reinterpret_cast<IUnknown **>(realloc(m_ppUnk, newSize * sizeof(IUnknown *)));
+               if (newArray == NULL)
+                       return 0;
+               memset(&m_ppUnk[m_nSize], 0, (newSize - m_nSize) * sizeof(IUnknown *));
+               m_ppUnk = newArray;
+               m_nSize = newSize;
+               m_ppUnk[m_nSize] = pUnk;
+               return m_nSize + 1;
+       }
+
+       BOOL Remove(DWORD dwCookie)
+       {
+               DWORD                                                           index;
+
+               index = dwCookie - 1;
+               ATLASSERT(index < dwCookie && index < static_cast<DWORD>(m_nSize));
+               if (index < dwCookie && index < static_cast<DWORD>(m_nSize) && m_ppUnk[index] != NULL)
+               {
+                       m_ppUnk[index] = NULL;
+                       return TRUE;
+               }
+               return FALSE;
+       }
+
+};
+
+struct _ATL_CONNMAP_ENTRY
+{
+       DWORD_PTR                                                               dwOffset;
+};
+
+template <const IID *piid>
+class _ICPLocator
+{
+public:
+       STDMETHOD(_LocCPQueryInterface)(REFIID riid, void **ppvObject) = 0;
+       virtual ULONG STDMETHODCALLTYPE AddRef() = 0;
+       virtual ULONG STDMETHODCALLTYPE Release() = 0;
+};
+
+template<class T, const IID *piid, class CDV = CComDynamicUnkArray>
+class IConnectionPointImpl : public _ICPLocator<piid>
+{
+       typedef CComEnum<IEnumConnections, &IID_IEnumConnections, CONNECTDATA, _Copy<CONNECTDATA> > CComEnumConnections;
+public:
+       CDV                                                                             m_vec;
+public:
+       ~IConnectionPointImpl()
+       {
+               IUnknown                                                        **x;
+
+               for (x = m_vec.begin(); x < m_vec.end(); x++)
+                       if (*x != NULL)
+                               (*x)->Release();
+       }
+
+       STDMETHOD(_LocCPQueryInterface)(REFIID riid, void **ppvObject)
+       {
+               IConnectionPointImpl<T, piid, CDV>      *pThis;
+
+               pThis = reinterpret_cast<IConnectionPointImpl<T, piid, CDV>*>(this);
+
+               ATLASSERT(ppvObject != NULL);
+               if (ppvObject == NULL)
+                       return E_POINTER;
+
+               if (InlineIsEqualGUID(riid, IID_IConnectionPoint) || InlineIsEqualUnknown(riid))
+               {
+                       *ppvObject = this;
+                       pThis->AddRef();
+                       return S_OK;
+               }
+               else
+               {
+                       *ppvObject = NULL;
+                       return E_NOINTERFACE;
+               }
+       }
+
+       STDMETHOD(GetConnectionInterface)(IID *piid2)
+       {
+               if (piid2 == NULL)
+                       return E_POINTER;
+               *piid2 = *piid;
+               return S_OK;
+       }
+
+       STDMETHOD(GetConnectionPointContainer)(IConnectionPointContainer **ppCPC)
+       {
+               T                                                                       *pThis;
+
+               pThis = static_cast<T *>(this);
+               return pThis->QueryInterface(IID_IConnectionPointContainer, reinterpret_cast<void **>(ppCPC));
+       }
+
+       STDMETHOD(Advise)(IUnknown *pUnkSink, DWORD *pdwCookie)
+       {
+               IUnknown                                                        *adviseTarget;
+               IID                                                                     interfaceID;
+               HRESULT                                                         hResult;
+
+               if (pdwCookie != NULL)
+                       *pdwCookie = 0;
+               if (pUnkSink == NULL || pdwCookie == NULL)
+                       return E_POINTER;
+               GetConnectionInterface(&interfaceID);                   // can't fail
+               hResult = pUnkSink->QueryInterface(interfaceID, reinterpret_cast<void **>(&adviseTarget));
+               if (SUCCEEDED(hResult))
+               {
+                       *pdwCookie = m_vec.Add(adviseTarget);
+                       if (*pdwCookie != 0)
+                               hResult = S_OK;
+                       else
+                       {
+                               adviseTarget->Release();
+                               hResult = CONNECT_E_ADVISELIMIT;
+                       }
+               }
+               else if (hResult == E_NOINTERFACE)
+                       hResult = CONNECT_E_CANNOTCONNECT;
+               return hResult;
+       }
+
+       STDMETHOD(Unadvise)(DWORD dwCookie)
+       {
+               IUnknown                                                        *adviseTarget;
+               HRESULT                                                         hResult;
+
+               adviseTarget = m_vec.GetUnknown(dwCookie);
+               if (m_vec.Remove(dwCookie))
+               {
+                       if (adviseTarget != NULL)
+                               adviseTarget->Release();
+                       hResult = S_OK;
+               }
+               else
+                       hResult = CONNECT_E_NOCONNECTION;
+               return hResult;
+       }
+
+       STDMETHOD(EnumConnections)(IEnumConnections **ppEnum)
+       {
+               CComObject<CComEnumConnections>         *newEnumerator;
+               CONNECTDATA                                                     *itemBuffer;
+               CONNECTDATA                                                     *itemBufferEnd;
+               IUnknown                                                        **x;
+               HRESULT                                                         hResult;
+
+               ATLASSERT(ppEnum != NULL);
+               if (ppEnum == NULL)
+                       return E_POINTER;
+               *ppEnum = NULL;
+
+               ATLTRY(itemBuffer = new CONNECTDATA[m_vec.end() - m_vec.begin()])
+               if (itemBuffer == NULL)
+                       return E_OUTOFMEMORY;
+               itemBufferEnd = itemBuffer;
+               for (x = m_vec.begin(); x < m_vec.end(); x++)
+               {
+                       if (*x != NULL)
+                       {
+                               (*x)->AddRef();
+                               itemBufferEnd->pUnk = *x;
+                               itemBufferEnd->dwCookie = m_vec.GetCookie(x);
+                               itemBufferEnd++;
+                       }
+               }
+               ATLTRY(newEnumerator = new CComObject<CComEnumConnections>)
+               if (newEnumerator == NULL)
+                       return E_OUTOFMEMORY;
+               newEnumerator->Init(itemBuffer, itemBufferEnd, NULL, AtlFlagTakeOwnership);             // can't fail
+               hResult = newEnumerator->_InternalQueryInterface(IID_IEnumConnections, (void **)ppEnum);
+               if (FAILED(hResult))
+                       delete newEnumerator;
+               return hResult;
+       }
+};
+
+template <class T>
+class IConnectionPointContainerImpl : public IConnectionPointContainer
+{
+               typedef const _ATL_CONNMAP_ENTRY * (*handlerFunctionType)(int *);
+               typedef CComEnum<IEnumConnectionPoints, &IID_IEnumConnectionPoints, IConnectionPoint *, _CopyInterface<IConnectionPoint> >
+                                               CComEnumConnectionPoints;
+
+public:
+       STDMETHOD(EnumConnectionPoints)(IEnumConnectionPoints **ppEnum)
+       {
+               const _ATL_CONNMAP_ENTRY                        *entryPtr;
+               int                                                                     connectionPointCount;
+               IConnectionPoint                                        **itemBuffer;
+               int                                                                     destIndex;
+               handlerFunctionType                                     handlerFunction;
+               CComEnumConnectionPoints                        *newEnumerator;
+               HRESULT                                                         hResult;
+
+               ATLASSERT(ppEnum != NULL);
+               if (ppEnum == NULL)
+                       return E_POINTER;
+               *ppEnum = NULL;
+
+               entryPtr = T::GetConnMap(&connectionPointCount);
+               ATLTRY(itemBuffer = new IConnectionPoint * [connectionPointCount])
+               if (itemBuffer == NULL)
+                       return E_OUTOFMEMORY;
+
+               destIndex = 0;
+               while (entryPtr->dwOffset != static_cast<DWORD_PTR>(-1))
+               {
+                       if (entryPtr->dwOffset == static_cast<DWORD_PTR>(-2))
+                       {
+                               entryPtr++;
+                               handlerFunction = reinterpret_cast<handlerFunctionType>(entryPtr->dwOffset);
+                               entryPtr = handlerFunction(NULL);
+                       }
+                       else
+                       {
+                               itemBuffer[destIndex++] = reinterpret_cast<IConnectionPoint *>((char *)this + entryPtr->dwOffset);
+                               entryPtr++;
+                       }
+               }
+
+               ATLTRY(newEnumerator = new CComObject<CComEnumConnectionPoints>)
+               if (newEnumerator == NULL)
+               {
+                       delete [] itemBuffer;
+                       return E_OUTOFMEMORY;
+               }
+
+               newEnumerator->Init(&itemBuffer[0], &itemBuffer[destIndex], NULL, AtlFlagTakeOwnership);        // can't fail
+               hResult = newEnumerator->QueryInterface(IID_IEnumConnectionPoints, (void**)ppEnum);
+               if (FAILED(hResult))
+                       delete newEnumerator;
+               return hResult;
+       }
+
+       STDMETHOD(FindConnectionPoint)(REFIID riid, IConnectionPoint **ppCP)
+       {
+               IID                                                                     interfaceID;
+               const _ATL_CONNMAP_ENTRY                        *entryPtr;
+               handlerFunctionType                                     handlerFunction;
+               IConnectionPoint                                        *connectionPoint;
+               HRESULT                                                         hResult;
+
+               if (ppCP == NULL)
+                       return E_POINTER;
+               *ppCP = NULL;
+               hResult = CONNECT_E_NOCONNECTION;
+               entryPtr = T::GetConnMap(NULL);
+               while (entryPtr->dwOffset != static_cast<DWORD_PTR>(-1))
+               {
+                       if (entryPtr->dwOffset == static_cast<DWORD_PTR>(-2))
+                       {
+                               entryPtr++;
+                               handlerFunction = reinterpret_cast<handlerFunctionType>(entryPtr->dwOffset);
+                               entryPtr = handlerFunction(NULL);
+                       }
+                       else
+                       {
+                               connectionPoint = reinterpret_cast<IConnectionPoint *>(reinterpret_cast<char *>(this) + entryPtr->dwOffset);
+                               if (SUCCEEDED(connectionPoint->GetConnectionInterface(&interfaceID)) && InlineIsEqualGUID(riid, interfaceID))
+                               {
+                                       *ppCP = connectionPoint;
+                                       connectionPoint->AddRef();
+                                       hResult = S_OK;
+                                       break;
+                               }
+                               entryPtr++;
+                       }
+               }
+               return hResult;
+       }
+};
+
+#define BEGIN_CONNECTION_POINT_MAP(x)                                                                                  \
+       typedef x _atl_conn_classtype;                                                                                          \
+       static const ATL::_ATL_CONNMAP_ENTRY *GetConnMap(int *pnEntries) {                      \
+       static const ATL::_ATL_CONNMAP_ENTRY _entries[] = {
+
+#define END_CONNECTION_POINT_MAP()                                                                                             \
+       {(DWORD_PTR)-1} };                                                                                                                      \
+       if (pnEntries)                                                                                                                          \
+               *pnEntries = sizeof(_entries) / sizeof(ATL::_ATL_CONNMAP_ENTRY) - 1;    \
+       return _entries;}
+
+#define CONNECTION_POINT_ENTRY(iid)                                                                                            \
+       {offsetofclass(ATL::_ICPLocator<&iid>, _atl_conn_classtype) -                           \
+       offsetofclass(ATL::IConnectionPointContainerImpl<_atl_conn_classtype>, _atl_conn_classtype)},
+
+}; // namespace ATL
+
+#endif // _atlcom_h
diff --git a/reactos/lib/atl/atlcore.cpp b/reactos/lib/atl/atlcore.cpp
new file mode 100644 (file)
index 0000000..83eb43f
--- /dev/null
@@ -0,0 +1,30 @@
+/*
+ * ReactOS ATL
+ *
+ * Copyright 2009 Andrew Hill <ash77@reactos.org>
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+#include <windows.h>
+#include <ocidl.h>
+#include "atlcore.h"
+
+namespace ATL
+{
+
+bool CAtlBaseModule::m_bInitFailed = false;
+
+};  // namespace ATL
diff --git a/reactos/lib/atl/atlcore.h b/reactos/lib/atl/atlcore.h
new file mode 100644 (file)
index 0000000..49198bb
--- /dev/null
@@ -0,0 +1,199 @@
+/*
+ * ReactOS ATL
+ *
+ * Copyright 2009 Andrew Hill <ash77@reactos.org>
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+#ifndef _atlcore_h
+#define _atlcore_h
+
+#include <malloc.h>
+#include <olectl.h>
+#include <ocidl.h>
+
+#include <crtdbg.h>
+
+#ifndef ATLASSERT
+#define ATLASSERT(expr) _ASSERTE(expr)
+#endif // ATLASSERT
+
+namespace ATL
+{
+
+class CComCriticalSection
+{
+public:
+       CRITICAL_SECTION                                                m_sec;
+public:
+       CComCriticalSection()
+       {
+               memset(&m_sec, 0, sizeof(CRITICAL_SECTION));
+       }
+
+       ~CComCriticalSection()
+       {
+       }
+
+       HRESULT Lock()
+       {
+               EnterCriticalSection(&m_sec);
+               return S_OK;
+       }
+
+       HRESULT Unlock()
+       {
+               LeaveCriticalSection(&m_sec);
+               return S_OK;
+       }
+
+       HRESULT Init()
+       {
+               InitializeCriticalSection(&m_sec);
+               return S_OK;
+       }
+
+       HRESULT Term()
+       {
+               DeleteCriticalSection(&m_sec);
+               return S_OK;
+       }
+};
+
+class CComFakeCriticalSection
+{
+public:
+       HRESULT Lock()
+       {
+               return S_OK;
+       }
+
+       HRESULT Unlock()
+       {
+               return S_OK;
+       }
+
+       HRESULT Init()
+       {
+               return S_OK;
+       }
+
+       HRESULT Term()
+       {
+               return S_OK;
+       }
+};
+
+class CComAutoCriticalSection : public CComCriticalSection
+{
+public:
+       CComAutoCriticalSection()
+       {
+               HRESULT                                                         hResult;
+
+               hResult = CComCriticalSection::Init();
+               ATLASSERT(SUCCEEDED(hResult));
+       }
+       ~CComAutoCriticalSection()
+       {
+               CComCriticalSection::Term();
+       }
+};
+
+class CComSafeDeleteCriticalSection : public CComCriticalSection
+{
+private:
+       bool                                                                    m_bInitialized;
+public:
+       CComSafeDeleteCriticalSection()
+       {
+               m_bInitialized = false;
+       }
+
+       ~CComSafeDeleteCriticalSection()
+       {
+               Term();
+       }
+
+       HRESULT Lock()
+       {
+               ATLASSERT(m_bInitialized);
+               return CComCriticalSection::Lock();
+       }
+
+       HRESULT Init()
+       {
+               HRESULT                                                         hResult;
+
+               ATLASSERT(!m_bInitialized);
+               hResult = CComCriticalSection::Init();
+               if (SUCCEEDED(hResult))
+                       m_bInitialized = true;
+               return hResult;
+       }
+
+       HRESULT Term()
+       {
+               if (!m_bInitialized)
+                       return S_OK;
+               m_bInitialized = false;
+               return CComCriticalSection::Term();
+       }
+};
+
+class CComAutoDeleteCriticalSection : public CComSafeDeleteCriticalSection
+{
+private:
+       // CComAutoDeleteCriticalSection::Term should never be called
+       HRESULT Term();
+};
+
+struct _ATL_BASE_MODULE70
+{
+       UINT                                                                    cbSize;
+       HINSTANCE                                                               m_hInst;
+       HINSTANCE                                                               m_hInstResource;
+       bool                                                                    m_bNT5orWin98;
+       DWORD                                                                   dwAtlBuildVer;
+       GUID                                                                    *pguidVer;
+       CRITICAL_SECTION                                                m_csResource;
+#ifdef NOTYET
+       CSimpleArray<HINSTANCE>                                 m_rgResourceInstance;
+#endif
+};
+typedef _ATL_BASE_MODULE70 _ATL_BASE_MODULE;
+
+class CAtlBaseModule : public _ATL_BASE_MODULE
+{
+public :
+       static bool                                                             m_bInitFailed;
+public:
+       HINSTANCE GetModuleInstance()
+       {
+               return m_hInst;
+       }
+
+       HINSTANCE GetResourceInstance()
+       {
+               return m_hInstResource;
+       }
+};
+
+extern CAtlBaseModule _AtlBaseModule;
+
+}; // namespace ATL
+
+#endif // _atlcore_h
diff --git a/reactos/lib/atl/atlwin.h b/reactos/lib/atl/atlwin.h
new file mode 100644 (file)
index 0000000..e85f8b7
--- /dev/null
@@ -0,0 +1,751 @@
+/*
+ * ReactOS ATL
+ *
+ * Copyright 2009 Andrew Hill <ash77@reactos.org>
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+#ifndef _atlin_h
+#define _atlwin_h
+
+#ifdef __GNUC__
+#define GCCU(x)        x __attribute__((unused))
+#define Unused(x)
+#else
+#define GCCU(x)
+#define Unused(x)      (x);
+#endif // __GNUC__
+
+#ifdef SetWindowLongPtr
+#undef SetWindowLongPtr
+inline LONG_PTR SetWindowLongPtr(HWND hWnd, int nIndex, LONG_PTR dwNewLong)
+{
+       return SetWindowLong(hWnd, nIndex, (LONG)dwNewLong);
+}
+#endif
+
+#ifdef GetWindowLongPtr
+#undef GetWindowLongPtr
+inline LONG_PTR GetWindowLongPtr(HWND hWnd, int nIndex)
+{
+       return (LONG_PTR)GetWindowLong(hWnd, nIndex);
+}
+#endif
+
+namespace ATL
+{
+
+struct _ATL_WNDCLASSINFOW;
+typedef _ATL_WNDCLASSINFOW CWndClassInfo;
+
+template <DWORD t_dwStyle = 0, DWORD t_dwExStyle = 0>
+class CWinTraits
+{
+public:
+       static DWORD GetWndStyle(DWORD dwStyle)
+       {
+               if (dwStyle == 0)
+                       return t_dwStyle;
+               return dwStyle;
+       }
+
+       static DWORD GetWndExStyle(DWORD dwExStyle)
+       {
+               if (dwExStyle == 0)
+                       return t_dwExStyle;
+               return dwExStyle;
+       }
+};
+
+typedef CWinTraits<WS_CHILD | WS_VISIBLE | WS_CLIPCHILDREN | WS_CLIPSIBLINGS, 0> CControlWinTraits;
+typedef CWinTraits<WS_OVERLAPPEDWINDOW | WS_CLIPCHILDREN | WS_CLIPSIBLINGS, WS_EX_APPWINDOW | WS_EX_WINDOWEDGE> CFrameWinTraits;
+typedef CWinTraits<WS_OVERLAPPEDWINDOW | WS_CHILD | WS_VISIBLE | WS_CLIPCHILDREN | WS_CLIPSIBLINGS, WS_EX_MDICHILD> CMDIChildWinTraits;
+
+template <DWORD t_dwStyle = 0, DWORD t_dwExStyle = 0, class TWinTraits = CControlWinTraits>
+class CWinTraitsOR
+{
+public:
+       static DWORD GetWndStyle(DWORD dwStyle)
+       {
+               return dwStyle | t_dwStyle | TWinTraits::GetWndStyle(dwStyle);
+       }
+
+       static DWORD GetWndExStyle(DWORD dwExStyle)
+       {
+               return dwExStyle | t_dwExStyle | TWinTraits::GetWndExStyle(dwExStyle);
+       }
+};
+
+class _U_MENUorID
+{
+public:
+       HMENU                                                                   m_hMenu;
+public:
+       _U_MENUorID(HMENU hMenu)
+       {
+               m_hMenu = hMenu;
+       }
+
+       _U_MENUorID(UINT nID)
+       {
+               m_hMenu = (HMENU)(UINT_PTR)nID;
+       }
+};
+
+class _U_RECT
+{
+public:
+       LPRECT                                                                  m_lpRect;
+public:
+       _U_RECT(LPRECT lpRect)
+       {
+               m_lpRect = lpRect;
+       }
+
+       _U_RECT(RECT &rc)
+       {
+               m_lpRect = &rc;
+       }
+};
+
+struct _ATL_MSG : public MSG
+{
+public:
+       BOOL                                                                    bHandled;
+public:
+       _ATL_MSG(HWND hWnd, UINT uMsg, WPARAM wParamIn, LPARAM lParamIn, BOOL bHandledIn = TRUE)
+       {
+               hwnd = hWnd;
+               message = uMsg;
+               wParam = wParamIn;
+               lParam = lParamIn;
+               time = 0;
+               pt.x = 0;
+               pt.y = 0;
+               bHandled = bHandledIn;
+       }
+};
+
+#if defined(_M_IX86)
+
+#pragma pack(push,1)
+struct thunkCode
+{
+       DWORD                                                                   m_mov;
+       DWORD                                                                   m_this;
+       BYTE                                                                    m_jmp;
+       DWORD                                                                   m_relproc;
+};
+#pragma pack(pop)
+
+class CWndProcThunk
+{
+public:
+       thunkCode                                                               m_thunk;
+       _AtlCreateWndData                                               cd;
+public:
+       BOOL Init(WNDPROC proc, void *pThis)
+       {
+               m_thunk.m_mov = 0x042444C7;
+               m_thunk.m_this = PtrToUlong(pThis);
+               m_thunk.m_jmp = 0xe9;
+               m_thunk.m_relproc = DWORD(reinterpret_cast<char *>(proc) - (reinterpret_cast<char *>(this) + sizeof(thunkCode)));
+               return TRUE;
+       }
+
+       WNDPROC GetWNDPROC()
+       {
+               return reinterpret_cast<WNDPROC>(&m_thunk);
+       }
+};
+
+#else
+#error Only X86 supported
+#endif
+
+class CMessageMap
+{
+public:
+       virtual BOOL ProcessWindowMessage(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam, LRESULT &lResult, DWORD dwMsgMapID) = 0;
+};
+
+class CWindow
+{
+public:
+       HWND                                                                    m_hWnd;
+       static RECT                                                             rcDefault;
+public:
+       CWindow(HWND hWnd = NULL)
+       {
+               m_hWnd = hWnd;
+       }
+
+       operator HWND() const
+       {
+               return m_hWnd;
+       }
+
+       static LPCTSTR GetWndClassName()
+       {
+               return NULL;
+       }
+
+       HDC BeginPaint(LPPAINTSTRUCT lpPaint)
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::BeginPaint(m_hWnd, lpPaint);
+       }
+
+       BOOL DestroyWindow()
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+
+               if (!::DestroyWindow(m_hWnd))
+                       return FALSE;
+
+               m_hWnd = NULL;
+               return TRUE;
+       }
+
+       void EndPaint(LPPAINTSTRUCT lpPaint)
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               ::EndPaint(m_hWnd, lpPaint);
+       }
+
+       BOOL GetClientRect(LPRECT lpRect) const
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::GetClientRect(m_hWnd, lpRect);
+       }
+
+       CWindow GetParent() const
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return CWindow(::GetParent(m_hWnd));
+       }
+
+       BOOL Invalidate(BOOL bErase = TRUE)
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::InvalidateRect(m_hWnd, NULL, bErase);
+       }
+
+       BOOL InvalidateRect(LPCRECT lpRect, BOOL bErase = TRUE)
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::InvalidateRect(m_hWnd, lpRect, bErase);
+       }
+
+       BOOL IsWindow() const
+       {
+               return ::IsWindow(m_hWnd);
+       }
+
+       BOOL KillTimer(UINT_PTR nIDEvent)
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::KillTimer(m_hWnd, nIDEvent);
+       }
+
+       BOOL LockWindowUpdate(BOOL bLock = TRUE)
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               if (bLock)
+                       return ::LockWindowUpdate(m_hWnd);
+               return ::LockWindowUpdate(NULL);
+       }
+
+       BOOL ScreenToClient(LPPOINT lpPoint) const
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::ScreenToClient(m_hWnd, lpPoint);
+       }
+
+       LRESULT SendMessage(UINT message, WPARAM wParam = 0, LPARAM lParam = 0)
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::SendMessage(m_hWnd, message, wParam, lParam);
+       }
+
+       static LRESULT SendMessage(HWND hWnd, UINT message, WPARAM wParam, LPARAM lParam)
+       {
+               ATLASSERT(::IsWindow(hWnd));
+               return ::SendMessage(hWnd, message, wParam, lParam);
+       }
+
+       HWND SetCapture()
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::SetCapture(m_hWnd);
+       }
+
+       HWND SetFocus()
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::SetFocus(m_hWnd);
+       }
+
+       UINT_PTR SetTimer(UINT_PTR nIDEvent, UINT nElapse, void (CALLBACK *lpfnTimer)(HWND, UINT, UINT_PTR, DWORD) = NULL)
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::SetTimer(m_hWnd, nIDEvent, nElapse, reinterpret_cast<TIMERPROC>(lpfnTimer));
+       }
+
+       BOOL SetWindowPos(HWND hWndInsertAfter, int x, int y, int cx, int cy, UINT nFlags)
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::SetWindowPos(m_hWnd, hWndInsertAfter, x, y, cx, cy, nFlags);
+       }
+
+       BOOL SetWindowText(LPCTSTR lpszString)
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::SetWindowText(m_hWnd, lpszString);
+       }
+
+       BOOL ShowWindow(int nCmdShow)
+       {
+               ATLASSERT(::IsWindow(m_hWnd));
+               return ::ShowWindow(m_hWnd, nCmdShow);
+       }
+
+};
+
+_declspec(selectany) RECT CWindow::rcDefault = { CW_USEDEFAULT, CW_USEDEFAULT, 0, 0 };
+
+template <class TBase = CWindow, class TWinTraits = CControlWinTraits>
+class CWindowImplBaseT : public TBase, public CMessageMap
+{
+public:
+       enum { WINSTATE_DESTROYED = 0x00000001 };
+       DWORD                                                                   m_dwState;
+       const _ATL_MSG                                                  *m_pCurrentMsg;
+       CWndProcThunk                                                   m_thunk;
+       WNDPROC                                                                 m_pfnSuperWindowProc;
+public:
+       CWindowImplBaseT()
+       {
+               m_dwState = 0;
+               m_pCurrentMsg = NULL;
+               m_pfnSuperWindowProc = ::DefWindowProc;
+       }
+
+       virtual void OnFinalMessage(HWND /* hWnd */)
+       {
+       }
+
+       BOOL SubclassWindow(HWND hWnd)
+       {
+               CWindowImplBaseT<TBase, TWinTraits>     *pThis;
+               WNDPROC                                                         newWindowProc;
+               WNDPROC                                                         oldWindowProc;
+               BOOL                                                            result;
+
+               ATLASSERT(m_hWnd == NULL);
+               ATLASSERT(::IsWindow(hWnd));
+
+               pThis = reinterpret_cast<CWindowImplBaseT<TBase, TWinTraits>*>(this);
+
+               result = m_thunk.Init(GetWindowProc(), this);
+               if (result == FALSE)
+                       return FALSE;
+               newWindowProc = m_thunk.GetWNDPROC();
+               oldWindowProc = reinterpret_cast<WNDPROC>(::SetWindowLongPtr(hWnd, GWLP_WNDPROC, reinterpret_cast<LONG_PTR>(newWindowProc)));
+               if (oldWindowProc == NULL)
+                       return FALSE;
+               m_pfnSuperWindowProc = oldWindowProc;
+               pThis->m_hWnd = hWnd;
+               return TRUE;
+       }
+
+       virtual WNDPROC GetWindowProc()
+       {
+               return WindowProc;
+       }
+
+       static DWORD GetWndStyle(DWORD dwStyle)
+       {
+               return TWinTraits::GetWndStyle(dwStyle);
+       }
+
+       static DWORD GetWndExStyle(DWORD dwExStyle)
+       {
+               return TWinTraits::GetWndExStyle(dwExStyle);
+       }
+
+       LRESULT DefWindowProc(UINT uMsg, WPARAM wParam, LPARAM lParam)
+       {
+               CWindowImplBaseT<TBase, TWinTraits>     *pThis;
+
+               pThis = reinterpret_cast<CWindowImplBaseT<TBase, TWinTraits> *>(this);
+               return ::CallWindowProc(m_pfnSuperWindowProc, pThis->m_hWnd, uMsg, wParam, lParam);
+       }
+
+       static LRESULT CALLBACK StartWindowProc(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam)
+       {
+               CWindowImplBaseT<TBase, TWinTraits>     *pThis;
+               WNDPROC                                                         newWindowProc;
+               WNDPROC                                                         GCCU(pOldProc);
+
+               pThis = reinterpret_cast<CWindowImplBaseT<TBase, TWinTraits> *>(_AtlWinModule.ExtractCreateWndData());
+               ATLASSERT(pThis != NULL);
+               if (pThis == NULL)
+                       return 0;
+               pThis->m_thunk.Init(pThis->GetWindowProc(), pThis);
+               newWindowProc = pThis->m_thunk.GetWNDPROC();
+               pOldProc = reinterpret_cast<WNDPROC>(::SetWindowLongPtr(hWnd, GWLP_WNDPROC, reinterpret_cast<LONG_PTR>(newWindowProc)));
+               Unused(pOldProc); // TODO: should generate trace message if overwriting another subclass
+               pThis->m_hWnd = hWnd;
+               return newWindowProc(hWnd, uMsg, wParam, lParam);
+       }
+
+       static LRESULT CALLBACK WindowProc(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam)
+       {
+               CWindowImplBaseT<TBase, TWinTraits>     *pThis = reinterpret_cast<CWindowImplBaseT< TBase, TWinTraits> *>(hWnd);
+               _ATL_MSG                                                        msg(pThis->m_hWnd, uMsg, wParam, lParam);
+               LRESULT                                                         lResult;
+               const _ATL_MSG                                          *previousMessage;
+               BOOL                                                            handled;
+               LONG_PTR                                                        saveWindowProc;
+
+               ATLASSERT(pThis != NULL && (pThis->m_dwState & WINSTATE_DESTROYED) == 0 && pThis->m_hWnd != NULL);
+               if (pThis == NULL || (pThis->m_dwState & WINSTATE_DESTROYED) != 0 || pThis->m_hWnd == NULL)
+                       return 0;
+
+               hWnd = pThis->m_hWnd;
+               previousMessage = pThis->m_pCurrentMsg;
+               pThis->m_pCurrentMsg = &msg;
+
+               handled = pThis->ProcessWindowMessage(hWnd, uMsg, wParam, lParam, lResult, 0);
+               ATLASSERT(pThis->m_pCurrentMsg == &msg);
+
+               if (handled == FALSE)
+               {
+                       if (uMsg == WM_NCDESTROY)
+                       {
+                               saveWindowProc = ::GetWindowLongPtr(hWnd, GWLP_WNDPROC);
+                               lResult = pThis->DefWindowProc(uMsg, wParam, lParam);
+                               if (pThis->m_pfnSuperWindowProc != ::DefWindowProc && saveWindowProc == ::GetWindowLongPtr(hWnd, GWLP_WNDPROC))
+                                       ::SetWindowLongPtr(hWnd, GWLP_WNDPROC, reinterpret_cast<LONG_PTR>(pThis->m_pfnSuperWindowProc));
+                               pThis->m_dwState |= WINSTATE_DESTROYED;
+                       }
+                       else
+                               lResult = pThis->DefWindowProc(uMsg, wParam, lParam);
+               }
+               ATLASSERT(pThis->m_pCurrentMsg == &msg);
+               pThis->m_pCurrentMsg = previousMessage;
+               if (previousMessage == NULL && (pThis->m_dwState & WINSTATE_DESTROYED) != 0)
+               {
+                       pThis->m_dwState &= ~WINSTATE_DESTROYED;
+                       pThis->m_hWnd = NULL;
+                       pThis->OnFinalMessage(hWnd);
+               }
+               return lResult;
+       }
+
+       HWND Create(HWND hWndParent, _U_RECT rect, LPCTSTR szWindowName, DWORD dwStyle, DWORD dwExStyle,
+                                       _U_MENUorID MenuOrID, ATOM atom, LPVOID lpCreateParam)
+       {
+               HWND                                                            hWnd;
+
+               ATLASSERT(m_hWnd == NULL);
+               ATLASSERT(atom != 0);
+               if (atom == 0)
+                       return NULL;
+               if (m_thunk.Init(NULL, NULL) == FALSE)
+               {
+                       SetLastError(ERROR_OUTOFMEMORY);
+                       return NULL;
+               }
+
+               _AtlWinModule.AddCreateWndData(&m_thunk.cd, this);
+               if (MenuOrID.m_hMenu == NULL && (dwStyle & WS_CHILD) != 0)
+                       MenuOrID.m_hMenu = (HMENU)(UINT_PTR)this;
+               if (rect.m_lpRect == NULL)
+                       rect.m_lpRect = &TBase::rcDefault;
+               hWnd = ::CreateWindowEx(dwExStyle, reinterpret_cast<LPCWSTR>(MAKEINTATOM(atom)), szWindowName, dwStyle, rect.m_lpRect->left,
+                                       rect.m_lpRect->top, rect.m_lpRect->right - rect.m_lpRect->left, rect.m_lpRect->bottom - rect.m_lpRect->top,
+                                       hWndParent, MenuOrID.m_hMenu, _AtlBaseModule.GetModuleInstance(), lpCreateParam);
+
+               ATLASSERT(m_hWnd == hWnd);
+
+               return hWnd;
+       }
+};
+
+template <class T, class TBase = CWindow, class TWinTraits = CControlWinTraits>
+class CWindowImpl : public CWindowImplBaseT<TBase, TWinTraits>
+{
+public:
+       static LPCTSTR GetWndCaption()
+       {
+               return NULL;
+       }
+
+       HWND Create(HWND hWndParent, _U_RECT rect = NULL, LPCTSTR szWindowName = NULL, DWORD dwStyle = 0,
+                                       DWORD dwExStyle = 0, _U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+       {
+               CWindowImplBaseT<TBase, TWinTraits>     *pThis;
+               ATOM                                                            atom;
+
+               ATLASSERT(m_hWnd == NULL);
+               pThis = reinterpret_cast<CWindowImplBaseT<TBase, TWinTraits>*>(this);
+
+               if (T::GetWndClassInfo().m_lpszOrigName == NULL)
+                       T::GetWndClassInfo().m_lpszOrigName = pThis->GetWndClassName();
+               atom = T::GetWndClassInfo().Register(&pThis->m_pfnSuperWindowProc);
+
+               if (szWindowName == NULL)
+                       szWindowName = T::GetWndCaption();
+               dwStyle = T::GetWndStyle(dwStyle);
+               dwExStyle = T::GetWndExStyle(dwExStyle);
+
+               return CWindowImplBaseT<TBase, TWinTraits>::Create(hWndParent, rect, szWindowName, dwStyle,
+                                               dwExStyle, MenuOrID, atom, lpCreateParam);
+       }
+};
+
+template <class TBase = CWindow, class TWinTraits = CControlWinTraits>
+class CContainedWindowT : public TBase
+{
+public:
+       CWndProcThunk                                                   m_thunk;
+       LPCTSTR                                                                 m_lpszClassName;
+       WNDPROC                                                                 m_pfnSuperWindowProc;
+       CMessageMap                                                             *m_pObject;
+       DWORD                                                                   m_dwMsgMapID;
+       const _ATL_MSG                                                  *m_pCurrentMsg;
+public:
+       CContainedWindowT(CMessageMap *pObject, DWORD dwMsgMapID = 0)
+       {
+               m_lpszClassName = TBase::GetWndClassName();
+               m_pfnSuperWindowProc = ::DefWindowProc;
+               m_pObject = pObject;
+               m_dwMsgMapID = dwMsgMapID;
+               m_pCurrentMsg = NULL;
+       }
+
+       CContainedWindowT(LPTSTR lpszClassName, CMessageMap *pObject, DWORD dwMsgMapID = 0)
+       {
+               m_lpszClassName = lpszClassName;
+               m_pfnSuperWindowProc = ::DefWindowProc;
+               m_pObject = pObject;
+               m_dwMsgMapID = dwMsgMapID;
+               m_pCurrentMsg = NULL;
+       }
+
+       LRESULT DefWindowProc(UINT uMsg, WPARAM wParam, LPARAM lParam)
+       {
+               CWindowImplBaseT<TBase, TWinTraits>     *pThis;
+
+               pThis = reinterpret_cast<CWindowImplBaseT<TBase, TWinTraits> *>(this);
+               return ::CallWindowProc(m_pfnSuperWindowProc, pThis->m_hWnd, uMsg, wParam, lParam);
+       }
+
+       BOOL SubclassWindow(HWND hWnd)
+       {
+               CContainedWindowT<TBase>                        *pThis;
+               WNDPROC                                                         newWindowProc;
+               WNDPROC                                                         oldWindowProc;
+               BOOL                                                            result;
+
+               ATLASSERT(m_hWnd == NULL);
+               ATLASSERT(::IsWindow(hWnd));
+
+               pThis = reinterpret_cast<CContainedWindowT<TBase> *>(this);
+
+               result = m_thunk.Init(WindowProc, pThis);
+               if (result == FALSE)
+                       return FALSE;
+               newWindowProc = m_thunk.GetWNDPROC();
+               oldWindowProc = reinterpret_cast<WNDPROC>(::SetWindowLongPtr(hWnd, GWLP_WNDPROC, reinterpret_cast<LONG_PTR>(newWindowProc)));
+               if (oldWindowProc == NULL)
+                       return FALSE;
+               m_pfnSuperWindowProc = oldWindowProc;
+               pThis->m_hWnd = hWnd;
+               return TRUE;
+       }
+
+       static LRESULT CALLBACK StartWindowProc(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam)
+       {
+               CContainedWindowT<TBase>                        *pThis;
+               WNDPROC                                                         newWindowProc;
+               WNDPROC                                                         GCCU(pOldProc);
+
+               pThis = reinterpret_cast<CContainedWindowT<TBase> *>(_AtlWinModule.ExtractCreateWndData());
+               ATLASSERT(pThis != NULL);
+               if (pThis == NULL)
+                       return 0;
+               pThis->m_thunk.Init(WindowProc, pThis);
+               newWindowProc = pThis->m_thunk.GetWNDPROC();
+               pOldProc = reinterpret_cast<WNDPROC>(::SetWindowLongPtr(hWnd, GWLP_WNDPROC, reinterpret_cast<LONG_PTR>(newWindowProc)));
+               Unused(pOldProc); // TODO: should generate trace message if overwriting another subclass
+               pThis->m_hWnd = hWnd;
+               return newWindowProc(hWnd, uMsg, wParam, lParam);
+       }
+
+       static LRESULT CALLBACK WindowProc(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam)
+       {
+               CContainedWindowT<TBase>                        *pThis = reinterpret_cast<CContainedWindowT<TBase> *>(hWnd);
+               _ATL_MSG                                                        msg(pThis->m_hWnd, uMsg, wParam, lParam);
+               LRESULT                                                         lResult;
+               const _ATL_MSG                                          *previousMessage;
+               BOOL                                                            handled;
+               LONG_PTR                                                        saveWindowProc;
+
+               ATLASSERT(pThis != NULL && pThis->m_hWnd != NULL && pThis->m_pObject != NULL);
+               if (pThis == NULL || pThis->m_hWnd == NULL || pThis->m_pObject == NULL)
+                       return 0;
+
+               hWnd = pThis->m_hWnd;
+               previousMessage = pThis->m_pCurrentMsg;
+               pThis->m_pCurrentMsg = &msg;
+
+               handled = pThis->m_pObject->ProcessWindowMessage(hWnd, uMsg, wParam, lParam, lResult, pThis->m_dwMsgMapID);
+               ATLASSERT(pThis->m_pCurrentMsg == &msg);
+
+               pThis->m_pCurrentMsg = previousMessage;
+               if (handled == FALSE)
+               {
+                       if (uMsg == WM_NCDESTROY)
+                       {
+                               saveWindowProc = ::GetWindowLongPtr(hWnd, GWLP_WNDPROC);
+                               lResult = pThis->DefWindowProc(uMsg, wParam, lParam);
+                               if (pThis->m_pfnSuperWindowProc != ::DefWindowProc && saveWindowProc == ::GetWindowLongPtr(hWnd, GWLP_WNDPROC))
+                                       ::SetWindowLongPtr(hWnd, GWLP_WNDPROC, reinterpret_cast<LONG_PTR>(pThis->m_pfnSuperWindowProc));
+                               pThis->m_hWnd = NULL;
+                       }
+                       else
+                               lResult = pThis->DefWindowProc(uMsg, wParam, lParam);
+               }
+               return lResult;
+       }
+
+};
+typedef CContainedWindowT<CWindow>     CContainedWindow;
+
+#define BEGIN_MSG_MAP(theClass)                                                                                                                                                                                                \
+public:                                                                                                                                                                                                                                                \
+       BOOL ProcessWindowMessage(HWND GCCU(hWnd), UINT GCCU(uMsg), WPARAM GCCU(wParam), LPARAM GCCU(lParam), LRESULT &GCCU(lResult), DWORD dwMsgMapID = 0)     \
+       {                                                                                                                                                                                                                                               \
+               BOOL GCCU(bHandled) = TRUE;                                                                                                                                                                                     \
+               Unused(hWnd);                                                                                                                                                                                                           \
+               Unused(uMsg);                                                                                                                                                                                                           \
+               Unused(wParam);                                                                                                                                                                                                         \
+               Unused(lParam);                                                                                                                                                                                                         \
+               Unused(lResult);                                                                                                                                                                                                        \
+               Unused(bHandled);                                                                                                                                                                                                       \
+               switch(dwMsgMapID)                                                                                                                                                                                                      \
+               {                                                                                                                                                                                                                                       \
+               case 0:
+
+#define END_MSG_MAP()                                                                                                                                                  \
+                       break;                                                                                                                                                          \
+               default:                                                                                                                                                                \
+                       ATLASSERT(FALSE);                                                                                                                                       \
+                       break;                                                                                                                                                          \
+               }                                                                                                                                                                               \
+               return FALSE;                                                                                                                                                   \
+       }
+
+#define MESSAGE_HANDLER(msg, func)                                                                                                                             \
+       if (uMsg == msg)                                                                                                                                                        \
+       {                                                                                                                                                                                       \
+               bHandled = TRUE;                                                                                                                                                \
+               lResult = func(uMsg, wParam, lParam, bHandled);                                                                                 \
+               if (bHandled)                                                                                                                                                   \
+                       return TRUE;                                                                                                                                            \
+       }
+
+#define MESSAGE_RANGE_HANDLER(msgFirst, msgLast, func)                                                                                 \
+       if (uMsg >= msgFirst && uMsg <= msgLast)                                                                                                        \
+       {                                                                                                                                                                                       \
+               bHandled = TRUE;                                                                                                                                                \
+               lResult = func(uMsg, wParam, lParam, bHandled);                                                                                 \
+               if (bHandled)                                                                                                                                                   \
+                       return TRUE;                                                                                                                                            \
+       }
+
+#define COMMAND_ID_HANDLER(id, func)                                                                                                                   \
+       if (uMsg == WM_COMMAND && id == LOWORD(wParam))                                                                                         \
+       {                                                                                                                                                                                       \
+               bHandled = TRUE;                                                                                                                                                \
+               lResult = func(HIWORD(wParam), LOWORD(wParam), (HWND)lParam, bHandled);                                 \
+               if (bHandled)                                                                                                                                                   \
+                       return TRUE;                                                                                                                                            \
+       }
+
+#define COMMAND_RANGE_HANDLER(idFirst, idLast, func)                                                                                   \
+       if (uMsg == WM_COMMAND && LOWORD(wParam) >= idFirst  && LOWORD(wParam) <= idLast)                       \
+       {                                                                                                                                                                                       \
+               bHandled = TRUE;                                                                                                                                                \
+               lResult = func(HIWORD(wParam), LOWORD(wParam), (HWND)lParam, bHandled);                                 \
+               if (bHandled)                                                                                                                                                   \
+                       return TRUE;                                                                                                                                            \
+       }
+
+#define NOTIFY_CODE_HANDLER(cd, func)                                                                                                                  \
+       if(uMsg == WM_NOTIFY && cd == ((LPNMHDR)lParam)->code)                                                                          \
+       {                                                                                                                                                                                       \
+               bHandled = TRUE;                                                                                                                                                \
+               lResult = func((int)wParam, (LPNMHDR)lParam, bHandled);                                                                 \
+               if (bHandled)                                                                                                                                                   \
+                       return TRUE;                                                                                                                                            \
+       }
+
+#define NOTIFY_HANDLER(id, cd, func)                                                                                                                   \
+       if(uMsg == WM_NOTIFY && id == ((LPNMHDR)lParam)->idFrom && cd == ((LPNMHDR)lParam)->code)       \
+       {                                                                                                                                                                                       \
+               bHandled = TRUE;                                                                                                                                                \
+               lResult = func((int)wParam, (LPNMHDR)lParam, bHandled);                                                                 \
+               if (bHandled)                                                                                                                                                   \
+                       return TRUE;                                                                                                                                            \
+       }
+
+#define DECLARE_WND_CLASS_EX(WndClassName, style, bkgnd)                                                                               \
+static ATL::CWndClassInfo& GetWndClassInfo()                                                                                                   \
+{                                                                                                                                                                                              \
+       static ATL::CWndClassInfo wc =                                                                                                                          \
+       {                                                                                                                                                                                       \
+               { sizeof(WNDCLASSEX), style, StartWindowProc,                                                                                   \
+                 0, 0, NULL, NULL, NULL, (HBRUSH)(bkgnd + 1), NULL, WndClassName, NULL },                              \
+               NULL, NULL, IDC_ARROW, TRUE, 0, _T("")                                                                                                  \
+       };                                                                                                                                                                                      \
+       return wc;                                                                                                                                                                      \
+}
+
+struct _ATL_WNDCLASSINFOW
+{
+       WNDCLASSEXW                                                             m_wc;
+       LPCWSTR                                                                 m_lpszOrigName;
+       WNDPROC                                                                 pWndProc;
+       LPCWSTR                                                                 m_lpszCursorID;
+       BOOL                                                                    m_bSystemCursor;
+       ATOM                                                                    m_atom;
+       WCHAR                                                                   m_szAutoName[5 + sizeof(void *)];
+
+       ATOM Register(WNDPROC *p)
+       {
+               if (m_atom == 0)
+                       m_atom = RegisterClassEx(&m_wc);
+               return m_atom;
+       }
+};
+
+}; // namespace ATL
+
+#endif // _atlwin_h
diff --git a/reactos/lib/atl/statreg.h b/reactos/lib/atl/statreg.h
new file mode 100644 (file)
index 0000000..19532d9
--- /dev/null
@@ -0,0 +1,679 @@
+/*
+ * ReactOS ATL
+ *
+ * Copyright 2005 Jacek Caban
+ * Copyright 2009 Andrew Hill <ash77@reactos.org>
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+#ifndef _statreg_h
+#define _statreg_h
+
+class IRegistrarBase : public IUnknown
+{
+public:
+       virtual HRESULT STDMETHODCALLTYPE AddReplacement(LPCOLESTR key, LPCOLESTR item) = 0;
+       virtual HRESULT STDMETHODCALLTYPE ClearReplacements() = 0;
+};
+
+namespace ATL
+{
+
+class CRegObject : public IRegistrarBase
+{
+public:
+       typedef struct rep_list_str
+       {
+               LPOLESTR                                                        key;
+               LPOLESTR                                                        item;
+               int                                                                     key_len;
+               struct rep_list_str                                     *next;
+       } rep_list;
+
+       typedef struct
+       {
+               LPOLESTR                                                        str;
+               DWORD                                                           alloc;
+               DWORD                                                           len;
+       } strbuf;
+
+       rep_list                                                                *m_rep;
+
+public:
+       CRegObject()
+       {
+               m_rep = NULL;
+       }
+
+       ~CRegObject()
+       {
+               HRESULT                                                         hResult;
+
+               hResult = ClearReplacements();
+               ATLASSERT(SUCCEEDED(hResult));
+       }
+
+       HRESULT STDMETHODCALLTYPE QueryInterface(const IID & /* riid */, void ** /* ppvObject */ )
+       {
+               ATLASSERT(_T("statically linked in CRegObject is not a com object. Do not callthis function"));
+               return E_NOTIMPL;
+       }
+
+       ULONG STDMETHODCALLTYPE AddRef()
+       {
+               ATLASSERT(_T("statically linked in CRegObject is not a com object. Do not callthis function"));
+               return 1;
+       }
+
+       ULONG STDMETHODCALLTYPE Release()
+       {
+               ATLASSERT(_T("statically linked in CRegObject is not a com object. Do not callthis function"));
+               return 0;
+       }
+
+       HRESULT STDMETHODCALLTYPE AddReplacement(LPCOLESTR key, LPCOLESTR item)
+       {
+               int                                                                     len;
+               rep_list                                                        *new_rep;
+
+               new_rep = reinterpret_cast<rep_list *>(HeapAlloc(GetProcessHeap(), 0, sizeof(rep_list)));
+
+               new_rep->key_len  = lstrlenW(key);
+               new_rep->key = reinterpret_cast<OLECHAR *>(HeapAlloc(GetProcessHeap(), 0, (new_rep->key_len + 1) * sizeof(OLECHAR)));
+               memcpy(new_rep->key, key, (new_rep->key_len + 1) * sizeof(OLECHAR));
+
+               len = lstrlenW(item) + 1;
+               new_rep->item = reinterpret_cast<OLECHAR *>(HeapAlloc(GetProcessHeap(), 0, len * sizeof(OLECHAR)));
+               memcpy(new_rep->item, item, len * sizeof(OLECHAR));
+
+               new_rep->next = m_rep;
+               m_rep = new_rep;
+
+               return S_OK;
+       }
+
+       HRESULT STDMETHODCALLTYPE ClearReplacements()
+       {
+               rep_list                                                        *iter;
+               rep_list                                                        *iter2;
+
+               iter = m_rep;
+               while (iter)
+               {
+                       iter2 = iter->next;
+                       HeapFree(GetProcessHeap(), 0, iter->key);
+                       HeapFree(GetProcessHeap(), 0, iter->item);
+                       HeapFree(GetProcessHeap(), 0, iter);
+                       iter = iter2;
+               }
+
+               m_rep = NULL;
+               return S_OK;
+       }
+
+       HRESULT STDMETHODCALLTYPE ResourceRegisterSz(LPCOLESTR resFileName, LPCOLESTR szID, LPCOLESTR szType)
+       {
+               return RegisterWithResource(resFileName, szID, szType, TRUE);
+       }
+
+       HRESULT STDMETHODCALLTYPE ResourceUnregisterSz(LPCOLESTR resFileName, LPCOLESTR szID, LPCOLESTR szType)
+       {
+               return RegisterWithResource(resFileName, szID, szType, FALSE);
+       }
+
+       HRESULT STDMETHODCALLTYPE FileRegister(LPCOLESTR fileName)
+       {
+               return RegisterWithFile(fileName, TRUE);
+       }
+
+       HRESULT STDMETHODCALLTYPE FileUnregister(LPCOLESTR fileName)
+       {
+               return RegisterWithFile(fileName, FALSE);
+       }
+
+       HRESULT STDMETHODCALLTYPE StringRegister(LPCOLESTR data)
+       {
+               return RegisterWithString(data, TRUE);
+       }
+
+       HRESULT STDMETHODCALLTYPE StringUnregister(LPCOLESTR data)
+       {
+               return RegisterWithString(data, FALSE);
+       }
+
+       HRESULT STDMETHODCALLTYPE ResourceRegister(LPCOLESTR resFileName, UINT nID, LPCOLESTR szType)
+       {
+               return ResourceRegisterSz(resFileName, MAKEINTRESOURCEW(nID), szType);
+       }
+
+       HRESULT STDMETHODCALLTYPE ResourceUnregister(LPCOLESTR resFileName, UINT nID, LPCOLESTR szType)
+       {
+               return ResourceRegisterSz(resFileName, MAKEINTRESOURCEW(nID), szType);
+       }
+
+protected:
+       HRESULT STDMETHODCALLTYPE RegisterWithResource(LPCOLESTR resFileName, LPCOLESTR szID, LPCOLESTR szType, BOOL doRegister)
+       {
+               return resource_register(resFileName, szID, szType, doRegister);
+       }
+
+       HRESULT STDMETHODCALLTYPE RegisterWithFile(LPCOLESTR fileName, BOOL doRegister)
+       {
+               return file_register(fileName, doRegister);
+       }
+
+       HRESULT STDMETHODCALLTYPE RegisterWithString(LPCOLESTR data, BOOL doRegister)
+       {
+               return string_register(data, doRegister);
+       }
+
+private:
+       inline LONG RegDeleteTreeX(HKEY parentKey, LPCTSTR subKeyName)
+       {
+               wchar_t                                                         szBuffer[256];
+               DWORD                                                           dwSize;
+               FILETIME                                                        time;
+               HKEY                                                            childKey;
+               LONG                                                            lRes;
+
+               ATLASSERT(parentKey != NULL);
+               lRes = RegOpenKeyEx(parentKey, subKeyName, 0, KEY_READ | KEY_WRITE, &childKey);
+               if (lRes != ERROR_SUCCESS)
+                       return lRes;
+
+               dwSize = sizeof(szBuffer) / sizeof(szBuffer[0]);
+               while (RegEnumKeyExW(parentKey, 0, szBuffer, &dwSize, NULL, NULL, NULL, &time) == ERROR_SUCCESS)
+               {
+                       lRes = RegDeleteTreeX(childKey, szBuffer);
+                       if (lRes != ERROR_SUCCESS)
+                               return lRes;
+                       dwSize = sizeof(szBuffer) / sizeof(szBuffer[0]);
+               }
+               RegCloseKey(childKey);
+               return RegDeleteKey(parentKey, subKeyName);
+       }
+
+       void strbuf_init(strbuf *buf)
+       {
+               buf->str = reinterpret_cast<LPOLESTR>(HeapAlloc(GetProcessHeap(), 0, 128 * sizeof(WCHAR)));
+               buf->alloc = 128;
+               buf->len = 0;
+       }
+
+       void strbuf_write(LPCOLESTR str, strbuf *buf, int len)
+       {
+               if (len == -1)
+                       len = lstrlenW(str);
+               if (buf->len+len+1 >= buf->alloc)
+               {
+                       buf->alloc = (buf->len + len) * 2;
+                       buf->str = reinterpret_cast<LPOLESTR>(HeapReAlloc(GetProcessHeap(), 0, buf->str, buf->alloc * sizeof(WCHAR)));
+               }
+               memcpy(buf->str + buf->len, str, len * sizeof(OLECHAR));
+               buf->len += len;
+               buf->str[buf->len] = '\0';
+       }
+
+
+       HRESULT file_register(LPCOLESTR fileName, BOOL do_register)
+       {
+               HANDLE                                                          file;
+               DWORD                                                           filelen;
+               DWORD                                                           len;
+               LPWSTR                                                          regstrw;
+               LPSTR                                                           regstra;
+               LRESULT                                                         lres;
+               HRESULT                                                         hResult;
+
+               file = CreateFileW(fileName, GENERIC_READ, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_READONLY, NULL);
+               if (file != INVALID_HANDLE_VALUE)
+               {
+                       filelen = GetFileSize(file, NULL);
+                       regstra = reinterpret_cast<LPSTR>(HeapAlloc(GetProcessHeap(), 0, filelen));
+                       lres = ReadFile(file, regstra, filelen, NULL, NULL);
+                       if (lres == ERROR_SUCCESS)
+                       {
+                               len = MultiByteToWideChar(CP_ACP, 0, regstra, filelen, NULL, 0) + 1;
+                               regstrw = reinterpret_cast<LPWSTR>(HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, len * sizeof(WCHAR)));
+                               MultiByteToWideChar(CP_ACP, 0, regstra, filelen, regstrw, len);
+                               regstrw[len - 1] = '\0';
+
+                               hResult = string_register(regstrw, do_register);
+
+                               HeapFree(GetProcessHeap(), 0, regstrw);
+                       }
+                       else
+                       {
+                               hResult = HRESULT_FROM_WIN32(lres);
+                       }
+                       HeapFree(GetProcessHeap(), 0, regstra);
+                       CloseHandle(file);
+               }
+               else
+               {
+                       hResult = HRESULT_FROM_WIN32(GetLastError());
+               }
+
+               return hResult;
+       }
+
+       HRESULT resource_register(LPCOLESTR resFileName, LPCOLESTR szID, LPCOLESTR szType, BOOL do_register)
+       {
+               HINSTANCE                                                       hins;
+               HRSRC                                                           src;
+               HGLOBAL                                                         regstra;
+               LPWSTR                                                          regstrw;
+               DWORD                                                           len;
+               DWORD                                                           reslen;
+               HRESULT                                                         hResult;
+
+               hins = LoadLibraryExW(resFileName, NULL, LOAD_LIBRARY_AS_DATAFILE);
+               if (hins)
+               {
+                       src = FindResourceW(hins, szID, szType);
+                       if (src)
+                       {
+                               regstra = LoadResource(hins, src);
+                               reslen = SizeofResource(hins, src);
+                               if (regstra)
+                               {
+                                       len = MultiByteToWideChar(CP_ACP, 0, reinterpret_cast<LPCSTR>(regstra), reslen, NULL, 0) + 1;
+                                       regstrw = reinterpret_cast<LPWSTR>(HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, len * sizeof(WCHAR)));
+                                       MultiByteToWideChar(CP_ACP, 0, reinterpret_cast<LPCSTR>(regstra), reslen, regstrw, len);
+                                       regstrw[len - 1] = '\0';
+
+                                       hResult = string_register(regstrw, do_register);
+
+                                       HeapFree(GetProcessHeap(), 0, regstrw);
+                               }
+                               else
+                                       hResult = HRESULT_FROM_WIN32(GetLastError());
+                       }
+                       else
+                               hResult = HRESULT_FROM_WIN32(GetLastError());
+                       FreeLibrary(hins);
+               }
+               else
+                       hResult = HRESULT_FROM_WIN32(GetLastError());
+
+               return hResult;
+       }
+
+       HRESULT string_register(LPCOLESTR data, BOOL do_register)
+       {
+               strbuf                                                          buf;
+               HRESULT                                                         hResult;
+
+               strbuf_init(&buf);
+               hResult = do_preprocess(data, &buf);
+               if (SUCCEEDED(hResult))
+               {
+                       hResult = do_process_root_key(buf.str, do_register);
+                       if (FAILED(hResult) && do_register)
+                               do_process_root_key(buf.str, FALSE);
+               }
+
+               HeapFree(GetProcessHeap(), 0, buf.str);
+               return hResult;
+       }
+
+       HRESULT do_preprocess(LPCOLESTR data, strbuf *buf)
+       {
+               LPCOLESTR                                                       iter;
+               LPCOLESTR                                                       iter2;
+               rep_list                                                        *rep_iter;
+
+               iter2 = data;
+               iter = wcschr(data, '%');
+               while (iter)
+               {
+                       strbuf_write(iter2, buf, static_cast<int>(iter - iter2));
+
+                       iter2 = ++iter;
+                       if (!*iter2)
+                               return DISP_E_EXCEPTION;
+                       iter = wcschr(iter2, '%');
+                       if (!iter)
+                               return DISP_E_EXCEPTION;
+
+                       if (iter == iter2)
+                               strbuf_write(_T("%"), buf, 1);
+                       else
+                       {
+                               for (rep_iter = m_rep; rep_iter; rep_iter = rep_iter->next)
+                               {
+                                       if (rep_iter->key_len == iter - iter2 && !_memicmp(iter2, rep_iter->key, rep_iter->key_len * sizeof(wchar_t)))
+                                               break;
+                               }
+                               if (!rep_iter)
+                                       return DISP_E_EXCEPTION;
+
+                               strbuf_write(rep_iter->item, buf, -1);
+                       }
+
+                       iter2 = ++iter;
+                       iter = wcschr(iter, '%');
+               }
+
+               strbuf_write(iter2, buf, -1);
+
+               return S_OK;
+       }
+
+       HRESULT get_word(LPCOLESTR *str, strbuf *buf)
+       {
+               LPCOLESTR                                                       iter;
+               LPCOLESTR                                                       iter2;
+
+               iter2 = *str;
+               buf->len = 0;
+               buf->str[0] = '\0';
+
+               while (iswspace (*iter2))
+                       iter2++;
+               iter = iter2;
+               if (!*iter)
+               {
+                       *str = iter;
+                       return S_OK;
+               }
+
+               if (*iter == '}' || *iter == '=')
+               {
+                       strbuf_write(iter++, buf, 1);
+               }
+               else if (*iter == '\'')
+               {
+                       iter2 = ++iter;
+                       iter = wcschr(iter, '\'');
+                       if (!iter)
+                       {
+                               *str = iter;
+                               return DISP_E_EXCEPTION;
+                       }
+                       strbuf_write(iter2, buf, static_cast<int>(iter - iter2));
+                       iter++;
+               }
+               else
+               {
+                       while (*iter && !iswspace(*iter))
+                               iter++;
+                       strbuf_write(iter2, buf, static_cast<int>(iter - iter2));
+               }
+
+               while (iswspace(*iter))
+                       iter++;
+               *str = iter;
+               return S_OK;
+       }
+
+       HRESULT do_process_key(LPCOLESTR *pstr, HKEY parent_key, strbuf *buf, BOOL do_register)
+       {
+               LPCOLESTR                                                       iter;
+               HRESULT                                                         hres;
+               LONG                                                            lres;
+               HKEY                                                            hkey;
+               strbuf                                                          name;
+           
+               enum {
+                       NORMAL,
+                       NO_REMOVE,
+                       IS_VAL,
+                       FORCE_REMOVE,
+                       DO_DELETE
+               } key_type = NORMAL; 
+
+               static const wchar_t *wstrNoRemove = _T("NoRemove");
+               static const wchar_t *wstrForceRemove = _T("ForceRemove");
+               static const wchar_t *wstrDelete = _T("Delete");
+               static const wchar_t *wstrval = _T("val");
+
+               iter = *pstr;
+               hkey = NULL;
+               iter = *pstr;
+               hres = get_word(&iter, buf);
+               if (FAILED(hres))
+                       return hres;
+               strbuf_init(&name);
+
+               while(buf->str[1] || buf->str[0] != '}')
+               {
+                       key_type = NORMAL;
+                       if (!lstrcmpiW(buf->str, wstrNoRemove))
+                               key_type = NO_REMOVE;
+                       else if (!lstrcmpiW(buf->str, wstrForceRemove))
+                               key_type = FORCE_REMOVE;
+                       else if (!lstrcmpiW(buf->str, wstrval))
+                               key_type = IS_VAL;
+                       else if (!lstrcmpiW(buf->str, wstrDelete))
+                               key_type = DO_DELETE;
+
+                       if (key_type != NORMAL)
+                       {
+                               hres = get_word(&iter, buf);
+                               if(FAILED(hres))
+                                       break;
+                       }
+           
+                       if (do_register)
+                       {
+                               if (key_type == IS_VAL)
+                               {
+                                       hkey = parent_key;
+                                       strbuf_write(buf->str, &name, -1);
+                               }
+                               else if (key_type == DO_DELETE)
+                               {
+                                       RegDeleteTreeX(parent_key, buf->str);
+                               }
+                               else
+                               {
+                                       if (key_type == FORCE_REMOVE)
+                                               RegDeleteTreeX(parent_key, buf->str);
+                                       lres = RegCreateKey(parent_key, buf->str, &hkey);
+                                       if (lres != ERROR_SUCCESS)
+                                       {
+                                               hres = HRESULT_FROM_WIN32(lres);
+                                               break;
+                                       }
+                               }
+                       }
+                       else if (key_type != IS_VAL && key_type != DO_DELETE)
+                       {
+                               strbuf_write(buf->str, &name, -1);
+                               lres = RegOpenKey(parent_key, buf->str, &hkey);
+                               if (lres != ERROR_SUCCESS)
+                               {
+                               }
+                       }
+
+                       if (key_type != DO_DELETE && *iter == '=')
+                       {
+                               iter++;
+                               hres = get_word(&iter, buf);
+                               if (FAILED(hres))
+                                       break;
+                               if (buf->len != 1)
+                               {
+                                       hres = DISP_E_EXCEPTION;
+                                       break;
+                               }
+                               if (do_register)
+                               {
+                                       switch(buf->str[0])
+                                       {
+                                               case 's':
+                                                       hres = get_word(&iter, buf);
+                                                       if (FAILED(hres))
+                                                               break;
+                                                       lres = RegSetValueEx(hkey, name.len ? name.str :  NULL, 0, REG_SZ, (PBYTE)buf->str,
+                                                                       (lstrlenW(buf->str) + 1) * sizeof(WCHAR));
+                                                       if (lres != ERROR_SUCCESS)
+                                                       {
+                                                               hres = HRESULT_FROM_WIN32(lres);
+                                                               break;
+                                                       }
+                                                       break;
+                                               case 'd':
+                                                       {
+                                                               WCHAR *end;
+                                                               DWORD dw;
+                                                               if(*iter == '0' && iter[1] == 'x')
+                                                               {
+                                                                       iter += 2;
+                                                                       dw = wcstol(iter, &end, 16);
+                                                               }
+                                                               else
+                                                               {
+                                                                       dw = wcstol(iter, &end, 10);
+                                                               }
+                                                               iter = end;
+                                                               lres = RegSetValueEx(hkey, name.len ? name.str :  NULL, 0, REG_DWORD, (PBYTE)&dw, sizeof(dw));
+                                                               if (lres != ERROR_SUCCESS)
+                                                               {
+                                                                       hres = HRESULT_FROM_WIN32(lres);
+                                                                       break;
+                                                               }
+                                                               break;
+                                                       }
+                                               default:
+                                                       hres = DISP_E_EXCEPTION;
+                                       }
+                                       if (FAILED(hres))
+                                               break;
+                               }
+                               else
+                               {
+                                       if (*iter == '-')
+                                               iter++;
+                                       hres = get_word(&iter, buf);
+                                       if (FAILED(hres))
+                                               break;
+                               }
+                       }
+                       else if(key_type == IS_VAL)
+                       {
+                               hres = DISP_E_EXCEPTION;
+                               break;
+                       }
+
+                       if (key_type != IS_VAL && key_type != DO_DELETE && *iter == '{' && iswspace(iter[1]))
+                       {
+                               hres = get_word(&iter, buf);
+                               if (FAILED(hres))
+                                       break;
+                               hres = do_process_key(&iter, hkey, buf, do_register);
+                               if (FAILED(hres))
+                                       break;
+                       }
+
+                       if (!do_register && (key_type == NORMAL || key_type == FORCE_REMOVE))
+                       {
+                               RegDeleteKey(parent_key, name.str);
+                       }
+
+                       if (hkey && key_type != IS_VAL)
+                               RegCloseKey(hkey);
+                       hkey = 0;
+                       name.len = 0;
+               
+                       hres = get_word(&iter, buf);
+                       if (FAILED(hres))
+                               break;
+               }
+
+               HeapFree(GetProcessHeap(), 0, name.str);
+               if (hkey && key_type != IS_VAL)
+                       RegCloseKey(hkey);
+               *pstr = iter;
+               return hres;
+       }
+
+       HRESULT do_process_root_key(LPCOLESTR data, BOOL do_register)
+       {
+               LPCOLESTR                                                       iter;
+               strbuf                                                          buf;
+               unsigned int                                            i;
+               HRESULT                                                         hResult;
+               static const struct {
+                       wchar_t                                                 *name;
+                       HKEY                                                    key;
+               } root_keys[] = {
+                       {_T("HKEY_CLASSES_ROOT"), HKEY_CLASSES_ROOT},
+                       {_T("HKEY_CURRENT_USER"), HKEY_CURRENT_USER},
+                       {_T("HKEY_LOCAL_MACHINE"), HKEY_LOCAL_MACHINE},
+                       {_T("HKEY_USERS"), HKEY_USERS},
+                       {_T("HKEY_PERFORMANCE_DATA"), HKEY_PERFORMANCE_DATA},
+                       {_T("HKEY_DYN_DATA"), HKEY_DYN_DATA},
+                       {_T("HKEY_CURRENT_CONFIG"), HKEY_CURRENT_CONFIG},
+                       {_T("HKCR"), HKEY_CLASSES_ROOT},
+                       {_T("HKCU"), HKEY_CURRENT_USER},
+                       {_T("HKLM"), HKEY_LOCAL_MACHINE},
+                       {_T("HKU"), HKEY_USERS},
+                       {_T("HKPD"), HKEY_PERFORMANCE_DATA},
+                       {_T("HKDD"), HKEY_DYN_DATA},
+                       {_T("HKCC"), HKEY_CURRENT_CONFIG},
+               };
+
+               iter = data;
+               hResult = S_OK;
+
+               strbuf_init(&buf);
+               hResult = get_word(&iter, &buf);
+               if (FAILED(hResult))
+                       return hResult;
+
+               while (*iter)
+               {
+                       if (!buf.len)
+                       {
+                               hResult = DISP_E_EXCEPTION;
+                               break;
+                       }
+                       for (i = 0; i < sizeof(root_keys) / sizeof(root_keys[0]); i++)
+                       {
+                               if (!lstrcmpiW(buf.str, root_keys[i].name))
+                                       break;
+                       }
+                       if (i == sizeof(root_keys) / sizeof(root_keys[0]))
+                       {
+                               hResult = DISP_E_EXCEPTION;
+                               break;
+                       }
+                       hResult = get_word(&iter, &buf);
+                       if (FAILED(hResult))
+                               break;
+                       if (buf.str[1] || buf.str[0] != '{')
+                       {
+                               hResult = DISP_E_EXCEPTION;
+                               break;
+                       }
+                       hResult = do_process_key(&iter, root_keys[i].key, &buf, do_register);
+                       if (FAILED(hResult))
+                               break;
+                       hResult = get_word(&iter, &buf);
+                       if (FAILED(hResult))
+                               break;
+               }
+               HeapFree(GetProcessHeap(), 0, buf.str);
+               return hResult;
+       }
+
+};
+
+}; //namespace ATL
+
+#endif // _statreg_h
index 0b149cd..d8e3f4d 100644 (file)
@@ -7,6 +7,9 @@
        <directory name="sdk">
                <xi:include href="sdk/sdk.rbuild" />
        </directory>
+       <directory name="atl">
+               <xi:include href="atl/atl.rbuild" />
+       </directory>
        <directory name="cmlib">
                <xi:include href="cmlib/cmlib.rbuild" />
        </directory>