[CMD] Protect certain actions with a critical section, patch by Katayama Hirofumi
[reactos.git] / reactos / base / shell / cmd / cmd.c
index 9859156..54ff9b2 100644 (file)
@@ -156,7 +156,7 @@ BOOL bCanExit = TRUE;     /* indicates if this shell is exitable */
 BOOL bCtrlBreak = FALSE;  /* Ctrl-Break or Ctrl-C hit */
 BOOL bIgnoreEcho = FALSE; /* Set this to TRUE to prevent a newline, when executing a command */
 INT  nErrorLevel = 0;     /* Errorlevel of last launched external program */
-BOOL bChildProcessRunning = FALSE;
+CRITICAL_SECTION ChildProcessRunningLock;
 BOOL bUnicodeOutput = FALSE;
 BOOL bDisableBatchEcho = FALSE;
 BOOL bDelayedExpansion = FALSE;
@@ -164,7 +164,6 @@ DWORD dwChildProcessId = 0;
 OSVERSIONINFO osvi;
 HANDLE hIn;
 HANDLE hOut;
-HANDLE hConsole;
 LPTSTR lpOriginalEnvironment;
 HANDLE CMD_ModuleHandle;
 
@@ -183,7 +182,7 @@ WORD wDefColor;           /* default color */
 INT
 ConvertULargeInteger(ULONGLONG num, LPTSTR des, INT len, BOOL bPutSeperator)
 {
-       TCHAR temp[32];
+       TCHAR temp[39];   /* maximum length with nNumberGroups == 1 */
        UINT  n, iTarget;
 
        if (len <= 1)
@@ -199,15 +198,15 @@ ConvertULargeInteger(ULONGLONG num, LPTSTR des, INT len, BOOL bPutSeperator)
                if (iTarget == n && bPutSeperator)
                {
                        iTarget += nNumberGroups + 1;
-                       temp[31 - n++] = cThousandSeparator;
+                       temp[38 - n++] = cThousandSeparator;
                }
-               temp[31 - n++] = (TCHAR)(num % 10) + _T('0');
+               temp[38 - n++] = (TCHAR)(num % 10) + _T('0');
                num /= 10;
        } while (num > 0);
        if (n > len-1)
                n = len-1;
 
-       memcpy(des, temp + 32 - n, n * sizeof(TCHAR));
+       memcpy(des, temp + 39 - n, n * sizeof(TCHAR));
        des[n] = _T('\0');
 
        return n;
@@ -245,7 +244,7 @@ static BOOL IsConsoleProcess(HANDLE Process)
                return TRUE;
        }
 
-       return IMAGE_SUBSYSTEM_WINDOWS_CUI == ProcessPeb.ImageSubSystem;
+       return IMAGE_SUBSYSTEM_WINDOWS_CUI == ProcessPeb.ImageSubsystem;
 }
 
 
@@ -309,7 +308,7 @@ HANDLE RunFile(DWORD flags, LPTSTR filename, LPTSTR params,
  * Rest  - rest of command line
  */
 
-static BOOL
+static INT
 Execute (LPTSTR Full, LPTSTR First, LPTSTR Rest, PARSED_COMMAND *Cmd)
 {
        TCHAR szFullName[MAX_PATH];
@@ -364,7 +363,7 @@ Execute (LPTSTR Full, LPTSTR First, LPTSTR Rest, PARSED_COMMAND *Cmd)
                }
 
                if (!working) ConErrResPuts (STRING_FREE_ERROR1);
-               return working;
+               return !working;
        }
 
        /* get the PATH environment variable and parse it */
@@ -373,7 +372,7 @@ Execute (LPTSTR Full, LPTSTR First, LPTSTR Rest, PARSED_COMMAND *Cmd)
        if (!SearchForExecutable(First, szFullName))
        {
                error_bad_command(first);
-               return FALSE;
+               return 1;
        }
 
        GetConsoleTitle (szWindowTitle, MAX_PATH);
@@ -385,7 +384,7 @@ Execute (LPTSTR Full, LPTSTR First, LPTSTR Rest, PARSED_COMMAND *Cmd)
                while (*rest == _T(' '))
                        rest++;
                TRACE ("[BATCH: %s %s]\n", debugstr_aw(szFullName), debugstr_aw(rest));
-               Batch (szFullName, first, rest, Cmd);
+               dwExitCode = Batch(szFullName, first, rest, Cmd);
        }
        else
        {
@@ -437,29 +436,23 @@ Execute (LPTSTR Full, LPTSTR First, LPTSTR Rest, PARSED_COMMAND *Cmd)
                {
                        if (IsConsoleProcess(prci.hProcess))
                        {
-                               /* FIXME: Protect this with critical section */
-                               bChildProcessRunning = TRUE;
+                               EnterCriticalSection(&ChildProcessRunningLock);
                                dwChildProcessId = prci.dwProcessId;
 
                                WaitForSingleObject (prci.hProcess, INFINITE);
 
-                               /* FIXME: Protect this with critical section */
-                               bChildProcessRunning = FALSE;
+                               LeaveCriticalSection(&ChildProcessRunningLock);
 
                                GetExitCodeProcess (prci.hProcess, &dwExitCode);
                                nErrorLevel = (INT)dwExitCode;
                        }
-                        else
-                        {
-                            nErrorLevel = 0;
-                        }
                        CloseHandle (prci.hProcess);
                }
                else
                {
                        TRACE ("[ShellExecute failed!: %s]\n", debugstr_aw(Full));
                        error_bad_command (first);
-                       nErrorLevel = 1;
+                       dwExitCode = 1;
                }
 
                // restore console mode
@@ -473,7 +466,7 @@ Execute (LPTSTR Full, LPTSTR First, LPTSTR Rest, PARSED_COMMAND *Cmd)
        OutputCodePage = GetConsoleOutputCP();
        SetConsoleTitle (szWindowTitle);
 
-       return nErrorLevel == 0;
+       return dwExitCode;
 }
 
 
@@ -486,18 +479,27 @@ Execute (LPTSTR Full, LPTSTR First, LPTSTR Rest, PARSED_COMMAND *Cmd)
  * rest  - rest of command line
  */
 
-BOOL
+INT
 DoCommand(LPTSTR first, LPTSTR rest, PARSED_COMMAND *Cmd)
 {
-       TCHAR com[_tcslen(first) + _tcslen(rest) + 2];  /* full command line */
+       TCHAR *com;
        TCHAR *cp;
        LPTSTR param;   /* pointer to command's parameters */
        INT cl;
        LPCOMMAND cmdptr;
        BOOL nointernal = FALSE;
+       INT ret;
 
        TRACE ("DoCommand: (\'%s\' \'%s\')\n", debugstr_aw(first), debugstr_aw(rest));
 
+       /* full command line */
+       com = cmd_alloc((_tcslen(first) + _tcslen(rest) + 2) * sizeof(TCHAR));
+       if (com == NULL)
+       {
+               error_out_of_memory();
+               return 1;
+       }
+
        /* If present in the first word, these characters end the name of an
         * internal command and become the beginning of its parameters. */
        cp = first + _tcscspn(first, _T("\t +,/;=[]"));
@@ -531,11 +533,15 @@ DoCommand(LPTSTR first, LPTSTR rest, PARSED_COMMAND *Cmd)
                        if (_tcsicmp(cmdptr->name, _T("echo")) != 0)
                                while (_istspace(*param))
                                        param++;
-                       return !cmdptr->func(param);
+                       ret = cmdptr->func(param);
+                       cmd_free(com);
+                       return ret;
                }
        }
 
-       return Execute(com, first, rest, Cmd);
+       ret = Execute(com, first, rest, Cmd);
+       cmd_free(com);
+       return ret;
 }
 
 
@@ -544,14 +550,16 @@ DoCommand(LPTSTR first, LPTSTR rest, PARSED_COMMAND *Cmd)
  * full input/output redirection and piping are supported
  */
 
-VOID ParseCommandLine (LPTSTR cmd)
+INT ParseCommandLine (LPTSTR cmd)
 {
+       INT Ret = 0;
        PARSED_COMMAND *Cmd = ParseCommand(cmd);
        if (Cmd)
        {
-               ExecuteCommand(Cmd);
+               Ret = ExecuteCommand(Cmd);
                FreeCommand(Cmd);
        }
+       return Ret;
 }
 
 /* Execute a command without waiting for it to finish. If it's an internal
@@ -655,9 +663,9 @@ ExecutePipeline(PARSED_COMMAND *Cmd)
        SetStdHandle(STD_INPUT_HANDLE, hOldConIn);
 
        /* Wait for all processes to complete */
-       bChildProcessRunning = TRUE;
+       EnterCriticalSection(&ChildProcessRunningLock);
        WaitForMultipleObjects(nProcesses, hProcess, TRUE, INFINITE);
-       bChildProcessRunning = FALSE;
+       LeaveCriticalSection(&ChildProcessRunningLock);
 
        /* Use the exit code of the last process in the pipeline */
        GetExitCodeProcess(hProcess[nProcesses - 1], &dwExitCode);
@@ -680,27 +688,27 @@ failed:
 #endif
 }
 
-BOOL
+INT
 ExecuteCommand(PARSED_COMMAND *Cmd)
 {
        PARSED_COMMAND *Sub;
        LPTSTR First, Rest;
-       BOOL Success = TRUE;
+       INT Ret = 0;
 
        if (!PerformRedirection(Cmd->Redirections))
-               return FALSE;
+               return 1;
 
        switch (Cmd->Type)
        {
        case C_COMMAND:
-               Success = FALSE;
+               Ret = 1;
                First = DoDelayedExpansion(Cmd->Command.First);
                if (First)
                {
                        Rest = DoDelayedExpansion(Cmd->Command.Rest);
                        if (Rest)
                        {
-                               Success = DoCommand(First, Rest, Cmd);
+                               Ret = DoCommand(First, Rest, Cmd);
                                cmd_free(Rest);
                        }
                        cmd_free(First);
@@ -710,69 +718,54 @@ ExecuteCommand(PARSED_COMMAND *Cmd)
        case C_BLOCK:
        case C_MULTI:
                for (Sub = Cmd->Subcommands; Sub; Sub = Sub->Next)
-                       Success = ExecuteCommand(Sub);
+                       Ret = ExecuteCommand(Sub);
                break;
        case C_IFFAILURE:
-       case C_IFSUCCESS:
                Sub = Cmd->Subcommands;
-               Success = ExecuteCommand(Sub);
-               if (Success == (Cmd->Type - C_IFFAILURE))
+               Ret = ExecuteCommand(Sub);
+               if (Ret != 0)
                {
-                       Sub = Sub->Next;
-                       Success = ExecuteCommand(Sub);
+                       nErrorLevel = Ret;
+                       Ret = ExecuteCommand(Sub->Next);
                }
                break;
+       case C_IFSUCCESS:
+               Sub = Cmd->Subcommands;
+               Ret = ExecuteCommand(Sub);
+               if (Ret == 0)
+                       Ret = ExecuteCommand(Sub->Next);
+               break;
        case C_PIPE:
                ExecutePipeline(Cmd);
                break;
        case C_IF:
-               Success = ExecuteIf(Cmd);
+               Ret = ExecuteIf(Cmd);
                break;
        case C_FOR:
-               Success = ExecuteFor(Cmd);
+               Ret = ExecuteFor(Cmd);
                break;
        }
 
        UndoRedirection(Cmd->Redirections, NULL);
-       return Success;
-}
-
-BOOL
-GrowIfNecessary_dbg ( UINT needed, LPTSTR* ret, UINT* retlen, const char *file, int line )
-{
-       if ( *ret && needed < *retlen )
-               return TRUE;
-       *retlen = needed;
-       if ( *ret )
-               cmd_free ( *ret );
-#ifdef _DEBUG_MEM
-       *ret = (LPTSTR)cmd_alloc_dbg ( *retlen * sizeof(TCHAR), file, line );
-#else
-       *ret = (LPTSTR)cmd_alloc ( *retlen * sizeof(TCHAR) );
-#endif
-       if ( !*ret )
-               SetLastError ( ERROR_OUTOFMEMORY );
-       return *ret != NULL;
+       return Ret;
 }
-#define GrowIfNecessary(x, y, z) GrowIfNecessary_dbg(x, y, z, __FILE__, __LINE__)
 
 LPTSTR
 GetEnvVar(LPCTSTR varName)
 {
        static LPTSTR ret = NULL;
-       static UINT retlen = 0;
        UINT size;
 
-       size = GetEnvironmentVariable ( varName, ret, retlen );
-       if ( size > retlen )
+       cmd_free(ret);
+       ret = NULL;
+       size = GetEnvironmentVariable(varName, NULL, 0);
+       if (size > 0)
        {
-               if ( !GrowIfNecessary ( size, &ret, &retlen ) )
-                       return NULL;
-               size = GetEnvironmentVariable ( varName, ret, retlen );
+               ret = cmd_alloc(size * sizeof(TCHAR));
+               if (ret != NULL)
+                       GetEnvironmentVariable(varName, ret, size + 1);
        }
-       if ( size )
-               return ret;
-       return NULL;
+       return ret;
 }
 
 LPCTSTR
@@ -1380,7 +1373,12 @@ ReadLine (TCHAR *commandline, BOOL bMore)
                        }
                }
 
-               ReadCommand (readline, CMDLINE_LENGTH - 1);
+               if (!ReadCommand(readline, CMDLINE_LENGTH - 1))
+               {
+                       bExit = TRUE;
+                       return FALSE;
+               }
+
                if (CheckCtrlBreak(BREAK_INPUT))
                {
                        ConOutPuts(_T("\n"));
@@ -1398,7 +1396,7 @@ ReadLine (TCHAR *commandline, BOOL bMore)
        return SubstituteVars(ip, commandline, _T('%'));
 }
 
-static INT
+static VOID
 ProcessInput()
 {
        PARSED_COMMAND *Cmd;
@@ -1412,8 +1410,6 @@ ProcessInput()
                ExecuteCommand(Cmd);
                FreeCommand(Cmd);
        }
-
-       return nErrorLevel;
 }
 
 
@@ -1441,13 +1437,16 @@ BOOL WINAPI BreakHandler (DWORD dwCtrlType)
                }
        }
 
-       if (bChildProcessRunning == TRUE)
+       if (!TryEnterCriticalSection(&ChildProcessRunningLock))
        {
                SelfGenerated = TRUE;
                GenerateConsoleCtrlEvent (dwCtrlType, 0);
                return TRUE;
        }
-
+       else
+       {
+               LeaveCriticalSection(&ChildProcessRunningLock);
+       }
 
     rec.EventType = KEY_EVENT;
     rec.Event.KeyEvent.bKeyDown = TRUE;
@@ -1521,13 +1520,13 @@ ShowCommands (VOID)
 #endif
 
 static VOID
-ExecuteAutoRunFile (VOID)
+ExecuteAutoRunFile(HKEY hkeyRoot)
 {
-    TCHAR autorun[MAX_PATH];
-       DWORD len = MAX_PATH;
+       TCHAR autorun[2048];
+       DWORD len = sizeof autorun;
        HKEY hkey;
 
-    if( RegOpenKeyEx(HKEY_LOCAL_MACHINE, 
+    if (RegOpenKeyEx(hkeyRoot,
                     _T("SOFTWARE\\Microsoft\\Command Processor"),
                     0, 
                     KEY_READ, 
@@ -1543,9 +1542,8 @@ ExecuteAutoRunFile (VOID)
                        if (*autorun)
                                ParseCommandLine(autorun);
            }
+           RegCloseKey(hkey);
     }
-
-       RegCloseKey(hkey);
 }
 
 /* Get the command that comes after a /C or /K switch */
@@ -1611,6 +1609,7 @@ Initialize()
        TCHAR commandline[CMDLINE_LENGTH];
        TCHAR ModuleName[_MAX_PATH + 1];
        TCHAR lpBuffer[2];
+       INT nExitCode;
 
        //INT len;
        TCHAR *ptr, *cmdLine, option = 0;
@@ -1631,9 +1630,6 @@ Initialize()
                NtReadVirtualMemoryPtr = (NtReadVirtualMemoryProc)GetProcAddress(NtDllModule, "NtReadVirtualMemory");
        }
 
-       cmdLine = GetCommandLine();
-       TRACE ("[command args: %s]\n", debugstr_aw(cmdLine));
-
        InitLocale ();
 
        /* get default input and output console handles */
@@ -1647,12 +1643,31 @@ Initialize()
        if (GetEnvironmentVariable(_T("PROMPT"),lpBuffer, sizeof(lpBuffer) / sizeof(lpBuffer[0])) == 0)
            SetEnvironmentVariable (_T("PROMPT"), _T("$P$G"));
 
+#ifdef FEATURE_DIR_STACK
+       /* initialize directory stack */
+       InitDirectoryStack ();
+#endif
+
+#ifdef FEATURE_HISTORY
+       /*initialize history*/
+       InitHistory();
+#endif
+
+       /* Set COMSPEC environment variable */
+       if (0 != GetModuleFileName (NULL, ModuleName, _MAX_PATH + 1))
+       {
+               ModuleName[_MAX_PATH] = _T('\0');
+               SetEnvironmentVariable (_T("COMSPEC"), ModuleName);
+       }
+
+       /* add ctrl break handler */
+       AddBreakHandler ();
+
 
        SetConsoleMode (hIn, ENABLE_PROCESSED_INPUT);
 
-#ifdef INCLUDE_CMD_CHDIR
-       InitLastPath ();
-#endif
+       cmdLine = GetCommandLine();
+       TRACE ("[command args: %s]\n", debugstr_aw(cmdLine));
 
        for (ptr = cmdLine; *ptr; ptr++)
        {
@@ -1662,7 +1677,9 @@ Initialize()
                        if (option == _T('?'))
                        {
                                ConOutResPaging(TRUE,STRING_CMD_HELP8);
-                               cmd_exit(0);
+                               nErrorLevel = 1;
+                               bExit = TRUE;
+                               return;
                        }
                        else if (option == _T('P'))
                        {
@@ -1730,37 +1747,22 @@ Initialize()
                ConOutPuts(_T("(C) Copyright 1998-") _T(COPYRIGHT_YEAR) _T(" ReactOS Team."));
        }
 
-#ifdef FEATURE_DIR_STACK
-       /* initialize directory stack */
-       InitDirectoryStack ();
-#endif
-
-
-#ifdef FEATURE_HISTORY
-       /*initialize history*/
-       InitHistory();
-#endif
-
-       /* Set COMSPEC environment variable */
-       if (0 != GetModuleFileName (NULL, ModuleName, _MAX_PATH + 1))
+       if (AutoRun)
        {
-               ModuleName[_MAX_PATH] = _T('\0');
-               SetEnvironmentVariable (_T("COMSPEC"), ModuleName);
+               ExecuteAutoRunFile(HKEY_LOCAL_MACHINE);
+               ExecuteAutoRunFile(HKEY_CURRENT_USER);
        }
 
-       /* add ctrl break handler */
-       AddBreakHandler ();
-
-       if (AutoRun)
-               ExecuteAutoRunFile();
-
        if (*ptr)
        {
                /* Do the /C or /K command */
                GetCmdLineCommand(commandline, &ptr[2], AlwaysStrip);
-               ParseCommandLine(commandline);
+               nExitCode = ParseCommandLine(commandline);
                if (option != _T('K'))
+               {
+                       nErrorLevel = nExitCode;
                        bExit = TRUE;
+               }
        }
 }
 
@@ -1785,19 +1787,18 @@ static VOID Cleanup()
        DestroyDirectoryStack ();
 #endif
 
-#ifdef INCLUDE_CMD_CHDIR
-       FreeLastPath ();
-#endif
-
 #ifdef FEATURE_HISTORY
        CleanHistory();
 #endif
 
+       /* free GetEnvVar's buffer */
+       GetEnvVar(NULL);
 
        /* remove ctrl break handler */
        RemoveBreakHandler ();
        SetConsoleMode( GetStdHandle( STD_INPUT_HANDLE ),
                        ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT | ENABLE_ECHO_INPUT );
+       DeleteCriticalSection(&ChildProcessRunningLock);
 }
 
 /*
@@ -1805,10 +1806,11 @@ static VOID Cleanup()
  */
 int cmd_main (int argc, const TCHAR *argv[])
 {
+       HANDLE hConsole;
        TCHAR startPath[MAX_PATH];
        CONSOLE_SCREEN_BUFFER_INFO Info;
-       INT nExitCode;
 
+       InitializeCriticalSection(&ChildProcessRunningLock);
        lpOriginalEnvironment = DuplicateEnvironment();
 
        GetCurrentDirectory(MAX_PATH,startPath);
@@ -1821,12 +1823,16 @@ int cmd_main (int argc, const TCHAR *argv[])
        hConsole = CreateFile(_T("CONOUT$"), GENERIC_READ|GENERIC_WRITE,
                FILE_SHARE_READ|FILE_SHARE_WRITE, NULL,
                OPEN_EXISTING, 0, NULL);
-       if (GetConsoleScreenBufferInfo(hConsole, &Info) == FALSE)
+       if (hConsole != INVALID_HANDLE_VALUE)
        {
-               ConErrFormatMessage(GetLastError());
-               return(1);
+               if (!GetConsoleScreenBufferInfo(hConsole, &Info))
+               {
+                       ConErrFormatMessage(GetLastError());
+                       return(1);
+               }
+               wDefColor = Info.wAttributes;
+               CloseHandle(hConsole);
        }
-       wDefColor = Info.wAttributes;
 
        InputCodePage= GetConsoleCP();
        OutputCodePage = GetConsoleOutputCP();
@@ -1836,15 +1842,15 @@ int cmd_main (int argc, const TCHAR *argv[])
        Initialize();
 
        /* call prompt routine */
-       nExitCode = ProcessInput();
+       ProcessInput();
 
        /* do the cleanup */
        Cleanup();
 
        cmd_free(lpOriginalEnvironment);
 
-       cmd_exit(nExitCode);
-       return(nExitCode);
+       cmd_exit(nErrorLevel);
+       return(nErrorLevel);
 }
 
 /* EOF */