Files
neuralchen-SimSwap/Sequential.patch
chenxuanhong 3ee304b0e1 update
2021-06-09 13:12:21 +08:00

70 lines
2.3 KiB
Diff

--- /usr/local/lib/python3.5/dist-packages/torch/nn/modules/container.py
+++ /usr/local/lib/python3.5/dist-packages/torch/nn/modules/container.py
@@ -22,15 +22,7 @@
]))
"""
- @overload
- def __init__(self, *args: Module) -> None:
- ...
-
- @overload
- def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
- ...
-
- def __init__(self, *args: Any):
+ def __init__(self, *args):
super(Sequential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
@@ -48,18 +40,17 @@
idx %= size
return next(islice(iterator, idx, None))
- @_copy_to_script_wrapper
- def __getitem__(self: T, idx) -> T:
+ def __getitem__(self, idx):
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
return self._get_item_by_idx(self._modules.values(), idx)
- def __setitem__(self, idx: int, module: Module) -> None:
+ def __setitem__(self, idx, module):
key = self._get_item_by_idx(self._modules.keys(), idx)
return setattr(self, key, module)
- def __delitem__(self, idx: Union[slice, int]) -> None:
+ def __delitem__(self, idx):
if isinstance(idx, slice):
for key in list(self._modules.keys())[idx]:
delattr(self, key)
@@ -67,26 +58,16 @@
key = self._get_item_by_idx(self._modules.keys(), idx)
delattr(self, key)
- @_copy_to_script_wrapper
- def __len__(self) -> int:
+ def __len__(self):
return len(self._modules)
- @_copy_to_script_wrapper
def __dir__(self):
keys = super(Sequential, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
- @_copy_to_script_wrapper
- def __iter__(self) -> Iterator[Module]:
- return iter(self._modules.values())
-
- # NB: We can't really type check this function as the type of input
- # may change dynamically (as is tested in
- # TestScript.test_sequential_intermediary_types). Cannot annotate
- # with Any as TorchScript expects a more precise type
def forward(self, input):
- for module in self:
+ for module in self._modules.values():
input = module(input)
return input