d71e6f1fe0b2eb4c7a9ab751cfad15cb7730b7eb
[reactos.git] / reactos / dll / win32 / msi / string.c
1 /*
2 * String Table Functions
3 *
4 * Copyright 2002-2004, Mike McCormack for CodeWeavers
5 * Copyright 2007 Robert Shearman for CodeWeavers
6 * Copyright 2010 Hans Leidekker for CodeWeavers
7 *
8 * This library is free software; you can redistribute it and/or
9 * modify it under the terms of the GNU Lesser General Public
10 * License as published by the Free Software Foundation; either
11 * version 2.1 of the License, or (at your option) any later version.
12 *
13 * This library is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 * Lesser General Public License for more details.
17 *
18 * You should have received a copy of the GNU Lesser General Public
19 * License along with this library; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
21 */
22
23 #define WIN32_NO_STATUS
24 #define _INC_WINDOWS
25 #define COM_NO_WINDOWS_H
26
27 #define COBJMACROS
28
29 //#include <stdarg.h>
30 //#include <assert.h>
31
32 //#include "windef.h"
33 //#include "winbase.h"
34 //#include "winerror.h"
35 #include <wine/debug.h>
36 #include <wine/unicode.h>
37 //#include "msi.h"
38 //#include "msiquery.h"
39 //#include "objbase.h"
40 //#include "objidl.h"
41 #include "msipriv.h"
42 //#include "winnls.h"
43
44 //#include "query.h"
45
46 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
47
48 struct msistring
49 {
50 USHORT persistent_refcount;
51 USHORT nonpersistent_refcount;
52 WCHAR *data;
53 int len;
54 };
55
56 struct string_table
57 {
58 UINT maxcount; /* the number of strings */
59 UINT freeslot;
60 UINT codepage;
61 UINT sortcount;
62 struct msistring *strings; /* an array of strings */
63 UINT *sorted; /* index */
64 };
65
66 static BOOL validate_codepage( UINT codepage )
67 {
68 if (codepage != CP_ACP && !IsValidCodePage( codepage ))
69 {
70 WARN("invalid codepage %u\n", codepage);
71 return FALSE;
72 }
73 return TRUE;
74 }
75
76 static string_table *init_stringtable( int entries, UINT codepage )
77 {
78 string_table *st;
79
80 if (!validate_codepage( codepage ))
81 return NULL;
82
83 st = msi_alloc( sizeof (string_table) );
84 if( !st )
85 return NULL;
86 if( entries < 1 )
87 entries = 1;
88
89 st->strings = msi_alloc_zero( sizeof(struct msistring) * entries );
90 if( !st->strings )
91 {
92 msi_free( st );
93 return NULL;
94 }
95
96 st->sorted = msi_alloc( sizeof (UINT) * entries );
97 if( !st->sorted )
98 {
99 msi_free( st->strings );
100 msi_free( st );
101 return NULL;
102 }
103
104 st->maxcount = entries;
105 st->freeslot = 1;
106 st->codepage = codepage;
107 st->sortcount = 0;
108
109 return st;
110 }
111
112 VOID msi_destroy_stringtable( string_table *st )
113 {
114 UINT i;
115
116 for( i=0; i<st->maxcount; i++ )
117 {
118 if( st->strings[i].persistent_refcount ||
119 st->strings[i].nonpersistent_refcount )
120 msi_free( st->strings[i].data );
121 }
122 msi_free( st->strings );
123 msi_free( st->sorted );
124 msi_free( st );
125 }
126
127 static int st_find_free_entry( string_table *st )
128 {
129 UINT i, sz, *s;
130 struct msistring *p;
131
132 TRACE("%p\n", st);
133
134 if( st->freeslot )
135 {
136 for( i = st->freeslot; i < st->maxcount; i++ )
137 if( !st->strings[i].persistent_refcount &&
138 !st->strings[i].nonpersistent_refcount )
139 return i;
140 }
141 for( i = 1; i < st->maxcount; i++ )
142 if( !st->strings[i].persistent_refcount &&
143 !st->strings[i].nonpersistent_refcount )
144 return i;
145
146 /* dynamically resize */
147 sz = st->maxcount + 1 + st->maxcount/2;
148 p = msi_realloc_zero( st->strings, sz * sizeof(struct msistring) );
149 if( !p )
150 return -1;
151
152 s = msi_realloc( st->sorted, sz*sizeof(UINT) );
153 if( !s )
154 {
155 msi_free( p );
156 return -1;
157 }
158
159 st->strings = p;
160 st->sorted = s;
161
162 st->freeslot = st->maxcount;
163 st->maxcount = sz;
164 if( st->strings[st->freeslot].persistent_refcount ||
165 st->strings[st->freeslot].nonpersistent_refcount )
166 ERR("oops. expected freeslot to be free...\n");
167 return st->freeslot;
168 }
169
170 static inline int cmp_string( const WCHAR *str1, int len1, const WCHAR *str2, int len2 )
171 {
172 if (len1 < len2) return -1;
173 else if (len1 > len2) return 1;
174 while (len1)
175 {
176 if (*str1 == *str2) { str1++; str2++; }
177 else return *str1 - *str2;
178 len1--;
179 }
180 return 0;
181 }
182
183 static int find_insert_index( const string_table *st, UINT string_id )
184 {
185 int i, c, low = 0, high = st->sortcount - 1;
186
187 while (low <= high)
188 {
189 i = (low + high) / 2;
190 c = cmp_string( st->strings[string_id].data, st->strings[string_id].len,
191 st->strings[st->sorted[i]].data, st->strings[st->sorted[i]].len );
192 if (c < 0)
193 high = i - 1;
194 else if (c > 0)
195 low = i + 1;
196 else
197 return -1; /* already exists */
198 }
199 return high + 1;
200 }
201
202 static void insert_string_sorted( string_table *st, UINT string_id )
203 {
204 int i;
205
206 i = find_insert_index( st, string_id );
207 if (i == -1)
208 return;
209
210 memmove( &st->sorted[i] + 1, &st->sorted[i], (st->sortcount - i) * sizeof(UINT) );
211 st->sorted[i] = string_id;
212 st->sortcount++;
213 }
214
215 static void set_st_entry( string_table *st, UINT n, WCHAR *str, int len, USHORT refcount,
216 enum StringPersistence persistence )
217 {
218 if (persistence == StringPersistent)
219 {
220 st->strings[n].persistent_refcount = refcount;
221 st->strings[n].nonpersistent_refcount = 0;
222 }
223 else
224 {
225 st->strings[n].persistent_refcount = 0;
226 st->strings[n].nonpersistent_refcount = refcount;
227 }
228
229 st->strings[n].data = str;
230 st->strings[n].len = len;
231
232 insert_string_sorted( st, n );
233
234 if( n < st->maxcount )
235 st->freeslot = n + 1;
236 }
237
238 static UINT msi_string2idA( const string_table *st, LPCSTR buffer, UINT *id )
239 {
240 DWORD sz;
241 UINT r = ERROR_INVALID_PARAMETER;
242 LPWSTR str;
243
244 TRACE("Finding string %s in string table\n", debugstr_a(buffer) );
245
246 if( buffer[0] == 0 )
247 {
248 *id = 0;
249 return ERROR_SUCCESS;
250 }
251
252 sz = MultiByteToWideChar( st->codepage, 0, buffer, -1, NULL, 0 );
253 if( sz <= 0 )
254 return r;
255 str = msi_alloc( sz*sizeof(WCHAR) );
256 if( !str )
257 return ERROR_NOT_ENOUGH_MEMORY;
258 MultiByteToWideChar( st->codepage, 0, buffer, -1, str, sz );
259
260 r = msi_string2id( st, str, sz - 1, id );
261 msi_free( str );
262 return r;
263 }
264
265 static int msi_addstring( string_table *st, UINT n, const char *data, UINT len, USHORT refcount, enum StringPersistence persistence )
266 {
267 LPWSTR str;
268 int sz;
269
270 if( !data || !len )
271 return 0;
272 if( n > 0 )
273 {
274 if( st->strings[n].persistent_refcount ||
275 st->strings[n].nonpersistent_refcount )
276 return -1;
277 }
278 else
279 {
280 if( ERROR_SUCCESS == msi_string2idA( st, data, &n ) )
281 {
282 if (persistence == StringPersistent)
283 st->strings[n].persistent_refcount += refcount;
284 else
285 st->strings[n].nonpersistent_refcount += refcount;
286 return n;
287 }
288 n = st_find_free_entry( st );
289 if( n == -1 )
290 return -1;
291 }
292
293 if( n < 1 )
294 {
295 ERR("invalid index adding %s (%d)\n", debugstr_a( data ), n );
296 return -1;
297 }
298
299 /* allocate a new string */
300 sz = MultiByteToWideChar( st->codepage, 0, data, len, NULL, 0 );
301 str = msi_alloc( (sz+1)*sizeof(WCHAR) );
302 if( !str )
303 return -1;
304 MultiByteToWideChar( st->codepage, 0, data, len, str, sz );
305 str[sz] = 0;
306
307 set_st_entry( st, n, str, sz, refcount, persistence );
308 return n;
309 }
310
311 int msi_addstringW( string_table *st, const WCHAR *data, int len, USHORT refcount, enum StringPersistence persistence )
312 {
313 UINT n;
314 LPWSTR str;
315
316 if( !data )
317 return 0;
318
319 if (len < 0) len = strlenW( data );
320
321 if( !data[0] && !len )
322 return 0;
323
324 if (msi_string2id( st, data, len, &n) == ERROR_SUCCESS )
325 {
326 if (persistence == StringPersistent)
327 st->strings[n].persistent_refcount += refcount;
328 else
329 st->strings[n].nonpersistent_refcount += refcount;
330 return n;
331 }
332
333 n = st_find_free_entry( st );
334 if( n == -1 )
335 return -1;
336
337 /* allocate a new string */
338 TRACE( "%s, n = %d len = %d\n", debugstr_wn(data, len), n, len );
339
340 str = msi_alloc( (len+1)*sizeof(WCHAR) );
341 if( !str )
342 return -1;
343 memcpy( str, data, len*sizeof(WCHAR) );
344 str[len] = 0;
345
346 set_st_entry( st, n, str, len, refcount, persistence );
347 return n;
348 }
349
350 /* find the string identified by an id - return null if there's none */
351 const WCHAR *msi_string_lookup( const string_table *st, UINT id, int *len )
352 {
353 if( id == 0 )
354 {
355 if (len) *len = 0;
356 return szEmpty;
357 }
358 if( id >= st->maxcount )
359 return NULL;
360
361 if( id && !st->strings[id].persistent_refcount && !st->strings[id].nonpersistent_refcount)
362 return NULL;
363
364 if (len) *len = st->strings[id].len;
365
366 return st->strings[id].data;
367 }
368
369 /*
370 * msi_id2stringA
371 *
372 * [in] st - pointer to the string table
373 * [in] id - id of the string to retrieve
374 * [out] buffer - destination of the UTF8 string
375 * [in/out] sz - number of bytes available in the buffer on input
376 * number of bytes used on output
377 *
378 * Returned string is not nul terminated.
379 */
380 static UINT msi_id2stringA( const string_table *st, UINT id, LPSTR buffer, UINT *sz )
381 {
382 int len, lenW;
383 const WCHAR *str;
384
385 TRACE("Finding string %d of %d\n", id, st->maxcount);
386
387 str = msi_string_lookup( st, id, &lenW );
388 if( !str )
389 return ERROR_FUNCTION_FAILED;
390
391 len = WideCharToMultiByte( st->codepage, 0, str, lenW, NULL, 0, NULL, NULL );
392 if( *sz < len )
393 {
394 *sz = len;
395 return ERROR_MORE_DATA;
396 }
397 *sz = WideCharToMultiByte( st->codepage, 0, str, lenW, buffer, *sz, NULL, NULL );
398 return ERROR_SUCCESS;
399 }
400
401 /*
402 * msi_string2id
403 *
404 * [in] st - pointer to the string table
405 * [in] str - string to find in the string table
406 * [out] id - id of the string, if found
407 */
408 UINT msi_string2id( const string_table *st, const WCHAR *str, int len, UINT *id )
409 {
410 int i, c, low = 0, high = st->sortcount - 1;
411
412 if (len < 0) len = strlenW( str );
413
414 while (low <= high)
415 {
416 i = (low + high) / 2;
417 c = cmp_string( str, len, st->strings[st->sorted[i]].data, st->strings[st->sorted[i]].len );
418
419 if (c < 0)
420 high = i - 1;
421 else if (c > 0)
422 low = i + 1;
423 else
424 {
425 *id = st->sorted[i];
426 return ERROR_SUCCESS;
427 }
428 }
429 return ERROR_INVALID_PARAMETER;
430 }
431
432 static void string_totalsize( const string_table *st, UINT *datasize, UINT *poolsize )
433 {
434 UINT i, len, holesize;
435
436 if( st->strings[0].data || st->strings[0].persistent_refcount || st->strings[0].nonpersistent_refcount)
437 ERR("oops. element 0 has a string\n");
438
439 *poolsize = 4;
440 *datasize = 0;
441 holesize = 0;
442 for( i=1; i<st->maxcount; i++ )
443 {
444 if( !st->strings[i].persistent_refcount )
445 {
446 TRACE("[%u] nonpersistent = %s\n", i, debugstr_wn(st->strings[i].data, st->strings[i].len));
447 (*poolsize) += 4;
448 }
449 else if( st->strings[i].data )
450 {
451 TRACE("[%u] = %s\n", i, debugstr_wn(st->strings[i].data, st->strings[i].len));
452 len = WideCharToMultiByte( st->codepage, 0, st->strings[i].data, st->strings[i].len + 1,
453 NULL, 0, NULL, NULL);
454 if( len )
455 len--;
456 (*datasize) += len;
457 if (len>0xffff)
458 (*poolsize) += 4;
459 (*poolsize) += holesize + 4;
460 holesize = 0;
461 }
462 else
463 holesize += 4;
464 }
465 TRACE("data %u pool %u codepage %x\n", *datasize, *poolsize, st->codepage );
466 }
467
468 HRESULT msi_init_string_table( IStorage *stg )
469 {
470 USHORT zero[2] = { 0, 0 };
471 UINT ret;
472
473 /* create the StringPool stream... add the zero string to it*/
474 ret = write_stream_data(stg, szStringPool, zero, sizeof zero, TRUE);
475 if (ret != ERROR_SUCCESS)
476 return E_FAIL;
477
478 /* create the StringData stream... make it zero length */
479 ret = write_stream_data(stg, szStringData, NULL, 0, TRUE);
480 if (ret != ERROR_SUCCESS)
481 return E_FAIL;
482
483 return S_OK;
484 }
485
486 string_table *msi_load_string_table( IStorage *stg, UINT *bytes_per_strref )
487 {
488 string_table *st = NULL;
489 CHAR *data = NULL;
490 USHORT *pool = NULL;
491 UINT r, datasize = 0, poolsize = 0, codepage;
492 DWORD i, count, offset, len, n, refs;
493
494 r = read_stream_data( stg, szStringPool, TRUE, (BYTE **)&pool, &poolsize );
495 if( r != ERROR_SUCCESS)
496 goto end;
497 r = read_stream_data( stg, szStringData, TRUE, (BYTE **)&data, &datasize );
498 if( r != ERROR_SUCCESS)
499 goto end;
500
501 if ( (poolsize > 4) && (pool[1] & 0x8000) )
502 *bytes_per_strref = LONG_STR_BYTES;
503 else
504 *bytes_per_strref = sizeof(USHORT);
505
506 count = poolsize/4;
507 if( poolsize > 4 )
508 codepage = pool[0] | ( (pool[1] & ~0x8000) << 16 );
509 else
510 codepage = CP_ACP;
511 st = init_stringtable( count, codepage );
512 if (!st)
513 goto end;
514
515 offset = 0;
516 n = 1;
517 i = 1;
518 while( i<count )
519 {
520 /* the string reference count is always the second word */
521 refs = pool[i*2+1];
522
523 /* empty entries have two zeros, still have a string id */
524 if (pool[i*2] == 0 && refs == 0)
525 {
526 i++;
527 n++;
528 continue;
529 }
530
531 /*
532 * If a string is over 64k, the previous string entry is made null
533 * and its the high word of the length is inserted in the null string's
534 * reference count field.
535 */
536 if( pool[i*2] == 0)
537 {
538 len = (pool[i*2+3] << 16) + pool[i*2+2];
539 i += 2;
540 }
541 else
542 {
543 len = pool[i*2];
544 i += 1;
545 }
546
547 if ( (offset + len) > datasize )
548 {
549 ERR("string table corrupt?\n");
550 break;
551 }
552
553 r = msi_addstring( st, n, data+offset, len, refs, StringPersistent );
554 if( r != n )
555 ERR("Failed to add string %d\n", n );
556 n++;
557 offset += len;
558 }
559
560 if ( datasize != offset )
561 ERR("string table load failed! (%08x != %08x), please report\n", datasize, offset );
562
563 TRACE("Loaded %d strings\n", count);
564
565 end:
566 msi_free( pool );
567 msi_free( data );
568
569 return st;
570 }
571
572 UINT msi_save_string_table( const string_table *st, IStorage *storage, UINT *bytes_per_strref )
573 {
574 UINT i, datasize = 0, poolsize = 0, sz, used, r, codepage, n;
575 UINT ret = ERROR_FUNCTION_FAILED;
576 CHAR *data = NULL;
577 USHORT *pool = NULL;
578
579 TRACE("\n");
580
581 /* construct the new table in memory first */
582 string_totalsize( st, &datasize, &poolsize );
583
584 TRACE("%u %u %u\n", st->maxcount, datasize, poolsize );
585
586 pool = msi_alloc( poolsize );
587 if( ! pool )
588 {
589 WARN("Failed to alloc pool %d bytes\n", poolsize );
590 goto err;
591 }
592 data = msi_alloc( datasize );
593 if( ! data )
594 {
595 WARN("Failed to alloc data %d bytes\n", datasize );
596 goto err;
597 }
598
599 used = 0;
600 codepage = st->codepage;
601 pool[0] = codepage & 0xffff;
602 pool[1] = codepage >> 16;
603 if (st->maxcount > 0xffff)
604 {
605 pool[1] |= 0x8000;
606 *bytes_per_strref = LONG_STR_BYTES;
607 }
608 else
609 *bytes_per_strref = sizeof(USHORT);
610
611 n = 1;
612 for( i=1; i<st->maxcount; i++ )
613 {
614 if( !st->strings[i].persistent_refcount )
615 {
616 pool[ n*2 ] = 0;
617 pool[ n*2 + 1] = 0;
618 n++;
619 continue;
620 }
621
622 sz = datasize - used;
623 r = msi_id2stringA( st, i, data+used, &sz );
624 if( r != ERROR_SUCCESS )
625 {
626 ERR("failed to fetch string\n");
627 sz = 0;
628 }
629
630 if (sz)
631 pool[ n*2 + 1 ] = st->strings[i].persistent_refcount;
632 else
633 pool[ n*2 + 1 ] = 0;
634 if (sz < 0x10000)
635 {
636 pool[ n*2 ] = sz;
637 n++;
638 }
639 else
640 {
641 pool[ n*2 ] = 0;
642 pool[ n*2 + 2 ] = sz&0xffff;
643 pool[ n*2 + 3 ] = (sz>>16);
644 n += 2;
645 }
646 used += sz;
647 if( used > datasize )
648 {
649 ERR("oops overran %d >= %d\n", used, datasize);
650 goto err;
651 }
652 }
653
654 if( used != datasize )
655 {
656 ERR("oops used %d != datasize %d\n", used, datasize);
657 goto err;
658 }
659
660 /* write the streams */
661 r = write_stream_data( storage, szStringData, data, datasize, TRUE );
662 TRACE("Wrote StringData r=%08x\n", r);
663 if( r )
664 goto err;
665 r = write_stream_data( storage, szStringPool, pool, poolsize, TRUE );
666 TRACE("Wrote StringPool r=%08x\n", r);
667 if( r )
668 goto err;
669
670 ret = ERROR_SUCCESS;
671
672 err:
673 msi_free( data );
674 msi_free( pool );
675
676 return ret;
677 }
678
679 UINT msi_get_string_table_codepage( const string_table *st )
680 {
681 return st->codepage;
682 }
683
684 UINT msi_set_string_table_codepage( string_table *st, UINT codepage )
685 {
686 if (validate_codepage( codepage ))
687 {
688 st->codepage = codepage;
689 return ERROR_SUCCESS;
690 }
691 return ERROR_FUNCTION_FAILED;
692 }