diff --git a/datasets/preprocess/preprocess.py b/datasets/preprocess/preprocess.py index 0b10895..92e7655 100644 --- a/datasets/preprocess/preprocess.py +++ b/datasets/preprocess/preprocess.py @@ -8,13 +8,51 @@ def mini_imagenet(): with tarfile.open('mini_imagenet_full_size.tar.bz2', 'r') as tar: - tar.extractall() + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar) os.rename('mini_imagenet_full_size', 'mini_imagenet') def tiered_imagenet(): with tarfile.open('tiered_imagenet.tar', 'r') as tar: - tar.extractall() + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar) def CIFAR_FS(): @@ -35,7 +73,26 @@ def CIFAR_FS(): def CUB(): phase_list = ['train', 'val', 'test'] with tarfile.open('CUB_200_2011.tgz', 'r') as tar: - tar.extractall() + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar) for phase in phase_list: os.makedirs('CUB/{}'.format(phase)) for phase in phase_list: