- Allocate the csrss request buffer from heap if the necessary length is larger than...
[reactos.git] / reactos / subsys / csrss / win32csr / conio.c
index a01987d..51084e1 100644 (file)
@@ -564,7 +564,7 @@ CSR_API(CsrReadConsole)
   /* truncate length to CSRSS_MAX_READ_CONSOLE_REQUEST */
   nNumberOfCharsToRead = min(Request->Data.ReadConsoleRequest.NrCharactersToRead, CSRSS_MAX_READ_CONSOLE / CharSize);
   Request->Header.u1.s1.TotalLength = sizeof(CSR_API_MESSAGE);
   /* truncate length to CSRSS_MAX_READ_CONSOLE_REQUEST */
   nNumberOfCharsToRead = min(Request->Data.ReadConsoleRequest.NrCharactersToRead, CSRSS_MAX_READ_CONSOLE / CharSize);
   Request->Header.u1.s1.TotalLength = sizeof(CSR_API_MESSAGE);
-  Request->Header.u1.s1.DataLength = Request->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
+  Request->Header.u1.s1.DataLength = sizeof(CSR_API_MESSAGE) - sizeof(PORT_MESSAGE);
 
   Buffer = Request->Data.ReadConsoleRequest.Buffer;
   UnicodeBuffer = (PWCHAR)Buffer;
 
   Buffer = Request->Data.ReadConsoleRequest.Buffer;
   UnicodeBuffer = (PWCHAR)Buffer;
@@ -672,9 +672,15 @@ CSR_API(CsrReadConsole)
     {
       Console->EchoCount = 0;             /* if the client is no longer waiting on input, do not echo */
     }
     {
       Console->EchoCount = 0;             /* if the client is no longer waiting on input, do not echo */
     }
-  Request->Header.u1.s1.TotalLength += i * CharSize;
 
   ConioUnlockConsole(Console);
 
   ConioUnlockConsole(Console);
+
+  if (CSR_API_MESSAGE_HEADER_SIZE(CSRSS_READ_CONSOLE) + i * CharSize > sizeof(CSR_API_MESSAGE))
+    {
+      Request->Header.u1.s1.TotalLength = CSR_API_MESSAGE_HEADER_SIZE(CSRSS_READ_CONSOLE) + i * CharSize;
+      Request->Header.u1.s1.DataLength = Request->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
+    }
+
   return Request->Status;
 }
 
   return Request->Status;
 }
 
@@ -923,8 +929,8 @@ CSR_API(CsrWriteConsole)
 
   DPRINT("CsrWriteConsole\n");
 
 
   DPRINT("CsrWriteConsole\n");
 
-  if (Request->Header.u1.s1.DataLength
-      < sizeof(CSRSS_WRITE_CONSOLE) 
+  if (Request->Header.u1.s1.TotalLength
+      < CSR_API_MESSAGE_HEADER_SIZE(CSRSS_WRITE_CONSOLE) 
         + (Request->Data.WriteConsoleRequest.NrCharactersToWrite * CharSize))
     {
       DPRINT1("Invalid request size\n");
         + (Request->Data.WriteConsoleRequest.NrCharactersToWrite * CharSize))
     {
       DPRINT1("Invalid request size\n");
@@ -1558,8 +1564,8 @@ CSR_API(CsrWriteConsoleOutputChar)
 
   CharSize = (Request->Data.WriteConsoleOutputCharRequest.Unicode ? sizeof(WCHAR) : sizeof(CHAR));
 
 
   CharSize = (Request->Data.WriteConsoleOutputCharRequest.Unicode ? sizeof(WCHAR) : sizeof(CHAR));
 
-  if (Request->Header.u1.s1.DataLength
-      < sizeof(CSRSS_WRITE_CONSOLE_OUTPUT_CHAR) 
+  if (Request->Header.u1.s1.TotalLength
+      < CSR_API_MESSAGE_HEADER_SIZE(CSRSS_WRITE_CONSOLE_OUTPUT_CHAR) 
         + (Request->Data.WriteConsoleOutputCharRequest.Length * CharSize))
     {
       DPRINT1("Invalid request size\n");
         + (Request->Data.WriteConsoleOutputCharRequest.Length * CharSize))
     {
       DPRINT1("Invalid request size\n");
@@ -1812,9 +1818,9 @@ CSR_API(CsrWriteConsoleOutputAttrib)
 
   DPRINT("CsrWriteConsoleOutputAttrib\n");
 
 
   DPRINT("CsrWriteConsoleOutputAttrib\n");
 
-  if (Request->Header.u1.s1.DataLength
-      < sizeof(CSRSS_WRITE_CONSOLE_OUTPUT_ATTRIB)
-        + Request->Data.WriteConsoleOutputAttribRequest.Length)
+  if (Request->Header.u1.s1.TotalLength
+      < CSR_API_MESSAGE_HEADER_SIZE(CSRSS_WRITE_CONSOLE_OUTPUT_ATTRIB)
+        + Request->Data.WriteConsoleOutputAttribRequest.Length * sizeof(WORD))
     {
       DPRINT1("Invalid request size\n");
       Request->Header.u1.s1.TotalLength = sizeof(CSR_API_MESSAGE);
     {
       DPRINT1("Invalid request size\n");
       Request->Header.u1.s1.TotalLength = sizeof(CSR_API_MESSAGE);
@@ -2256,11 +2262,12 @@ CSR_API(CsrSetTitle)
 {
   NTSTATUS Status;
   PCSRSS_CONSOLE Console;
 {
   NTSTATUS Status;
   PCSRSS_CONSOLE Console;
+  PWCHAR Buffer;
 
   DPRINT("CsrSetTitle\n");
 
 
   DPRINT("CsrSetTitle\n");
 
-  if (Request->Header.u1.s1.DataLength
-      < sizeof(CSRSS_SET_TITLE) 
+  if (Request->Header.u1.s1.TotalLength
+      < CSR_API_MESSAGE_HEADER_SIZE(CSRSS_SET_TITLE)
         + Request->Data.SetTitleRequest.Length)
     {
       DPRINT1("Invalid request size\n");
         + Request->Data.SetTitleRequest.Length)
     {
       DPRINT1("Invalid request size\n");
@@ -2278,16 +2285,26 @@ CSR_API(CsrSetTitle)
     }
   else
     {
     }
   else
     {
-      /* copy title to console */
-      RtlFreeUnicodeString(&Console->Title);
-      RtlCreateUnicodeString(&Console->Title, Request->Data.SetTitleRequest.Title);
-      if (! ConioChangeTitle(Console))
+      Buffer =  RtlAllocateHeap(RtlGetProcessHeap(), 0, Request->Data.SetTitleRequest.Length);
+      if (Buffer)
         {
         {
-          Request->Status = STATUS_UNSUCCESSFUL;
+          /* copy title to console */
+          RtlFreeUnicodeString(&Console->Title);
+          Console->Title.Buffer = Buffer;
+          Console->Title.Length = Console->Title.MaximumLength = Request->Data.SetTitleRequest.Length;
+          memcpy(Console->Title.Buffer, Request->Data.SetTitleRequest.Title, Console->Title.Length);
+          if (! ConioChangeTitle(Console))
+            {
+              Request->Status = STATUS_UNSUCCESSFUL;
+            }
+          else
+            {
+              Request->Status = STATUS_SUCCESS;
+            }
         }
       else
         {
         }
       else
         {
-          Request->Status = STATUS_SUCCESS;
+          Request->Status = STATUS_NO_MEMORY;
         }
     }
   ConioUnlockConsole(Console);
         }
     }
   ConioUnlockConsole(Console);
@@ -2299,6 +2316,7 @@ CSR_API(CsrGetTitle)
 {
   NTSTATUS Status;
   PCSRSS_CONSOLE Console;
 {
   NTSTATUS Status;
   PCSRSS_CONSOLE Console;
+  DWORD Length;
 
   DPRINT("CsrGetTitle\n");
 
 
   DPRINT("CsrGetTitle\n");
 
@@ -2318,12 +2336,17 @@ CSR_API(CsrGetTitle)
   Request->Data.GetTitleRequest.ConsoleHandle = Request->Data.GetTitleRequest.ConsoleHandle;
   Request->Data.GetTitleRequest.Length = Console->Title.Length;
   wcscpy (Request->Data.GetTitleRequest.Title, Console->Title.Buffer);
   Request->Data.GetTitleRequest.ConsoleHandle = Request->Data.GetTitleRequest.ConsoleHandle;
   Request->Data.GetTitleRequest.Length = Console->Title.Length;
   wcscpy (Request->Data.GetTitleRequest.Title, Console->Title.Buffer);
-  Request->Header.u1.s1.TotalLength += Console->Title.Length;
-  Request->Header.u1.s1.DataLength += Console->Title.Length;
-  Request->Status = STATUS_SUCCESS;
+  Length = CSR_API_MESSAGE_HEADER_SIZE(CSRSS_SET_TITLE) + Console->Title.Length;
 
   ConioUnlockConsole(Console);
 
 
   ConioUnlockConsole(Console);
 
+  if (Length > sizeof(CSR_API_MESSAGE))
+    {
+      Request->Header.u1.s1.TotalLength = Length;
+      Request->Header.u1.s1.DataLength = Length - sizeof(PORT_MESSAGE);
+    }
+  Request->Status = STATUS_SUCCESS;
+
   return Request->Status;
 }
 
   return Request->Status;
 }
 
@@ -2646,8 +2669,6 @@ CSR_API(CsrReadConsoleOutputChar)
   Request->Status = STATUS_SUCCESS;
   Request->Data.ReadConsoleOutputCharRequest.EndCoord.X = Xpos - Buff->ShowX;
   Request->Data.ReadConsoleOutputCharRequest.EndCoord.Y = (Ypos - Buff->ShowY + Buff->MaxY) % Buff->MaxY;
   Request->Status = STATUS_SUCCESS;
   Request->Data.ReadConsoleOutputCharRequest.EndCoord.X = Xpos - Buff->ShowX;
   Request->Data.ReadConsoleOutputCharRequest.EndCoord.Y = (Ypos - Buff->ShowY + Buff->MaxY) % Buff->MaxY;
-  Request->Header.u1.s1.TotalLength += Request->Data.ReadConsoleOutputCharRequest.NumCharsToRead;
-  Request->Header.u1.s1.DataLength += Request->Data.ReadConsoleOutputCharRequest.NumCharsToRead;
 
   ConioUnlockScreenBuffer(Buff);
   if (NULL != Console)
 
   ConioUnlockScreenBuffer(Buff);
   if (NULL != Console)
@@ -2656,6 +2677,11 @@ CSR_API(CsrReadConsoleOutputChar)
     }
 
   Request->Data.ReadConsoleOutputCharRequest.CharsRead = (DWORD)((ULONG_PTR)ReadBuffer - (ULONG_PTR)Request->Data.ReadConsoleOutputCharRequest.String) / CharSize;
     }
 
   Request->Data.ReadConsoleOutputCharRequest.CharsRead = (DWORD)((ULONG_PTR)ReadBuffer - (ULONG_PTR)Request->Data.ReadConsoleOutputCharRequest.String) / CharSize;
+  if (Request->Data.ReadConsoleOutputCharRequest.CharsRead * CharSize + CSR_API_MESSAGE_HEADER_SIZE(CSRSS_READ_CONSOLE_OUTPUT_CHAR) > sizeof(CSR_API_MESSAGE))
+    {
+      Request->Header.u1.s1.TotalLength = Request->Data.ReadConsoleOutputCharRequest.CharsRead * CharSize + CSR_API_MESSAGE_HEADER_SIZE(CSRSS_READ_CONSOLE_OUTPUT_CHAR);
+      Request->Header.u1.s1.DataLength = Request->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
+    }
 
   return Request->Status;
 }
 
   return Request->Status;
 }
@@ -2668,6 +2694,7 @@ CSR_API(CsrReadConsoleOutputAttrib)
   DWORD Xpos, Ypos;
   PWORD ReadBuffer;
   DWORD i;
   DWORD Xpos, Ypos;
   PWORD ReadBuffer;
   DWORD i;
+  DWORD CurrentLength;
 
   DPRINT("CsrReadConsoleOutputAttrib\n");
 
 
   DPRINT("CsrReadConsoleOutputAttrib\n");
 
@@ -2708,11 +2735,17 @@ CSR_API(CsrReadConsoleOutputAttrib)
   Request->Status = STATUS_SUCCESS;
   Request->Data.ReadConsoleOutputAttribRequest.EndCoord.X = Xpos - Buff->ShowX;
   Request->Data.ReadConsoleOutputAttribRequest.EndCoord.Y = (Ypos - Buff->ShowY + Buff->MaxY) % Buff->MaxY;
   Request->Status = STATUS_SUCCESS;
   Request->Data.ReadConsoleOutputAttribRequest.EndCoord.X = Xpos - Buff->ShowX;
   Request->Data.ReadConsoleOutputAttribRequest.EndCoord.Y = (Ypos - Buff->ShowY + Buff->MaxY) % Buff->MaxY;
-  Request->Header.u1.s1.TotalLength += Request->Data.ReadConsoleOutputAttribRequest.NumAttrsToRead;
-  Request->Header.u1.s1.DataLength += Request->Data.ReadConsoleOutputAttribRequest.NumAttrsToRead;
 
   ConioUnlockScreenBuffer(Buff);
 
 
   ConioUnlockScreenBuffer(Buff);
 
+  CurrentLength = CSR_API_MESSAGE_HEADER_SIZE(CSRSS_READ_CONSOLE_OUTPUT_ATTRIB)
+                     + Request->Data.ReadConsoleOutputAttribRequest.NumAttrsToRead * sizeof(WORD);
+  if (CurrentLength > sizeof(CSR_API_MESSAGE))
+    {
+      Request->Header.u1.s1.TotalLength = CurrentLength;
+      Request->Header.u1.s1.DataLength = CurrentLength - sizeof(PORT_MESSAGE);
+    }
+
   return Request->Status;
 }
 
   return Request->Status;
 }
 
@@ -3204,14 +3237,14 @@ CSR_API(CsrGetProcessList)
   PCSRSS_CONSOLE Console;
   PCSRSS_PROCESS_DATA current;
   PLIST_ENTRY current_entry;
   PCSRSS_CONSOLE Console;
   PCSRSS_PROCESS_DATA current;
   PLIST_ENTRY current_entry;
-  ULONG nItems, nCopied;
+  ULONG nItems, nCopied, Length;
   NTSTATUS Status;
 
   DPRINT("CsrGetProcessList\n");
 
   Buffer = Request->Data.GetProcessListRequest.ProcessId;
   Request->Header.u1.s1.TotalLength = sizeof(CSR_API_MESSAGE);
   NTSTATUS Status;
 
   DPRINT("CsrGetProcessList\n");
 
   Buffer = Request->Data.GetProcessListRequest.ProcessId;
   Request->Header.u1.s1.TotalLength = sizeof(CSR_API_MESSAGE);
-  Request->Header.u1.s1.DataLength = Request->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
+  Request->Header.u1.s1.DataLength = sizeof(CSR_API_MESSAGE) - sizeof(PORT_MESSAGE);
 
   nItems = nCopied = 0;
   Request->Data.GetProcessListRequest.nProcessIdsCopied = 0;
 
   nItems = nCopied = 0;
   Request->Data.GetProcessListRequest.nProcessIdsCopied = 0;
@@ -3242,6 +3275,12 @@ CSR_API(CsrGetProcessList)
   Request->Data.GetProcessListRequest.nProcessIdsCopied = nCopied;
   Request->Data.GetProcessListRequest.nProcessIdsTotal = nItems;
 
   Request->Data.GetProcessListRequest.nProcessIdsCopied = nCopied;
   Request->Data.GetProcessListRequest.nProcessIdsTotal = nItems;
 
+  Length = CSR_API_MESSAGE_HEADER_SIZE(CSRSS_GET_PROCESS_LIST) + nCopied * sizeof(HANDLE);
+  if (Length > sizeof(CSR_API_MESSAGE))
+  {
+     Request->Header.u1.s1.TotalLength = Length;
+     Request->Header.u1.s1.DataLength = Length - sizeof(PORT_MESSAGE);
+  }
   return Request->Status = STATUS_SUCCESS;
 }
 
   return Request->Status = STATUS_SUCCESS;
 }