Skip to content
代码片段 群组 项目
提交 a00cd8b9 编辑于 作者: AUTOMATIC's avatar AUTOMATIC
浏览文件

attempt to fix memory monitor with multiple CUDA devices

上级 6033de18
No related branches found
No related tags found
无相关合并请求
...@@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread): ...@@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread):
self.data = defaultdict(int) self.data = defaultdict(int)
try: try:
torch.cuda.mem_get_info() self.cuda_mem_get_info()
torch.cuda.memory_stats(self.device) torch.cuda.memory_stats(self.device)
except Exception as e: # AMD or whatever except Exception as e: # AMD or whatever
print(f"Warning: caught exception '{e}', memory monitor disabled") print(f"Warning: caught exception '{e}', memory monitor disabled")
self.disabled = True self.disabled = True
def cuda_mem_get_info(self):
index = self.device.index if self.device.index is not None else torch.cuda.current_device()
return torch.cuda.mem_get_info(index)
def run(self): def run(self):
if self.disabled: if self.disabled:
return return
...@@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread): ...@@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread):
self.run_flag.clear() self.run_flag.clear()
continue continue
self.data["min_free"] = torch.cuda.mem_get_info()[0] self.data["min_free"] = self.cuda_mem_get_info()[0]
while self.run_flag.is_set(): while self.run_flag.is_set():
free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug? free, total = self.cuda_mem_get_info()
self.data["min_free"] = min(self.data["min_free"], free) self.data["min_free"] = min(self.data["min_free"], free)
time.sleep(1 / self.opts.memmon_poll_rate) time.sleep(1 / self.opts.memmon_poll_rate)
...@@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread): ...@@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread):
def read(self): def read(self):
if not self.disabled: if not self.disabled:
free, total = torch.cuda.mem_get_info() free, total = self.cuda_mem_get_info()
self.data["free"] = free self.data["free"] = free
self.data["total"] = total self.data["total"] = total
......
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册