70 lines
2.3 KiB
Diff
70 lines
2.3 KiB
Diff
--- /usr/local/lib/python3.5/dist-packages/torch/nn/modules/container.py
|
|
+++ /usr/local/lib/python3.5/dist-packages/torch/nn/modules/container.py
|
|
@@ -22,15 +22,7 @@
|
|
]))
|
|
"""
|
|
|
|
- @overload
|
|
- def __init__(self, *args: Module) -> None:
|
|
- ...
|
|
-
|
|
- @overload
|
|
- def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
|
|
- ...
|
|
-
|
|
- def __init__(self, *args: Any):
|
|
+ def __init__(self, *args):
|
|
super(Sequential, self).__init__()
|
|
if len(args) == 1 and isinstance(args[0], OrderedDict):
|
|
for key, module in args[0].items():
|
|
@@ -48,18 +40,17 @@
|
|
idx %= size
|
|
return next(islice(iterator, idx, None))
|
|
|
|
- @_copy_to_script_wrapper
|
|
- def __getitem__(self: T, idx) -> T:
|
|
+ def __getitem__(self, idx):
|
|
if isinstance(idx, slice):
|
|
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
|
|
else:
|
|
return self._get_item_by_idx(self._modules.values(), idx)
|
|
|
|
- def __setitem__(self, idx: int, module: Module) -> None:
|
|
+ def __setitem__(self, idx, module):
|
|
key = self._get_item_by_idx(self._modules.keys(), idx)
|
|
return setattr(self, key, module)
|
|
|
|
- def __delitem__(self, idx: Union[slice, int]) -> None:
|
|
+ def __delitem__(self, idx):
|
|
if isinstance(idx, slice):
|
|
for key in list(self._modules.keys())[idx]:
|
|
delattr(self, key)
|
|
@@ -67,26 +58,16 @@
|
|
key = self._get_item_by_idx(self._modules.keys(), idx)
|
|
delattr(self, key)
|
|
|
|
- @_copy_to_script_wrapper
|
|
- def __len__(self) -> int:
|
|
+ def __len__(self):
|
|
return len(self._modules)
|
|
|
|
- @_copy_to_script_wrapper
|
|
def __dir__(self):
|
|
keys = super(Sequential, self).__dir__()
|
|
keys = [key for key in keys if not key.isdigit()]
|
|
return keys
|
|
|
|
- @_copy_to_script_wrapper
|
|
- def __iter__(self) -> Iterator[Module]:
|
|
- return iter(self._modules.values())
|
|
-
|
|
- # NB: We can't really type check this function as the type of input
|
|
- # may change dynamically (as is tested in
|
|
- # TestScript.test_sequential_intermediary_types). Cannot annotate
|
|
- # with Any as TorchScript expects a more precise type
|
|
def forward(self, input):
|
|
- for module in self:
|
|
+ for module in self._modules.values():
|
|
input = module(input)
|
|
return input
|
|
|