Prevent buffer overflows in vsn(w)printf
authorguido <guido@svn.reactos.org>
Thu, 12 Sep 2002 17:50:42 +0000 (17:50 +0000)
committerguido <guido@svn.reactos.org>
Thu, 12 Sep 2002 17:50:42 +0000 (17:50 +0000)
svn path=/trunk/; revision=3488

reactos/lib/ntdll/stdio/sprintf.c
reactos/lib/ntdll/stdio/swprintf.c
reactos/ntoskrnl/rtl/sprintf.c
reactos/ntoskrnl/rtl/swprintf.c

index d2a71dd..7ed8c1e 100644 (file)
@@ -1,4 +1,4 @@
-/* $Id: sprintf.c,v 1.8 2002/09/08 10:23:07 chorns Exp $
+/* $Id: sprintf.c,v 1.9 2002/09/12 17:50:42 guido Exp $
  *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
  *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
@@ -7,8 +7,6 @@
  * PROGRAMMERS:     David Welch
  *                  Eric Kohl
  *
  * PROGRAMMERS:     David Welch
  *                  Eric Kohl
  *
- * TODO:
- *     - Implement maximum length (cnt) in _vsnprintf().
  */
 
 /*
  */
 
 /*
@@ -60,22 +58,21 @@ static int skip_atoi(const char **s)
 
 
 static char *
 
 
 static char *
-number (char * str, long long num, int base, int size, int precision, int type)
+number(char * buf, char * end, long long num, int base, int size, int precision, int type)
 {
        char c,sign,tmp[66];
 {
        char c,sign,tmp[66];
-       const char *digits="0123456789abcdefghijklmnopqrstuvwxyz";
+       const char *digits;
+       const char small_digits[] = "0123456789abcdefghijklmnopqrstuvwxyz";
+       const char large_digits[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        int i;
 
        int i;
 
-       if (type & LARGE)
-               digits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+       digits = (type & LARGE) ? large_digits : small_digits;
        if (type & LEFT)
                type &= ~ZEROPAD;
        if (base < 2 || base > 36)
                return 0;
        if (type & LEFT)
                type &= ~ZEROPAD;
        if (base < 2 || base > 36)
                return 0;
-
        c = (type & ZEROPAD) ? '0' : ' ';
        sign = 0;
        c = (type & ZEROPAD) ? '0' : ' ';
        sign = 0;
-
        if (type & SIGN) {
                if (num < 0) {
                        sign = '-';
        if (type & SIGN) {
                if (num < 0) {
                        sign = '-';
@@ -89,14 +86,12 @@ number (char * str, long long num, int base, int size, int precision, int type)
                        size--;
                }
        }
                        size--;
                }
        }
-
        if (type & SPECIAL) {
                if (base == 16)
                        size -= 2;
                else if (base == 8)
                        size--;
        }
        if (type & SPECIAL) {
                if (base == 16)
                        size -= 2;
                else if (base == 8)
                        size--;
        }
-
        i = 0;
        if (num == 0)
                tmp[i++]='0';
        i = 0;
        if (num == 0)
                tmp[i++]='0';
@@ -105,29 +100,55 @@ number (char * str, long long num, int base, int size, int precision, int type)
        if (i > precision)
                precision = i;
        size -= precision;
        if (i > precision)
                precision = i;
        size -= precision;
-       if (!(type&(ZEROPAD+LEFT)))
-               while(size-->0)
-                       *str++ = ' ';
-       if (sign)
-               *str++ = sign;
+       if (!(type&(ZEROPAD+LEFT))) {
+               while(size-->0) {
+                       if (buf <= end)
+                               *buf = ' ';
+                       ++buf;
+               }
+       }
+       if (sign) {
+               if (buf <= end)
+                       *buf = sign;
+               ++buf;
+       }
        if (type & SPECIAL) {
                if (base==8) {
        if (type & SPECIAL) {
                if (base==8) {
-                       *str++ = '0';
+                       if (buf <= end)
+                               *buf = '0';
+                       ++buf;
                } else if (base==16) {
                } else if (base==16) {
-                       *str++ = '0';
-                       *str++ = digits[33];
+                       if (buf <= end)
+                               *buf = '0';
+                       ++buf;
+                       if (buf <= end)
+                               *buf = digits[33];
+                       ++buf;
+               }
+       }
+       if (!(type & LEFT)) {
+               while (size-- > 0) {
+                       if (buf <= end)
+                               *buf = c;
+                       ++buf;
                }
        }
                }
        }
-       if (!(type & LEFT))
-               while (size-- > 0)
-                       *str++ = c;
-       while (i < precision--)
-               *str++ = '0';
-       while (i-- > 0)
-               *str++ = tmp[i];
-       while (size-- > 0)
-               *str++ = ' ';
-       return str;
+       while (i < precision--) {
+               if (buf <= end)
+                       *buf = '0';
+               ++buf;
+       }
+       while (i-- > 0) {
+               if (buf <= end)
+                       *buf = tmp[i];
+               ++buf;
+       }
+       while (size-- > 0) {
+               if (buf <= end)
+                       *buf = ' ';
+               ++buf;
+       }
+       return buf;
 }
 
 
 }
 
 
@@ -136,7 +157,7 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
        int len;
        unsigned long long num;
        int i, base;
        int len;
        unsigned long long num;
        int i, base;
-       char * str;
+       char *str, *end;
        const char *s;
        const wchar_t *sw;
 
        const char *s;
        const wchar_t *sw;
 
@@ -147,9 +168,18 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                                   number of chars for from string */
        int qualifier;          /* 'h', 'l', 'L', 'I' or 'w' for integer fields */
 
                                   number of chars for from string */
        int qualifier;          /* 'h', 'l', 'L', 'I' or 'w' for integer fields */
 
-       for (str=buf ; *fmt ; ++fmt) {
+       str = buf;
+       end = buf + cnt - 1;
+       if (end < buf - 1) {
+               end = ((void *) -1);
+               cnt = end - buf + 1;
+       }
+
+       for ( ; *fmt ; ++fmt) {
                if (*fmt != '%') {
                if (*fmt != '%') {
-                       *str++ = *fmt;
+                       if (str <= end)
+                               *str = *fmt;
+                       ++str;
                        continue;
                }
 
                        continue;
                }
 
@@ -210,28 +240,48 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                switch (*fmt) {
                case 'c': /* finished */
                        if (!(flags & LEFT))
                switch (*fmt) {
                case 'c': /* finished */
                        if (!(flags & LEFT))
-                               while (--field_width > 0)
-                                       *str++ = ' ';
-                       if (qualifier == 'l' || qualifier == 'w')
-                               *str++ = 
-                                 (unsigned char)(wchar_t) va_arg(args, int);
-                       else
-                               *str++ = (unsigned char) va_arg(args, int);
-                       while (--field_width > 0)
-                               *str++ = ' ';
+                               while (--field_width > 0) {
+                                       if (str <= end)
+                                               *str = ' ';
+                                       ++str;
+                               }
+                       if (qualifier == 'l' || qualifier == 'w') {
+                               if (str <= end)
+                                       *str = (unsigned char)(wchar_t) va_arg(args, int);
+                               ++str;
+                       } else {
+                               if (str <= end)
+                                       *str = (unsigned char) va_arg(args, int);
+                               ++str;
+                       }
+                       while (--field_width > 0) {
+                               if (str <= end)
+                                       *str = ' ';
+                               ++str;
+                       }
                        continue;
 
                case 'C': /* finished */
                        if (!(flags & LEFT))
                        continue;
 
                case 'C': /* finished */
                        if (!(flags & LEFT))
-                               while (--field_width > 0)
-                                       *str++ = ' ';
-                       if (qualifier == 'h')
-                               *str++ = (unsigned char) va_arg(args, int);
-                       else
-                               *str++ = 
-                                 (unsigned char)(wchar_t) va_arg(args, int);
-                       while (--field_width > 0)
-                               *str++ = ' ';
+                               while (--field_width > 0) {
+                                       if (str <= end)
+                                               *str = ' ';
+                                       ++str;
+                               }
+                       if (qualifier == 'h') {
+                               if (str <= end)
+                                       *str = (unsigned char) va_arg(args, int);
+                               ++str;
+                       } else {
+                               if (str <= end)
+                                       *str = (unsigned char)(wchar_t) va_arg(args, int);
+                               ++str;
+                       }
+                       while (--field_width > 0) {
+                               if (str <= end)
+                                       *str = ' ';
+                               ++str;
+                       }
                        continue;
 
                case 's': /* finished */
                        continue;
 
                case 's': /* finished */
@@ -246,12 +296,22 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = ' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = (unsigned char)(*sw++);
-                               while (len < field_width--)
-                                       *str++ = ' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = ' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = (unsigned char)(*sw);
+                                       ++str;
+                                       ++sw;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = ' ';
+                                       ++str;
+                               }
                        } else {
                                /* print ascii string */
                                s = va_arg(args, char *);
                        } else {
                                /* print ascii string */
                                s = va_arg(args, char *);
@@ -263,12 +323,22 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = ' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = *s++;
-                               while (len < field_width--)
-                                       *str++ = ' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = ' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = *s;
+                                       ++str;
+                                       ++s;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = ' ';
+                                       ++str;
+                               }
                        }
                        continue;
 
                        }
                        continue;
 
@@ -284,12 +354,22 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = ' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = *s++;
-                               while (len < field_width--)
-                                       *str++ = ' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = ' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = *s;
+                                       ++str;
+                                       ++s;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = ' ';
+                                       ++str;
+                               }
                        } else {
                                /* print unicode string */
                                sw = va_arg(args, wchar_t *);
                        } else {
                                /* print unicode string */
                                sw = va_arg(args, wchar_t *);
@@ -301,12 +381,22 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = ' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = (unsigned char)(*sw++);
-                               while (len < field_width--)
-                                       *str++ = ' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = ' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = (unsigned char)(*sw);
+                                       ++str;
+                                       ++sw;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = ' ';
+                                       ++str;
+                               }
                        }
                        continue;
 
                        }
                        continue;
 
@@ -316,22 +406,36 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                                PUNICODE_STRING pus = va_arg(args, PUNICODE_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        s = "<NULL>";
                                PUNICODE_STRING pus = va_arg(args, PUNICODE_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        s = "<NULL>";
-                                       while ((*s) != 0)
-                                               *str++ = *s++;
+                                       while ((*s) != 0) {
+                                               if (str <= end)
+                                                       *str = *s;
+                                               ++str;
+                                               ++s;
+                                       }
                                } else {
                                } else {
-                                       for (i = 0; pus->Buffer[i] && i < pus->Length / sizeof(WCHAR); i++)
-                                               *str++ = (unsigned char)(pus->Buffer[i]);
+                                       for (i = 0; pus->Buffer[i] && i < pus->Length / sizeof(WCHAR); i++) {
+                                               if (str <= end)
+                                                       *str = (unsigned char)(pus->Buffer[i]);
+                                               ++str;
+                                       }
                                }
                        } else {
                                /* print counted ascii string */
                                PANSI_STRING pus = va_arg(args, PANSI_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        s = "<NULL>";
                                }
                        } else {
                                /* print counted ascii string */
                                PANSI_STRING pus = va_arg(args, PANSI_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        s = "<NULL>";
-                                       while ((*s) != 0)
-                                               *str++ = *s++;
+                                       while ((*s) != 0) {
+                                               if (str <= end)
+                                                       *str = *s;
+                                               ++str;
+                                               ++s;
+                                       }
                                } else {
                                } else {
-                                       for (i = 0; pus->Buffer[i] && i < pus->Length; i++)
-                                               *str++ = pus->Buffer[i];
+                                       for (i = 0; pus->Buffer[i] && i < pus->Length; i++) {
+                                               if (str <= end)
+                                                       *str = pus->Buffer[i];
+                                               ++str;
+                                       }
                                }
                        }
                        continue;
                                }
                        }
                        continue;
@@ -341,12 +445,13 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                                field_width = 2 * sizeof(void *);
                                flags |= ZEROPAD;
                        }
                                field_width = 2 * sizeof(void *);
                                flags |= ZEROPAD;
                        }
-                       str = number(str,
+                       str = number(str, end,
                                (unsigned long) va_arg(args, void *), 16,
                                field_width, precision, flags);
                        continue;
 
                case 'n':
                                (unsigned long) va_arg(args, void *), 16,
                                field_width, precision, flags);
                        continue;
 
                case 'n':
+                       /* FIXME: What does C99 say about the overflow case here? */
                        if (qualifier == 'l') {
                                long * ip = va_arg(args, long *);
                                *ip = (str - buf);
                        if (qualifier == 'l') {
                                long * ip = va_arg(args, long *);
                                *ip = (str - buf);
@@ -378,11 +483,16 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                        break;
 
                default:
                        break;
 
                default:
-                       if (*fmt != '%')
-                               *str++ = '%';
-                       if (*fmt)
-                               *str++ = *fmt;
-                       else
+                       if (*fmt != '%') {
+                               if (str <= end)
+                                       *str = '%';
+                               ++str;
+                       }
+                       if (*fmt) {
+                               if (str <= end)
+                                       *str = *fmt;
+                               ++str;
+                       } else
                                --fmt;
                        continue;
                }
                                --fmt;
                        continue;
                }
@@ -403,9 +513,13 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                        else
                                num = va_arg(args, unsigned int);
                }
                        else
                                num = va_arg(args, unsigned int);
                }
-               str = number(str, num, base, field_width, precision, flags);
+               str = number(str, end, num, base, field_width, precision, flags);
        }
        }
-       *str = '\0';
+       if (str <= end)
+               *str = '\0';
+       else if (cnt > 0)
+               /* don't write out a null byte if the buf size is zero */
+               *end = '\0';
        return str-buf;
 }
 
        return str-buf;
 }
 
index 83cd0b3..37f181f 100644 (file)
@@ -1,4 +1,4 @@
-/* $Id: swprintf.c,v 1.10 2002/09/08 10:23:07 chorns Exp $
+/* $Id: swprintf.c,v 1.11 2002/09/12 17:50:42 guido Exp $
  *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
  *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
@@ -8,7 +8,6 @@
  *                  Eric Kohl
  *
  * TODO:
  *                  Eric Kohl
  *
  * TODO:
- *   - Implement maximum length (cnt) in _vsnwprintf().
  *   - Verify the implementation of '%Z'.
  */
 
  *   - Verify the implementation of '%Z'.
  */
 
@@ -61,86 +60,97 @@ static int skip_atoi(const wchar_t **s)
 
 
 static wchar_t *
 
 
 static wchar_t *
-number (wchar_t *str, long long num, int base, int size, int precision,
-       int type)
+number(wchar_t * buf, wchar_t * end, long long num, int base, int size, int precision, int type)
 {
 {
-   wchar_t c,sign,tmp[66];
-   const wchar_t *digits = L"0123456789abcdefghijklmnopqrstuvwxyz";
-   int i;
-   
-   if (type & LARGE)
-     digits = L"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
-   if (type & LEFT)
-     type &= ~ZEROPAD;
-   if (base < 2 || base > 36)
-     return 0;
-   
-   c = (type & ZEROPAD) ? L'0' : L' ';
-   sign = 0;
-   
-   if (type & SIGN) 
-     {
-       if (num < 0) 
-         {
-            sign = L'-';
-            num = -num;
-            size--;
-         } 
-       else if (type & PLUS) 
-         {
-            sign = L'+';
-            size--;
-         } 
-       else if (type & SPACE) 
-         {
-            sign = L' ';
-            size--;
-         }
-     }
-   
-   if (type & SPECIAL) 
-     {
-       if (base == 16)
-         size -= 2;
-       else if (base == 8)
-         size--;
-     }
-   
-   i = 0;
-   if (num == 0)
-     tmp[i++]='0';
-   else while (num != 0)
-     tmp[i++] = digits[do_div(num,base)];
-   if (i > precision)
-     precision = i;
-   size -= precision;
-   if (!(type&(ZEROPAD+LEFT)))
-     while(size-->0)
-       *str++ = L' ';
-   if (sign)
-     *str++ = sign;
-   if (type & SPECIAL)
-     {
-       if (base==8)
-         {
-             *str++ = L'0';
-         }
-       else if (base==16) 
-         {
-             *str++ = L'0';
-            *str++ = digits[33];
-         }
-     }
-   if (!(type & LEFT))
-     while (size-- > 0)
-       *str++ = c;
-   while (i < precision--)
-     *str++ = '0';
-   while (i-- > 0)
-     *str++ = tmp[i];
-   while (size-- > 0)
-     *str++ = L' ';
-   return str;
+       wchar_t c, sign, tmp[66];
+       const wchar_t *digits;
+       const wchar_t small_digits[] = L"0123456789abcdefghijklmnopqrstuvwxyz";
+       const wchar_t large_digits[] = L"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+       int i;
+
+       digits = (type & LARGE) ? large_digits : small_digits;
+       if (type & LEFT)
+               type &= ~ZEROPAD;
+       if (base < 2 || base > 36)
+               return 0;
+       c = (type & ZEROPAD) ? L'0' : L' ';
+       sign = 0;
+       if (type & SIGN) {
+               if (num < 0) {
+                       sign = L'-';
+                       num = -num;
+                       size--;
+               } else if (type & PLUS) {
+                       sign = L'+';
+                       size--;
+               } else if (type & SPACE) {
+                       sign = ' ';
+                       size--;
+               }
+       }
+       if (type & SPECIAL) {
+               if (base == 16)
+                       size -= 2;
+               else if (base == 8)
+                       size--;
+       }
+       i = 0;
+       if (num == 0)
+               tmp[i++] = L'0';
+       else while (num != 0)
+               tmp[i++] = digits[do_div(num,base)];
+       if (i > precision)
+               precision = i;
+       size -= precision;
+       if (!(type&(ZEROPAD+LEFT))) {
+               while(size-->0) {
+                       if (buf <= end)
+                               *buf = L' ';
+                       ++buf;
+               }
+       }
+       if (sign) {
+               if (buf <= end)
+                       *buf = sign;
+               ++buf;
+       }
+       if (type & SPECIAL) {
+               if (base==8) {
+                       if (buf <= end)
+                               *buf = L'0';
+                       ++buf;
+               } else if (base==16) {
+                       if (buf <= end)
+                               *buf = L'0';
+                       ++buf;
+                       if (buf <= end)
+                               *buf = digits[33];
+                       ++buf;
+               }
+       }
+       if (!(type & LEFT)) {
+               while (size-- > 0) {
+                       if (buf <= end)
+                               *buf = c;
+                       ++buf;
+               }
+       }
+       while (i < precision--) {
+               if (buf <= end)
+                       *buf = L'0';
+               ++buf;
+       }
+       while (i-- > 0) {
+               if (buf <= end)
+                       *buf = tmp[i];
+               ++buf;
+       }
+       while (size-- > 0) {
+               if (buf <= end)
+                       *buf = L' ';
+               ++buf;
+       }
+       return buf;
 }
 
 
 }
 
 
@@ -149,7 +159,7 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
        int len;
        unsigned long long num;
        int i, base;
        int len;
        unsigned long long num;
        int i, base;
-       wchar_t * str;
+       wchar_t * str, * end;
        const char *s;
        const wchar_t *sw;
 
        const char *s;
        const wchar_t *sw;
 
@@ -160,9 +170,18 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                   number of chars for from string */
        int qualifier;          /* 'h', 'l', 'L', 'w' or 'I' for integer fields */
 
                                   number of chars for from string */
        int qualifier;          /* 'h', 'l', 'L', 'w' or 'I' for integer fields */
 
-       for (str=buf ; *fmt ; ++fmt) {
+       str = buf;
+       end = buf + cnt - 1;
+       if (end < buf - 1) {
+               end = ((void *) -1);
+               cnt = end - buf + 1;
+       }
+
+       for ( ; *fmt ; ++fmt) {
                if (*fmt != L'%') {
                if (*fmt != L'%') {
-                       *str++ = *fmt;
+                       if (str <= end)
+                               *str = *fmt;
+                       ++str;
                        continue;
                }
 
                        continue;
                }
 
@@ -195,7 +214,7 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                /* get the precision */
                precision = -1;
                if (*fmt == L'.') {
                /* get the precision */
                precision = -1;
                if (*fmt == L'.') {
-                       ++fmt;  
+                       ++fmt;
                        if (iswdigit(*fmt))
                                precision = skip_atoi(&fmt);
                        else if (*fmt == L'*') {
                        if (iswdigit(*fmt))
                                precision = skip_atoi(&fmt);
                        else if (*fmt == L'*') {
@@ -223,27 +242,48 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                switch (*fmt) {
                case L'c':
                        if (!(flags & LEFT))
                switch (*fmt) {
                case L'c':
                        if (!(flags & LEFT))
-                               while (--field_width > 0)
-                                       *str++ = L' ';
-                       if (qualifier == 'h')
-                               *str++ = (wchar_t) va_arg(args, int);
-                       else
-                               *str++ = 
-                                 (wchar_t) va_arg(args, int);
-                       while (--field_width > 0)
-                               *str++ = L' ';
+                               while (--field_width > 0) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
+                       if (qualifier == 'h') {
+                               if (str <= end)
+                                       *str = (wchar_t) va_arg(args, int);
+                               ++str;
+                       } else {
+                               if (str <= end)
+                                       *str = (wchar_t) va_arg(args, int);
+                               ++str;
+                       }
+                       while (--field_width > 0) {
+                               if (str <= end)
+                                       *str = L' ';
+                               ++str;
+                       }
                        continue;
 
                case L'C':
                        if (!(flags & LEFT))
                        continue;
 
                case L'C':
                        if (!(flags & LEFT))
-                               while (--field_width > 0)
-                                       *str++ = L' ';
-                       if (qualifier == 'l' || qualifier == 'w')
-                               *str++ = (wchar_t) va_arg(args, int);
-                       else
-                               *str++ = (wchar_t) va_arg(args, int);
-                       while (--field_width > 0)
-                               *str++ = L' ';
+                               while (--field_width > 0) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
+                       if (qualifier == 'l' || qualifier == 'w') {
+                               if (str <= end)
+                                       *str = (wchar_t) va_arg(args, int);
+                               ++str;
+                       } else {
+                               if (str <= end)
+                                       *str = (wchar_t) va_arg(args, int);
+                               ++str;
+                       }
+                       while (--field_width > 0) {
+                               if (str <= end)
+                                       *str = L' ';
+                               ++str;
+                       }
                        continue;
 
                case L's':
                        continue;
 
                case L's':
@@ -258,12 +298,22 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = L' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = (wchar_t)(*s++);
-                               while (len < field_width--)
-                                       *str++ = L' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = L' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = (wchar_t)(*s);
+                                       ++str;
+                                       ++s;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
                        } else {
                                /* print unicode string */
                                sw = va_arg(args, wchar_t *);
                        } else {
                                /* print unicode string */
                                sw = va_arg(args, wchar_t *);
@@ -275,12 +325,22 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = L' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = *sw++;
-                               while (len < field_width--)
-                                       *str++ = L' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = L' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = *sw;
+                                       ++str;
+                                       ++sw;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
                        }
                        continue;
 
                        }
                        continue;
 
@@ -296,12 +356,22 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = L' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = *sw++;
-                               while (len < field_width--)
-                                       *str++ = L' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = L' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = *sw;
+                                       ++str;
+                                       ++sw;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
                        } else {
                                /* print ascii string */
                                s = va_arg(args, char *);
                        } else {
                                /* print ascii string */
                                s = va_arg(args, char *);
@@ -313,37 +383,61 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = L' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = (wchar_t)(*s++);
-                               while (len < field_width--)
-                                       *str++ = L' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = L' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = (wchar_t)(*s);
+                                       ++str;
+                                       ++s;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
                        }
                        continue;
 
                        }
                        continue;
 
-               case 'Z':
+               case L'Z':
                        if (qualifier == 'h') {
                                /* print counted ascii string */
                                PANSI_STRING pus = va_arg(args, PANSI_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        sw = L"<NULL>";
                        if (qualifier == 'h') {
                                /* print counted ascii string */
                                PANSI_STRING pus = va_arg(args, PANSI_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        sw = L"<NULL>";
-                                       while ((*sw) != 0)
-                                               *str++ = *sw++;
+                                       while ((*sw) != 0) {
+                                               if (str <= end)
+                                                       *str = *sw;
+                                               ++str;
+                                               ++sw;
+                                       }
                                } else {
                                } else {
-                                       for (i = 0; pus->Buffer[i] && i < pus->Length; i++)
-                                               *str++ = (wchar_t)(pus->Buffer[i]);
+                                       for (i = 0; pus->Buffer[i] && i < pus->Length; i++) {
+                                               if (str <= end)
+                                                       *str = (wchar_t)(pus->Buffer[i]);
+                                               ++str;
+                                       }
                                }
                        } else {
                                /* print counted unicode string */
                                PUNICODE_STRING pus = va_arg(args, PUNICODE_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        sw = L"<NULL>";
                                }
                        } else {
                                /* print counted unicode string */
                                PUNICODE_STRING pus = va_arg(args, PUNICODE_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        sw = L"<NULL>";
-                                       while ((*sw) != 0)
-                                               *str++ = *sw++;
+                                       while ((*sw) != 0) {
+                                               if (str <= end)
+                                                       *str = *sw;
+                                               ++str;
+                                               ++sw;
+                                       }
                                } else {
                                } else {
-                                       for (i = 0; pus->Buffer[i] && i < pus->Length / sizeof(WCHAR); i++)
-                                               *str++ = pus->Buffer[i];
+                                       for (i = 0; pus->Buffer[i] && i < pus->Length / sizeof(WCHAR); i++) {
+                                               if (str <= end)
+                                                       *str = pus->Buffer[i];
+                                               ++str;
+                                       }
                                }
                        }
                        continue;
                                }
                        }
                        continue;
@@ -353,12 +447,13 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                field_width = 2*sizeof(void *);
                                flags |= ZEROPAD;
                        }
                                field_width = 2*sizeof(void *);
                                flags |= ZEROPAD;
                        }
-                       str = number(str,
+                       str = number(str, end,
                                (unsigned long) va_arg(args, void *), 16,
                                field_width, precision, flags);
                        continue;
 
                case L'n':
                                (unsigned long) va_arg(args, void *), 16,
                                field_width, precision, flags);
                        continue;
 
                case L'n':
+                       /* FIXME: What does C99 say about the overflow case here? */
                        if (qualifier == 'l') {
                                long * ip = va_arg(args, long *);
                                *ip = (str - buf);
                        if (qualifier == 'l') {
                                long * ip = va_arg(args, long *);
                                *ip = (str - buf);
@@ -390,11 +485,16 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                        break;
 
                default:
                        break;
 
                default:
-                       if (*fmt != L'%')
-                               *str++ = L'%';
-                       if (*fmt)
-                               *str++ = *fmt;
-                       else
+                       if (*fmt != L'%') {
+                               if (str <= end)
+                                       *str = L'%';
+                               ++str;
+                       }
+                       if (*fmt) {
+                               if (str <= end)
+                                       *str = *fmt;
+                               ++str;
+                       } else
                                --fmt;
                        continue;
                }
                                --fmt;
                        continue;
                }
@@ -415,9 +515,13 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                        else
                                num = va_arg(args, unsigned int);
                }
                        else
                                num = va_arg(args, unsigned int);
                }
-               str = number(str, num, base, field_width, precision, flags);
+               str = number(str, end, num, base, field_width, precision, flags);
        }
        }
-       *str = L'\0';
+       if (str <= end)
+               *str = L'\0';
+       else if (cnt > 0)
+               /* don't write out a null byte if the buf size is zero */
+               *end = L'\0';
        return str-buf;
 }
 
        return str-buf;
 }
 
index 1cd2ce4..a480c73 100644 (file)
@@ -1,4 +1,4 @@
-/* $Id: sprintf.c,v 1.9 2002/09/08 10:23:42 chorns Exp $
+/* $Id: sprintf.c,v 1.10 2002/09/12 17:50:05 guido Exp $
  *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
  *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
@@ -7,8 +7,6 @@
  * PROGRAMMERS:     David Welch
  *                  Eric Kohl
  *
  * PROGRAMMERS:     David Welch
  *                  Eric Kohl
  *
- * TODO:
- *      - Implement maximum length (cnt) in _vsnprintf().
  */
 
 /*
  */
 
 /*
@@ -59,26 +57,25 @@ static int skip_atoi(const char **s)
 
 
 static char *
 
 
 static char *
-number (char * str, long long num, int base, int size, int precision, int type)
+number(char *buf, char *end, long long num, int base, int size, int precision, int type)
 {
   char c,sign,tmp[66];
 {
   char c,sign,tmp[66];
-  const char *digits="0123456789abcdefghijklmnopqrstuvwxyz";
+  const char *digits;
+  const char small_digits[] = "0123456789abcdefghijklmnopqrstuvwxyz";
+  const char large_digits[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
   int i;
 
   int i;
 
-  if (type & LARGE)
-    digits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+  digits = (type & LARGE) ? large_digits : small_digits;
   if (type & LEFT)
     type &= ~ZEROPAD;
   if (base < 2 || base > 36)
     return 0;
   if (type & LEFT)
     type &= ~ZEROPAD;
   if (base < 2 || base > 36)
     return 0;
-
   c = (type & ZEROPAD) ? '0' : ' ';
   sign = 0;
   c = (type & ZEROPAD) ? '0' : ' ';
   sign = 0;
-
   if (type & SIGN) {
     if (num < 0) {
       sign = '-';
   if (type & SIGN) {
     if (num < 0) {
       sign = '-';
-      num = -num;
+        num = -num;
       size--;
     } else if (type & PLUS) {
       sign = '+';
       size--;
     } else if (type & PLUS) {
       sign = '+';
@@ -88,14 +85,12 @@ number (char * str, long long num, int base, int size, int precision, int type)
       size--;
     }
   }
       size--;
     }
   }
-
   if (type & SPECIAL) {
     if (base == 16)
       size -= 2;
     else if (base == 8)
       size--;
   }
   if (type & SPECIAL) {
     if (base == 16)
       size -= 2;
     else if (base == 8)
       size--;
   }
-
   i = 0;
   if (num == 0)
     tmp[i++]='0';
   i = 0;
   if (num == 0)
     tmp[i++]='0';
@@ -104,29 +99,55 @@ number (char * str, long long num, int base, int size, int precision, int type)
   if (i > precision)
     precision = i;
   size -= precision;
   if (i > precision)
     precision = i;
   size -= precision;
-  if (!(type&(ZEROPAD+LEFT)))
-    while(size-->0)
-      *str++ = ' ';
-  if (sign)
-    *str++ = sign;
+  if (!(type&(ZEROPAD+LEFT))) {
+    while(size-->0) {
+      if (buf <= end)
+        *buf = ' ';
+      ++buf;
+    }
+  }
+  if (sign) {
+    if (buf <= end)
+      *buf = sign;
+    ++buf;
+  }
   if (type & SPECIAL) {
     if (base==8) {
   if (type & SPECIAL) {
     if (base==8) {
-      *str++ = '0';
+      if (buf <= end)
+        *buf = '0';
+      ++buf;
     } else if (base==16) {
     } else if (base==16) {
-      *str++ = '0';
-      *str++ = digits[33];
+      if (buf <= end)
+        *buf = '0';
+      ++buf;
+      if (buf <= end)
+        *buf = digits[33];
+      ++buf;
+    }
+  }
+  if (!(type & LEFT)) {
+    while (size-- > 0) {
+      if (buf <= end)
+        *buf = c;
+      ++buf;
     }
   }
     }
   }
-  if (!(type & LEFT))
-    while (size-- > 0)
-      *str++ = c;
-  while (i < precision--)
-    *str++ = '0';
-  while (i-- > 0)
-    *str++ = tmp[i];
-  while (size-- > 0)
-    *str++ = ' ';
-  return str;
+  while (i < precision--) {
+    if (buf <= end)
+      *buf = '0';
+    ++buf;
+  }
+  while (i-- > 0) {
+    if (buf <= end)
+      *buf = tmp[i];
+    ++buf;
+  }
+  while (size-- > 0) {
+    if (buf <= end)
+      *buf = ' ';
+    ++buf;
+  }
+  return buf;
 }
 
 
 }
 
 
@@ -135,7 +156,7 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
   int len;
   unsigned long long num;
   int i, base;
   int len;
   unsigned long long num;
   int i, base;
-  char * str;
+  char *str, *end;
   const char *s;
   const wchar_t *sw;
 
   const char *s;
   const wchar_t *sw;
 
@@ -146,24 +167,31 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                             number of chars for from string */
   int qualifier;          /* 'h', 'l', 'L', 'I' or 'w' for integer fields */
 
                             number of chars for from string */
   int qualifier;          /* 'h', 'l', 'L', 'I' or 'w' for integer fields */
 
-  for (str=buf ; *fmt ; ++fmt) {
+  str = buf;
+  end = buf + cnt - 1;
+  if (end < buf - 1) {
+    end = ((void *) -1);
+    cnt = end - buf + 1;
+  }
+
+  for ( ; *fmt ; ++fmt) {
     if (*fmt != '%') {
     if (*fmt != '%') {
-      *str++ = *fmt;
-      if( --cnt == 0 )
-       goto out;
+      if (str <= end)
+        *str = *fmt;
+      ++str;
       continue;
     }
 
     /* process flags */
     flags = 0;
       continue;
     }
 
     /* process flags */
     flags = 0;
-  repeat:
+    repeat:
     ++fmt;          /* this also skips first '%' */
     switch (*fmt) {
     ++fmt;          /* this also skips first '%' */
     switch (*fmt) {
-    case '-': flags |= LEFT; goto repeat;
-    case '+': flags |= PLUS; goto repeat;
-    case ' ': flags |= SPACE; goto repeat;
-    case '#': flags |= SPECIAL; goto repeat;
-    case '0': flags |= ZEROPAD; goto repeat;
+      case '-': flags |= LEFT; goto repeat;
+      case '+': flags |= PLUS; goto repeat;
+      case ' ': flags |= SPACE; goto repeat;
+      case '#': flags |= SPECIAL; goto repeat;
+      case '0': flags |= ZEROPAD; goto repeat;
     }
 
     /* get field width */
     }
 
     /* get field width */
@@ -213,27 +241,27 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
       if (!(flags & LEFT))
        while (--field_width > 0)
          {
       if (!(flags & LEFT))
        while (--field_width > 0)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+              *str = ' ';
+           ++str;
          }
       if (qualifier == 'l' || qualifier == 'w')
        {
          }
       if (qualifier == 'l' || qualifier == 'w')
        {
-         *str++ = (unsigned char)(wchar_t) va_arg(args, int);
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+            *str = (unsigned char)(wchar_t) va_arg(args, int);
+          ++str;
        }
       else
        {
        }
       else
        {
-         *str++ = (unsigned char) va_arg(args, int);
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = (unsigned char) va_arg(args, int);
+         ++str;
        }
       while (--field_width > 0)
        {
        }
       while (--field_width > 0)
        {
-         *str++ = ' ';
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = ' ';
+         ++str;
        }
       continue;
 
        }
       continue;
 
@@ -241,33 +269,33 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
       if (!(flags & LEFT))
        while (--field_width > 0)
          {
       if (!(flags & LEFT))
        while (--field_width > 0)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = ' ';
+           ++str;
          }
       if (qualifier == 'h')
        {
          }
       if (qualifier == 'h')
        {
-         *str++ = (unsigned char) va_arg(args, int);
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = (unsigned char) va_arg(args, int);
+         ++str;
        }
       else
        {
        }
       else
        {
-         *str++ = (unsigned char)(wchar_t) va_arg(args, int);
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = (unsigned char)(wchar_t) va_arg(args, int);
+         ++str;
        }
       while (--field_width > 0)
        {
        }
       while (--field_width > 0)
        {
-         *str++ = ' ';
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = ' ';
+         ++str;
        }
       continue;
 
     case 's': /* finished */
       if (qualifier == 'l' || qualifier == 'w') {
        }
       continue;
 
     case 's': /* finished */
       if (qualifier == 'l' || qualifier == 'w') {
-                                /* print unicode string */
+        /* print unicode string */
        sw = va_arg(args, wchar_t *);
        if (sw == NULL)
          sw = L"<NULL>";
        sw = va_arg(args, wchar_t *);
        if (sw == NULL)
          sw = L"<NULL>";
@@ -277,24 +305,25 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
        if (!(flags & LEFT))
          while (len < field_width--)
            {
        if (!(flags & LEFT))
          while (len < field_width--)
            {
-             *str++ = ' ';
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = ' ';
+             ++str;
            }
        for (i = 0; i < len; ++i)
          {
            }
        for (i = 0; i < len; ++i)
          {
-           *str++ = (unsigned char)(*sw++);
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = (unsigned char)(*sw);
+           ++str;
+           ++sw;
          }
        while (len < field_width--)
          {
          }
        while (len < field_width--)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = ' ';
+           ++str;
          }
       } else {
          }
       } else {
-                                /* print ascii string */
+        /* print ascii string */
        s = va_arg(args, char *);
        if (s == NULL)
          s = "<NULL>";
        s = va_arg(args, char *);
        if (s == NULL)
          s = "<NULL>";
@@ -304,21 +333,22 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
        if (!(flags & LEFT))
          while (len < field_width--)
            {
        if (!(flags & LEFT))
          while (len < field_width--)
            {
-             *str++ = ' ';
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = ' ';
+             ++str;
            }
        for (i = 0; i < len; ++i)
          {
            }
        for (i = 0; i < len; ++i)
          {
-           *str++ = *s++;
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = *s;
+           ++str;
+           ++s;
          }
        while (len < field_width--)
          {
          }
        while (len < field_width--)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = ' ';
+           ++str;
          }
       }
       continue;
          }
       }
       continue;
@@ -335,24 +365,25 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
        if (!(flags & LEFT))
          while (len < field_width--)
            {
        if (!(flags & LEFT))
          while (len < field_width--)
            {
-             *str++ = ' ';
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = ' ';
+             ++str;
            }
        for (i = 0; i < len; ++i)
          {
            }
        for (i = 0; i < len; ++i)
          {
-           *str++ = *s++;
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = *s;
+           ++str;
+           ++s;
          }
        while (len < field_width--)
          {
          }
        while (len < field_width--)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = ' ';
+           ++str;
          }
       } else {
          }
       } else {
-                                /* print unicode string */
+        /* print unicode string */
        sw = va_arg(args, wchar_t *);
        if (sw == NULL)
          sw = L"<NULL>";
        sw = va_arg(args, wchar_t *);
        if (sw == NULL)
          sw = L"<NULL>";
@@ -362,62 +393,65 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
        if (!(flags & LEFT))
          while (len < field_width--)
            {
        if (!(flags & LEFT))
          while (len < field_width--)
            {
-             *str++ = ' ';
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = ' ';
+             ++str;
            }
        for (i = 0; i < len; ++i)
          {
            }
        for (i = 0; i < len; ++i)
          {
-           *str++ = (unsigned char)(*sw++);
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = (unsigned char)(*sw);
+           ++str;
+           ++sw;
          }
        while (len < field_width--)
          {
          }
        while (len < field_width--)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = ' ';
+           ++str;
          }
       }
       continue;
 
     case 'Z':
       if (qualifier == 'w') {
          }
       }
       continue;
 
     case 'Z':
       if (qualifier == 'w') {
-                                /* print counted unicode string */
+        /* print counted unicode string */
        PUNICODE_STRING pus = va_arg(args, PUNICODE_STRING);
        if ((pus == NULL) || (pus->Buffer == NULL)) {
          s = "<NULL>";
          while ((*s) != 0)
            {
        PUNICODE_STRING pus = va_arg(args, PUNICODE_STRING);
        if ((pus == NULL) || (pus->Buffer == NULL)) {
          s = "<NULL>";
          while ((*s) != 0)
            {
-             *str++ = *s++;
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = *s;
+             ++str;
+             ++s;
            }
        } else {
          for (i = 0; pus->Buffer[i] && i < pus->Length / sizeof(WCHAR); i++)
            {
            }
        } else {
          for (i = 0; pus->Buffer[i] && i < pus->Length / sizeof(WCHAR); i++)
            {
-             *str++ = (unsigned char)(pus->Buffer[i]);
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = (unsigned char)(pus->Buffer[i]);
+             ++str;
            }
        }
       } else {
            }
        }
       } else {
-                                /* print counted ascii string */
+        /* print counted ascii string */
        PANSI_STRING pus = va_arg(args, PANSI_STRING);
        if ((pus == NULL) || (pus->Buffer == NULL)) {
          s = "<NULL>";
          while ((*s) != 0)
            {
        PANSI_STRING pus = va_arg(args, PANSI_STRING);
        if ((pus == NULL) || (pus->Buffer == NULL)) {
          s = "<NULL>";
          while ((*s) != 0)
            {
-             *str++ = *s++;
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = *s;
+             ++str;
+             ++s;
            }
        } else {
          for (i = 0; pus->Buffer[i] && i < pus->Length; i++)
            {
            }
        } else {
          for (i = 0; pus->Buffer[i] && i < pus->Length; i++)
            {
-             *str++ = pus->Buffer[i];
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = pus->Buffer[i];
+             ++str;
            }
        }
       }
            }
        }
       }
@@ -428,12 +462,13 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
        field_width = 2 * sizeof(void *);
        flags |= ZEROPAD;
       }
        field_width = 2 * sizeof(void *);
        flags |= ZEROPAD;
       }
-      str = number(str,
-                  (unsigned long) va_arg(args, void *), 16,
-                  field_width, precision, flags);
+      str = number(str, end,
+                   (unsigned long) va_arg(args, void *),
+                   16, field_width, precision, flags);
       continue;
 
     case 'n':
       continue;
 
     case 'n':
+      /* FIXME: What does C99 say about the overflow case here? */
       if (qualifier == 'l') {
        long * ip = va_arg(args, long *);
        *ip = (str - buf);
       if (qualifier == 'l') {
        long * ip = va_arg(args, long *);
        *ip = (str - buf);
@@ -467,15 +502,15 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
     default:
       if (*fmt != '%')
        {
     default:
       if (*fmt != '%')
        {
-         *str++ = '%';
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = '%';
+         ++str;
        }
       if (*fmt)
        {
        }
       if (*fmt)
        {
-         *str++ = *fmt;
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = *fmt;
+         ++str;
        }
       else
        --fmt;
        }
       else
        --fmt;
@@ -498,10 +533,15 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
       else
        num = va_arg(args, unsigned int);
     }
       else
        num = va_arg(args, unsigned int);
     }
-    str = number(str, num, base, field_width, precision, flags);
+    str = number(str, end, num, base, field_width, precision, flags);
   }
   }
- out:
-  *str = '\0';
+
+  if (str <= end)
+    *str = '\0';
+  else if (cnt > 0)
+    /* don't write out a null byte if the buf size is zero */
+    *end = '\0';
+  
   return str-buf;
 }
 
   return str-buf;
 }
 
index ad6f502..cd81ae8 100644 (file)
@@ -1,4 +1,4 @@
-/* $Id: swprintf.c,v 1.7 2002/09/08 10:23:42 chorns Exp $
+/* $Id: swprintf.c,v 1.8 2002/09/12 17:50:05 guido Exp $
  *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
  *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
@@ -8,7 +8,6 @@
  *                  Eric Kohl
  *
  * TODO:
  *                  Eric Kohl
  *
  * TODO:
- *   - Implement maximum length (cnt) in _vsnwprintf().
  *   - Verify the implementation of '%Z'.
  */
 
  *   - Verify the implementation of '%Z'.
  */
 
@@ -61,86 +60,97 @@ static int skip_atoi(const wchar_t **s)
 
 
 static wchar_t *
 
 
 static wchar_t *
-number (wchar_t *str, long long num, int base, int size, int precision,
-       int type)
+number(wchar_t * buf, wchar_t * end, long long num, int base, int size, int precision, int type)
 {
 {
-   wchar_t c,sign,tmp[66];
-   const wchar_t *digits = L"0123456789abcdefghijklmnopqrstuvwxyz";
-   int i;
-   
-   if (type & LARGE)
-     digits = L"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
-   if (type & LEFT)
-     type &= ~ZEROPAD;
-   if (base < 2 || base > 36)
-     return 0;
-   
-   c = (type & ZEROPAD) ? L'0' : L' ';
-   sign = 0;
-   
-   if (type & SIGN) 
-     {
-       if (num < 0) 
-         {
-            sign = L'-';
-            num = -num;
-            size--;
-         } 
-       else if (type & PLUS) 
-         {
-            sign = L'+';
-            size--;
-         } 
-       else if (type & SPACE) 
-         {
-            sign = L' ';
-            size--;
-         }
-     }
-   
-   if (type & SPECIAL) 
-     {
-       if (base == 16)
-         size -= 2;
-       else if (base == 8)
-         size--;
-     }
-   
-   i = 0;
-   if (num == 0)
-     tmp[i++]='0';
-   else while (num != 0)
-     tmp[i++] = digits[do_div(num,base)];
-   if (i > precision)
-     precision = i;
-   size -= precision;
-   if (!(type&(ZEROPAD+LEFT)))
-     while(size-->0)
-       *str++ = L' ';
-   if (sign)
-     *str++ = sign;
-   if (type & SPECIAL)
-     {
-       if (base==8)
-         {
-            *str++ = L'0';
-         }
-       else if (base==16) 
-         {
-            *str++ = L'0';
-            *str++ = digits[33];
-         }
-     }
-   if (!(type & LEFT))
-     while (size-- > 0)
-       *str++ = c;
-   while (i < precision--)
-     *str++ = '0';
-   while (i-- > 0)
-     *str++ = tmp[i];
-   while (size-- > 0)
-     *str++ = L' ';
-   return str;
+       wchar_t c,sign, tmp[66];
+       const wchar_t *digits;
+       const wchar_t small_digits[] = L"0123456789abcdefghijklmnopqrstuvwxyz";
+       const wchar_t large_digits[] = L"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+       int i;
+
+       digits = (type & LARGE) ? large_digits : small_digits;
+       if (type & LEFT)
+               type &= ~ZEROPAD;
+       if (base < 2 || base > 36)
+               return 0;
+       c = (type & ZEROPAD) ? L'0' : L' ';
+       sign = 0;
+       if (type & SIGN) {
+               if (num < 0) {
+                       sign = L'-';
+                       num = -num;
+                       size--;
+               } else if (type & PLUS) {
+                       sign = L'+';
+                       size--;
+               } else if (type & SPACE) {
+                       sign = ' ';
+                       size--;
+               }
+       }
+       if (type & SPECIAL) {
+               if (base == 16)
+                       size -= 2;
+               else if (base == 8)
+                       size--;
+       }
+       i = 0;
+       if (num == 0)
+               tmp[i++] = L'0';
+       else while (num != 0)
+               tmp[i++] = digits[do_div(num,base)];
+       if (i > precision)
+               precision = i;
+       size -= precision;
+       if (!(type&(ZEROPAD+LEFT))) {
+               while(size-->0) {
+                       if (buf <= end)
+                               *buf = L' ';
+                       ++buf;
+               }
+       }
+       if (sign) {
+               if (buf <= end)
+                       *buf = sign;
+               ++buf;
+       }
+       if (type & SPECIAL) {
+               if (base==8) {
+                       if (buf <= end)
+                               *buf = L'0';
+                       ++buf;
+               } else if (base==16) {
+                       if (buf <= end)
+                               *buf = L'0';
+                       ++buf;
+                       if (buf <= end)
+                               *buf = digits[33];
+                       ++buf;
+               }
+       }
+       if (!(type & LEFT)) {
+               while (size-- > 0) {
+                       if (buf <= end)
+                               *buf = c;
+                       ++buf;
+               }
+       }
+       while (i < precision--) {
+               if (buf <= end)
+                       *buf = L'0';
+               ++buf;
+       }
+       while (i-- > 0) {
+               if (buf <= end)
+                       *buf = tmp[i];
+               ++buf;
+       }
+       while (size-- > 0) {
+               if (buf <= end)
+                       *buf = L' ';
+               ++buf;
+       }
+       return buf;
 }
 
 
 }
 
 
@@ -149,7 +159,7 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
        int len;
        unsigned long long num;
        int i, base;
        int len;
        unsigned long long num;
        int i, base;
-       wchar_t * str;
+       wchar_t * str, * end;
        const char *s;
        const wchar_t *sw;
 
        const char *s;
        const wchar_t *sw;
 
@@ -160,9 +170,18 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                   number of chars for from string */
        int qualifier;          /* 'h', 'l', 'L', 'w' or 'I' for integer fields */
 
                                   number of chars for from string */
        int qualifier;          /* 'h', 'l', 'L', 'w' or 'I' for integer fields */
 
-       for (str=buf ; *fmt ; ++fmt) {
+       str = buf;
+       end = buf + cnt - 1;
+       if (end < buf - 1) {
+               end = ((void *) -1);
+               cnt = end - buf + 1;
+       }
+
+       for ( ; *fmt ; ++fmt) {
                if (*fmt != L'%') {
                if (*fmt != L'%') {
-                       *str++ = *fmt;
+                       if (str <= end)
+                               *str = *fmt;
+                       ++str;
                        continue;
                }
 
                        continue;
                }
 
@@ -195,7 +214,7 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                /* get the precision */
                precision = -1;
                if (*fmt == L'.') {
                /* get the precision */
                precision = -1;
                if (*fmt == L'.') {
-                       ++fmt;  
+                       ++fmt;
                        if (iswdigit(*fmt))
                                precision = skip_atoi(&fmt);
                        else if (*fmt == L'*') {
                        if (iswdigit(*fmt))
                                precision = skip_atoi(&fmt);
                        else if (*fmt == L'*') {
@@ -223,26 +242,48 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                switch (*fmt) {
                case L'c':
                        if (!(flags & LEFT))
                switch (*fmt) {
                case L'c':
                        if (!(flags & LEFT))
-                               while (--field_width > 0)
-                                       *str++ = L' ';
-                       if (qualifier == 'h')
-                               *str++ = (wchar_t) va_arg(args, int);
-                       else
-                               *str++ = (wchar_t) va_arg(args, int);
-                       while (--field_width > 0)
-                               *str++ = L' ';
+                               while (--field_width > 0) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
+                       if (qualifier == 'h') {
+                               if (str <= end)
+                                       *str = (wchar_t) va_arg(args, int);
+                               ++str;
+                       } else {
+                               if (str <= end)
+                                       *str = (wchar_t) va_arg(args, int);
+                               ++str;
+                       }
+                       while (--field_width > 0) {
+                               if (str <= end)
+                                       *str = L' ';
+                               ++str;
+                       }
                        continue;
 
                case L'C':
                        if (!(flags & LEFT))
                        continue;
 
                case L'C':
                        if (!(flags & LEFT))
-                               while (--field_width > 0)
-                                       *str++ = L' ';
-                       if (qualifier == 'l' || qualifier == 'w')
-                               *str++ = (wchar_t) va_arg(args, int);
-                       else
-                               *str++ = (wchar_t) va_arg(args, int);
-                       while (--field_width > 0)
-                               *str++ = L' ';
+                               while (--field_width > 0) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
+                       if (qualifier == 'l' || qualifier == 'w') {
+                               if (str <= end)
+                                       *str = (wchar_t) va_arg(args, int);
+                               ++str;
+                       } else {
+                               if (str <= end)
+                                       *str = (wchar_t) va_arg(args, int);
+                               ++str;
+                       }
+                       while (--field_width > 0) {
+                               if (str <= end)
+                                       *str = L' ';
+                               ++str;
+                       }
                        continue;
 
                case L's':
                        continue;
 
                case L's':
@@ -257,12 +298,22 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = L' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = (wchar_t)(*s++);
-                               while (len < field_width--)
-                                       *str++ = L' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = L' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = (wchar_t)(*s);
+                                       ++str;
+                                       ++s;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
                        } else {
                                /* print unicode string */
                                sw = va_arg(args, wchar_t *);
                        } else {
                                /* print unicode string */
                                sw = va_arg(args, wchar_t *);
@@ -274,12 +325,22 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = L' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = *sw++;
-                               while (len < field_width--)
-                                       *str++ = L' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = L' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = *sw;
+                                       ++str;
+                                       ++sw;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
                        }
                        continue;
 
                        }
                        continue;
 
@@ -295,12 +356,22 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = L' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = *sw++;
-                               while (len < field_width--)
-                                       *str++ = L' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = L' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = *sw;
+                                       ++str;
+                                       ++sw;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
                        } else {
                                /* print ascii string */
                                s = va_arg(args, char *);
                        } else {
                                /* print ascii string */
                                s = va_arg(args, char *);
@@ -312,12 +383,22 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                        len = precision;
 
                                if (!(flags & LEFT))
                                        len = precision;
 
                                if (!(flags & LEFT))
-                                       while (len < field_width--)
-                                               *str++ = L' ';
-                               for (i = 0; i < len; ++i)
-                                       *str++ = (wchar_t)(*s++);
-                               while (len < field_width--)
-                                       *str++ = L' ';
+                                       while (len < field_width--) {
+                                               if (str <= end)
+                                                       *str = L' ';
+                                               ++str;
+                                       }
+                               for (i = 0; i < len; ++i) {
+                                       if (str <= end)
+                                               *str = (wchar_t)(*s);
+                                       ++str;
+                                       ++s;
+                               }
+                               while (len < field_width--) {
+                                       if (str <= end)
+                                               *str = L' ';
+                                       ++str;
+                               }
                        }
                        continue;
 
                        }
                        continue;
 
@@ -327,22 +408,36 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                PANSI_STRING pus = va_arg(args, PANSI_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        sw = L"<NULL>";
                                PANSI_STRING pus = va_arg(args, PANSI_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        sw = L"<NULL>";
-                                       while ((*sw) != 0)
-                                               *str++ = *sw++;
+                                       while ((*sw) != 0) {
+                                               if (str <= end)
+                                                       *str = *sw;
+                                               ++str;
+                                               ++sw;
+                                       }
                                } else {
                                } else {
-                                       for (i = 0; pus->Buffer[i] && i < pus->Length; i++)
-                                               *str++ = (wchar_t)(pus->Buffer[i]);
+                                       for (i = 0; pus->Buffer[i] && i < pus->Length; i++) {
+                                               if (str <= end)
+                                                       *str = (wchar_t)(pus->Buffer[i]);
+                                               ++str;
+                                       }
                                }
                        } else {
                                /* print counted unicode string */
                                PUNICODE_STRING pus = va_arg(args, PUNICODE_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        sw = L"<NULL>";
                                }
                        } else {
                                /* print counted unicode string */
                                PUNICODE_STRING pus = va_arg(args, PUNICODE_STRING);
                                if ((pus == NULL) || (pus->Buffer == NULL)) {
                                        sw = L"<NULL>";
-                                       while ((*sw) != 0)
-                                               *str++ = *sw++;
+                                       while ((*sw) != 0) {
+                                               if (str <= end)
+                                                       *str = *sw;
+                                               ++str;
+                                               ++sw;
+                                       }
                                } else {
                                } else {
-                                       for (i = 0; pus->Buffer[i] && i < pus->Length / sizeof(WCHAR); i++)
-                                               *str++ = pus->Buffer[i];
+                                       for (i = 0; pus->Buffer[i] && i < pus->Length / sizeof(WCHAR); i++) {
+                                               if (str <= end)
+                                                       *str = pus->Buffer[i];
+                                               ++str;
+                                       }
                                }
                        }
                        continue;
                                }
                        }
                        continue;
@@ -352,12 +447,13 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                                field_width = 2*sizeof(void *);
                                flags |= ZEROPAD;
                        }
                                field_width = 2*sizeof(void *);
                                flags |= ZEROPAD;
                        }
-                       str = number(str,
+                       str = number(str, end,
                                (unsigned long) va_arg(args, void *), 16,
                                field_width, precision, flags);
                        continue;
 
                case L'n':
                                (unsigned long) va_arg(args, void *), 16,
                                field_width, precision, flags);
                        continue;
 
                case L'n':
+                       /* FIXME: What does C99 say about the overflow case here? */
                        if (qualifier == 'l') {
                                long * ip = va_arg(args, long *);
                                *ip = (str - buf);
                        if (qualifier == 'l') {
                                long * ip = va_arg(args, long *);
                                *ip = (str - buf);
@@ -389,11 +485,16 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                        break;
 
                default:
                        break;
 
                default:
-                       if (*fmt != L'%')
-                               *str++ = L'%';
-                       if (*fmt)
-                               *str++ = *fmt;
-                       else
+                       if (*fmt != L'%') {
+                               if (str <= end)
+                                       *str = L'%';
+                               ++str;
+                       }
+                       if (*fmt) {
+                               if (str <= end)
+                                       *str = *fmt;
+                               ++str;
+                       } else
                                --fmt;
                        continue;
                }
                                --fmt;
                        continue;
                }
@@ -414,9 +515,13 @@ int _vsnwprintf(wchar_t *buf, size_t cnt, const wchar_t *fmt, va_list args)
                        else
                                num = va_arg(args, unsigned int);
                }
                        else
                                num = va_arg(args, unsigned int);
                }
-               str = number(str, num, base, field_width, precision, flags);
+               str = number(str, end, num, base, field_width, precision, flags);
        }
        }
-       *str = L'\0';
+       if (str <= end)
+               *str = L'\0';
+       else if (cnt > 0)
+               /* don't write out a null byte if the buf size is zero */
+               *end = L'\0';
        return str-buf;
 }
 
        return str-buf;
 }