Files
neuralchen-SimSwap/DataParallel.patch
chenxuanhong 3ee304b0e1 update
2021-06-09 13:12:21 +08:00

97 lines
4.4 KiB
Diff

--- /usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py
+++ /usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py
@@ -10,16 +10,13 @@
The batch size should be larger than the number of GPUs used.
- .. warning::
- It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`,
- instead of this class, to do multi-GPU training, even if there is only a single
- node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`.
+ See also: :ref:`cuda-nn-dataparallel-instead`
Arbitrary positional and keyword inputs are allowed to be passed into
- DataParallel but some types are specially handled. tensors will be
- **scattered** on dim specified (default 0). tuple, list and dict types will
- be shallow copied. The other types will be shared among different threads
- and can be corrupted if written to in the model's forward pass.
+ DataParallel EXCEPT Tensors. All tensors will be scattered on dim
+ specified (default 0). Primitive types will be broadcasted, but all
+ other types will be a shallow copy and can be corrupted if written to in
+ the model's forward pass.
The parallelized :attr:`module` must have its parameters and buffers on
``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
@@ -27,9 +24,9 @@
.. warning::
In each forward, :attr:`module` is **replicated** on each device, so any
- updates to the running module in ``forward`` will be lost. For example,
+    updates to the running module in ``forward`` will be lost. For example,
if :attr:`module` has a counter attribute that is incremented in each
- ``forward``, it will always stay at the initial value because the update
+    ``forward``, it will always stay at the initial value because the update
is done on the replicas which are destroyed after ``forward``. However,
:class:`~torch.nn.DataParallel` guarantees that the replica on
``device[0]`` will have its parameters and buffers sharing storage with
@@ -74,7 +71,7 @@
Example::
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
- >>> output = net(input_var) # input_var can be on any device, including CPU
+ >>> output = net(input_var)
"""
# TODO: update notes/cuda.rst when this class handles 8+ GPUs well
@@ -82,15 +79,13 @@
def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(DataParallel, self).__init__()
- device_type = _get_available_device_type()
- if device_type is None:
+ if not torch.cuda.is_available():
self.module = module
self.device_ids = []
return
if device_ids is None:
- device_ids = _get_all_device_indices()
-
+ device_ids = list(range(torch.cuda.device_count()))
if output_device is None:
output_device = device_ids[0]
@@ -98,23 +93,15 @@
self.module = module
self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
self.output_device = _get_device_index(output_device, True)
- self.src_device_obj = torch.device(device_type, self.device_ids[0])
_check_balance(self.device_ids)
if len(self.device_ids) == 1:
- self.module.to(self.src_device_obj)
+ self.module.cuda(device_ids[0])
def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
-
- for t in chain(self.module.parameters(), self.module.buffers()):
- if t.device != self.src_device_obj:
- raise RuntimeError("module must have its parameters and buffers "
- "on device {} (device_ids[0]) but found one of "
- "them on device: {}".format(self.src_device_obj, t.device))
-
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
@@ -123,7 +110,7 @@
return self.gather(outputs, self.output_device)
def replicate(self, module, device_ids):
- return replicate(module, device_ids, not torch.is_grad_enabled())
+ return replicate(module, device_ids)
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)