Prevent buffer overflows in vsn(w)printf
[reactos.git] / reactos / ntoskrnl / rtl / sprintf.c
index 1cd2ce4..a480c73 100644 (file)
@@ -1,4 +1,4 @@
-/* $Id: sprintf.c,v 1.9 2002/09/08 10:23:42 chorns Exp $
+/* $Id: sprintf.c,v 1.10 2002/09/12 17:50:05 guido Exp $
  *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
@@ -7,8 +7,6 @@
  * PROGRAMMERS:     David Welch
  *                  Eric Kohl
  *
- * TODO:
- *      - Implement maximum length (cnt) in _vsnprintf().
  */
 
 /*
@@ -59,26 +57,25 @@ static int skip_atoi(const char **s)
 
 
 static char *
-number (char * str, long long num, int base, int size, int precision, int type)
+number(char *buf, char *end, long long num, int base, int size, int precision, int type)
 {
   char c,sign,tmp[66];
-  const char *digits="0123456789abcdefghijklmnopqrstuvwxyz";
+  const char *digits;
+  const char small_digits[] = "0123456789abcdefghijklmnopqrstuvwxyz";
+  const char large_digits[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
   int i;
 
-  if (type & LARGE)
-    digits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+  digits = (type & LARGE) ? large_digits : small_digits;
   if (type & LEFT)
     type &= ~ZEROPAD;
   if (base < 2 || base > 36)
     return 0;
-
   c = (type & ZEROPAD) ? '0' : ' ';
   sign = 0;
-
   if (type & SIGN) {
     if (num < 0) {
       sign = '-';
-      num = -num;
+        num = -num;
       size--;
     } else if (type & PLUS) {
       sign = '+';
@@ -88,14 +85,12 @@ number (char * str, long long num, int base, int size, int precision, int type)
       size--;
     }
   }
-
   if (type & SPECIAL) {
     if (base == 16)
       size -= 2;
     else if (base == 8)
       size--;
   }
-
   i = 0;
   if (num == 0)
     tmp[i++]='0';
@@ -104,29 +99,55 @@ number (char * str, long long num, int base, int size, int precision, int type)
   if (i > precision)
     precision = i;
   size -= precision;
-  if (!(type&(ZEROPAD+LEFT)))
-    while(size-->0)
-      *str++ = ' ';
-  if (sign)
-    *str++ = sign;
+  if (!(type&(ZEROPAD+LEFT))) {
+    while(size-->0) {
+      if (buf <= end)
+        *buf = ' ';
+      ++buf;
+    }
+  }
+  if (sign) {
+    if (buf <= end)
+      *buf = sign;
+    ++buf;
+  }
   if (type & SPECIAL) {
     if (base==8) {
-      *str++ = '0';
+      if (buf <= end)
+        *buf = '0';
+      ++buf;
     } else if (base==16) {
-      *str++ = '0';
-      *str++ = digits[33];
+      if (buf <= end)
+        *buf = '0';
+      ++buf;
+      if (buf <= end)
+        *buf = digits[33];
+      ++buf;
+    }
+  }
+  if (!(type & LEFT)) {
+    while (size-- > 0) {
+      if (buf <= end)
+        *buf = c;
+      ++buf;
     }
   }
-  if (!(type & LEFT))
-    while (size-- > 0)
-      *str++ = c;
-  while (i < precision--)
-    *str++ = '0';
-  while (i-- > 0)
-    *str++ = tmp[i];
-  while (size-- > 0)
-    *str++ = ' ';
-  return str;
+  while (i < precision--) {
+    if (buf <= end)
+      *buf = '0';
+    ++buf;
+  }
+  while (i-- > 0) {
+    if (buf <= end)
+      *buf = tmp[i];
+    ++buf;
+  }
+  while (size-- > 0) {
+    if (buf <= end)
+      *buf = ' ';
+    ++buf;
+  }
+  return buf;
 }
 
 
@@ -135,7 +156,7 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
   int len;
   unsigned long long num;
   int i, base;
-  char * str;
+  char *str, *end;
   const char *s;
   const wchar_t *sw;
 
@@ -146,24 +167,31 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
                             number of chars for from string */
   int qualifier;          /* 'h', 'l', 'L', 'I' or 'w' for integer fields */
 
-  for (str=buf ; *fmt ; ++fmt) {
+  str = buf;
+  end = buf + cnt - 1;
+  if (end < buf - 1) {
+    end = ((void *) -1);
+    cnt = end - buf + 1;
+  }
+
+  for ( ; *fmt ; ++fmt) {
     if (*fmt != '%') {
-      *str++ = *fmt;
-      if( --cnt == 0 )
-       goto out;
+      if (str <= end)
+        *str = *fmt;
+      ++str;
       continue;
     }
 
     /* process flags */
     flags = 0;
-  repeat:
+    repeat:
     ++fmt;          /* this also skips first '%' */
     switch (*fmt) {
-    case '-': flags |= LEFT; goto repeat;
-    case '+': flags |= PLUS; goto repeat;
-    case ' ': flags |= SPACE; goto repeat;
-    case '#': flags |= SPECIAL; goto repeat;
-    case '0': flags |= ZEROPAD; goto repeat;
+      case '-': flags |= LEFT; goto repeat;
+      case '+': flags |= PLUS; goto repeat;
+      case ' ': flags |= SPACE; goto repeat;
+      case '#': flags |= SPECIAL; goto repeat;
+      case '0': flags |= ZEROPAD; goto repeat;
     }
 
     /* get field width */
@@ -213,27 +241,27 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
       if (!(flags & LEFT))
        while (--field_width > 0)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+              *str = ' ';
+           ++str;
          }
       if (qualifier == 'l' || qualifier == 'w')
        {
-         *str++ = (unsigned char)(wchar_t) va_arg(args, int);
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+            *str = (unsigned char)(wchar_t) va_arg(args, int);
+          ++str;
        }
       else
        {
-         *str++ = (unsigned char) va_arg(args, int);
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = (unsigned char) va_arg(args, int);
+         ++str;
        }
       while (--field_width > 0)
        {
-         *str++ = ' ';
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = ' ';
+         ++str;
        }
       continue;
 
@@ -241,33 +269,33 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
       if (!(flags & LEFT))
        while (--field_width > 0)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = ' ';
+           ++str;
          }
       if (qualifier == 'h')
        {
-         *str++ = (unsigned char) va_arg(args, int);
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = (unsigned char) va_arg(args, int);
+         ++str;
        }
       else
        {
-         *str++ = (unsigned char)(wchar_t) va_arg(args, int);
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = (unsigned char)(wchar_t) va_arg(args, int);
+         ++str;
        }
       while (--field_width > 0)
        {
-         *str++ = ' ';
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = ' ';
+         ++str;
        }
       continue;
 
     case 's': /* finished */
       if (qualifier == 'l' || qualifier == 'w') {
-                                /* print unicode string */
+        /* print unicode string */
        sw = va_arg(args, wchar_t *);
        if (sw == NULL)
          sw = L"<NULL>";
@@ -277,24 +305,25 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
        if (!(flags & LEFT))
          while (len < field_width--)
            {
-             *str++ = ' ';
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = ' ';
+             ++str;
            }
        for (i = 0; i < len; ++i)
          {
-           *str++ = (unsigned char)(*sw++);
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = (unsigned char)(*sw);
+           ++str;
+           ++sw;
          }
        while (len < field_width--)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = ' ';
+           ++str;
          }
       } else {
-                                /* print ascii string */
+        /* print ascii string */
        s = va_arg(args, char *);
        if (s == NULL)
          s = "<NULL>";
@@ -304,21 +333,22 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
        if (!(flags & LEFT))
          while (len < field_width--)
            {
-             *str++ = ' ';
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = ' ';
+             ++str;
            }
        for (i = 0; i < len; ++i)
          {
-           *str++ = *s++;
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = *s;
+           ++str;
+           ++s;
          }
        while (len < field_width--)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = ' ';
+           ++str;
          }
       }
       continue;
@@ -335,24 +365,25 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
        if (!(flags & LEFT))
          while (len < field_width--)
            {
-             *str++ = ' ';
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = ' ';
+             ++str;
            }
        for (i = 0; i < len; ++i)
          {
-           *str++ = *s++;
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = *s;
+           ++str;
+           ++s;
          }
        while (len < field_width--)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = ' ';
+           ++str;
          }
       } else {
-                                /* print unicode string */
+        /* print unicode string */
        sw = va_arg(args, wchar_t *);
        if (sw == NULL)
          sw = L"<NULL>";
@@ -362,62 +393,65 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
        if (!(flags & LEFT))
          while (len < field_width--)
            {
-             *str++ = ' ';
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = ' ';
+             ++str;
            }
        for (i = 0; i < len; ++i)
          {
-           *str++ = (unsigned char)(*sw++);
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = (unsigned char)(*sw);
+           ++str;
+           ++sw;
          }
        while (len < field_width--)
          {
-           *str++ = ' ';
-           if( --cnt == 0 )
-             goto out;
+            if (str <= end)
+             *str = ' ';
+           ++str;
          }
       }
       continue;
 
     case 'Z':
       if (qualifier == 'w') {
-                                /* print counted unicode string */
+        /* print counted unicode string */
        PUNICODE_STRING pus = va_arg(args, PUNICODE_STRING);
        if ((pus == NULL) || (pus->Buffer == NULL)) {
          s = "<NULL>";
          while ((*s) != 0)
            {
-             *str++ = *s++;
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = *s;
+             ++str;
+             ++s;
            }
        } else {
          for (i = 0; pus->Buffer[i] && i < pus->Length / sizeof(WCHAR); i++)
            {
-             *str++ = (unsigned char)(pus->Buffer[i]);
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = (unsigned char)(pus->Buffer[i]);
+             ++str;
            }
        }
       } else {
-                                /* print counted ascii string */
+        /* print counted ascii string */
        PANSI_STRING pus = va_arg(args, PANSI_STRING);
        if ((pus == NULL) || (pus->Buffer == NULL)) {
          s = "<NULL>";
          while ((*s) != 0)
            {
-             *str++ = *s++;
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = *s;
+             ++str;
+             ++s;
            }
        } else {
          for (i = 0; pus->Buffer[i] && i < pus->Length; i++)
            {
-             *str++ = pus->Buffer[i];
-             if( --cnt == 0 )
-               goto out;
+              if (str <= end)
+               *str = pus->Buffer[i];
+             ++str;
            }
        }
       }
@@ -428,12 +462,13 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
        field_width = 2 * sizeof(void *);
        flags |= ZEROPAD;
       }
-      str = number(str,
-                  (unsigned long) va_arg(args, void *), 16,
-                  field_width, precision, flags);
+      str = number(str, end,
+                   (unsigned long) va_arg(args, void *),
+                   16, field_width, precision, flags);
       continue;
 
     case 'n':
+      /* FIXME: What does C99 say about the overflow case here? */
       if (qualifier == 'l') {
        long * ip = va_arg(args, long *);
        *ip = (str - buf);
@@ -467,15 +502,15 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
     default:
       if (*fmt != '%')
        {
-         *str++ = '%';
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = '%';
+         ++str;
        }
       if (*fmt)
        {
-         *str++ = *fmt;
-         if( --cnt == 0 )
-           goto out;
+          if (str <= end)
+           *str = *fmt;
+         ++str;
        }
       else
        --fmt;
@@ -498,10 +533,15 @@ int _vsnprintf(char *buf, size_t cnt, const char *fmt, va_list args)
       else
        num = va_arg(args, unsigned int);
     }
-    str = number(str, num, base, field_width, precision, flags);
+    str = number(str, end, num, base, field_width, precision, flags);
   }
- out:
-  *str = '\0';
+
+  if (str <= end)
+    *str = '\0';
+  else if (cnt > 0)
+    /* don't write out a null byte if the buf size is zero */
+    *end = '\0';
+  
   return str-buf;
 }