# hypernetwork.py — hypernetwork modules, loading, and training for the Stable Diffusion web UI.
import datetime
import glob
import html
import os
import sys
import traceback
import tqdm

import torch

from ldm.util import default
from modules import devices, shared, processing, sd_models
import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset


class HypernetworkModule(torch.nn.Module):
    """A small residual two-layer MLP applied to attention context vectors.

    Maps a `dim`-wide vector through dim -> dim*2 -> dim and adds the result
    back onto the input (residual connection), so an untrained module is a
    near-identity transform.
    """

    def __init__(self, dim, state_dict=None):
        """
        dim: width of the context vectors this module transforms.
        state_dict: optional saved weights; loaded strictly when given,
            otherwise weights are freshly initialized.
        """
        super().__init__()

        self.linear1 = torch.nn.Linear(dim, dim * 2)
        self.linear2 = torch.nn.Linear(dim * 2, dim)

        if state_dict is not None:
            self.load_state_dict(state_dict, strict=True)
        else:
            # Small std keeps the residual branch's initial output near zero,
            # so a new hypernetwork barely perturbs the base model.
            self.linear1.weight.data.normal_(mean=0.0, std=0.01)
            self.linear1.bias.data.zero_()
            self.linear2.weight.data.normal_(mean=0.0, std=0.01)
            self.linear2.bias.data.zero_()

        self.to(devices.device)

    def forward(self, x):
        # Residual connection: x + MLP(x). Note there is no nonlinearity
        # between the two linear layers in this version.
        return x + (self.linear2(self.linear1(x)))


class Hypernetwork:
    """A set of HypernetworkModule pairs keyed by cross-attention context width."""

    # Class-level defaults; both are overwritten per-instance in __init__/load().
    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None):
        """
        name: display name; when None it is derived from the file stem in load().
        enable_sizes: iterable of context widths to create module pairs for;
            defaults to [320, 640, 768, 1280].
        """
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None

        for size in enable_sizes or [320, 640, 768, 1280]:
            # One (key-transform, value-transform) module pair per width.
            self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))

    def weights(self):
        """Return all trainable parameters, switching every module to train mode."""
        res = []

        for k, layers in self.layers.items():
            for layer in layers:
                layer.train()
                res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]

        return res

    def save(self, filename):
        """Serialize module weights plus training metadata to `filename` via torch.save."""
        state_dict = {}

        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name

        torch.save(state_dict, filename)

    def load(self, filename):
        """Load weights/metadata written by save(); integer keys are layer widths."""
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        state_dict = torch.load(filename, map_location='cpu')

        for size, sd in state_dict.items():
            # Integer keys hold module-pair state dicts; string keys are metadata.
            if type(size) == int:
                self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)


def list_hypernetworks(path):
    """Return {name: full path} for every .pt file under `path`, searched recursively.

    The name is the file's basename without extension; later duplicates of the
    same stem overwrite earlier ones.
    """
    res = {}
    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
        name = os.path.splitext(os.path.basename(filename))[0]
        res[name] = filename
    return res

def load_hypernetwork(filename):
    """Make the named hypernetwork active as shared.loaded_hypernetwork.

    filename: a name key into shared.hypernetworks; an unknown/empty name
        unloads the currently active hypernetwork instead.
    """
    path = shared.hypernetworks.get(filename, None)
    if path is not None:
        # Original message lost its placeholder in transit; include the name.
        print(f"Loading hypernetwork {filename}")
        try:
            shared.loaded_hypernetwork = Hypernetwork()
            shared.loaded_hypernetwork.load(path)

        except Exception:
            # Report and continue: a corrupt .pt file should not crash the UI.
            print(f"Error loading hypernetwork {path}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
    else:
        if shared.loaded_hypernetwork is not None:
            print(f"Unloading hypernetwork")

        shared.loaded_hypernetwork = None


def apply_hypernetwork(hypernetwork, context, layer=None):
    """Return (context_k, context_v) for cross-attention.

    Picks the module pair matching the context embedding width
    (context.shape[2]); when no hypernetwork is loaded or no pair matches,
    the untouched context is returned for both.

    layer: optional attention module; the chosen pair is stashed on it as
        hyper_k/hyper_v — presumably consumed by optimized attention
        implementations elsewhere in the project (not visible here).
    """
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is None:
        return context, context

    if layer is not None:
        layer.hyper_k = hypernetwork_layers[0]
        layer.hyper_v = hypernetwork_layers[1]

    context_k = hypernetwork_layers[0](context)
    context_v = hypernetwork_layers[1](context)
    return context_k, context_v


def attention_CrossAttention_forward(self, x, context=None, mask=None):
    """Replacement for CrossAttention.forward that routes the key/value
    context through the currently loaded hypernetwork (if any)."""
    h = self.heads

    q = self.to_q(x)
    # Self-attention when no external context is provided.
    context = default(context, x)

    # Transform k/v context with the active hypernetwork; this also stashes
    # the module pair on `self` via apply_hypernetwork's `layer` argument.
    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
    k = self.to_k(context_k)
    v = self.to_v(context_v)

    # Fold the head dimension into the batch dimension.
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    if mask is not None:
        # Mask out disallowed positions with the most negative representable
        # value so they vanish in the softmax.
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    attn = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', attn, v)
    # Un-fold heads back out of the batch dimension.
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)


def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
    """Train the named hypernetwork on images under `data_root`.

    Saves periodic checkpoints and preview images under `log_directory`,
    updates shared.state for UI progress, and writes the final weights to
    the hypernetwork directory. Returns (hypernetwork, filename).
    """
    assert hypernetwork_name, 'hypernetwork not selected'

    path = shared.hypernetworks.get(hypernetwork_name, None)
    shared.loaded_hypernetwork = Hypernetwork()
    shared.loaded_hypernetwork.load(path)

    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

    # Group logs by date, then by hypernetwork name.
    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
    unload = shared.opts.unload_models_when_training

    if save_hypernetwork_every > 0:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)

    if unload:
        # Free VRAM: the conditioning/VAE stages are only needed again for previews.
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)

    hypernetwork = shared.loaded_hypernetwork
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True

    optimizer = torch.optim.AdamW(weights, lr=learn_rate)

    # Rolling window of the last 32 losses for a smoothed progress readout.
    losses = torch.zeros((32,))

    last_saved_file = "<none>"
    last_saved_image = "<none>"

    initial_step = hypernetwork.step or 0
    if initial_step > steps:
        # Nothing left to do: previous training already reached the target.
        return hypernetwork, filename

    pbar = tqdm.tqdm(enumerate(ds), total=steps - initial_step)
    # NOTE(review): loop header reconstructed — dataset entries appear to be
    # (latent x, prompt text, conditioning) given the uses below; confirm
    # against PersonalizedBase.
    for i, (x, text, cond) in pbar:
        hypernetwork.step = i + initial_step

        if hypernetwork.step > steps:
            break

        if shared.state.interrupted:
            break

        with torch.autocast("cuda"):
            x = x.to(devices.device)
            loss = shared.sd_model(x.unsqueeze(0), cond)[0]
            del x

            losses[hypernetwork.step % losses.shape[0]] = loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        pbar.set_description(f"loss: {losses.mean():.7f}")

        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
            hypernetwork.save(last_saved_file)

        if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
            last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')

            preview_text = text if preview_image_prompt == "" else preview_image_prompt

            optimizer.zero_grad()
            # Previews need the conditioning/VAE stages back on the device.
            shared.sd_model.cond_stage_model.to(devices.device)
            shared.sd_model.first_stage_model.to(devices.device)

            p = processing.StableDiffusionProcessingTxt2Img(
                sd_model=shared.sd_model,
                prompt=preview_text,
                steps=20,
                do_not_save_grid=True,
                do_not_save_samples=True,
            )

            processed = processing.process_images(p)
            image = processed.images[0]

            if unload:
                shared.sd_model.cond_stage_model.to(devices.cpu)
                shared.sd_model.first_stage_model.to(devices.cpu)

            shared.state.current_image = image
            image.save(last_saved_image)

            last_saved_image += f", prompt: {preview_text}"

        shared.state.job_no = hypernetwork.step

        shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""

    # Record which base checkpoint the hypernetwork was trained against.
    checkpoint = sd_models.select_checkpoint()

    hypernetwork.sd_checkpoint = checkpoint.hash
    hypernetwork.sd_checkpoint_name = checkpoint.model_name
    hypernetwork.save(filename)

    return hypernetwork, filename