[CMAKE]
[reactos.git] / dll / win32 / ole32 / storage32.c
index 98df144..5c58201 100644 (file)
@@ -92,6 +92,7 @@ static StorageInternalImpl* StorageInternalImpl_Construct(StorageBaseImpl* paren
                                                           DWORD openFlags, DirRef storageDirEntry);
 static void StorageImpl_Destroy(StorageBaseImpl* iface);
 static void StorageImpl_Invalidate(StorageBaseImpl* iface);
+static HRESULT StorageImpl_Flush(StorageBaseImpl* iface);
 static BOOL StorageImpl_ReadBigBlock(StorageImpl* This, ULONG blockIndex, void* buffer);
 static BOOL StorageImpl_WriteBigBlock(StorageImpl* This, ULONG blockIndex, const void* buffer);
 static void StorageImpl_SetNextBlockInChain(StorageImpl* This, ULONG blockIndex, ULONG nextBlock);
@@ -327,7 +328,7 @@ static HRESULT StorageImpl_ReadAt(StorageImpl* This,
   ULONG          size,
   ULONG*         bytesRead)
 {
-    return BIGBLOCKFILE_ReadAt(This->bigBlockFile,offset,buffer,size,bytesRead);
+    return ILockBytes_ReadAt(This->lockBytes,offset,buffer,size,bytesRead);
 }
 
 static HRESULT StorageImpl_WriteAt(StorageImpl* This,
@@ -336,7 +337,7 @@ static HRESULT StorageImpl_WriteAt(StorageImpl* This,
   const ULONG    size,
   ULONG*         bytesWritten)
 {
-    return BIGBLOCKFILE_WriteAt(This->bigBlockFile,offset,buffer,size,bytesWritten);
+    return ILockBytes_WriteAt(This->lockBytes,offset,buffer,size,bytesWritten);
 }
 
 /************************************************************************
@@ -855,7 +856,7 @@ static HRESULT WINAPI StorageBaseImpl_RenameElement(
     return STG_E_FILENOTFOUND;
   }
 
-  return S_OK;
+  return StorageBaseImpl_Flush(This);
 }
 
 /************************************************************************
@@ -1010,7 +1011,7 @@ static HRESULT WINAPI StorageBaseImpl_CreateStream(
     return STG_E_INSUFFICIENTMEMORY;
   }
 
-  return S_OK;
+  return StorageBaseImpl_Flush(This);
 }
 
 /************************************************************************
@@ -1046,6 +1047,9 @@ static HRESULT WINAPI StorageBaseImpl_SetClass(
                                          &currentEntry);
   }
 
+  if (SUCCEEDED(hRes))
+    hRes = StorageBaseImpl_Flush(This);
+
   return hRes;
 }
 
@@ -1202,6 +1206,8 @@ static HRESULT WINAPI StorageBaseImpl_CreateStorage(
     return hr;
   }
 
+  if (SUCCEEDED(hr))
+    hr = StorageBaseImpl_Flush(This);
 
   return S_OK;
 }
@@ -1306,6 +1312,8 @@ static HRESULT StorageImpl_CreateDirEntry(
         entryIndex,
         emptyData);
     }
+
+    StorageImpl_SaveFileHeader(storage);
   }
 
   UpdateRawDirEntry(currentData, newData);
@@ -1813,16 +1821,15 @@ static HRESULT WINAPI StorageBaseImpl_MoveElementTo(
  * Ensures that any changes made to a storage object open in transacted mode
  * are reflected in the parent storage
  *
- * NOTES
- *  Wine doesn't implement transacted mode, which seems to be a basic
- *  optimization, so we can ignore this stub for now.
+ * In a non-transacted mode, this ensures all cached writes are completed.
  */
 static HRESULT WINAPI StorageImpl_Commit(
   IStorage*   iface,
   DWORD         grfCommitFlags)/* [in] */
 {
-  FIXME("(%p %d): stub\n", iface, grfCommitFlags);
-  return S_OK;
+  StorageBaseImpl* const base=(StorageBaseImpl*)iface;
+  TRACE("(%p %d)\n", iface, grfCommitFlags);
+  return StorageBaseImpl_Flush(base);
 }
 
 /*************************************************************************
@@ -1914,6 +1921,9 @@ static HRESULT WINAPI StorageBaseImpl_DestroyElement(
   if (SUCCEEDED(hr))
     StorageBaseImpl_DestroyDirEntry(This, entryToDeleteRef);
 
+  if (SUCCEEDED(hr))
+    hr = StorageBaseImpl_Flush(This);
+
   return hr;
 }
 
@@ -2581,6 +2591,19 @@ static HRESULT StorageImpl_StreamLink(StorageBaseImpl *base, DirRef dst,
   return hr;
 }
 
+static HRESULT StorageImpl_GetFilename(StorageBaseImpl* iface, LPWSTR *result)
+{
+  StorageImpl *This = (StorageImpl*) iface;
+  STATSTG statstg;
+  HRESULT hr;
+
+  hr = ILockBytes_Stat(This->lockBytes, &statstg, 0);
+
+  *result = statstg.pwcsName;
+
+  return hr;
+}
+
 /*
  * Virtual function table for the IStorage32Impl class.
  */
@@ -2610,6 +2633,8 @@ static const StorageBaseImplVtbl StorageImpl_BaseVtbl =
 {
   StorageImpl_Destroy,
   StorageImpl_Invalidate,
+  StorageImpl_Flush,
+  StorageImpl_GetFilename,
   StorageImpl_CreateDirEntry,
   StorageImpl_BaseWriteDirEntry,
   StorageImpl_BaseReadDirEntry,
@@ -2634,7 +2659,6 @@ static HRESULT StorageImpl_Construct(
   HRESULT     hr = S_OK;
   DirEntry currentEntry;
   DirRef      currentEntryRef;
-  WCHAR fullpath[MAX_PATH];
 
   if ( FAILED( validateSTGM(openFlags) ))
     return STG_E_INVALIDFLAG;
@@ -2658,40 +2682,22 @@ static HRESULT StorageImpl_Construct(
 
   This->base.reverted = 0;
 
-  This->hFile = hFile;
-
-  if(pwcsName) {
-      if (!GetFullPathNameW(pwcsName, MAX_PATH, fullpath, NULL))
-      {
-        lstrcpynW(fullpath, pwcsName, MAX_PATH);
-      }
-      This->pwcsName = HeapAlloc(GetProcessHeap(), 0,
-                                (lstrlenW(fullpath)+1)*sizeof(WCHAR));
-      if (!This->pwcsName)
-      {
-         hr = STG_E_INSUFFICIENTMEMORY;
-         goto end;
-      }
-      strcpyW(This->pwcsName, fullpath);
-      This->base.filename = This->pwcsName;
-  }
-
   /*
    * Initialize the big block cache.
    */
   This->bigBlockSize   = sector_size;
   This->smallBlockSize = DEF_SMALL_BLOCK_SIZE;
-  This->bigBlockFile   = BIGBLOCKFILE_Construct(hFile,
-                                                pLkbyt,
-                                                openFlags,
-                                                fileBased);
-
-  if (This->bigBlockFile == 0)
+  if (hFile)
+    hr = FileLockBytesImpl_Construct(hFile, openFlags, pwcsName, &This->lockBytes);
+  else
   {
-    hr = E_FAIL;
-    goto end;
+    This->lockBytes = pLkbyt;
+    ILockBytes_AddRef(pLkbyt);
   }
 
+  if (FAILED(hr))
+    goto end;
+
   if (create)
   {
     ULARGE_INTEGER size;
@@ -2727,7 +2733,7 @@ static HRESULT StorageImpl_Construct(
      */
     size.u.HighPart = 0;
     size.u.LowPart  = This->bigBlockSize * 3;
-    BIGBLOCKFILE_SetSize(This->bigBlockFile, size);
+    ILockBytes_SetSize(This->lockBytes, size);
 
     /*
      * Initialize the big block depot
@@ -2851,7 +2857,10 @@ end:
     *result = NULL;
   }
   else
+  {
+    StorageImpl_Flush((StorageBaseImpl*)This);
     *result = This;
+  }
 
   return hr;
 }
@@ -2871,9 +2880,9 @@ static void StorageImpl_Destroy(StorageBaseImpl* iface)
   int i;
   TRACE("(%p)\n", This);
 
-  StorageImpl_Invalidate(iface);
+  StorageImpl_Flush(iface);
 
-  HeapFree(GetProcessHeap(), 0, This->pwcsName);
+  StorageImpl_Invalidate(iface);
 
   BlockChainStream_Destroy(This->smallBlockRootChain);
   BlockChainStream_Destroy(This->rootBlockChain);
@@ -2882,11 +2891,36 @@ static void StorageImpl_Destroy(StorageBaseImpl* iface)
   for (i=0; i<BLOCKCHAIN_CACHE_SIZE; i++)
     BlockChainStream_Destroy(This->blockChainCache[i]);
 
-  if (This->bigBlockFile)
-    BIGBLOCKFILE_Destructor(This->bigBlockFile);
+  if (This->lockBytes)
+    ILockBytes_Release(This->lockBytes);
   HeapFree(GetProcessHeap(), 0, This);
 }
 
+static HRESULT StorageImpl_Flush(StorageBaseImpl* iface)
+{
+  StorageImpl *This = (StorageImpl*) iface;
+  int i;
+  HRESULT hr;
+  TRACE("(%p)\n", This);
+
+  hr = BlockChainStream_Flush(This->smallBlockRootChain);
+
+  if (SUCCEEDED(hr))
+    hr = BlockChainStream_Flush(This->rootBlockChain);
+
+  if (SUCCEEDED(hr))
+    hr = BlockChainStream_Flush(This->smallBlockDepotChain);
+
+  for (i=0; SUCCEEDED(hr) && i<BLOCKCHAIN_CACHE_SIZE; i++)
+    if (This->blockChainCache[i])
+      hr = BlockChainStream_Flush(This->blockChainCache[i]);
+
+  if (SUCCEEDED(hr))
+    hr = ILockBytes_Flush(This->lockBytes);
+
+  return hr;
+}
+
 /******************************************************************************
  *      Storage32Impl_GetNextFreeBigBlock
  *
@@ -2906,6 +2940,7 @@ static ULONG StorageImpl_GetNextFreeBigBlock(
   int   depotIndex        = 0;
   ULONG freeBlock         = BLOCK_UNUSED;
   ULARGE_INTEGER neededSize;
+  STATSTG statstg;
 
   depotIndex = This->prevFreeBlock / blocksPerDepot;
   depotBlockOffset = (This->prevFreeBlock % blocksPerDepot) * sizeof(ULONG);
@@ -3020,7 +3055,11 @@ static ULONG StorageImpl_GetNextFreeBigBlock(
    * make sure that the block physically exists before using it
    */
   neededSize.QuadPart = StorageImpl_GetBigBlockOffset(This, freeBlock)+This->bigBlockSize;
-  BIGBLOCKFILE_Expand(This->bigBlockFile, neededSize);
+
+  ILockBytes_Stat(This->lockBytes, &statstg, STATFLAG_NONAME);
+
+  if (neededSize.QuadPart > statstg.cbSize.QuadPart)
+    ILockBytes_SetSize(This->lockBytes, neededSize);
 
   This->prevFreeBlock = freeBlock;
 
@@ -3465,6 +3504,7 @@ static void StorageImpl_SaveFileHeader(
   HRESULT hr;
   ULARGE_INTEGER offset;
   DWORD bytes_read, bytes_written;
+  DWORD major_version, dirsectorcount;
 
   /*
    * Get a pointer to the big block of data containing the header.
@@ -3475,6 +3515,16 @@ static void StorageImpl_SaveFileHeader(
   if (SUCCEEDED(hr) && bytes_read != HEADER_SIZE)
     hr = STG_E_FILENOTFOUND;
 
+  if (This->bigBlockSizeBits == 0x9)
+    major_version = 3;
+  else if (This->bigBlockSizeBits == 0xc)
+    major_version = 4;
+  else
+  {
+    ERR("invalid big block shift 0x%x\n", This->bigBlockSizeBits);
+    major_version = 4;
+  }
+
   /*
    * If the block read failed, the file is probably new.
    */
@@ -3489,18 +3539,26 @@ static void StorageImpl_SaveFileHeader(
      * Initialize the magic number.
      */
     memcpy(headerBigBlock, STORAGE_magic, sizeof(STORAGE_magic));
-
-    /*
-     * And a bunch of things we don't know what they mean
-     */
-    StorageUtl_WriteWord(headerBigBlock,  0x18, 0x3b);
-    StorageUtl_WriteWord(headerBigBlock,  0x1a, 0x3);
-    StorageUtl_WriteWord(headerBigBlock,  0x1c, (WORD)-2);
   }
 
   /*
    * Write the information to the header.
    */
+  StorageUtl_WriteWord(
+    headerBigBlock,
+    OFFSET_MINORVERSION,
+    0x3e);
+
+  StorageUtl_WriteWord(
+    headerBigBlock,
+    OFFSET_MAJORVERSION,
+    major_version);
+
+  StorageUtl_WriteWord(
+    headerBigBlock,
+    OFFSET_BYTEORDERMARKER,
+    (WORD)-2);
+
   StorageUtl_WriteWord(
     headerBigBlock,
     OFFSET_BIGBLOCKSIZEBITS,
@@ -3511,6 +3569,23 @@ static void StorageImpl_SaveFileHeader(
     OFFSET_SMALLBLOCKSIZEBITS,
     This->smallBlockSizeBits);
 
+  if (major_version >= 4)
+  {
+    if (This->rootBlockChain)
+      dirsectorcount = BlockChainStream_GetCount(This->rootBlockChain);
+    else
+      /* This file is being created, and it will start out with one block. */
+      dirsectorcount = 1;
+  }
+  else
+    /* This field must be 0 in versions older than 4 */
+    dirsectorcount = 0;
+
+  StorageUtl_WriteDWord(
+    headerBigBlock,
+    OFFSET_DIRSECTORCOUNT,
+    dirsectorcount);
+
   StorageUtl_WriteDWord(
     headerBigBlock,
     OFFSET_BBDEPOTCOUNT,
@@ -3800,13 +3875,20 @@ static BOOL StorageImpl_ReadBigBlock(
   void*          buffer)
 {
   ULARGE_INTEGER ulOffset;
-  DWORD  read;
+  DWORD  read=0;
 
   ulOffset.u.HighPart = 0;
   ulOffset.u.LowPart = StorageImpl_GetBigBlockOffset(This, blockIndex);
 
   StorageImpl_ReadAt(This, ulOffset, buffer, This->bigBlockSize, &read);
-  return (read == This->bigBlockSize);
+
+  if (read && read < This->bigBlockSize)
+  {
+    /* File ends during this block; fill the rest with 0's. */
+    memset((LPBYTE)buffer+read, 0, This->bigBlockSize-read);
+  }
+
+  return (read != 0);
 }
 
 static BOOL StorageImpl_ReadDWordFromBigBlock(
@@ -4489,9 +4571,12 @@ static HRESULT WINAPI TransactedSnapshotImpl_Commit(
   else
     dir_root_ref = This->entries[root_entry->data.dirRootEntry].newTransactedParentEntry;
 
+  hr = StorageBaseImpl_Flush(This->transactedParent);
+
   /* Update the storage to use the new data in one step. */
-  hr = StorageBaseImpl_ReadDirEntry(This->transactedParent,
-    root_entry->transactedParentEntry, &data);
+  if (SUCCEEDED(hr))
+    hr = StorageBaseImpl_ReadDirEntry(This->transactedParent,
+      root_entry->transactedParentEntry, &data);
 
   if (SUCCEEDED(hr))
   {
@@ -4504,6 +4589,11 @@ static HRESULT WINAPI TransactedSnapshotImpl_Commit(
       root_entry->transactedParentEntry, &data);
   }
 
+  /* Try to flush after updating the root storage, but if the flush fails, keep
+   * going, on the theory that it'll either succeed later or the subsequent
+   * writes will fail. */
+  StorageBaseImpl_Flush(This->transactedParent);
+
   if (SUCCEEDED(hr))
   {
     /* Destroy the old now-orphaned data. */
@@ -4543,6 +4633,9 @@ static HRESULT WINAPI TransactedSnapshotImpl_Commit(
     TransactedSnapshotImpl_DestroyTemporaryCopy(This, DIRENTRY_NULL);
   }
 
+  if (SUCCEEDED(hr))
+    hr = StorageBaseImpl_Flush(This->transactedParent);
+
   return hr;
 }
 
@@ -4606,6 +4699,19 @@ static void TransactedSnapshotImpl_Destroy( StorageBaseImpl *iface)
   HeapFree(GetProcessHeap(), 0, This);
 }
 
+static HRESULT TransactedSnapshotImpl_Flush(StorageBaseImpl* iface)
+{
+  /* We only need to flush when committing. */
+  return S_OK;
+}
+
+static HRESULT TransactedSnapshotImpl_GetFilename(StorageBaseImpl* iface, LPWSTR *result)
+{
+  TransactedSnapshotImpl* This = (TransactedSnapshotImpl*) iface;
+
+  return StorageBaseImpl_GetFilename(This->transactedParent, result);
+}
+
 static HRESULT TransactedSnapshotImpl_CreateDirEntry(StorageBaseImpl *base,
   const DirEntry *newData, DirRef *index)
 {
@@ -4853,6 +4959,8 @@ static const StorageBaseImplVtbl TransactedSnapshotImpl_BaseVtbl =
 {
   TransactedSnapshotImpl_Destroy,
   TransactedSnapshotImpl_Invalidate,
+  TransactedSnapshotImpl_Flush,
+  TransactedSnapshotImpl_GetFilename,
   TransactedSnapshotImpl_CreateDirEntry,
   TransactedSnapshotImpl_WriteDirEntry,
   TransactedSnapshotImpl_ReadDirEntry,
@@ -4886,8 +4994,6 @@ static HRESULT TransactedSnapshotImpl_Construct(StorageBaseImpl *parentStorage,
 
     (*result)->base.openFlags = parentStorage->openFlags;
 
-    (*result)->base.filename = parentStorage->filename;
-
     /* Create a new temporary storage to act as the scratch file. */
     hr = StgCreateDocfile(NULL, STGM_READWRITE|STGM_SHARE_EXCLUSIVE|STGM_CREATE,
         0, (IStorage**)&(*result)->scratch);
@@ -5000,6 +5106,20 @@ static void StorageInternalImpl_Destroy( StorageBaseImpl *iface)
   HeapFree(GetProcessHeap(), 0, This);
 }
 
+static HRESULT StorageInternalImpl_Flush(StorageBaseImpl* iface)
+{
+  StorageInternalImpl* This = (StorageInternalImpl*) iface;
+
+  return StorageBaseImpl_Flush(This->parentStorage);
+}
+
+static HRESULT StorageInternalImpl_GetFilename(StorageBaseImpl* iface, LPWSTR *result)
+{
+  StorageInternalImpl* This = (StorageInternalImpl*) iface;
+
+  return StorageBaseImpl_GetFilename(This->parentStorage, result);
+}
+
 static HRESULT StorageInternalImpl_CreateDirEntry(StorageBaseImpl *base,
   const DirEntry *newData, DirRef *index)
 {
@@ -5081,8 +5201,9 @@ static HRESULT WINAPI StorageInternalImpl_Commit(
   IStorage*            iface,
   DWORD                  grfCommitFlags)  /* [in] */
 {
-  FIXME("(%p,%x): stub\n", iface, grfCommitFlags);
-  return S_OK;
+  StorageBaseImpl* base = (StorageBaseImpl*) iface;
+  TRACE("(%p,%x)\n", iface, grfCommitFlags);
+  return StorageBaseImpl_Flush(base);
 }
 
 /******************************************************************************
@@ -5425,6 +5546,8 @@ static const StorageBaseImplVtbl StorageInternalImpl_BaseVtbl =
 {
   StorageInternalImpl_Destroy,
   StorageInternalImpl_Invalidate,
+  StorageInternalImpl_Flush,
+  StorageInternalImpl_GetFilename,
   StorageInternalImpl_CreateDirEntry,
   StorageInternalImpl_WriteDirEntry,
   StorageInternalImpl_ReadDirEntry,
@@ -5565,33 +5688,26 @@ void StorageUtl_CopyDirEntryToSTATSTG(
   const DirEntry*       source,
   int                   statFlags)
 {
-  LPCWSTR entryName;
-
-  if (source->stgType == STGTY_ROOT)
-  {
-    /* replace the name of root entry (often "Root Entry") by the file name */
-    entryName = storage->filename;
-  }
-  else
-  {
-    entryName = source->name;
-  }
-
   /*
    * The copy of the string occurs only when the flag is not set
    */
-  if( ((statFlags & STATFLAG_NONAME) != 0) || 
-       (entryName == NULL) ||
-       (entryName[0] == 0) )
+  if (!(statFlags & STATFLAG_NONAME) && source->stgType == STGTY_ROOT)
+  {
+    /* Use the filename for the root storage. */
+    destination->pwcsName = 0;
+    StorageBaseImpl_GetFilename(storage, &destination->pwcsName);
+  }
+  else if( ((statFlags & STATFLAG_NONAME) != 0) ||
+       (source->name[0] == 0) )
   {
     destination->pwcsName = 0;
   }
   else
   {
     destination->pwcsName =
-      CoTaskMemAlloc((lstrlenW(entryName)+1)*sizeof(WCHAR));
+      CoTaskMemAlloc((lstrlenW(source->name)+1)*sizeof(WCHAR));
 
-    strcpyW(destination->pwcsName, entryName);
+    strcpyW(destination->pwcsName, source->name);
   }
 
   switch (source->stgType)
@@ -5732,6 +5848,53 @@ ULONG BlockChainStream_GetSectorOfOffset(BlockChainStream *This, ULONG offset)
   return This->indexCache[min_run].firstSector + offset - This->indexCache[min_run].firstOffset;
 }
 
+HRESULT BlockChainStream_GetBlockAtOffset(BlockChainStream *This,
+    ULONG index, BlockChainBlock **block, ULONG *sector, BOOL create)
+{
+  BlockChainBlock *result=NULL;
+  int i;
+
+  for (i=0; i<2; i++)
+    if (This->cachedBlocks[i].index == index)
+    {
+      *sector = This->cachedBlocks[i].sector;
+      *block = &This->cachedBlocks[i];
+      return S_OK;
+    }
+
+  *sector = BlockChainStream_GetSectorOfOffset(This, index);
+  if (*sector == BLOCK_END_OF_CHAIN)
+    return STG_E_DOCFILECORRUPT;
+
+  if (create)
+  {
+    if (This->cachedBlocks[0].index == 0xffffffff)
+      result = &This->cachedBlocks[0];
+    else if (This->cachedBlocks[1].index == 0xffffffff)
+      result = &This->cachedBlocks[1];
+    else
+    {
+      result = &This->cachedBlocks[This->blockToEvict++];
+      if (This->blockToEvict == 2)
+        This->blockToEvict = 0;
+    }
+
+    if (result->dirty)
+    {
+      if (!StorageImpl_WriteBigBlock(This->parentStorage, result->sector, result->data))
+        return STG_E_WRITEFAULT;
+      result->dirty = 0;
+    }
+
+    result->read = 0;
+    result->index = index;
+    result->sector = *sector;
+  }
+
+  *block = result;
+  return S_OK;
+}
+
 BlockChainStream* BlockChainStream_Construct(
   StorageImpl* parentStorage,
   ULONG*         headOfStreamPlaceHolder,
@@ -5747,6 +5910,11 @@ BlockChainStream* BlockChainStream_Construct(
   newStream->indexCache              = NULL;
   newStream->indexCacheLen           = 0;
   newStream->indexCacheSize          = 0;
+  newStream->cachedBlocks[0].index = 0xffffffff;
+  newStream->cachedBlocks[0].dirty = 0;
+  newStream->cachedBlocks[1].index = 0xffffffff;
+  newStream->cachedBlocks[1].dirty = 0;
+  newStream->blockToEvict          = 0;
 
   if (FAILED(BlockChainStream_UpdateIndexCache(newStream)))
   {
@@ -5758,10 +5926,30 @@ BlockChainStream* BlockChainStream_Construct(
   return newStream;
 }
 
+HRESULT BlockChainStream_Flush(BlockChainStream* This)
+{
+  int i;
+  if (!This) return S_OK;
+  for (i=0; i<2; i++)
+  {
+    if (This->cachedBlocks[i].dirty)
+    {
+      if (StorageImpl_WriteBigBlock(This->parentStorage, This->cachedBlocks[i].sector, This->cachedBlocks[i].data))
+        This->cachedBlocks[i].dirty = 0;
+      else
+        return STG_E_WRITEFAULT;
+    }
+  }
+  return S_OK;
+}
+
 void BlockChainStream_Destroy(BlockChainStream* This)
 {
   if (This)
+  {
+    BlockChainStream_Flush(This);
     HeapFree(GetProcessHeap(), 0, This->indexCache);
+  }
   HeapFree(GetProcessHeap(), 0, This);
 }
 
@@ -5827,6 +6015,8 @@ HRESULT BlockChainStream_ReadAt(BlockChainStream* This,
   ULONG blockIndex;
   BYTE* bufferWalker;
   ULARGE_INTEGER stream_size;
+  HRESULT hr;
+  BlockChainBlock *cachedBlock;
 
   TRACE("(%p)-> %i %p %i %p\n",This, offset.u.LowPart, buffer, size, bytesRead);
 
@@ -5848,32 +6038,50 @@ HRESULT BlockChainStream_ReadAt(BlockChainStream* This,
    */
   bufferWalker = buffer;
 
-  while ( (size > 0) && (blockIndex != BLOCK_END_OF_CHAIN) )
+  while (size > 0)
   {
     ULARGE_INTEGER ulOffset;
     DWORD bytesReadAt;
+
     /*
      * Calculate how many bytes we can copy from this big block.
      */
     bytesToReadInBuffer =
       min(This->parentStorage->bigBlockSize - offsetInBlock, size);
 
-     TRACE("block %i\n",blockIndex);
-     ulOffset.u.HighPart = 0;
-     ulOffset.u.LowPart = StorageImpl_GetBigBlockOffset(This->parentStorage, blockIndex) +
-                             offsetInBlock;
+    hr = BlockChainStream_GetBlockAtOffset(This, blockNoInSequence, &cachedBlock, &blockIndex, size == bytesToReadInBuffer);
 
-     StorageImpl_ReadAt(This->parentStorage,
-         ulOffset,
-         bufferWalker,
-         bytesToReadInBuffer,
-         &bytesReadAt);
-    /*
-     * Step to the next big block.
-     */
-    if( size > bytesReadAt && FAILED(StorageImpl_GetNextBlockInChain(This->parentStorage, blockIndex, &blockIndex)))
-      return STG_E_DOCFILECORRUPT;
+    if (FAILED(hr))
+      return hr;
+
+    if (!cachedBlock)
+    {
+      /* Not in cache, and we're going to read past the end of the block. */
+      ulOffset.u.HighPart = 0;
+      ulOffset.u.LowPart = StorageImpl_GetBigBlockOffset(This->parentStorage, blockIndex) +
+                               offsetInBlock;
+
+      StorageImpl_ReadAt(This->parentStorage,
+           ulOffset,
+           bufferWalker,
+           bytesToReadInBuffer,
+           &bytesReadAt);
+    }
+    else
+    {
+      if (!cachedBlock->read)
+      {
+        if (!StorageImpl_ReadBigBlock(This->parentStorage, cachedBlock->sector, cachedBlock->data))
+          return STG_E_READFAULT;
 
+        cachedBlock->read = 1;
+      }
+
+      memcpy(bufferWalker, cachedBlock->data+offsetInBlock, bytesToReadInBuffer);
+      bytesReadAt = bytesToReadInBuffer;
+    }
+
+    blockNoInSequence++;
     bufferWalker += bytesReadAt;
     size         -= bytesReadAt;
     *bytesRead   += bytesReadAt;
@@ -5903,51 +6111,61 @@ HRESULT BlockChainStream_WriteAt(BlockChainStream* This,
   ULONG bytesToWrite;
   ULONG blockIndex;
   const BYTE* bufferWalker;
-
-  /*
-   * Find the first block in the stream that contains part of the buffer.
-   */
-  blockIndex = BlockChainStream_GetSectorOfOffset(This, blockNoInSequence);
-
-  /* BlockChainStream_SetSize should have already been called to ensure we have
-   * enough blocks in the chain to write into */
-  if (blockIndex == BLOCK_END_OF_CHAIN)
-  {
-    ERR("not enough blocks in chain to write data\n");
-    return STG_E_DOCFILECORRUPT;
-  }
+  HRESULT hr;
+  BlockChainBlock *cachedBlock;
 
   *bytesWritten   = 0;
   bufferWalker = buffer;
 
-  while ( (size > 0) && (blockIndex != BLOCK_END_OF_CHAIN) )
+  while (size > 0)
   {
     ULARGE_INTEGER ulOffset;
     DWORD bytesWrittenAt;
+
     /*
-     * Calculate how many bytes we can copy from this big block.
+     * Calculate how many bytes we can copy to this big block.
      */
     bytesToWrite =
       min(This->parentStorage->bigBlockSize - offsetInBlock, size);
 
-    TRACE("block %i\n",blockIndex);
-    ulOffset.u.HighPart = 0;
-    ulOffset.u.LowPart = StorageImpl_GetBigBlockOffset(This->parentStorage, blockIndex) +
-                             offsetInBlock;
+    hr = BlockChainStream_GetBlockAtOffset(This, blockNoInSequence, &cachedBlock, &blockIndex, size == bytesToWrite);
 
-    StorageImpl_WriteAt(This->parentStorage,
-         ulOffset,
-         bufferWalker,
-         bytesToWrite,
-         &bytesWrittenAt);
+    /* BlockChainStream_SetSize should have already been called to ensure we have
+     * enough blocks in the chain to write into */
+    if (FAILED(hr))
+    {
+      ERR("not enough blocks in chain to write data\n");
+      return hr;
+    }
 
-    /*
-     * Step to the next big block.
-     */
-    if(size > bytesWrittenAt && FAILED(StorageImpl_GetNextBlockInChain(This->parentStorage, blockIndex,
-                                             &blockIndex)))
-      return STG_E_DOCFILECORRUPT;
+    if (!cachedBlock)
+    {
+      /* Not in cache, and we're going to write past the end of the block. */
+      ulOffset.u.HighPart = 0;
+      ulOffset.u.LowPart = StorageImpl_GetBigBlockOffset(This->parentStorage, blockIndex) +
+                               offsetInBlock;
+
+      StorageImpl_WriteAt(This->parentStorage,
+           ulOffset,
+           bufferWalker,
+           bytesToWrite,
+           &bytesWrittenAt);
+    }
+    else
+    {
+      if (!cachedBlock->read && bytesToWrite != This->parentStorage->bigBlockSize)
+      {
+        if (!StorageImpl_ReadBigBlock(This->parentStorage, cachedBlock->sector, cachedBlock->data))
+          return STG_E_READFAULT;
+      }
 
+      memcpy(cachedBlock->data+offsetInBlock, bufferWalker, bytesToWrite);
+      bytesWrittenAt = bytesToWrite;
+      cachedBlock->read = 1;
+      cachedBlock->dirty = 1;
+    }
+
+    blockNoInSequence++;
     bufferWalker  += bytesWrittenAt;
     size          -= bytesWrittenAt;
     *bytesWritten += bytesWrittenAt;
@@ -5970,6 +6188,7 @@ static BOOL BlockChainStream_Shrink(BlockChainStream* This,
 {
   ULONG blockIndex;
   ULONG numBlocks;
+  int i;
 
   /*
    * Figure out how many blocks are needed to contain the new size
@@ -6037,6 +6256,18 @@ static BOOL BlockChainStream_Shrink(BlockChainStream* This,
       last_run->lastOffset--;
   }
 
+  /*
+   * Reset the last accessed block cache.
+   */
+  for (i=0; i<2; i++)
+  {
+    if (This->cachedBlocks[i].index >= numBlocks)
+    {
+      This->cachedBlocks[i].index = 0xffffffff;
+      This->cachedBlocks[i].dirty = 0;
+    }
+  }
+
   return TRUE;
 }