--- /usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py
+++ /usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py
@@ -10,16 +10,13 @@
 
     The batch size should be larger than the number of GPUs used.
 
-    .. warning::
-        It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`,
-        instead of this class, to do multi-GPU training, even if there is only a single
-        node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`.
+    See also: :ref:`cuda-nn-dataparallel-instead`
 
     Arbitrary positional and keyword inputs are allowed to be passed into
-    DataParallel but some types are specially handled. tensors will be
-    **scattered** on dim specified (default 0). tuple, list and dict types will
-    be shallow copied. The other types will be shared among different threads
-    and can be corrupted if written to in the model's forward pass.
+    DataParallel EXCEPT Tensors. All tensors will be scattered on dim
+    specified (default 0). Primitive types will be broadcasted, but all
+    other types will be a shallow copy and can be corrupted if written to in
+    the model's forward pass.
 
     The parallelized :attr:`module` must have its parameters and buffers on
     ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
@@ -27,9 +24,9 @@
 
     .. warning::
         In each forward, :attr:`module` is **replicated** on each device, so any
-        updates to the running module in ``forward`` will be lost. For example,
+        updates to the runing module in ``forward`` will be lost. For example,
         if :attr:`module` has a counter attribute that is incremented in each
-        ``forward``, it will always stay at the initial value because the update
+        ``forward``, it will always stay at the initial value becasue the update
         is done on the replicas which are destroyed after ``forward``. However,
         :class:`~torch.nn.DataParallel` guarantees that the replica on
         ``device[0]`` will have its parameters and buffers sharing storage with
@@ -74,7 +71,7 @@
     Example::
 
         >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
-        >>> output = net(input_var) # input_var can be on any device, including CPU
+        >>> output = net(input_var)
     """
 
     # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
@@ -82,15 +79,13 @@
     def __init__(self, module, device_ids=None, output_device=None, dim=0):
         super(DataParallel, self).__init__()
 
-        device_type = _get_available_device_type()
-        if device_type is None:
+        if not torch.cuda.is_available():
             self.module = module
             self.device_ids = []
             return
 
         if device_ids is None:
-            device_ids = _get_all_device_indices()
-
+            device_ids = list(range(torch.cuda.device_count()))
         if output_device is None:
             output_device = device_ids[0]
 
@@ -98,23 +93,15 @@
         self.module = module
         self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
         self.output_device = _get_device_index(output_device, True)
-        self.src_device_obj = torch.device(device_type, self.device_ids[0])
 
         _check_balance(self.device_ids)
 
         if len(self.device_ids) == 1:
-            self.module.to(self.src_device_obj)
+            self.module.cuda(device_ids[0])
 
     def forward(self, *inputs, **kwargs):
         if not self.device_ids:
             return self.module(*inputs, **kwargs)
-
-        for t in chain(self.module.parameters(), self.module.buffers()):
-            if t.device != self.src_device_obj:
-                raise RuntimeError("module must have its parameters and buffers "
-                                   "on device {} (device_ids[0]) but found one of "
-                                   "them on device: {}".format(self.src_device_obj, t.device))
-
         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
         if len(self.device_ids) == 1:
             return self.module(*inputs[0], **kwargs[0])
@@ -123,7 +110,7 @@
         return self.gather(outputs, self.output_device)
 
     def replicate(self, module, device_ids):
-        return replicate(module, device_ids, not torch.is_grad_enabled())
+        return replicate(module, device_ids)
 
     def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)