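    """BLIP + CLIP image interrogation.

    Builds a text description of an image: a caption generated by BLIP, extended
    with the best-matching terms from each CLIP-ranked category file (artists,
    flavors, mediums, movements, ...).
    """
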
    import os
    import sys
    import traceback
    from collections import namedtuple
    
    from pathlib import Path
    import re

    import torch
    import torch.hub  # provides download_url_to_file, used below

    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode

    from modules import devices, paths, shared, lowvram, modelloader, errors
    
    blip_image_eval_size = 384  # BLIP expects square inputs at this resolution
    clip_model_name = 'ViT-L/14'  # CLIP variant used to rank category terms
    
    Category = namedtuple("Category", ["name", "topn", "items"])
    
    re_topn = re.compile(r"\.top(\d+)\.")  # a ".topN." marker in a category filename sets how many matches to keep
    
    
    def category_types():
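        """Return the stems of all category files available to the interrogator."""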
        return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
    
    
    def download_default_clip_interrogate_categories(content_dir):
        print("Downloading CLIP categories...")

        # Download into a sibling temporary directory and rename it into place at
        # the end, so a failed download does not leave a half-filled content_dir.
        tmpdir = f"{content_dir}_tmp"
        category_types = ["artists", "flavors", "mediums", "movements"]

        try:
            os.makedirs(tmpdir, exist_ok=True)
            for category_type in category_types:
                torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))

            os.rename(tmpdir, content_dir)
        except Exception as e:
            errors.display(e, "downloading default CLIP interrogate categories")
        finally:
            if os.path.exists(tmpdir):
                os.removedirs(tmpdir)
    
    class InterrogateModels:
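        """Lazy-loading wrapper around a BLIP captioner and a CLIP ranker.

        Models are instantiated on first use, optionally converted to fp16, and
        can be moved back to system RAM between requests to free VRAM.
        """
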
        blip_model = None
        clip_model = None
        clip_preprocess = None
    
        dtype = None
    
        running_on_cpu = None
    
        def __init__(self, content_dir):
            self.loaded_categories = None
            self.skip_categories = []
            self.content_dir = content_dir
            self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
    
        def categories(self):
            if not os.path.exists(self.content_dir):
                download_default_clip_interrogate_categories(self.content_dir)
    
            if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
                return self.loaded_categories
    
    
            self.loaded_categories = []
    
            if os.path.exists(self.content_dir):
    
                self.skip_categories = shared.opts.interrogate_clip_skip_categories
                category_types = []
                for filename in Path(self.content_dir).glob('*.txt'):
                    category_types.append(filename.stem)
                    if filename.stem in self.skip_categories:
    
                        continue
    
                    m = re_topn.search(filename.stem)
    
                    topn = 1 if m is None else int(m.group(1))
    
                    with open(filename, "r", encoding="utf8") as file:
    
                        lines = [x.strip() for x in file.readlines()]
    
    
                    self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
    
            return self.loaded_categories
    
    
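        # BLIP's model code imports checkpoint_wrapper from fairscale; registering
        # this stub under the fairscale module name satisfies that import without
        # requiring the real dependency.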
        def create_fake_fairscale(self):
            class FakeFairscale:
                def checkpoint_wrapper(self):
                    pass
    
            sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
    
    
        def load_blip_model(self):
            # fairscale must be stubbed out before BLIP's model code is imported;
            # the BLIP repository is expected to be importable as models.blip.
            self.create_fake_fairscale()
            import models.blip

            files = modelloader.load_models(
                model_path=os.path.join(paths.models_path, "BLIP"),
                model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
                ext_filter=[".pth"],
                download_name='model_base_caption_capfilt_large.pth',
            )
    
            blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
    
            blip_model.eval()
    
            return blip_model
    
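        # A note on devices: clip.load() with an explicit device="cpu" converts
        # the fp16 checkpoint to float32, which is presumably why CPU-only setups
        # take the first branch below.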
        def load_clip_model(self):
            import clip
    
    
            if self.running_on_cpu:
    
                model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
            else:
                model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
    
            model.eval()
    
            model = model.to(devices.device_interrogate)
    
            return model, preprocess
    
        def load(self):
            if self.blip_model is None:
                self.blip_model = self.load_blip_model()
    
                if not shared.cmd_opts.no_half and not self.running_on_cpu:
                    self.blip_model = self.blip_model.half()
    
            self.blip_model = self.blip_model.to(devices.device_interrogate)
    
            if self.clip_model is None:
                self.clip_model, self.clip_preprocess = self.load_clip_model()
    
                if not shared.cmd_opts.no_half and not self.running_on_cpu:
                    self.clip_model = self.clip_model.half()
    
            self.clip_model = self.clip_model.to(devices.device_interrogate)
    
            self.dtype = next(self.clip_model.parameters()).dtype
    
    
        def send_clip_to_ram(self):
    
            if not shared.opts.interrogate_keep_models_in_memory:
                if self.clip_model is not None:
                    self.clip_model = self.clip_model.to(devices.cpu)
    
    
        def send_blip_to_ram(self):
            if not shared.opts.interrogate_keep_models_in_memory:
    
                if self.blip_model is not None:
                    self.blip_model = self.blip_model.to(devices.cpu)
    
    
        def unload(self):
            self.send_clip_to_ram()
            self.send_blip_to_ram()
    
            devices.torch_gc()
    
    
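        # Embeds text_array with CLIP and scores it against precomputed image
        # features. Both sides are L2-normalized, so the matrix product is cosine
        # similarity; softmax over 100 * similarity mirrors CLIP's zero-shot
        # classification setup, averaged over all image feature rows.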
        def rank(self, image_features, text_array, top_count=1):
            import clip
    
    
            devices.torch_gc()
    
    
            if shared.opts.interrogate_clip_dict_limit != 0:
                text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
    
    
            top_count = min(top_count, len(text_array))
    
            text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
    
            text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
    
            text_features /= text_features.norm(dim=-1, keepdim=True)
    
    
            similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
    
            for i in range(image_features.shape[0]):
                similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
            similarity /= image_features.shape[0]
    
            top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
            return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
    
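        # BLIP takes a 384x384 tensor normalized with the standard CLIP image
        # mean/std; captioning runs beam search with the UI's beam count and
        # length limits.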
        def generate_caption(self, pil_image):
            gpu_image = transforms.Compose([
                transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    
            ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
    
            with torch.no_grad():
                caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
    
            return caption[0]
    
    
        def interrogate(self, pil_image):
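            """Caption pil_image with BLIP, then append the top-ranked term(s) from
            every CLIP category; returns the combined description string."""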
    
            res = ""  # accumulated description; "<error>" is appended on failure
            shared.state.begin()
            shared.state.job = 'interrogate'
    
            try:
    
                if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                    lowvram.send_everything_to_cpu()
                    devices.torch_gc()
    
    
                self.load()
    
                caption = self.generate_caption(pil_image)
    
                self.send_blip_to_ram()
                devices.torch_gc()
    
    
                res = caption
    
    
                clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
    
                with torch.no_grad(), devices.autocast():
    
                    image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
    
                    image_features /= image_features.norm(dim=-1, keepdim=True)
    
                    for cat in self.categories():
                        matches = self.rank(image_features, cat.items, top_count=cat.topn)
    
                        for match, score in matches:
    
                            if shared.opts.interrogate_return_ranks:
                                res += f", ({match}:{score/100:.3f})"
                            else:
                                res += f", {match}"
    
            except Exception:
    
                print("Error interrogating", file=sys.stderr)
    
                print(traceback.format_exc(), file=sys.stderr)
    
                res += "<error>"
    
            self.unload()
    
            shared.state.end()
    
            return res
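

    # Typical use, as a sketch: the web UI creates a shared instance as
    # shared.interrogator = InterrogateModels("interrogate") and calls it
    # with a PIL image:
    #
    #     interrogator = InterrogateModels("interrogate")
    #     description = interrogator.interrogate(pil_image)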