Skip to content
代码片段 群组 项目
sd_hijack_utils.py 1.3 KB
更新 更旧
  • 了解如何忽略特定修订
  • import importlib
    
    class CondFunc:
        """Conditionally replace a function with a substitute.

        ``CondFunc(orig_func, sub_func, cond_func)`` does not return a
        ``CondFunc`` instance; it returns a plain forwarding callable. Calling
        it with ``(*args, **kwargs)`` behaves as follows:

        * if ``cond_func`` is falsy, or ``cond_func(orig_func, *args, **kwargs)``
          is truthy, the result is ``sub_func(orig_func, *args, **kwargs)``;
        * otherwise the result is ``orig_func(*args, **kwargs)``.

        ``orig_func`` may be a callable or a dotted-path string such as
        ``"package.module.attr.func"``. For a string, the longest importable
        module prefix is imported, the remaining components are resolved with
        ``getattr``, and the final attribute is overwritten in place with the
        forwarding callable (so existing lookups through that owner object
        reach the hijack).

        Raises:
            ValueError: ``orig_func`` is a string and no leading part of it
                can be imported as a module.
            AttributeError: a component after the module prefix does not
                exist on the resolved object.
        """

        def __new__(cls, orig_func, sub_func, cond_func=None):
            self = super(CondFunc, cls).__new__(cls)

            def hijacked(*args, **kwargs):
                # Late-bound forwarder: only touches ``self`` when invoked,
                # by which time ``__init__`` below has populated it.
                return self(*args, **kwargs)

            if isinstance(orig_func, str):
                func_path = orig_func.split('.')

                # Try the longest module prefix first and shorten it one
                # component at a time. The last component is always the
                # target attribute, so at least one component must remain
                # for the module name.
                for split_at in range(len(func_path) - 1, 0, -1):
                    try:
                        resolved_obj = importlib.import_module('.'.join(func_path[:split_at]))
                        break
                    except ImportError:
                        continue
                else:
                    raise ValueError(
                        "CondFunc: no importable module prefix found in {!r}".format(orig_func)
                    )

                # Walk any intermediate attributes (classes, namespaces, ...)
                # between the module and the function's owner.
                for attr_name in func_path[split_at:-1]:
                    resolved_obj = getattr(resolved_obj, attr_name)

                orig_func = getattr(resolved_obj, func_path[-1])
                setattr(resolved_obj, func_path[-1], hijacked)

            # ``__new__`` returns a non-instance, so Python will not call
            # ``__init__`` automatically; do it explicitly.
            self.__init__(orig_func, sub_func, cond_func)
            return hijacked

        def __init__(self, orig_func, sub_func, cond_func=None):
            """Store the original callable, its substitute and the condition."""
            self.__orig_func = orig_func
            self.__sub_func = sub_func
            self.__cond_func = cond_func

        def __call__(self, *args, **kwargs):
            """Dispatch to the substitute or the original based on the condition."""
            if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
                return self.__sub_func(self.__orig_func, *args, **kwargs)
            else:
                return self.__orig_func(*args, **kwargs)