diff --git a/pytorch_wrapper/system.py b/pytorch_wrapper/system.py index f992230..56b2d62 100644 --- a/pytorch_wrapper/system.py +++ b/pytorch_wrapper/system.py @@ -514,6 +514,9 @@ def _execute_method_on_multi_gpus(self, return results + def __call__(self, *args, **kwargs): + return self.model(*args, **kwargs) + class _Trainer(object):