PyTorch 并行处理最佳实践
原文: https://pytorch.org/docs/stable/notes/multiprocessing.html
torch.multiprocessing
是 Python python:multiprocessing
模块的替代品。 它支持完全相同的操作,但是对其进行了扩展,以便所有通过python:multiprocessing.Queue
发送的张量将其数据移至共享内存中,并且仅将句柄发送给另一个进程。
注意
将 Tensor
发送到另一个进程时,将共享 Tensor
数据。 如果 torch.Tensor.grad
不是None,则也将其共享。 将没有 torch.Tensor.grad
字段的 Tensor
发送到另一个进程后,它将创建一个特定于标准进程的.grad Tensor
不会自动在所有进程之间共享,这与 Tensor
的数据共享方式不同。
这允许实施各种训练方法,例如 Hogwild,A3C 或任何其他需要异步操作的方法。
并行处理中的 CUDA
CUDA 运行时不支持fork
启动方法。 但是,Python 2 中的python:multiprocessing
只能使用fork
创建子进程。 因此,需要 Python 3 和spawn
或forkserver
启动方法在子进程中使用 CUDA。
Note
可以通过使用multiprocessing.get_context(...)
创建上下文或直接使用multiprocessing.set_start_method(...)
来设置启动方法。
与 CPU 张量不同,只要接收过程保留张量的副本,就需要发送过程来保留原始张量。 它是在幕后实施的,但要求用户遵循最佳实践才能使程序正确运行。 例如,只要使用者进程具有对张量的引用,发送进程就必须保持活动状态,并且如果使用者进程通过致命信号异常退出,则引用计数无法保存您。
最佳做法和提示
避免和消除僵局
产生新进程时,很多事情都会出错,死锁的最常见原因是后台线程。 如果有任何持有锁的线程或导入模块的线程,并且调用了fork
,则子进程很可能会处于损坏状态,并且将以不同的方式死锁或失败。 请注意,即使您不这样做,内置库的 Python 也会这样做-不需要比python:multiprocessing
看起来更深。 python:multiprocessing.Queue
实际上是一个非常复杂的类,它产生用于序列化,发送和接收对象的多个线程,它们也可能导致上述问题。
如果您遇到这种情况,请尝试使用SimpleQueue
,它不使用任何其他线程。
我们正在努力为您提供便利,并确保不会发生这些僵局,但有些事情是我们无法控制的。 如果您有一段时间无法解决的问题,请尝试与论坛联系,我们将解决是否可以解决的问题。
重用通过队列传递的缓冲区
请记住,每次将 Tensor
放入python:multiprocessing.Queue
时,都必须将其移到共享内存中。 如果已共享,则为空操作,否则将产生额外的内存副本,从而减慢整个过程。 即使您有一个将数据发送到单个进程的进程池,也要使它将缓冲区发送回去-这几乎是免费的,并且可以避免在发送下一批时复制。
异步多进程训练(例如 Hogwild)
使用 torch.multiprocessing
,可以异步训练模型,参数可以始终共享,也可以定期同步。 在第一种情况下,我们建议发送整个模型对象,而在后一种情况下,建议仅发送 state_dict() 。
我们建议使用python:multiprocessing.Queue
在进程之间传递各种 PyTorch 对象。 例如 当使用fork
start 方法时,它会继承共享内存中已经存在的张量和存储,但是,它非常容易出错,应谨慎使用,并且只有高级用户可以使用。 队列即使有时不是很优雅的解决方案,也可以在所有情况下正常工作。
警告
您应谨慎使用不受if __name__ == '__main__'
约束的全局语句。 如果使用与fork不同的启动方法,则将在所有子过程中执行它们。
霍格威尔德
您可以在示例存储库中找到具体的 Hogwild 实现,但为了展示代码的整体结构,下面还有一个最小的示例:
import torch.multiprocessing as mp
from model import MyModel
def train(model):
# Construct data_loader, optimizer, etc.
for data, labels in data_loader:
optimizer.zero_grad()
loss_fn(model(data), labels).backward()
optimizer.step() # This will update the shared parameters
if __name__ == '__main__':
num_processes = 4
model = MyModel()
# NOTE: this is required for the ``fork`` method to work
model.share_memory()
processes = []
for rank in range(num_processes):
p = mp.Process(target=train, args=(model,))
p.start()
processes.append(p)
for p in processes:
p.join()