[NTOSKRNL] Unify buffer size calculation in FsRtlIs{Name,Dbcs}InExpression
[reactos.git] / ntoskrnl / fsrtl / name.c
index 38ccca4..3d2b33e 100644 (file)
@@ -23,13 +23,16 @@ FsRtlIsNameInExpressionPrivate(IN PUNICODE_STRING Expression,
                                IN BOOLEAN IgnoreCase,
                                IN PWCHAR UpcaseTable OPTIONAL)
 {
-    SHORT StarFound = -1, DosStarFound = -1;
-    USHORT BackTrackingBuffer[5], DosBackTrackingBuffer[5];
-    PUSHORT BackTracking = BackTrackingBuffer, DosBackTracking = DosBackTrackingBuffer;
-    SHORT BackTrackingSize = RTL_NUMBER_OF(BackTrackingBuffer);
-    SHORT DosBackTrackingSize = RTL_NUMBER_OF(DosBackTrackingBuffer);
+    USHORT Offset, Position, BackTrackingPosition, OldBackTrackingPosition;
+    USHORT BackTrackingBuffer[16], OldBackTrackingBuffer[16] = {0};
+    PUSHORT BackTrackingSwap, BackTracking = BackTrackingBuffer, OldBackTracking = OldBackTrackingBuffer;
+    ULONG BackTrackingBufferSize = RTL_NUMBER_OF(BackTrackingBuffer);
+    PVOID AllocatedBuffer = NULL;
     UNICODE_STRING IntExpression;
-    USHORT ExpressionPosition = 0, NamePosition = 0, MatchingChars, LastDot;
+    USHORT ExpressionPosition, NamePosition = 0, MatchingChars = 1;
+    BOOLEAN EndOfName = FALSE;
+    BOOLEAN Result;
+    BOOLEAN DontSkipDot;
     WCHAR CompareChar;
     PAGED_CODE();
 
@@ -37,7 +40,7 @@ FsRtlIsNameInExpressionPrivate(IN PUNICODE_STRING Expression,
     if (!Name->Length || !Expression->Length)
     {
         /* Return TRUE if both strings are empty, otherwise FALSE */
-        if (Name->Length == 0 && Expression->Length == 0)
+        if (!Name->Length && !Expression->Length)
             return TRUE;
         else
             return FALSE;
@@ -103,189 +106,164 @@ FsRtlIsNameInExpressionPrivate(IN PUNICODE_STRING Expression,
         }
     }
 
-    while ((NamePosition < Name->Length / sizeof(WCHAR)) &&
-           (ExpressionPosition < Expression->Length / sizeof(WCHAR)))
+    /* Name parsing loop */
+    for (; !EndOfName; MatchingChars = BackTrackingPosition, NamePosition++)
     {
-        /* Basic check to test if chars are equal */
-        CompareChar = IgnoreCase ? UpcaseTable[Name->Buffer[NamePosition]] :
-                                   Name->Buffer[NamePosition];
-        if (Expression->Buffer[ExpressionPosition] == CompareChar)
-        {
-            NamePosition++;
-            ExpressionPosition++;
-        }
-        /* Check cases that eat one char */
-        else if (Expression->Buffer[ExpressionPosition] == L'?')
+        /* Reset positions */
+        OldBackTrackingPosition = BackTrackingPosition = 0;
+
+        if (NamePosition >= Name->Length / sizeof(WCHAR))
         {
-            NamePosition++;
-            ExpressionPosition++;
+            EndOfName = TRUE;
+            if (MatchingChars && (OldBackTracking[MatchingChars - 1] == Expression->Length * 2))
+                break;
         }
-        /* Test star */
-        else if (Expression->Buffer[ExpressionPosition] == L'*')
+
+        while (MatchingChars > OldBackTrackingPosition)
         {
-            /* Skip contigous stars */
-            while ((ExpressionPosition + 1 < (USHORT)(Expression->Length / sizeof(WCHAR))) &&
-                   (Expression->Buffer[ExpressionPosition + 1] == L'*'))
-            {
-                ExpressionPosition++;
-            }
+            ExpressionPosition = (OldBackTracking[OldBackTrackingPosition++] + 1) / 2;
 
-            /* Save star position */
-            StarFound++;
-            if (StarFound >= BackTrackingSize)
+            /* Expression parsing loop */
+            for (Offset = 0; ExpressionPosition < Expression->Length; Offset = sizeof(WCHAR))
             {
-                BackTrackingSize = Expression->Length / sizeof(WCHAR);
-                BackTracking = ExAllocatePoolWithTag(PagedPool | POOL_RAISE_IF_ALLOCATION_FAILURE,
-                                                     BackTrackingSize * sizeof(USHORT),
-                                                     'nrSF');
-                RtlCopyMemory(BackTracking, BackTrackingBuffer, sizeof(BackTrackingBuffer));
+                ExpressionPosition += Offset;
 
-            }
-            BackTracking[StarFound] = ExpressionPosition++;
+                if (ExpressionPosition == Expression->Length)
+                {
+                    BackTracking[BackTrackingPosition++] = Expression->Length * 2;
+                    break;
+                }
 
-            /* If star is at the end, then eat all rest and leave */
-            if (ExpressionPosition == Expression->Length / sizeof(WCHAR))
-            {
-                NamePosition = Name->Length / sizeof(WCHAR);
-                break;
-            }
+                /* If buffer too small */
+                if (BackTrackingPosition > BackTrackingBufferSize - 3)
+                {
+                    /* We should only ever get here once! */
+                    ASSERT(AllocatedBuffer == NULL);
+                    ASSERT((BackTracking == BackTrackingBuffer) || (BackTracking == OldBackTrackingBuffer));
+                    ASSERT((OldBackTracking == BackTrackingBuffer) || (OldBackTracking == OldBackTrackingBuffer));
 
-            /* Allow null matching */
-            if (Expression->Buffer[ExpressionPosition] != L'?' &&
-                     Expression->Buffer[ExpressionPosition] != Name->Buffer[NamePosition])
-            {
-                NamePosition++;
-            }
-        }
-        /* Check DOS_STAR */
-        else if (Expression->Buffer[ExpressionPosition] == DOS_STAR)
-        {
-            /* Skip contigous stars */
-            while ((ExpressionPosition + 1 < (USHORT)(Expression->Length / sizeof(WCHAR))) &&
-                   (Expression->Buffer[ExpressionPosition + 1] == DOS_STAR))
-            {
-                ExpressionPosition++;
-            }
+                    /* Calculate buffer size */
+                    BackTrackingBufferSize = Expression->Length / sizeof(WCHAR) * 2 + 1;
 
-            /* Look for last dot */
-            MatchingChars = 0;
-            LastDot = (USHORT)-1;
-            while (MatchingChars < Name->Length / sizeof(WCHAR))
-            {
-                if (Name->Buffer[MatchingChars] == L'.')
+                    /* Allocate memory for both back-tracking buffers */
+                    AllocatedBuffer = ExAllocatePoolWithTag(PagedPool | POOL_RAISE_IF_ALLOCATION_FAILURE,
+                                                            2 * BackTrackingBufferSize * sizeof(USHORT),
+                                                            'nrSF');
+                    if (AllocatedBuffer == NULL)
+                    {
+                        DPRINT1("Failed to allocate BackTracking buffer. BackTrackingBufferSize = =x%lx\n",
+                                BackTrackingBufferSize);
+                        Result = FALSE;
+                        goto Exit;
+                    }
+
+                    /* Copy BackTracking content. Note that it can point to either BackTrackingBuffer or OldBackTrackingBuffer */
+                    RtlCopyMemory(AllocatedBuffer,
+                                  BackTracking,
+                                  RTL_NUMBER_OF(BackTrackingBuffer) * sizeof(USHORT));
+
+                    /* Place current Backtracking is at the start of the new buffer */
+                    BackTracking = AllocatedBuffer;
+
+                    /* Copy OldBackTracking content */
+                    RtlCopyMemory(&BackTracking[BackTrackingBufferSize],
+                                  OldBackTracking,
+                                  RTL_NUMBER_OF(OldBackTrackingBuffer) * sizeof(USHORT));
+
+                    /* Place current OldBackTracking after current BackTracking in the buffer */
+                    OldBackTracking = &BackTracking[BackTrackingBufferSize];
+                }
+
+                /* Basic check to test if chars are equal */
+                CompareChar = (NamePosition >= Name->Length / sizeof(WCHAR)) ? UNICODE_NULL : (IgnoreCase ? UpcaseTable[Name->Buffer[NamePosition]] :
+                                           Name->Buffer[NamePosition]);
+                if (Expression->Buffer[ExpressionPosition / sizeof(WCHAR)] == CompareChar && !EndOfName)
+                {
+                    BackTracking[BackTrackingPosition++] = (ExpressionPosition + sizeof(WCHAR)) * 2;
+                }
+                /* Check cases that eat one char */
+                else if (Expression->Buffer[ExpressionPosition / sizeof(WCHAR)] == L'?' && !EndOfName)
                 {
-                    LastDot = MatchingChars;
-                    if (LastDot > NamePosition)
-                        break;
+                    BackTracking[BackTrackingPosition++] = (ExpressionPosition + sizeof(WCHAR)) * 2;
                 }
+                /* Test star */
+                else if (Expression->Buffer[ExpressionPosition / sizeof(WCHAR)] == L'*')
+                {
+                    BackTracking[BackTrackingPosition++] = ExpressionPosition * 2;
+                    BackTracking[BackTrackingPosition++] = (ExpressionPosition * 2) + 3;
+                    continue;
+                }
+                /* Check DOS_STAR */
+                else if (Expression->Buffer[ExpressionPosition / sizeof(WCHAR)] == DOS_STAR)
+                {
+                    /* Look for last dot */
+                    DontSkipDot = TRUE;
+                    if (!EndOfName && Name->Buffer[NamePosition] == '.')
+                    {
+                        for (Position = NamePosition + 1; Position < Name->Length / sizeof(WCHAR); Position++)
+                        {
+                            if (Name->Buffer[Position] == L'.')
+                            {
+                                DontSkipDot = FALSE;
+                                break;
+                            }
+                        }
+                    }
 
-                MatchingChars++;
-            }
+                    if (EndOfName || Name->Buffer[NamePosition] != L'.' || !DontSkipDot)
+                        BackTracking[BackTrackingPosition++] = ExpressionPosition * 2;
 
-            /* If we don't have dots or we didn't find last yet
-             * start eating everything
-             */
-            if (MatchingChars != Name->Length || LastDot == (USHORT)-1)
-            {
-                DosStarFound++;
-                if (DosStarFound >= DosBackTrackingSize)
+                    BackTracking[BackTrackingPosition++] = (ExpressionPosition * 2) + 3;
+                    continue;
+                }
+                /* Check DOS_DOT */
+                else if (Expression->Buffer[ExpressionPosition / sizeof(WCHAR)] == DOS_DOT)
                 {
-                    DosBackTrackingSize = Expression->Length / sizeof(WCHAR);
-                    DosBackTracking = ExAllocatePoolWithTag(PagedPool | POOL_RAISE_IF_ALLOCATION_FAILURE,
-                                                            DosBackTrackingSize * sizeof(USHORT),
-                                                            'nrSF');
-                    RtlCopyMemory(DosBackTracking, DosBackTrackingBuffer, sizeof(DosBackTrackingBuffer));
+                    if (EndOfName) continue;
+
+                    if (Name->Buffer[NamePosition] == L'.')
+                        BackTracking[BackTrackingPosition++] = (ExpressionPosition + sizeof(WCHAR)) * 2;
                 }
-                DosBackTracking[DosStarFound] = ExpressionPosition++;
+                /* Check DOS_QM */
+                else if (Expression->Buffer[ExpressionPosition / sizeof(WCHAR)] == DOS_QM)
+                {
+                    if (EndOfName || Name->Buffer[NamePosition] == L'.') continue;
 
-                /* Not the same char, start exploring */
-                if (Expression->Buffer[ExpressionPosition] != Name->Buffer[NamePosition])
-                    NamePosition++;
-            }
-            else
-            {
-                /* Else, if we are at last dot, eat it - otherwise, null match */
-                if (Name->Buffer[NamePosition] == '.')
-                    NamePosition++;
+                    BackTracking[BackTrackingPosition++] = (ExpressionPosition + sizeof(WCHAR)) * 2;
+                }
 
-                 ExpressionPosition++;
-            }
-        }
-        /* Check DOS_DOT */
-        else if (Expression->Buffer[ExpressionPosition] == DOS_DOT)
-        {
-            /* We only match dots */
-            if (Name->Buffer[NamePosition] == L'.')
-            {
-                NamePosition++;
-            }
-            /* Try to explore later on for null matching */
-            else if ((ExpressionPosition + 1 < (USHORT)(Expression->Length / sizeof(WCHAR))) &&
-                     (Name->Buffer[NamePosition] == Expression->Buffer[ExpressionPosition + 1]))
-            {
-                NamePosition++;
-            }
-            ExpressionPosition++;
-        }
-        /* Check DOS_QM */
-        else if (Expression->Buffer[ExpressionPosition] == DOS_QM)
-        {
-            /* We match everything except dots */
-            if (Name->Buffer[NamePosition] != L'.')
-            {
-                NamePosition++;
+                /* Leave from loop */
+                break;
             }
-            ExpressionPosition++;
-        }
-        /* If nothing match, try to backtrack */
-        else if (StarFound >= 0)
-        {
-            ExpressionPosition = BackTracking[StarFound--];
-        }
-        else if (DosStarFound >= 0)
-        {
-            ExpressionPosition = DosBackTracking[DosStarFound--];
-        }
-        /* Otherwise, fail */
-        else
-        {
-            break;
-        }
 
-        /* Under certain circumstances, expression is over, but name isn't
-         * and we can backtrack, then, backtrack */
-        if (ExpressionPosition == Expression->Length / sizeof(WCHAR) &&
-            NamePosition != Name->Length / sizeof(WCHAR) &&
-            StarFound >= 0)
-        {
-            ExpressionPosition = BackTracking[StarFound--];
-        }
-    }
-    /* If we have nullable matching wc at the end of the string, eat them */
-    if (ExpressionPosition != Expression->Length / sizeof(WCHAR) && NamePosition == Name->Length / sizeof(WCHAR))
-    {
-        while (ExpressionPosition < Expression->Length / sizeof(WCHAR))
-        {
-            if (Expression->Buffer[ExpressionPosition] != DOS_DOT &&
-                Expression->Buffer[ExpressionPosition] != L'*' &&
-                Expression->Buffer[ExpressionPosition] != DOS_STAR)
+            for (Position = 0; MatchingChars > OldBackTrackingPosition && Position < BackTrackingPosition; Position++)
             {
-                break;
+                while (MatchingChars > OldBackTrackingPosition &&
+                       BackTracking[Position] > OldBackTracking[OldBackTrackingPosition])
+                {
+                    ++OldBackTrackingPosition;
+                }
             }
-            ExpressionPosition++;
         }
-    }
 
-    if (BackTracking != BackTrackingBuffer)
-    {
-        ExFreePoolWithTag(BackTracking, 'nrSF');
+        /* Swap pointers */
+        BackTrackingSwap = BackTracking;
+        BackTracking = OldBackTracking;
+        OldBackTracking = BackTrackingSwap;
     }
-    if (DosBackTracking != DosBackTrackingBuffer)
+
+    /* Store result value */
+    Result = MatchingChars > 0 && (OldBackTracking[MatchingChars - 1] == (Expression->Length * 2));
+
+Exit:
+
+    /* Frees the memory if necessary */
+    if (AllocatedBuffer != NULL)
     {
-        ExFreePoolWithTag(DosBackTracking, 'nrSF');
+        ExFreePoolWithTag(AllocatedBuffer, 'nrSF');
     }
 
-    return (ExpressionPosition == Expression->Length / sizeof(WCHAR) && NamePosition == Name->Length / sizeof(WCHAR));
+    return Result;
 }
 
 /* PUBLIC FUNCTIONS **********************************************************/
@@ -338,7 +316,13 @@ FsRtlAreNamesEqual(IN PCUNICODE_STRING Name1,
         if (!NT_SUCCESS(Status)) RtlRaiseStatus(Status);
 
         /* Upcase the second string too */
-        RtlUpcaseUnicodeString(&UpcaseName2, Name2, TRUE);
+        Status = RtlUpcaseUnicodeString(&UpcaseName2, Name2, TRUE);
+        if (!NT_SUCCESS(Status))
+        {
+            RtlFreeUnicodeString(&UpcaseName1);
+            RtlRaiseStatus(Status);
+        }
+
         Name1 = &UpcaseName1;
         Name2 = &UpcaseName2;