* Sync up to trunk HEAD (r62975).
[reactos.git] / ntoskrnl / fsrtl / name.c
index 1405b9f..1cda629 100644 (file)
@@ -5,7 +5,8 @@
  * PURPOSE:         Provides name parsing and other support routines for FSDs
  * PROGRAMMERS:     Alex Ionescu (alex.ionescu@reactos.org)
  *                  Filip Navara (navaraf@reactos.org)
- *                  Pierre Schweitzer (pierre.schweitzer@reactos.org) 
+ *                  Pierre Schweitzer (pierre.schweitzer@reactos.org)
+ *                  Aleksey Bragin (aleksey@reactos.org)
  */
 
 /* INCLUDES ******************************************************************/
@@ -22,64 +23,255 @@ FsRtlIsNameInExpressionPrivate(IN PUNICODE_STRING Expression,
                                IN BOOLEAN IgnoreCase,
                                IN PWCHAR UpcaseTable OPTIONAL)
 {
-    ULONG i = 0, j, k = 0;
+    SHORT StarFound = -1, DosStarFound = -1;
+    PUSHORT BackTracking = NULL, DosBackTracking = NULL;
+    UNICODE_STRING IntExpression;
+    USHORT ExpressionPosition = 0, NamePosition = 0, MatchingChars, LastDot;
+    WCHAR CompareChar;
+    PAGED_CODE();
 
-    ASSERT(!FsRtlDoesNameContainWildCards(Name));
+    /* Check if we were given strings at all */
+    if (!Name->Length || !Expression->Length)
+    {
+        /* Return TRUE if both strings are empty, otherwise FALSE */
+        if (Name->Length == 0 && Expression->Length == 0)
+            return TRUE;
+        else
+            return FALSE;
+    }
+
+    /* Check for a shortcut: just one wildcard */
+    if (Expression->Length == sizeof(WCHAR))
+    {
+        if (Expression->Buffer[0] == L'*')
+            return TRUE;
+    }
+
+    ASSERT(!IgnoreCase || UpcaseTable);
+
+    /* Another shortcut, wildcard followed by some string */
+    if (Expression->Buffer[0] == L'*')
+    {
+        /* Copy Expression to our local variable */
+        IntExpression = *Expression;
+
+        /* Skip the first char */
+        IntExpression.Buffer++;
+        IntExpression.Length -= sizeof(WCHAR);
+
+        /* Continue only if the rest of the expression does NOT contain
+           any more wildcards */
+        if (!FsRtlDoesNameContainWildCards(&IntExpression))
+        {
+            /* Check for a degenerate case */
+            if (Name->Length < (Expression->Length - sizeof(WCHAR)))
+                return FALSE;
+
+            /* Calculate position */
+            NamePosition = (Name->Length - IntExpression.Length) / sizeof(WCHAR);
+
+            /* Compare */
+            if (!IgnoreCase)
+            {
+                /* We can just do a byte compare */
+                return RtlEqualMemory(IntExpression.Buffer,
+                                      Name->Buffer + NamePosition,
+                                      IntExpression.Length);
+            }
+            else
+            {
+                /* Not so easy, need to upcase and check char by char */
+                for (ExpressionPosition = 0; ExpressionPosition < (IntExpression.Length / sizeof(WCHAR)); ExpressionPosition++)
+                {
+                    /* Assert that expression is already upcased! */
+                    ASSERT(IntExpression.Buffer[ExpressionPosition] == UpcaseTable[IntExpression.Buffer[ExpressionPosition]]);
+
+                    /* Now compare upcased name char with expression */
+                    if (UpcaseTable[Name->Buffer[NamePosition + ExpressionPosition]] !=
+                        IntExpression.Buffer[ExpressionPosition])
+                    {
+                        return FALSE;
+                    }
+                }
 
-    while (i < Name->Length / sizeof(WCHAR) && k < Expression->Length / sizeof(WCHAR))
+                /* It matches */
+                return TRUE;
+            }
+        }
+    }
+
+    while ((NamePosition < Name->Length / sizeof(WCHAR)) &&
+           (ExpressionPosition < Expression->Length / sizeof(WCHAR)))
     {
-      if ((Expression->Buffer[k] == (IgnoreCase ? UpcaseTable[Name->Buffer[i]] : Name->Buffer[i])) ||
-          (Expression->Buffer[k] == L'?') || (Expression->Buffer[k] == DOS_QM) ||
-          (Expression->Buffer[k] == DOS_DOT && (Name->Buffer[i] == L'.' || Name->Buffer[i] == L'0')))
-      {
-        i++;
-        k++;
-      }
-      else if (Expression->Buffer[k] == L'*')
-      {
-        if (k < (Expression->Length / sizeof(WCHAR) - 1))
+        /* Basic check to test if chars are equal */
+        CompareChar = IgnoreCase ? UpcaseTable[Name->Buffer[NamePosition]] :
+                                   Name->Buffer[NamePosition];
+        if (Expression->Buffer[ExpressionPosition] == CompareChar)
+        {
+            NamePosition++;
+            ExpressionPosition++;
+        }
+        /* Check cases that eat one char */
+        else if (Expression->Buffer[ExpressionPosition] == L'?')
+        {
+            NamePosition++;
+            ExpressionPosition++;
+        }
+        /* Test star */
+        else if (Expression->Buffer[ExpressionPosition] == L'*')
+        {
+            /* Skip contigous stars */
+            while ((ExpressionPosition + 1 < (USHORT)(Expression->Length / sizeof(WCHAR))) &&
+                   (Expression->Buffer[ExpressionPosition + 1] == L'*'))
+            {
+                ExpressionPosition++;
+            }
+
+            /* Save star position */
+            if (!BackTracking)
+            {
+                BackTracking = ExAllocatePoolWithTag(PagedPool | POOL_RAISE_IF_ALLOCATION_FAILURE,
+                                                     (Expression->Length / sizeof(WCHAR)) * sizeof(USHORT),
+                                                     'nrSF');
+            }
+            BackTracking[++StarFound] = ExpressionPosition++;
+
+            /* If star is at the end, then eat all rest and leave */
+            if (ExpressionPosition == Expression->Length / sizeof(WCHAR))
+            {
+                NamePosition = Name->Length / sizeof(WCHAR);
+                break;
+            }
+            /* Allow null matching */
+            else if (Expression->Buffer[ExpressionPosition] != L'?' &&
+                     Expression->Buffer[ExpressionPosition] != Name->Buffer[NamePosition])
+            {
+                NamePosition++;
+            }
+        }
+        /* Check DOS_STAR */
+        else if (Expression->Buffer[ExpressionPosition] == DOS_STAR)
+        {
+            /* Skip contigous stars */
+            while ((ExpressionPosition + 1 < (USHORT)(Expression->Length / sizeof(WCHAR))) &&
+                   (Expression->Buffer[ExpressionPosition + 1] == DOS_STAR))
+            {
+                ExpressionPosition++;
+            }
+
+            /* Look for last dot */
+            MatchingChars = 0;
+            LastDot = (USHORT)-1;
+            while (MatchingChars < Name->Length / sizeof(WCHAR))
+            {
+                if (Name->Buffer[MatchingChars] == L'.')
+                {
+                    LastDot = MatchingChars;
+                    if (LastDot > NamePosition)
+                        break;
+                }
+
+                MatchingChars++;
+            }
+
+            /* If we don't have dots or we didn't find last yet
+             * start eating everything
+             */
+            if (MatchingChars != Name->Length || LastDot == (USHORT)-1)
+            {
+                if (!DosBackTracking) DosBackTracking = ExAllocatePoolWithTag(PagedPool | POOL_RAISE_IF_ALLOCATION_FAILURE,
+                                                                              (Expression->Length / sizeof(WCHAR)) * sizeof(USHORT),
+                                                                              'nrSF');
+                DosBackTracking[++DosStarFound] = ExpressionPosition++;
+
+                /* Not the same char, start exploring */
+                if (Expression->Buffer[ExpressionPosition] != Name->Buffer[NamePosition])
+                    NamePosition++;
+            }
+            else
+            {
+                /* Else, if we are at last dot, eat it - otherwise, null match */
+                if (Name->Buffer[NamePosition] == '.')
+                    NamePosition++;
+
+                 ExpressionPosition++;
+            }
+        }
+        /* Check DOS_DOT */
+        else if (Expression->Buffer[ExpressionPosition] == DOS_DOT)
+        {
+            /* We only match dots */
+            if (Name->Buffer[NamePosition] == L'.')
+            {
+                NamePosition++;
+            }
+            /* Try to explore later on for null matching */
+            else if ((ExpressionPosition + 1 < (USHORT)(Expression->Length / sizeof(WCHAR))) && 
+                     (Name->Buffer[NamePosition] == Expression->Buffer[ExpressionPosition + 1]))
+            {
+                NamePosition++;
+            }
+            ExpressionPosition++;
+        }
+        /* Check DOS_QM */
+        else if (Expression->Buffer[ExpressionPosition] == DOS_QM)
         {
-          if (Expression->Buffer[k+1] != L'*' && Expression->Buffer[k+1] != L'?' &&
-              Expression->Buffer[k+1] != DOS_DOT && Expression->Buffer[k+1] != DOS_QM &&
-              Expression->Buffer[k+1] != DOS_STAR)
-          {
-            while ((IgnoreCase ? UpcaseTable[Name->Buffer[i]] : Name->Buffer[i]) != Expression->Buffer[k+1] &&
-                   i < Name->Length / sizeof(WCHAR)) i++;
-          }
-          else
-          {
-            if (!(Expression->Buffer[k+1] != DOS_DOT && (Name->Buffer[i] == L'.' || Name->Buffer[i] == L'0')))
+            /* We match everything except dots */
+            if (Name->Buffer[NamePosition] != L'.')
             {
-              i++;
+                NamePosition++;
             }
-          }
+            ExpressionPosition++;
+        }
+        /* If nothing match, try to backtrack */
+        else if (StarFound >= 0)
+        {
+            ExpressionPosition = BackTracking[StarFound--];
         }
+        else if (DosStarFound >= 0)
+        {
+            ExpressionPosition = DosBackTracking[DosStarFound--];
+        }
+        /* Otherwise, fail */
         else
         {
-          i = Name->Length / sizeof(WCHAR);
+            break;
+        }
+
+        /* Under certain circumstances, expression is over, but name isn't
+         * and we can backtrack, then, backtrack */
+        if (ExpressionPosition == Expression->Length / sizeof(WCHAR) &&
+            NamePosition != Name->Length / sizeof(WCHAR) &&
+            StarFound >= 0)
+        {
+            ExpressionPosition = BackTracking[StarFound--];
         }
-        k++;
-      }
-      else if (Expression->Buffer[k] == DOS_STAR)
-      {
-        j = i;
-        while (j < Name->Length / sizeof(WCHAR))
+    }
+    /* If we have nullable matching wc at the end of the string, eat them */
+    if (ExpressionPosition != Expression->Length / sizeof(WCHAR) && NamePosition == Name->Length / sizeof(WCHAR))
+    {
+        while (ExpressionPosition < Expression->Length / sizeof(WCHAR))
         {
-          if (Name->Buffer[j] == L'.')
-          {
-            i = j;
-          }
-          j++;
+            if (Expression->Buffer[ExpressionPosition] != DOS_DOT &&
+                Expression->Buffer[ExpressionPosition] != L'*' &&
+                Expression->Buffer[ExpressionPosition] != DOS_STAR)
+            {
+                break;
+            }
+            ExpressionPosition++;
         }
-        k++;
-      }
-      else
-      {
-        i = Name->Length / sizeof(WCHAR);
-      }
     }
 
-    return (k == Expression->Length / sizeof(WCHAR) && i == Name->Length / sizeof(WCHAR));
+    if (BackTracking)
+    {
+        ExFreePoolWithTag(BackTracking, 'nrSF');
+    }
+    if (DosBackTracking)
+    {
+        ExFreePoolWithTag(DosBackTracking, 'nrSF');
+    }
+
+    return (ExpressionPosition == Expression->Length / sizeof(WCHAR) && NamePosition == Name->Length / sizeof(WCHAR));
 }
 
 /* PUBLIC FUNCTIONS **********************************************************/
@@ -117,8 +309,9 @@ FsRtlAreNamesEqual(IN PCUNICODE_STRING Name1,
     UNICODE_STRING UpcaseName1;
     UNICODE_STRING UpcaseName2;
     BOOLEAN StringsAreEqual, MemoryAllocated = FALSE;
-    ULONG i;
+    USHORT i;
     NTSTATUS Status;
+    PAGED_CODE();
 
     /* Well, first check their size */
     if (Name1->Length != Name2->Length) return FALSE;
@@ -208,8 +401,9 @@ FsRtlDissectName(IN UNICODE_STRING Name,
                  OUT PUNICODE_STRING FirstPart,
                  OUT PUNICODE_STRING RemainingPart)
 {
-    ULONG FirstPosition, i;
-    ULONG SkipFirstSlash = 0;
+    USHORT FirstPosition, i;
+    USHORT SkipFirstSlash = 0;
+    PAGED_CODE();
 
     /* Zero the strings before continuing */
     RtlZeroMemory(FirstPart, sizeof(UNICODE_STRING));
@@ -272,6 +466,7 @@ NTAPI
 FsRtlDoesNameContainWildCards(IN PUNICODE_STRING Name)
 {
     PWCHAR Ptr;
+    PAGED_CODE();
 
     /* Loop through every character */
     if (Name->Length)
@@ -297,7 +492,7 @@ FsRtlDoesNameContainWildCards(IN PUNICODE_STRING Name)
  *
  * @param Expression
  *        The string in which we've to find Name. It can contain wildcards.
- *        If IgnoreCase is set to TRUE, this string MUST BE uppercase. 
+ *        If IgnoreCase is set to TRUE, this string MUST BE uppercase.
  *
  * @param Name
  *        The string to find. It cannot contain wildcards
@@ -307,7 +502,7 @@ FsRtlDoesNameContainWildCards(IN PUNICODE_STRING Name)
  *
  * @param UpcaseTable
  *        If not NULL, and if IgnoreCase is set to TRUE, it will be used to
- *        upcase the both strings 
+ *        upcase the both strings
  *
  * @return TRUE if Name is in Expression, FALSE otherwise
  *
@@ -330,7 +525,7 @@ FsRtlIsNameInExpression(IN PUNICODE_STRING Expression,
     if (IgnoreCase && !UpcaseTable)
     {
         Status = RtlUpcaseUnicodeString(&IntName, Name, TRUE);
-        if (Status != STATUS_SUCCESS)
+        if (!NT_SUCCESS(Status))
         {
             ExRaiseStatus(Status);
         }