update
This commit is contained in:
@@ -0,0 +1,97 @@
|
||||
--- /usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py
|
||||
+++ /usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py
|
||||
@@ -10,16 +10,13 @@
|
||||
|
||||
The batch size should be larger than the number of GPUs used.
|
||||
|
||||
- .. warning::
|
||||
- It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`,
|
||||
- instead of this class, to do multi-GPU training, even if there is only a single
|
||||
- node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`.
|
||||
+ See also: :ref:`cuda-nn-dataparallel-instead`
|
||||
|
||||
Arbitrary positional and keyword inputs are allowed to be passed into
|
||||
- DataParallel but some types are specially handled. tensors will be
|
||||
- **scattered** on dim specified (default 0). tuple, list and dict types will
|
||||
- be shallow copied. The other types will be shared among different threads
|
||||
- and can be corrupted if written to in the model's forward pass.
|
||||
+ DataParallel EXCEPT Tensors. All tensors will be scattered on dim
|
||||
+ specified (default 0). Primitive types will be broadcasted, but all
|
||||
+ other types will be a shallow copy and can be corrupted if written to in
|
||||
+ the model's forward pass.
|
||||
|
||||
The parallelized :attr:`module` must have its parameters and buffers on
|
||||
``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
|
||||
@@ -27,9 +24,9 @@
|
||||
|
||||
.. warning::
|
||||
In each forward, :attr:`module` is **replicated** on each device, so any
|
||||
- updates to the running module in ``forward`` will be lost. For example,
|
||||
+ updates to the running module in ``forward`` will be lost. For example,
|
||||
if :attr:`module` has a counter attribute that is incremented in each
|
||||
- ``forward``, it will always stay at the initial value because the update
|
||||
+ ``forward``, it will always stay at the initial value because the update
|
||||
is done on the replicas which are destroyed after ``forward``. However,
|
||||
:class:`~torch.nn.DataParallel` guarantees that the replica on
|
||||
``device[0]`` will have its parameters and buffers sharing storage with
|
||||
@@ -74,7 +71,7 @@
|
||||
Example::
|
||||
|
||||
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
|
||||
- >>> output = net(input_var) # input_var can be on any device, including CPU
|
||||
+ >>> output = net(input_var)
|
||||
"""
|
||||
|
||||
# TODO: update notes/cuda.rst when this class handles 8+ GPUs well
|
||||
@@ -82,15 +79,13 @@
|
||||
def __init__(self, module, device_ids=None, output_device=None, dim=0):
|
||||
super(DataParallel, self).__init__()
|
||||
|
||||
- device_type = _get_available_device_type()
|
||||
- if device_type is None:
|
||||
+ if not torch.cuda.is_available():
|
||||
self.module = module
|
||||
self.device_ids = []
|
||||
return
|
||||
|
||||
if device_ids is None:
|
||||
- device_ids = _get_all_device_indices()
|
||||
-
|
||||
+ device_ids = list(range(torch.cuda.device_count()))
|
||||
if output_device is None:
|
||||
output_device = device_ids[0]
|
||||
|
||||
@@ -98,23 +93,15 @@
|
||||
self.module = module
|
||||
self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
|
||||
self.output_device = _get_device_index(output_device, True)
|
||||
- self.src_device_obj = torch.device(device_type, self.device_ids[0])
|
||||
|
||||
_check_balance(self.device_ids)
|
||||
|
||||
if len(self.device_ids) == 1:
|
||||
- self.module.to(self.src_device_obj)
|
||||
+ self.module.cuda(device_ids[0])
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
if not self.device_ids:
|
||||
return self.module(*inputs, **kwargs)
|
||||
-
|
||||
- for t in chain(self.module.parameters(), self.module.buffers()):
|
||||
- if t.device != self.src_device_obj:
|
||||
- raise RuntimeError("module must have its parameters and buffers "
|
||||
- "on device {} (device_ids[0]) but found one of "
|
||||
- "them on device: {}".format(self.src_device_obj, t.device))
|
||||
-
|
||||
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
||||
if len(self.device_ids) == 1:
|
||||
return self.module(*inputs[0], **kwargs[0])
|
||||
@@ -123,7 +110,7 @@
|
||||
return self.gather(outputs, self.output_device)
|
||||
|
||||
def replicate(self, module, device_ids):
|
||||
- return replicate(module, device_ids, not torch.is_grad_enabled())
|
||||
+ return replicate(module, device_ids)
|
||||
|
||||
def scatter(self, inputs, kwargs, device_ids):
|
||||
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
|
||||
Reference in New Issue
Block a user