[ATL] Add CString.AllocSysString
[reactos.git] / sdk / lib / atl / atlsimpstr.h
1 #ifndef __ATLSIMPSTR_H__
2 #define __ATLSIMPSTR_H__
3
4 #pragma once
5
6 #include <atlcore.h>
7 #include <atlexcept.h>
8
9 namespace ATL
10 {
11 struct CStringData;
12
13 // Pure virtual interface
14 class IAtlStringMgr
15 {
16 public:
17
18 virtual ~IAtlStringMgr() {}
19
20 virtual _Ret_maybenull_ _Post_writable_byte_size_(sizeof(CStringData) + nAllocLength*nCharSize)
21 CStringData* Allocate(
22 _In_ int nAllocLength,
23 _In_ int nCharSize
24 ) = 0;
25
26 virtual void Free(
27 _Inout_ CStringData* pData
28 ) = 0;
29
30 virtual _Ret_maybenull_ _Post_writable_byte_size_(sizeof(CStringData) + nAllocLength*nCharSize)
31 CStringData* Reallocate(
32 _Inout_ CStringData* pData,
33 _In_ int nAllocLength,
34 _In_ int nCharSize
35 ) = 0;
36
37 virtual CStringData* GetNilString(void) = 0;
38 virtual IAtlStringMgr* Clone(void) = 0;
39 };
40
41
42 struct CStringData
43 {
44 IAtlStringMgr* pStringMgr;
45 int nAllocLength;
46 int nDataLength;
47 long nRefs;
48
49 void* data() throw()
50 {
51 return (this + 1);
52 }
53
54 void AddRef() throw()
55 {
56 ATLASSERT(nRefs > 0);
57 _InterlockedIncrement(&nRefs);
58 }
59
60 void Release() throw()
61 {
62 ATLASSERT(nRefs != 0);
63
64 if (_InterlockedDecrement(&nRefs) <= 0)
65 {
66 pStringMgr->Free(this);
67 }
68 }
69
70 bool IsLocked() const throw()
71 {
72 return (nRefs < 0);
73 }
74
75 bool IsShared() const throw()
76 {
77 return (nRefs > 1);
78 }
79 };
80
81 class CNilStringData :
82 public CStringData
83 {
84 public:
85 CNilStringData() throw()
86 {
87 pStringMgr = NULL;
88 nRefs = 2;
89 nDataLength = 0;
90 nAllocLength = 0;
91 achNil[0] = 0;
92 achNil[1] = 0;
93 }
94
95 void SetManager(_In_ IAtlStringMgr* pMgr) throw()
96 {
97 ATLASSERT(pStringMgr == NULL);
98 pStringMgr = pMgr;
99 }
100
101 public:
102 wchar_t achNil[2];
103 };
104
105
106 template< typename BaseType = char >
107 class ChTraitsBase
108 {
109 public:
110 typedef char XCHAR;
111 typedef LPSTR PXSTR;
112 typedef LPCSTR PCXSTR;
113 typedef wchar_t YCHAR;
114 typedef LPWSTR PYSTR;
115 typedef LPCWSTR PCYSTR;
116 };
117
118 template<>
119 class ChTraitsBase<wchar_t>
120 {
121 public:
122 typedef wchar_t XCHAR;
123 typedef LPWSTR PXSTR;
124 typedef LPCWSTR PCXSTR;
125 typedef char YCHAR;
126 typedef LPSTR PYSTR;
127 typedef LPCSTR PCYSTR;
128 };
129
130 template< typename BaseType, bool t_bMFCDLL = false>
131 class CSimpleStringT
132 {
133 public:
134 typedef typename ChTraitsBase<BaseType>::XCHAR XCHAR;
135 typedef typename ChTraitsBase<BaseType>::PXSTR PXSTR;
136 typedef typename ChTraitsBase<BaseType>::PCXSTR PCXSTR;
137 typedef typename ChTraitsBase<BaseType>::YCHAR YCHAR;
138 typedef typename ChTraitsBase<BaseType>::PYSTR PYSTR;
139 typedef typename ChTraitsBase<BaseType>::PCYSTR PCYSTR;
140
141 private:
142 PXSTR m_pszData;
143
144 public:
145 explicit CSimpleStringT(_Inout_ IAtlStringMgr* pStringMgr)
146 {
147 CStringData* pData = pStringMgr->GetNilString();
148 Attach(pData);
149 }
150
151 CSimpleStringT(_In_ const CSimpleStringT& strSrc)
152 {
153 CStringData* pSrcData = strSrc.GetData();
154 CStringData* pNewData = CloneData(pSrcData);
155 Attach(pNewData);
156 }
157
158 CSimpleStringT(
159 _In_z_ PCXSTR pszSrc,
160 _Inout_ IAtlStringMgr* pStringMgr)
161 {
162 int nLength = StringLength(pszSrc);
163 CStringData* pData = pStringMgr->Allocate(nLength, sizeof(XCHAR));
164 if (pData == NULL)
165 ThrowMemoryException();
166
167 Attach(pData);
168 SetLength(nLength);
169 CopyChars(m_pszData, nLength, pszSrc, nLength);
170 }
171
172 CSimpleStringT(
173 _In_count_(nLength) const XCHAR* pchSrc,
174 _In_ int nLength,
175 _Inout_ IAtlStringMgr* pStringMgr)
176 {
177 if (pchSrc == NULL && nLength != 0)
178 ThrowInvalidArgException();
179
180 CStringData* pData = pStringMgr->Allocate(nLength, sizeof(XCHAR));
181 if (pData == NULL)
182 {
183 ThrowMemoryException();
184 }
185 Attach(pData);
186 SetLength(nLength);
187 CopyChars(m_pszData, nLength, pchSrc, nLength);
188 }
189
190 ~CSimpleStringT() throw()
191 {
192 CStringData* pData = GetData();
193 pData->Release();
194 }
195
196 CSimpleStringT& operator=(_In_opt_z_ PCXSTR pszSrc)
197 {
198 SetString(pszSrc);
199 return *this;
200 }
201
202 CSimpleStringT& operator=(_In_ const CSimpleStringT& strSrc)
203 {
204 CStringData* pData = GetData();
205 CStringData* pNewData = strSrc.GetData();
206
207 if (pNewData != pData)
208 {
209 if (!pData->IsLocked() && (pNewData->pStringMgr == pData->pStringMgr))
210 {
211 pNewData = CloneData(pNewData);
212 pData->Release();
213 Attach(pNewData);
214 }
215 else
216 {
217 SetString(strSrc.GetString(), strSrc.GetLength());
218 }
219 }
220
221 return *this;
222 }
223
224 CSimpleStringT& operator+=(_In_ const CSimpleStringT& strSrc)
225 {
226 Append(strSrc);
227 return *this;
228 }
229
230 CSimpleStringT& operator+=(_In_z_ PCXSTR pszSrc)
231 {
232 Append(pszSrc);
233 return *this;
234 }
235
236 CSimpleStringT& operator+=(XCHAR ch)
237 {
238 Append(&ch, 1);
239 return *this;
240 }
241
242 operator PCXSTR() const throw()
243 {
244 return m_pszData;
245 }
246
247 void Empty() throw()
248 {
249 CStringData* pOldData = GetData();
250 IAtlStringMgr* pStringMgr = pOldData->pStringMgr;
251 if (pOldData->nDataLength == 0) return;
252
253 if (pOldData->IsLocked())
254 {
255 SetLength(0);
256 }
257 else
258 {
259 pOldData->Release();
260 CStringData* pNewData = pStringMgr->GetNilString();
261 Attach(pNewData);
262 }
263 }
264
265 void Append(
266 _In_count_(nLength) PCXSTR pszSrc,
267 _In_ int nLength)
268 {
269 UINT_PTR nOffset = pszSrc - GetString();
270
271 int nOldLength = GetLength();
272 if (nOldLength < 0)
273 nOldLength = 0;
274
275 ATLASSERT(nLength >= 0);
276
277 #if 0 // FIXME: See comment for StringLengthN below.
278 nLength = StringLengthN(pszSrc, nLength);
279 if (!(INT_MAX - nLength >= nOldLength))
280 throw;
281 #endif
282
283 int nNewLength = nOldLength + nLength;
284 PXSTR pszBuffer = GetBuffer(nNewLength);
285 if (nOffset <= (UINT_PTR)nOldLength)
286 {
287 pszSrc = pszBuffer + nOffset;
288 }
289 CopyChars(pszBuffer + nOldLength, nLength, pszSrc, nLength);
290 ReleaseBufferSetLength(nNewLength);
291 }
292
293 void Append(_In_z_ PCXSTR pszSrc)
294 {
295 Append(pszSrc, StringLength(pszSrc));
296 }
297
298 void Append(_In_ const CSimpleStringT& strSrc)
299 {
300 Append(strSrc.GetString(), strSrc.GetLength());
301 }
302
303 void SetString(_In_opt_z_ PCXSTR pszSrc)
304 {
305 SetString(pszSrc, StringLength(pszSrc));
306 }
307
308 void SetString(_In_reads_opt_(nLength) PCXSTR pszSrc,
309 _In_ int nLength)
310 {
311 if (nLength == 0)
312 {
313 Empty();
314 }
315 else
316 {
317 UINT nOldLength = GetLength();
318 UINT_PTR nOffset = pszSrc - GetString();
319
320 PXSTR pszBuffer = GetBuffer(nLength);
321 if (nOffset <= nOldLength)
322 {
323 CopyCharsOverlapped(pszBuffer, GetAllocLength(),
324 pszBuffer + nOffset, nLength);
325 }
326 else
327 {
328 CopyChars(pszBuffer, GetAllocLength(), pszSrc, nLength);
329 }
330 ReleaseBufferSetLength(nLength);
331 }
332 }
333
334 PXSTR GetBuffer()
335 {
336 CStringData* pData = GetData();
337 if (pData->IsShared())
338 {
339 // We should fork here
340 Fork(pData->nDataLength);
341 }
342
343 return m_pszData;
344 }
345
346 _Ret_notnull_ _Post_writable_size_(nMinBufferLength + 1) PXSTR GetBuffer(_In_ int nMinBufferLength)
347 {
348 return PrepareWrite(nMinBufferLength);
349 }
350
351 int GetAllocLength() const throw()
352 {
353 return GetData()->nAllocLength;
354 }
355
356 int GetLength() const throw()
357 {
358 return GetData()->nDataLength;
359 }
360
361 PCXSTR GetString() const throw()
362 {
363 return m_pszData;
364 }
365
366 void ReleaseBufferSetLength(_In_ int nNewLength)
367 {
368 ATLASSERT(nNewLength >= 0);
369 SetLength(nNewLength);
370 }
371
372 void ReleaseBuffer(_In_ int nNewLength = -1)
373 {
374 if (nNewLength < 0)
375 nNewLength = StringLength(m_pszData);
376 ReleaseBufferSetLength(nNewLength);
377 }
378
379 bool IsEmpty() const throw()
380 {
381 return (GetLength() == 0);
382 }
383
384 CStringData* GetData() const throw()
385 {
386 return (reinterpret_cast<CStringData*>(m_pszData) - 1);
387 }
388
389 IAtlStringMgr* GetManager() const throw()
390 {
391 IAtlStringMgr* pStringMgr = GetData()->pStringMgr;
392 return (pStringMgr ? pStringMgr->Clone() : NULL);
393 }
394
395 public:
396 friend CSimpleStringT operator+(
397 _In_ const CSimpleStringT& str1,
398 _In_ const CSimpleStringT& str2)
399 {
400 CSimpleStringT s(str1.GetManager());
401 Concatenate(s, str1, str1.GetLength(), str2, str2.GetLength());
402 return s;
403 }
404
405 friend CSimpleStringT operator+(
406 _In_ const CSimpleStringT& str1,
407 _In_z_ PCXSTR psz2)
408 {
409 CSimpleStringT s(str1.GetManager());
410 Concatenate(s, str1, str1.GetLength(), psz2, StringLength(psz2));
411 return s;
412 }
413
414 friend CSimpleStringT operator+(
415 _In_z_ PCXSTR psz1,
416 _In_ const CSimpleStringT& str2)
417 {
418 CSimpleStringT s(str2.GetManager());
419 Concatenate(s, psz1, StringLength(psz1), str2, str2.GetLength());
420 return s;
421 }
422
423 static void __cdecl CopyChars(
424 _Out_writes_to_(nDestLen, nChars) XCHAR* pchDest,
425 _In_ size_t nDestLen,
426 _In_reads_opt_(nChars) const XCHAR* pchSrc,
427 _In_ int nChars) throw()
428 {
429 memcpy(pchDest, pchSrc, nChars * sizeof(XCHAR));
430 }
431
432 static void __cdecl CopyCharsOverlapped(
433 _Out_writes_to_(nDestLen, nDestLen) XCHAR* pchDest,
434 _In_ size_t nDestLen,
435 _In_reads_(nChars) const XCHAR* pchSrc,
436 _In_ int nChars) throw()
437 {
438 memmove(pchDest, pchSrc, nChars * sizeof(XCHAR));
439 }
440
441 static int __cdecl StringLength(_In_opt_z_ const char* psz) throw()
442 {
443 if (psz == NULL) return 0;
444 return (int)strlen(psz);
445 }
446
447 static int __cdecl StringLength(_In_opt_z_ const wchar_t* psz) throw()
448 {
449 if (psz == NULL) return 0;
450 return (int)wcslen(psz);
451 }
452
#if 0 // For whatever reason we do not link with strnlen / wcsnlen. Please investigate!
      // strnlen / wcsnlen are available in MSVCRT starting Vista+.

// Bounded length of a narrow string: at most sizeInXChar characters are
// examined; NULL counts as length 0.
static int __cdecl StringLengthN(
    _In_opt_z_count_(sizeInXChar) const char* psz,
    _In_ size_t sizeInXChar) throw()
{
    return (psz != NULL) ? (int)strnlen(psz, sizeInXChar) : 0;
}

// Bounded length of a wide string: at most sizeInXChar characters are
// examined; NULL counts as length 0.
static int __cdecl StringLengthN(
    _In_opt_z_count_(sizeInXChar) const wchar_t* psz,
    _In_ size_t sizeInXChar) throw()
{
    return (psz != NULL) ? (int)wcsnlen(psz, sizeInXChar) : 0;
}
#endif
471
472 protected:
473 static void __cdecl Concatenate(
474 _Inout_ CSimpleStringT& strResult,
475 _In_count_(nLength1) PCXSTR psz1,
476 _In_ int nLength1,
477 _In_count_(nLength2) PCXSTR psz2,
478 _In_ int nLength2)
479 {
480 int nNewLength = nLength1 + nLength2;
481 PXSTR pszBuffer = strResult.GetBuffer(nNewLength);
482 CopyChars(pszBuffer, nLength1, psz1, nLength1);
483 CopyChars(pszBuffer + nLength1, nLength2, psz2, nLength2);
484 strResult.ReleaseBufferSetLength(nNewLength);
485 }
486
487 private:
488 void Attach(_Inout_ CStringData* pData) throw()
489 {
490 m_pszData = static_cast<PXSTR>(pData->data());
491 }
492
493 __declspec(noinline) void Fork(_In_ int nLength)
494 {
495 CStringData* pOldData = GetData();
496 int nOldLength = pOldData->nDataLength;
497 CStringData* pNewData = pOldData->pStringMgr->Clone()->Allocate(nLength, sizeof(XCHAR));
498 if (pNewData == NULL)
499 {
500 ThrowMemoryException();
501 }
502 int nCharsToCopy = ((nOldLength < nLength) ? nOldLength : nLength) + 1;
503 CopyChars(PXSTR(pNewData->data()), nCharsToCopy,
504 PCXSTR(pOldData->data()), nCharsToCopy);
505 pNewData->nDataLength = nOldLength;
506 pOldData->Release();
507 Attach(pNewData);
508 }
509
510 PXSTR PrepareWrite(_In_ int nLength)
511 {
512 CStringData* pOldData = GetData();
513 int nShared = 1 - pOldData->nRefs;
514 int nTooShort = pOldData->nAllocLength - nLength;
515 if ((nShared | nTooShort) < 0)
516 {
517 PrepareWrite2(nLength);
518 }
519
520 return m_pszData;
521 }
522 void PrepareWrite2(_In_ int nLength)
523 {
524 CStringData* pOldData = GetData();
525 if (pOldData->nDataLength > nLength)
526 {
527 nLength = pOldData->nDataLength;
528 }
529 if (pOldData->IsShared())
530 {
531 Fork(nLength);
532 //ATLASSERT(FALSE);
533 }
534 else if (pOldData->nAllocLength < nLength)
535 {
536 int nNewLength = pOldData->nAllocLength;
537 if (nNewLength > 1024 * 1024 * 1024)
538 {
539 nNewLength += 1024 * 1024;
540 }
541 else
542 {
543 nNewLength = nNewLength + nNewLength / 2;
544 }
545 if (nNewLength < nLength)
546 {
547 nNewLength = nLength;
548 }
549 Reallocate(nNewLength);
550 }
551 }
552
553 void Reallocate(_In_ int nLength)
554 {
555 CStringData* pOldData = GetData();
556 ATLASSERT(pOldData->nAllocLength < nLength);
557 IAtlStringMgr* pStringMgr = pOldData->pStringMgr;
558 if (pOldData->nAllocLength >= nLength || nLength <= 0)
559 {
560 return;
561 }
562 CStringData* pNewData = pStringMgr->Reallocate(pOldData, nLength, sizeof(XCHAR));
563 if (pNewData == NULL)
564 {
565 ThrowMemoryException();
566 }
567
568 Attach(pNewData);
569 }
570
571 void SetLength(_In_ int nLength)
572 {
573 ATLASSERT(nLength >= 0);
574 ATLASSERT(nLength <= GetData()->nAllocLength);
575
576 if (nLength < 0 || nLength > GetData()->nAllocLength)
577 {
578 AtlThrow(E_INVALIDARG);
579 }
580
581 GetData()->nDataLength = nLength;
582 m_pszData[nLength] = 0;
583 }
584
585 static CStringData* __cdecl CloneData(_Inout_ CStringData* pData)
586 {
587 CStringData* pNewData = NULL;
588
589 IAtlStringMgr* pNewStringMgr = pData->pStringMgr->Clone();
590 if (!pData->IsLocked() && (pNewStringMgr == pData->pStringMgr))
591 {
592 pNewData = pData;
593 pNewData->AddRef();
594 }
595 else
596 {
597 pNewData = pNewStringMgr->Allocate(pData->nDataLength, sizeof(XCHAR));
598 if (pNewData == NULL)
599 {
600 ThrowMemoryException();
601 }
602
603 pNewData->nDataLength = pData->nDataLength;
604 CopyChars(PXSTR(pNewData->data()), pData->nDataLength + 1,
605 PCXSTR(pData->data()), pData->nDataLength + 1);
606 }
607
608 return pNewData;
609 }
610
611
612 static void ThrowMemoryException()
613 {
614 AtlThrow(E_OUTOFMEMORY);
615 }
616
617 static void ThrowInvalidArgException()
618 {
619 AtlThrow(E_INVALIDARG);
620 }
621
622 };
623
624 #ifdef UNICODE
625 typedef CSimpleStringT<WCHAR> CSimpleString;
626 #else
627 typedef CSimpleStringT<CHAR> CSimpleString;
628 #endif
629 }
630
631 #endif