From c787b351ea37dc248a982a68ac5302b21f5cb962 Mon Sep 17 00:00:00 2001
From: Sen Huang <senhuang96@fb.com>
Date: Thu, 7 Nov 2019 11:46:25 -0500
Subject: [PATCH] Use ZSTD Error codes, improve explanation of
 ZSTD_loadCEntropy() and ZSTD_loadDEntropy()

---
 lib/compress/zstd_compress.c              | 4 ++--
 lib/compress/zstd_compress_internal.h     | 4 ++--
 lib/decompress/zstd_decompress_internal.h | 2 +-
 lib/dictBuilder/zdict.c                   | 8 ++++----
 lib/dictBuilder/zdict.h                   | 2 +-
 5 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c
index 3c04718e3..edc238b8b 100644
--- a/lib/compress/zstd_compress.c
+++ b/lib/compress/zstd_compress.c
@@ -2772,7 +2772,7 @@ size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace,
                          short* offcodeNCount, unsigned* offcodeMaxValue,
                          const void* const dict, size_t dictSize) 
 {
-    const BYTE* dictPtr = (const BYTE*)dict + 8;
+    const BYTE* dictPtr = (const BYTE*)dict + 8;    /* skip magic num and dict ID */
     const BYTE* const dictEnd = dictPtr + dictSize;
 
     {   unsigned maxSymbolValue = 255;
@@ -2869,7 +2869,7 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs,
     dictID = params->fParams.noDictIDFlag ? 0 :  MEM_readLE32(dictPtr);
     dictPtr += 4;
 
-    dictPtr += eSize - 8;   /* size of header + magic number already accounted for */
+    dictPtr += eSize - 8;
 
     {   size_t const dictContentSize = (size_t)(dictEnd - dictPtr);
         U32 offcodeMax = MaxOff;
diff --git a/lib/compress/zstd_compress_internal.h b/lib/compress/zstd_compress_internal.h
index 62ee3f9bc..0811ccf9a 100644
--- a/lib/compress/zstd_compress_internal.h
+++ b/lib/compress/zstd_compress_internal.h
@@ -931,13 +931,13 @@ MEM_STATIC void ZSTD_debugTable(const U32* table, U32 max)
 }
 #endif
 /* ===============================================================
- * Public declarations
+ * Shared internal declarations
  * These prototypes may be called from sources not in lib/compress
  * =============================================================== */
 
 /* ZSTD_loadCEntropy() :
  * dict : must point at beginning of a valid zstd dictionary.
- * return : size of entropy tables read */
+ * return : size of dictionary header (size of magic number + dict ID + entropy tables) */
 size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace,
                          short* offcodeNCount, unsigned* offcodeMaxValue,
                          const void* const dict, size_t dictSize);
diff --git a/lib/decompress/zstd_decompress_internal.h b/lib/decompress/zstd_decompress_internal.h
index ccbdfa090..99eab854c 100644
--- a/lib/decompress/zstd_decompress_internal.h
+++ b/lib/decompress/zstd_decompress_internal.h
@@ -160,7 +160,7 @@ struct ZSTD_DCtx_s
 
 /*! ZSTD_loadDEntropy() :
  *  dict : must point at beginning of a valid zstd dictionary.
- * @return : size of entropy tables read */
+ * @return : size of dictionary header (size of magic number + dict ID + entropy tables) */
 size_t ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy,
                    const void* const dict, size_t const dictSize);
 
diff --git a/lib/dictBuilder/zdict.c b/lib/dictBuilder/zdict.c
index 6d76fb521..1c0915fe3 100644
--- a/lib/dictBuilder/zdict.c
+++ b/lib/dictBuilder/zdict.c
@@ -102,22 +102,22 @@ unsigned ZDICT_getDictID(const void* dictBuffer, size_t dictSize)
 
 size_t ZDICT_getDictHeaderSize(const void* dictBuffer, size_t dictSize)
 {
-    if (dictSize <= 8 || MEM_readLE32(dictBuffer) != ZSTD_MAGIC_DICTIONARY) return 0;
+    if (dictSize <= 8 || MEM_readLE32(dictBuffer) != ZSTD_MAGIC_DICTIONARY) return ERROR(dictionary_corrupted);
 
     {   size_t headerSize;
         unsigned offcodeMaxValue = MaxOff;
         ZSTD_compressedBlockState_t* dummyBs = (ZSTD_compressedBlockState_t*)malloc(sizeof(ZSTD_compressedBlockState_t));
         U32* wksp = (U32*)malloc(HUF_WORKSPACE_SIZE);
         short* offcodeNCount = (short*)malloc((MaxOff+1)*sizeof(short));
-        if (!dummyBs || !wksp) {
-            return 0;
+        if (!dummyBs || !wksp || !offcodeNCount) {
+            return ERROR(memory_allocation);
         }
 
         headerSize = ZSTD_loadCEntropy(dummyBs, wksp, offcodeNCount, &offcodeMaxValue, dictBuffer, dictSize);
         free(dummyBs);
         free(wksp);
         free(offcodeNCount);
-        return headerSize;
+        return headerSize;  /* this may be an error value if ZSTD_loadCEntropy() encountered an error */
     }
 }
 
diff --git a/lib/dictBuilder/zdict.h b/lib/dictBuilder/zdict.h
index bb89f1f9f..1313bd214 100644
--- a/lib/dictBuilder/zdict.h
+++ b/lib/dictBuilder/zdict.h
@@ -64,7 +64,7 @@ ZDICTLIB_API size_t ZDICT_trainFromBuffer(void* dictBuffer, size_t dictBufferCap
 
 /*======   Helper functions   ======*/
 ZDICTLIB_API unsigned ZDICT_getDictID(const void* dictBuffer, size_t dictSize);  /**< extracts dictID; @return zero if error (not a valid dictionary) */
-ZDICTLIB_API size_t ZDICT_getDictHeaderSize(const void* dictBuffer, size_t dictSize);  /* returns dict header size; returns zero if error (not a valid dictionary or mem alloc failure) */
+ZDICTLIB_API size_t ZDICT_getDictHeaderSize(const void* dictBuffer, size_t dictSize);  /* returns dict header size; returns a ZSTD error code on failure */
 ZDICTLIB_API unsigned ZDICT_isError(size_t errorCode);
 ZDICTLIB_API const char* ZDICT_getErrorName(size_t errorCode);
 
-- 
GitLab