Sync with trunk head (r49139)
[reactos.git] / dll / win32 / ole32 / storage32.c
index 185b8b0..5c58201 100644 (file)
@@ -856,7 +856,7 @@ static HRESULT WINAPI StorageBaseImpl_RenameElement(
     return STG_E_FILENOTFOUND;
   }
 
-  return S_OK;
+  return StorageBaseImpl_Flush(This);
 }
 
 /************************************************************************
@@ -1011,7 +1011,7 @@ static HRESULT WINAPI StorageBaseImpl_CreateStream(
     return STG_E_INSUFFICIENTMEMORY;
   }
 
-  return S_OK;
+  return StorageBaseImpl_Flush(This);
 }
 
 /************************************************************************
@@ -1047,6 +1047,9 @@ static HRESULT WINAPI StorageBaseImpl_SetClass(
                                          &currentEntry);
   }
 
+  if (SUCCEEDED(hRes))
+    hRes = StorageBaseImpl_Flush(This);
+
   return hRes;
 }
 
@@ -1203,6 +1206,8 @@ static HRESULT WINAPI StorageBaseImpl_CreateStorage(
     return hr;
   }
 
+  if (SUCCEEDED(hr))
+    hr = StorageBaseImpl_Flush(This);
 
   return S_OK;
 }
@@ -1916,6 +1921,9 @@ static HRESULT WINAPI StorageBaseImpl_DestroyElement(
   if (SUCCEEDED(hr))
     StorageBaseImpl_DestroyDirEntry(This, entryToDeleteRef);
 
+  if (SUCCEEDED(hr))
+    hr = StorageBaseImpl_Flush(This);
+
   return hr;
 }
 
@@ -2849,7 +2857,10 @@ end:
     *result = NULL;
   }
   else
+  {
+    StorageImpl_Flush((StorageBaseImpl*)This);
     *result = This;
+  }
 
   return hr;
 }
@@ -2888,8 +2899,26 @@ static void StorageImpl_Destroy(StorageBaseImpl* iface)
 static HRESULT StorageImpl_Flush(StorageBaseImpl* iface)
 {
   StorageImpl *This = (StorageImpl*) iface;
+  int i;
+  HRESULT hr;
+  TRACE("(%p)\n", This);
+
+  hr = BlockChainStream_Flush(This->smallBlockRootChain);
+
+  if (SUCCEEDED(hr))
+    hr = BlockChainStream_Flush(This->rootBlockChain);
 
-  return ILockBytes_Flush(This->lockBytes);
+  if (SUCCEEDED(hr))
+    hr = BlockChainStream_Flush(This->smallBlockDepotChain);
+
+  for (i=0; SUCCEEDED(hr) && i<BLOCKCHAIN_CACHE_SIZE; i++)
+    if (This->blockChainCache[i])
+      hr = BlockChainStream_Flush(This->blockChainCache[i]);
+
+  if (SUCCEEDED(hr))
+    hr = ILockBytes_Flush(This->lockBytes);
+
+  return hr;
 }
 
 /******************************************************************************
@@ -3846,13 +3875,20 @@ static BOOL StorageImpl_ReadBigBlock(
   void*          buffer)
 {
   ULARGE_INTEGER ulOffset;
-  DWORD  read;
+  DWORD  read=0;
 
   ulOffset.u.HighPart = 0;
   ulOffset.u.LowPart = StorageImpl_GetBigBlockOffset(This, blockIndex);
 
   StorageImpl_ReadAt(This, ulOffset, buffer, This->bigBlockSize, &read);
-  return (read == This->bigBlockSize);
+
+  if (read && read < This->bigBlockSize)
+  {
+    /* File ends during this block; fill the rest with 0's. */
+    memset((LPBYTE)buffer+read, 0, This->bigBlockSize-read);
+  }
+
+  return (read != 0);
 }
 
 static BOOL StorageImpl_ReadDWordFromBigBlock(
@@ -5812,6 +5848,53 @@ ULONG BlockChainStream_GetSectorOfOffset(BlockChainStream *This, ULONG offset)
   return This->indexCache[min_run].firstSector + offset - This->indexCache[min_run].firstOffset;
 }
 
+HRESULT BlockChainStream_GetBlockAtOffset(BlockChainStream *This,
+    ULONG index, BlockChainBlock **block, ULONG *sector, BOOL create)
+{
+  BlockChainBlock *result=NULL;
+  int i;
+
+  for (i=0; i<2; i++)
+    if (This->cachedBlocks[i].index == index)
+    {
+      *sector = This->cachedBlocks[i].sector;
+      *block = &This->cachedBlocks[i];
+      return S_OK;
+    }
+
+  *sector = BlockChainStream_GetSectorOfOffset(This, index);
+  if (*sector == BLOCK_END_OF_CHAIN)
+    return STG_E_DOCFILECORRUPT;
+
+  if (create)
+  {
+    if (This->cachedBlocks[0].index == 0xffffffff)
+      result = &This->cachedBlocks[0];
+    else if (This->cachedBlocks[1].index == 0xffffffff)
+      result = &This->cachedBlocks[1];
+    else
+    {
+      result = &This->cachedBlocks[This->blockToEvict++];
+      if (This->blockToEvict == 2)
+        This->blockToEvict = 0;
+    }
+
+    if (result->dirty)
+    {
+      if (!StorageImpl_WriteBigBlock(This->parentStorage, result->sector, result->data))
+        return STG_E_WRITEFAULT;
+      result->dirty = 0;
+    }
+
+    result->read = 0;
+    result->index = index;
+    result->sector = *sector;
+  }
+
+  *block = result;
+  return S_OK;
+}
+
 BlockChainStream* BlockChainStream_Construct(
   StorageImpl* parentStorage,
   ULONG*         headOfStreamPlaceHolder,
@@ -5827,6 +5910,11 @@ BlockChainStream* BlockChainStream_Construct(
   newStream->indexCache              = NULL;
   newStream->indexCacheLen           = 0;
   newStream->indexCacheSize          = 0;
+  newStream->cachedBlocks[0].index = 0xffffffff;
+  newStream->cachedBlocks[0].dirty = 0;
+  newStream->cachedBlocks[1].index = 0xffffffff;
+  newStream->cachedBlocks[1].dirty = 0;
+  newStream->blockToEvict          = 0;
 
   if (FAILED(BlockChainStream_UpdateIndexCache(newStream)))
   {
@@ -5838,10 +5926,30 @@ BlockChainStream* BlockChainStream_Construct(
   return newStream;
 }
 
+HRESULT BlockChainStream_Flush(BlockChainStream* This)
+{
+  int i;
+  if (!This) return S_OK;
+  for (i=0; i<2; i++)
+  {
+    if (This->cachedBlocks[i].dirty)
+    {
+      if (StorageImpl_WriteBigBlock(This->parentStorage, This->cachedBlocks[i].sector, This->cachedBlocks[i].data))
+        This->cachedBlocks[i].dirty = 0;
+      else
+        return STG_E_WRITEFAULT;
+    }
+  }
+  return S_OK;
+}
+
 void BlockChainStream_Destroy(BlockChainStream* This)
 {
   if (This)
+  {
+    BlockChainStream_Flush(This);
     HeapFree(GetProcessHeap(), 0, This->indexCache);
+  }
   HeapFree(GetProcessHeap(), 0, This);
 }
 
@@ -5907,6 +6015,8 @@ HRESULT BlockChainStream_ReadAt(BlockChainStream* This,
   ULONG blockIndex;
   BYTE* bufferWalker;
   ULARGE_INTEGER stream_size;
+  HRESULT hr;
+  BlockChainBlock *cachedBlock;
 
   TRACE("(%p)-> %i %p %i %p\n",This, offset.u.LowPart, buffer, size, bytesRead);
 
@@ -5928,32 +6038,50 @@ HRESULT BlockChainStream_ReadAt(BlockChainStream* This,
    */
   bufferWalker = buffer;
 
-  while ( (size > 0) && (blockIndex != BLOCK_END_OF_CHAIN) )
+  while (size > 0)
   {
     ULARGE_INTEGER ulOffset;
     DWORD bytesReadAt;
+
     /*
      * Calculate how many bytes we can copy from this big block.
      */
     bytesToReadInBuffer =
       min(This->parentStorage->bigBlockSize - offsetInBlock, size);
 
-     TRACE("block %i\n",blockIndex);
-     ulOffset.u.HighPart = 0;
-     ulOffset.u.LowPart = StorageImpl_GetBigBlockOffset(This->parentStorage, blockIndex) +
-                             offsetInBlock;
+    hr = BlockChainStream_GetBlockAtOffset(This, blockNoInSequence, &cachedBlock, &blockIndex, size == bytesToReadInBuffer);
 
-     StorageImpl_ReadAt(This->parentStorage,
-         ulOffset,
-         bufferWalker,
-         bytesToReadInBuffer,
-         &bytesReadAt);
-    /*
-     * Step to the next big block.
-     */
-    if( size > bytesReadAt && FAILED(StorageImpl_GetNextBlockInChain(This->parentStorage, blockIndex, &blockIndex)))
-      return STG_E_DOCFILECORRUPT;
+    if (FAILED(hr))
+      return hr;
 
+    if (!cachedBlock)
+    {
+      /* Not in cache, and we're going to read past the end of the block. */
+      ulOffset.u.HighPart = 0;
+      ulOffset.u.LowPart = StorageImpl_GetBigBlockOffset(This->parentStorage, blockIndex) +
+                               offsetInBlock;
+
+      StorageImpl_ReadAt(This->parentStorage,
+           ulOffset,
+           bufferWalker,
+           bytesToReadInBuffer,
+           &bytesReadAt);
+    }
+    else
+    {
+      if (!cachedBlock->read)
+      {
+        if (!StorageImpl_ReadBigBlock(This->parentStorage, cachedBlock->sector, cachedBlock->data))
+          return STG_E_READFAULT;
+
+        cachedBlock->read = 1;
+      }
+
+      memcpy(bufferWalker, cachedBlock->data+offsetInBlock, bytesToReadInBuffer);
+      bytesReadAt = bytesToReadInBuffer;
+    }
+
+    blockNoInSequence++;
     bufferWalker += bytesReadAt;
     size         -= bytesReadAt;
     *bytesRead   += bytesReadAt;
@@ -5983,51 +6111,61 @@ HRESULT BlockChainStream_WriteAt(BlockChainStream* This,
   ULONG bytesToWrite;
   ULONG blockIndex;
   const BYTE* bufferWalker;
-
-  /*
-   * Find the first block in the stream that contains part of the buffer.
-   */
-  blockIndex = BlockChainStream_GetSectorOfOffset(This, blockNoInSequence);
-
-  /* BlockChainStream_SetSize should have already been called to ensure we have
-   * enough blocks in the chain to write into */
-  if (blockIndex == BLOCK_END_OF_CHAIN)
-  {
-    ERR("not enough blocks in chain to write data\n");
-    return STG_E_DOCFILECORRUPT;
-  }
+  HRESULT hr;
+  BlockChainBlock *cachedBlock;
 
   *bytesWritten   = 0;
   bufferWalker = buffer;
 
-  while ( (size > 0) && (blockIndex != BLOCK_END_OF_CHAIN) )
+  while (size > 0)
   {
     ULARGE_INTEGER ulOffset;
     DWORD bytesWrittenAt;
+
     /*
-     * Calculate how many bytes we can copy from this big block.
+     * Calculate how many bytes we can copy to this big block.
      */
     bytesToWrite =
       min(This->parentStorage->bigBlockSize - offsetInBlock, size);
 
-    TRACE("block %i\n",blockIndex);
-    ulOffset.u.HighPart = 0;
-    ulOffset.u.LowPart = StorageImpl_GetBigBlockOffset(This->parentStorage, blockIndex) +
-                             offsetInBlock;
+    hr = BlockChainStream_GetBlockAtOffset(This, blockNoInSequence, &cachedBlock, &blockIndex, size == bytesToWrite);
 
-    StorageImpl_WriteAt(This->parentStorage,
-         ulOffset,
-         bufferWalker,
-         bytesToWrite,
-         &bytesWrittenAt);
+    /* BlockChainStream_SetSize should have already been called to ensure we have
+     * enough blocks in the chain to write into */
+    if (FAILED(hr))
+    {
+      ERR("not enough blocks in chain to write data\n");
+      return hr;
+    }
 
-    /*
-     * Step to the next big block.
-     */
-    if(size > bytesWrittenAt && FAILED(StorageImpl_GetNextBlockInChain(This->parentStorage, blockIndex,
-                                             &blockIndex)))
-      return STG_E_DOCFILECORRUPT;
+    if (!cachedBlock)
+    {
+      /* Not in cache, and we're going to write past the end of the block. */
+      ulOffset.u.HighPart = 0;
+      ulOffset.u.LowPart = StorageImpl_GetBigBlockOffset(This->parentStorage, blockIndex) +
+                               offsetInBlock;
+
+      StorageImpl_WriteAt(This->parentStorage,
+           ulOffset,
+           bufferWalker,
+           bytesToWrite,
+           &bytesWrittenAt);
+    }
+    else
+    {
+      if (!cachedBlock->read && bytesToWrite != This->parentStorage->bigBlockSize)
+      {
+        if (!StorageImpl_ReadBigBlock(This->parentStorage, cachedBlock->sector, cachedBlock->data))
+          return STG_E_READFAULT;
+      }
 
+      memcpy(cachedBlock->data+offsetInBlock, bufferWalker, bytesToWrite);
+      bytesWrittenAt = bytesToWrite;
+      cachedBlock->read = 1;
+      cachedBlock->dirty = 1;
+    }
+
+    blockNoInSequence++;
     bufferWalker  += bytesWrittenAt;
     size          -= bytesWrittenAt;
     *bytesWritten += bytesWrittenAt;
@@ -6050,6 +6188,7 @@ static BOOL BlockChainStream_Shrink(BlockChainStream* This,
 {
   ULONG blockIndex;
   ULONG numBlocks;
+  int i;
 
   /*
    * Figure out how many blocks are needed to contain the new size
@@ -6117,6 +6256,18 @@ static BOOL BlockChainStream_Shrink(BlockChainStream* This,
       last_run->lastOffset--;
   }
 
+  /*
+   * Reset the last accessed block cache.
+   */
+  for (i=0; i<2; i++)
+  {
+    if (This->cachedBlocks[i].index >= numBlocks)
+    {
+      This->cachedBlocks[i].index = 0xffffffff;
+      This->cachedBlocks[i].dirty = 0;
+    }
+  }
+
   return TRUE;
 }