Skip to content
代码片段 群组 项目
fix_torch.py 946 字节
更新 更旧
  • 了解如何忽略特定修订
  • import importlib.util
    import shutil
    import os
    import ctypes
    import logging
    
    
    
    Chenlei Hu's avatar
    Chenlei Hu 已提交
    def fix_pytorch_libomp():
        """
        Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.

        For each search location of the installed ``torch`` package, look in
        its ``lib`` folder. If ``fbgemm.dll`` references
        ``libomp140.x86_64.dll`` but cannot be loaded because a file is
        missing, copy ``libiomp5md.dll`` to ``libomp140.x86_64.dll`` in the
        same folder.

        The function does nothing when ``torch`` cannot be found, when
        ``fbgemm.dll`` is absent, or when ``libomp140.x86_64.dll`` already
        exists.

        Returns:
            None.
        """
        torch_spec = importlib.util.find_spec("torch")
        # find_spec returns None when torch is not installed; there is then
        # nothing to patch.
        if torch_spec is None or not torch_spec.submodule_search_locations:
            return

        for folder in torch_spec.submodule_search_locations:
            lib_folder = os.path.join(folder, "lib")
            test_file = os.path.join(lib_folder, "fbgemm.dll")
            dest = os.path.join(lib_folder, "libomp140.x86_64.dll")

            # The DLL is already in place: no patch is needed.
            if os.path.exists(dest):
                break

            # This location does not ship fbgemm.dll (e.g. a non-Windows
            # build); skip it instead of failing on open().
            if not os.path.isfile(test_file):
                continue

            with open(test_file, "rb") as f:
                contents = f.read()
            # fbgemm.dll does not reference the libomp DLL, so this build is
            # not affected by the issue.
            if b"libomp140.x86_64.dll" not in contents:
                break

            try:
                ctypes.cdll.LoadLibrary(test_file)
            except FileNotFoundError:
                # The load failed because a file could not be found; this is
                # treated as the missing libomp dependency.
                source = os.path.join(lib_folder, "libiomp5md.dll")
                if not os.path.isfile(source):
                    logging.warning(
                        "Detected pytorch version with libomp issue, but %s is missing; cannot patch.",
                        source,
                    )
                    break
                logging.warning("Detected pytorch version with libomp issue, patching.")
                shutil.copyfile(source, dest)