Skip to content
代码片段 群组 项目
safe.py 6.8 KB
更新 更旧
  • 了解如何忽略特定修订
    # this code is adapted from the script contributed by anon from /h/
    
    import io
    import pickle
    import collections
    import sys
    import traceback
    
    import torch
    import numpy
    import _codecs
    import zipfile
    
    import re
    
    # Newer PyTorch (1.13+) exposes the typed storage class publicly as
    # TypedStorage; older releases only provide the private _TypedStorage name.
    try:
        TypedStorage = torch.storage.TypedStorage
    except AttributeError:
        TypedStorage = torch.storage._TypedStorage
    
    
    
    def encode(*args):
        """Forward every argument unchanged to ``_codecs.encode``.

        RestrictedUnpickler.find_class hands out this wrapper when a pickle
        references the global ``_codecs.encode``.
        """
        return _codecs.encode(*args)
    
    
    class RestrictedUnpickler(pickle.Unpickler):
        """Unpickler that resolves only an explicit allow-list of globals.

        Any global reference not accepted by ``extra_handler`` or the built-in
        allow-list in ``find_class`` raises instead of being imported.
        """

        # Optional callable ``(module, name) -> object | None`` consulted before
        # the built-in allow-list; a non-None result is used as-is. Defaults to
        # None so instances work without the caller having to set it.
        extra_handler = None

        def persistent_load(self, saved_id):
            """Return an empty TypedStorage placeholder for a storage reference.

            The storage payload is never read here. Any persistent id whose
            first element is not ``'storage'`` is rejected.
            """
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if saved_id[0] != 'storage':
                raise Exception(f"persistent id '{saved_id[0]}' is forbidden")

            return TypedStorage()

        def find_class(self, module, name):
            """Resolve the global ``module.name`` only if it is allowed.

            Raises:
                Exception: for any global that is not explicitly allowed.
            """
            if self.extra_handler is not None:
                res = self.extra_handler(module, name)
                if res is not None:
                    return res

            if module == 'collections' and name == 'OrderedDict':
                return getattr(collections, name)
            if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
                return getattr(torch._utils, name)
            if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
                return getattr(torch, name)
            if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
                return getattr(torch.nn.modules.container, name)
            if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
                return getattr(numpy.core.multiarray, name)
            if module == 'numpy' and name in ['dtype', 'ndarray']:
                return getattr(numpy, name)
            if module == '_codecs' and name == 'encode':
                return encode
            # pytorch_lightning is imported lazily, only when a checkpoint references it.
            if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
                import pytorch_lightning.callbacks
                return pytorch_lightning.callbacks.model_checkpoint
            if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
                import pytorch_lightning.callbacks.model_checkpoint
                return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
            # Overriding find_class means the unpickler does not remap the
            # Python 2 module name, so 'set' arrives as __builtin__.set.
            if module == "__builtin__" and name == 'set':
                return set

            # Forbid everything else.
            raise Exception(f"global '{module}/{name}' is forbidden")
    
    # Regular expression for the entries accepted inside a zip-format checkpoint:
    # 'dirname/version', 'dirname/byteorder', 'dirname/.data/serialization_id',
    # 'dirname/data.pkl', and 'dirname/data/<number>'. Of these, check_pt only
    # unpickles data.pkl.
    # NOTE(review): 'byteorder' and '.data/serialization_id' are accepted for
    # archives written by newer PyTorch releases — confirm against the torch
    # versions you support.
    allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|(\.data/serialization_id)|(data\.pkl))$")
    data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")


    def check_zip_filenames(filename, names):
        """Raise if any archive entry name in ``names`` is not on the allow-list.

        ``fullmatch`` is used instead of ``match`` because ``$`` also matches
        just before a trailing newline, so ``match`` would accept a name that
        ends in a newline character.

        Raises:
            Exception: naming ``filename`` and the first offending entry.
        """
        for name in names:
            if not allowed_zip_names_re.fullmatch(name):
                raise Exception(f"bad file inside {filename}: {name}")


    def check_pt(filename, extra_handler):
        """Run a checkpoint's pickle stream(s) through RestrictedUnpickler.

        Nothing is returned; this raises if the archive has unexpected entries,
        if data.pkl is missing or duplicated, if a forbidden global is
        referenced, or if the pickle itself cannot be read.

        Args:
            filename: path of the checkpoint to verify.
            extra_handler: optional ``(module, name) -> object | None`` callable
                consulted before the built-in allow-list.
        """
        try:
            # new pytorch format is a zip file
            with zipfile.ZipFile(filename) as z:
                names = z.namelist()
                check_zip_filenames(filename, names)

                # find filename of data.pkl in zip file: '<directory name>/data.pkl'
                data_pkl_filenames = [f for f in names if data_pkl_re.fullmatch(f)]
                if not data_pkl_filenames:
                    raise Exception(f"data.pkl not found in {filename}")
                if len(data_pkl_filenames) > 1:
                    raise Exception(f"Multiple data.pkl found in {filename}")
                with z.open(data_pkl_filenames[0]) as file:
                    unpickler = RestrictedUnpickler(file)
                    unpickler.extra_handler = extra_handler
                    unpickler.load()

        except zipfile.BadZipfile:
            # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
            with open(filename, "rb") as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                for _ in range(5):
                    unpickler.load()


    def load(filename, *args, **kwargs):
        """Replacement for ``torch.load`` that verifies the file before loading it.

        The handler installed by an active ``Extra`` block (if any) is used as
        the extra handler.
        """
        # extra_handler is passed positionally: as a keyword it collides with a
        # positional torch.load argument (e.g. map_location), which would bind
        # to load_with_extra's extra_handler slot and raise a TypeError.
        return load_with_extra(filename, global_extra_handler, *args, **kwargs)


    def load_with_extra(filename, extra_handler=None, *args, **kwargs):
        """
        This function is intended to be used by extensions that want to load models with
        some extra classes in them that the usual unpickler would find suspicious.

        Use the extra_handler argument to specify a function that takes module and field name as text,
        and returns that field's value:

        ```python
        def extra(module, name):
            if module == 'collections' and name == 'OrderedDict':
                return collections.OrderedDict

            return None

        safe.load_with_extra('model.pt', extra_handler=extra)
        ```

        The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
        definitely unsafe.

        Returns None (after printing the error to stderr) if verification fails;
        otherwise returns whatever unsafe_torch_load returns. Verification is
        skipped when the --disable-safe-unpickle command-line option is set.
        """
        from modules import shared

        try:
            if not shared.cmd_opts.disable_safe_unpickle:
                check_pt(filename, extra_handler)

        except pickle.UnpicklingError:
            print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
            print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
            return None

        except Exception:
            print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
            print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
            return None

        return unsafe_torch_load(filename, *args, **kwargs)
    
    
    
    class Extra:
        """Context manager that installs a global extra handler for safe loading.

        Use it when you cannot call load_with_extra yourself because some other
        code makes the ``torch.load`` call. While the block is active, the
        module-level ``global_extra_handler`` is set to ``handler``; on exit it is
        reset to None.

        Example::

            import torch
            from modules import safe

            def handler(module, name):
                if module == 'torch' and name in ['float64', 'float16']:
                    return getattr(torch, name)

                return None

            with safe.Extra(handler):
                x = torch.load('model.pt')

        Blocks cannot be nested: entering while another Extra is active fails
        the assertion in ``__enter__``.
        """

        def __init__(self, handler):
            self.handler = handler

        @staticmethod
        def _install(new_handler):
            # Single place that rebinds the module-level handler.
            global global_extra_handler
            global_extra_handler = new_handler

        def __enter__(self):
            assert global_extra_handler is None, 'already inside an Extra() block'
            Extra._install(self.handler)

        def __exit__(self, exc_type, exc_val, exc_tb):
            self._install(None)
    
    
    
    # Handler read by load() and swapped in/out by the Extra context manager.
    # It must exist at module level; otherwise load() and Extra.__enter__ raise
    # NameError when no Extra block has ever run.
    global_extra_handler = None

    # Keep the original loader for the post-verification path, then route calls
    # made through torch.load to the checking wrapper.
    unsafe_torch_load = torch.load
    torch.load = load