diff --git a/.gitignore b/.gitignore index 954f6dfb..c60a7810 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ src __pycache__/ *.py[cod] *$py.class - +outputs/ .vscode # C extensions diff --git a/.gitmodules b/.gitmodules index 21f138b4..e69de29b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "examples/imagenet-example"] - path = examples/imagenet-example - url = git@github.com:libffcv/ffcv-imagenet.git diff --git a/.nojekyll b/.nojekyll deleted file mode 100644 index e69de29b..00000000 diff --git a/README.md b/README.md index 74c403a3..1b6a4b5b 100644 --- a/README.md +++ b/README.md @@ -1,133 +1,23 @@ -
-Fast Forward Computer Vision: train models at a fraction of the cost with accelerated data loading! -
-
[install]
-[quickstart]
-[features]
+[new features]
[docs]
-[support slack]
-[homepage]
[paper]
-
-Maintainers:
-Guillaume Leclerc,
-Andrew Ilyas and
-Logan Engstrom
-
-Computer vision or not, FFCV can help make training faster in a variety of
-resource-constrained settings!
-Our performance guide
-has a more detailed account of the ways in which FFCV can adapt to different
-performance bottlenecks.
-
-
-- **Plug-and-play with any existing training code**: Rather than changing
- aspects of model training itself, FFCV focuses on removing *data bottlenecks*,
- which turn out to be a problem everywhere from neural network training to
- linear regression. This means that:
-
- - FFCV can be introduced into any existing training code in just a few
- lines of code (e.g., just swapping out the data loader and optionally the
- augmentation pipeline);
- - You don't have to change the model itself to make it faster (e.g., feel
- free to analyze models *without* CutMix, Dropout, momentum scheduling, etc.);
- - FFCV can speed up a lot more beyond just neural network training---in
- fact, the more data-bottlenecked the application (e.g., linear regression,
- bulk inference, etc.), the faster FFCV will make it!
-
- See our [Getting started](https://docs.ffcv.io/basics.html) guide,
- [Example walkthroughs](https://docs.ffcv.io/examples.html), and
- [Code examples](https://github.com/libffcv/ffcv/tree/main/examples)
- to see how easy it is to get started!
-- **Fast data processing without the pain**: FFCV automatically handles data
- reading, pre-fetching, caching, and transfer between devices in an extremely
- efficiently way, so that users don't have to think about it.
-- **Automatically fused-and-compiled data processing**: By either using
- [pre-written](https://docs.ffcv.io/api/transforms.html) FFCV transformations
- or
- [easily writing custom ones](https://docs.ffcv.io/ffcv_examples/custom_transforms.html),
- users can
- take advantage of FFCV's compilation and pipelining abilities, which will
- automatically fuse and compile simple Python augmentations to machine code
- using [Numba](https://numba.pydata.org), and schedule them asynchronously to avoid
- loading delays.
-- **Load data fast from RAM, SSD, or networked disk**: FFCV exposes
- user-friendly options that can be adjusted based on the resources
- available. For example, if a dataset fits into memory, FFCV can cache it
- at the OS level and ensure that multiple concurrent processes all get fast
- data access. Otherwise, FFCV can use fast process-level caching and will
- optimize data loading to minimize the underlying number of disk reads. See
- [The Bottleneck Doctor](https://docs.ffcv.io/bottleneck_doctor.html)
- guide for more information.
-- **Training multiple models per GPU**: Thanks to fully asynchronous
- thread-based data loading, you can now interleave training multiple models on
- the same GPU efficiently, without any data-loading overhead. See
- [this guide](https://docs.ffcv.io/parameter_tuning.html) for more info.
-- **Dedicated tools for image handling**: All the features above work are
- equally applicable to all sorts of machine learning models, but FFCV also
- offers some vision-specific features, such as fast JPEG encoding and decoding,
- storing datasets as mixtures of raw and compressed images to trade off I/O
- overhead and compute overhead, etc. See the
- [Working with images](https://docs.ffcv.io/working_with_images.html) guide for
- more information.
-
-# Contributors
-
-- [Guillaume Leclerc](https://github.com/GuillaumeLeclerc)
-- [Logan Engstrom](http://loganengstrom.com/)
-- [Andrew Ilyas](http://andrewilyas.com/)
-- [Sam Park](http://sungminpark.com/)
-- [Hadi Salman](http://hadisalman.com/)
+
+Compared to the original FFCV, this library has the following new features:
+
+- **crop decode**: RandomCrop and CenterCrop now decode only the crop region, which saves memory and speeds up decoding.
+
+- **cache strategy**: Data held only in the OS cache may be swapped out, so the cache backend is selectable. We use `FFCV_DEFAULT_CACHE_PROCESS` to choose it. The available choices are:
+  - `0`: OS cache
+  - `1`: process cache
+  - `2`: shared memory
+
+- **lossless compression**: PNG is supported for lossless compression. Use `RGBImageField(write_mode='png')` to enable it.
+
+- **lower memory usage**: Memory usage is reduced and data loading is faster; see the comparisons below.
+
+Comparison of throughput:
+
+| img\_size | 112 | 160 | 192 | 224 | | | | | 512 |
+|--------------|--------:|--------:|--------:|:-------:|--------:|--------:|--------:|--------:|-------:|
+| batch\_size | 512 | 512 | 512 | 128 | 256 | 512 | | | 512 |
+| num\_workers | 10 | 10 | 10 | 10 | 10 | 5 | 10 | 20 | 10 |
+| loader | | | | | | | | | |
+| ours | 23024.0 | 19396.5 | 16503.6 | 16536.1 | 16338.5 | 12369.7 | 14521.4 | 14854.6 | 4260.3 |
+| ffcv | 16853.2 | 13906.3 | 13598.4 | 12192.7 | 11960.2 | 9112.7 | 12539.4 | 12601.8 | 3577.8 |
+
+Comparison of memory usage:
+| img\_size | 112 | 160 | 192 | 224 | | | | | 512 |
+|--------------|-----:|-----:|-----:|:---:|-----:|-----:|-----:|-----:|-----:|
+| batch\_size | 512 | 512 | 512 | 128 | 256 | 512 | | | 512 |
+| num\_workers | 10 | 10 | 10 | 10 | 10 | 5 | 10 | 20 | 10 |
+| loader | | | | | | | | | |
+| ours | 9.0 | 9.8 | 11.4 | 5.8 | 7.7 | 11.4 | 11.4 | 11.4 | 34.0 |
+| ffcv | 13.4 | 14.8 | 17.7 | 7.6 | 11.0 | 17.7 | 17.7 | 17.7 | 56.6 |
+
diff --git a/examples/benchmark.py b/examples/benchmark.py
new file mode 100644
index 00000000..a8658945
--- /dev/null
+++ b/examples/benchmark.py
@@ -0,0 +1,263 @@
+import argparse
+import builtins
+import datetime
+import json
+import math
+import os
+import sys
+import time
+from pathlib import Path
+
+from ffcv.loader import Loader, OrderOption
+import gin
+import numpy as np
+import timm
+import torch.backends.cudnn as cudnn
+from PIL import Image # a trick to solve loading lib problem
+from tqdm import tqdm
+
+assert timm.__version__ >= "0.6.12" # version check
+from torchvision import datasets
+import ffcv
+
+from psutil import Process, net_io_counters
+import socket
+import json
+from os import getpid
+
+from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,RandomHorizontalFlip, View, Convert
+from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder, CenterCropRGBImageDecoder
+
+import torch
+
+IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
+IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
+
+@gin.configurable
+def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0),device='cuda'):
+    """Return the per-field FFCV pipelines: random-resized-crop 'image' and integer 'label'."""
+    target = torch.device(device)
+    image_ops = [
+        RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale, ratio=ratio),
+        RandomHorizontalFlip(),
+        NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32),
+        ToTensor(),
+        ToTorchImage(),
+    ]
+    label_ops = [IntDecoder(), ToTensor(), ToDevice(target), View(-1)]
+    return {
+        'image': image_ops,
+        'label': label_ops,
+    }
+
+
+def get_args_parser():
+    """Create the command-line parser for the data-loading benchmark."""
+    p = argparse.ArgumentParser('Data loading benchmark', add_help=False)
+    p.add_argument('--batch_size', type=int, default=64,
+                   help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
+    p.add_argument('--epochs', type=int, default=5)
+    p.add_argument('--img_size', type=int, default=224)
+
+    # Dataset parameters
+    p.add_argument('--data_set', default='ffcv')
+    p.add_argument('--cache_type', type=int, default=0)
+    p.add_argument('--data_path', type=str, default=os.getenv("IMAGENET_DIR"),
+                   help='dataset path')
+
+    p.add_argument('--output_dir', type=str, default=None,
+                   help='path where to save, empty for no saving')
+    p.add_argument('--device', default='cuda',
+                   help='device to use for training / testing')
+    p.add_argument('--seed', type=int, default=0)
+
+    # pin_mem defaults to True; --no_pin_mem writes False into the same destination.
+    p.add_argument('--num_workers', type=int, default=10)
+    p.add_argument('--pin_mem', action='store_true',
+                   help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+    p.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
+    p.set_defaults(pin_mem=True)
+
+    # distributed training parameters
+    p.add_argument('--world_size', type=int, default=1,
+                   help='number of distributed processes')
+    p.add_argument('--local-rank', '--local_rank', type=int, default=-1)
+    p.add_argument('--dist_on_itp', action='store_true')
+    p.add_argument('--dist_url', default='env://',
+                   help='url used to set up distributed training')
+    return p
+
+
+class ramqdm(tqdm):
+    """tqdm progress bar that reports RAM, CPU and read-IO usage of this process.
+
+    Every call to ``update`` appends one sample to ``self.metrics``;
+    ``summary`` averages those samples.
+    """
+    _empty_desc = "using ? GB RAM; ? CPU ? IO"
+    _desc = "{:.2f} GB RAM; {:.2f} % CPU {:.2f} MB IO"
+    _GB = 10**9
+
+    def __init__(self, *args, **kwargs):
+        """Override desc and get reference to current process"""
+        prefix = kwargs.pop("desc", None)
+        if prefix is not None:
+            # prepend desc to the reporter mask:
+            self._empty_desc = prefix + " " + self._empty_desc
+            self._desc = prefix + " " + self._desc
+        else:
+            # Upper-case only the first letter: str.capitalize() would also
+            # lower-case the unit labels ("GB RAM" -> "gb ram").
+            self._empty_desc = self._empty_desc[:1].upper() + self._empty_desc[1:]
+        super().__init__(*args, desc=self._empty_desc, **kwargs)
+        self._process = Process(getpid())
+        self.metrics = []
+
+    def update(self, n=1):
+        """Sample RSS, CPU percent and bytes read, then advance the bar by ``n``."""
+        rss_gb = self._process.memory_info().rss / self._GB
+        cpu = self._process.cpu_percent()
+        read_mb = self._process.io_counters().read_bytes / 1e6
+        self.set_description(self._desc.format(rss_gb, cpu, read_mb))
+        self.metrics.append({'mem': rss_gb, 'cpu': cpu, 'io': read_mb})
+        super().update(n)
+
+    def summary(self):
+        """Return the mean of each recorded metric; an empty dict if nothing was sampled."""
+        if not self.metrics:
+            return {}
+        return {key: np.mean([m[key] for m in self.metrics]) for key in self.metrics[0]}
+
+@gin.configurable(denylist=["args"])
+def build_dataset(args, transform_fn=SimplePipeline):
+    """Create the data source named by ``args.data_set`` ('IF', 'cifar10' or 'ffcv')."""
+    transform_train = transform_fn(img_size=args.img_size)
+    kind = args.data_set
+    # torchvision datasets are returned bare; main() wraps them in a DataLoader.
+    if kind == 'IF':
+        return datasets.ImageFolder(args.data_path, transform=transform_train)
+    if kind == 'cifar10':
+        return datasets.CIFAR10(args.data_path, transform=transform_train)
+    if kind != 'ffcv':
+        raise ValueError("Wrong dataset: ", args.data_set)
+    # NOTE: args.cache_type is currently not forwarded to the Loader.
+    order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM
+    return Loader(args.data_path, pipelines=transform_train, batch_size=args.batch_size,
+                  num_workers=args.num_workers, batches_ahead=4,
+                  order=order, distributed=args.distributed, seed=args.seed, drop_last=True)
+
+def load_one_epoch(args,loader):
+    """Iterate ``loader`` once; rank 0 returns usage stats plus 'throughput', other ranks ``None``."""
+    start = time.time()
+    progress = ramqdm(loader, disable=args.rank > 0)
+    for x1, y in progress:
+        x1.mean()  # run a reduction over the batch so it is actually consumed
+        torch.cuda.synchronize()
+    elapsed = time.time() - start
+    if args.rank != 0:
+        return None  # ``res`` used to be unbound here (UnboundLocalError)
+    res = progress.summary()
+    # FFCV loaders expose ``reader``; torch DataLoaders only expose ``dataset``.
+    num_samples = loader.reader.num_samples if hasattr(loader, 'reader') else len(loader.dataset)
+    res['throughput'] = num_samples / elapsed
+    return res
+
+import torch
+
+def main(args):
+    """Run the data-loading benchmark for ``args.epochs`` epochs.
+
+    Per-epoch stats are printed and, when ``args.output_dir`` is set, appended
+    as one JSON line each to ``<output_dir>/data_loading.txt``.
+    """
+    init_distributed_mode(args)
+    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+    print("{}".format(args).replace(', ', ',\n'))
+    cudnn.benchmark = True
+
+    # build dataset
+    dataset_train = build_dataset(args)
+    if args.data_set != "ffcv":
+        # torchvision datasets still need a sampler and a DataLoader around them
+        sampler_train = torch.utils.data.DistributedSampler(
+            dataset_train, num_replicas=args.world_size, rank=args.rank, shuffle=True
+        )
+        print("Sampler_train = %s" % str(sampler_train))
+        data_loader_train = torch.utils.data.DataLoader(
+            dataset_train, sampler=sampler_train, batch_size=args.batch_size,
+            num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True,
+        )
+    else:
+        data_loader_train = dataset_train
+
+    if args.output_dir:
+        # open(..., "a") creates the file but not missing parent directories
+        os.makedirs(args.output_dir, exist_ok=True)
+
+    for epoch in range(args.epochs):
+        res = load_one_epoch(args, data_loader_train)
+        if not res:
+            continue  # load_one_epoch only fills in stats on rank 0
+        print(f"Throughput: {res['throughput']:.2f} samples/s for {args.data_path}.")
+        res.update(args.__dict__)
+        res.update(version=ffcv.__version__, hostname=socket.gethostname(), epoch=epoch)
+        if args.output_dir:
+            with open(os.path.join(args.output_dir, "data_loading.txt"), "a") as file:
+                file.write(json.dumps(res) + "\n")
+
+
+def init_distributed_mode(args):
+    """Set ``args.rank``/``gpu``/``distributed`` from the launcher environment and init NCCL."""
+    if getattr(args, 'dist_on_itp', False):
+        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+        os.environ['LOCAL_RANK'] = str(args.gpu)
+        os.environ['RANK'] = str(args.rank)
+        os.environ['WORLD_SIZE'] = str(args.world_size)
+    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print('Not using distributed mode')
+        setup_for_distributed(is_master=True)  # hack
+        # main() and load_one_epoch() read args.rank even for a single process
+        args.rank = 0
+        args.distributed = False
+        return
+    args.distributed = True
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}, gpu {}'.format(
+        args.rank, args.dist_url, args.gpu), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+
+def setup_for_distributed(is_master):
+    """Replace ``builtins.print`` with a time-stamped variant that is silent on
+    non-master processes unless it is called with ``force=True``."""
+    original_print = builtins.print
+
+    def print(*args, **kwargs):
+        forced = kwargs.pop('force', False)
+        if not (is_master or forced):
+            return
+        stamp = datetime.datetime.now().time()
+        original_print('[{}] '.format(stamp), *args, **kwargs)  # print with time stamp
+
+    builtins.print = print
+
+
+
+if __name__ == '__main__':  # script entry point: parse the CLI flags, then run the benchmark
+ parser = get_args_parser()
+ args = parser.parse_args()
+ main(args)
diff --git a/examples/docs_examples/linear_regression.py b/examples/docs_examples/linear_regression.py
index f9a6e81c..5431448a 100644
--- a/examples/docs_examples/linear_regression.py
+++ b/examples/docs_examples/linear_regression.py
@@ -57,7 +57,7 @@ def __len__(self):
train_loader = DataLoader(dataset, batch_size=2048, num_workers=8, shuffle=True)
else:
train_loader = Loader('/tmp/linreg_data.beton', batch_size=2048,
- num_workers=8, order=OrderOption.QUASI_RANDOM, os_cache=False,
+ num_workers=8, order=OrderOption.QUASI_RANDOM, cache_type=1,
pipelines={
'covariate': [NDArrayDecoder(), ToTensor(), ToDevice(ch.device('cuda:0'))],
'label': [NDArrayDecoder(), ToTensor(), Squeeze(), ToDevice(ch.device('cuda:0'))]
diff --git a/examples/imagenet-example b/examples/imagenet-example
deleted file mode 160000
index f134cbff..00000000
--- a/examples/imagenet-example
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit f134cbfff7f590954edc5c24275444b7dd2f57f6
diff --git a/examples/profiler.py b/examples/profiler.py
new file mode 100644
index 00000000..c959baee
--- /dev/null
+++ b/examples/profiler.py
@@ -0,0 +1,155 @@
+#%%
+
+import time
+from PIL import Image# a trick to solve loading lib problem
+from ffcv.fields.rgb_image import *
+from ffcv.transforms import RandomHorizontalFlip, NormalizeImage, ToTensor, ToTorchImage, ToDevice
+import numpy as np
+import torchvision
+
+from ffcv import Loader
+import ffcv
+import argparse
+from tqdm.auto import tqdm,trange
+import torch.nn as nn
+import torch
+from psutil import Process, net_io_counters
+import json
+from os import getpid
+
+from ffcv.transforms.ops import Convert
+
+IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
+IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
+
+class ramqdm(tqdm):
+    """tqdm progress bar that reports RAM, CPU and read-IO usage of this process.
+
+    Every call to ``update`` appends one sample to ``self.metrics``;
+    ``summary`` averages those samples.
+    """
+    _empty_desc = "using ? GB RAM; ? CPU ? IO"
+    _desc = "{:.2f} GB RAM; {:.2f} % CPU {:.2f} MB IO"
+    _GB = 10**9
+
+    def __init__(self, *args, **kwargs):
+        """Override desc and get reference to current process"""
+        prefix = kwargs.pop("desc", None)
+        if prefix is not None:
+            # prepend desc to the reporter mask:
+            self._empty_desc = prefix + " " + self._empty_desc
+            self._desc = prefix + " " + self._desc
+        else:
+            # Upper-case only the first letter: str.capitalize() would also
+            # lower-case the unit labels ("GB RAM" -> "gb ram").
+            self._empty_desc = self._empty_desc[:1].upper() + self._empty_desc[1:]
+        super().__init__(*args, desc=self._empty_desc, **kwargs)
+        self._process = Process(getpid())
+        self.metrics = []
+
+    def update(self, n=1):
+        """Sample RSS, CPU percent and bytes read, then advance the bar by ``n``."""
+        rss_gb = self._process.memory_info().rss / self._GB
+        cpu = self._process.cpu_percent()
+        read_mb = self._process.io_counters().read_bytes / 1e6
+        self.set_description(self._desc.format(rss_gb, cpu, read_mb))
+        self.metrics.append({'mem': rss_gb, 'cpu': cpu, 'io': read_mb})
+        super().update(n)
+
+    def summary(self):
+        """Return the mean of each recorded metric; an empty dict if nothing was sampled."""
+        if not self.metrics:
+            return {}
+        return {key: np.mean([m[key] for m in self.metrics]) for key in self.metrics[0]}
+
+
+def load_one_epoch(args,loader):
+    """Iterate ``loader`` once and return the ``ramqdm`` summary plus 'throughput' (samples/s)."""
+    start = time.time()
+    progress = ramqdm(loader)
+    x1 = None  # stays None if the loader yields no batch
+    for x1, y in progress:
+        pass
+    elapsed = time.time() - start
+    res = progress.summary()
+    # FFCV loaders expose ``reader``; torch DataLoaders only expose ``dataset``.
+    num_samples = loader.reader.num_samples if hasattr(loader, 'reader') else len(loader.dataset)
+    res['throughput'] = num_samples / elapsed
+    if x1 is not None:
+        x1 = x1.float()
+        print("Mean: ", x1.mean().item(), "Std: ", x1.std().item())
+    return res
+
+def main(args):
+    """Yield one stats dict per measured epoch, after one un-recorded warm-up epoch.
+
+    ``args.no_ffcv`` selects a torchvision ImageFolder/DataLoader instead of the FFCV Loader.
+    """
+    if not args.no_ffcv:
+        size = (args.img_size, args.img_size)
+        # Device transfer and normalisation are left out of this pipeline.
+        image_ops = [
+            RandomResizedCropRGBImageDecoder(size),
+            RandomHorizontalFlip(),
+            ToTensor(),
+            ToTorchImage(),
+        ]
+        loader = Loader(args.data_path, batch_size=args.batch_size, num_workers=args.num_workers,
+                        pipelines={'image': image_ops},
+                        batches_ahead=2, distributed=False, seed=0, drop_last=True)
+    else:
+        tfms = torchvision.transforms.Compose([
+            torchvision.transforms.RandomResizedCrop(args.img_size),
+            torchvision.transforms.RandomHorizontalFlip(),
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Normalize(IMAGENET_MEAN/255, IMAGENET_STD/255),
+        ])
+        dataset = torchvision.datasets.ImageFolder(args.data_path, transform=tfms)
+        loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
+                                             num_workers=args.num_workers, drop_last=True)
+
+    load_one_epoch(args, loader)  # warmup
+
+    for _ in range(args.repeat):
+        yield load_one_epoch(args, loader)
+
+#%%
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="FFCV Profiler")
+    parser.add_argument("-r", "--repeat", type=int, default=3, help="number of samples to record one step for profile.")
+    parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size")
+    parser.add_argument("-p", "--data_path", type=str, help="data path", required=True)
+    parser.add_argument("--no_ffcv",default=False,action="store_true")
+    parser.add_argument("--num_workers", type=int, default=10, help="number of workers")
+    parser.add_argument("--exp", default=False, action="store_true", help="run experiments")
+    parser.add_argument("--img_size", type=int, default=224, help="image size")
+    parser.add_argument("--write_path", type=str, help='path to write result',default=None)
+    args = parser.parse_args()
+    if not args.exp:
+        for res in main(args):
+            print(f"Throughput: {res['throughput']:.2f} samples/s for {args.data_path}.")
+            res.update(args.__dict__)
+            if args.write_path:
+                with open(args.write_path,"a") as file:
+                    file.write(json.dumps(res)+"\n")
+    else:
+        if not args.write_path:
+            parser.error("--exp requires --write_path")  # open(None, "a") would raise TypeError
+        data = []
+        with open(args.write_path,"a") as file:
+            for num_workers in [10,20,40]:
+                for use_ffcv in [False,True]:
+                    for bs in [128,256,512]:
+                        args.num_workers=num_workers
+                        args.batch_size = bs
+                        # main() reads ``no_ffcv``; setting only ``use_ffcv`` never switched loaders
+                        args.use_ffcv, args.no_ffcv = use_ffcv, not use_ffcv
+                        for res in main(args):
+                            row = {**vars(args), **res}  # fresh dict per result, args stays clean
+                            file.write(json.dumps(row)+"\n")
+                            file.flush()
+                            print(row)
+                            data.append(row)
+        import pandas as pd
+        print(pd.DataFrame(data))
+        exit(0)
\ No newline at end of file
diff --git a/examples/vis_loader.py b/examples/vis_loader.py
new file mode 100644
index 00000000..770f51a0
--- /dev/null
+++ b/examples/vis_loader.py
@@ -0,0 +1,47 @@
+import argparse
+import time
+from PIL import Image # a trick to solve loading lib problem
+from ffcv import Loader
+from ffcv.transforms import *
+from ffcv.fields.decoders import CenterCropRGBImageDecoder, RandomResizedCropRGBImageDecoder
+
+
+import numpy as np
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='FFCV Profiler')
+    # nargs='?' makes the positional optional, so its default can actually apply
+    parser.add_argument('data_path', type=str, nargs='?', default='data/imagenet', help='Path to the dataset')
+    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
+    parser.add_argument('--write_path', type=str, default='viz.png', help='Path to write result')
+    args = parser.parse_args()
+
+    loader = Loader(args.data_path, batch_size=args.batch_size, num_workers=10, cache_type=0, pipelines={
+        'image':[CenterCropRGBImageDecoder((224, 224),224/256),
+                 ToTensor(),
+                 ToTorchImage()]
+    }, batches_ahead=0,)
+
+    print("num samples: ", loader.reader.num_samples, "fields: ", loader.reader.field_names)
+    for x,_ in loader:
+        x1 = x.float()
+        print("Mean: ", x1.mean().item(), "Std: ", x1.std().item())
+        break
+
+    print('Done')
+    num = int(np.sqrt(args.batch_size))
+    import cv2
+    # Tile the first num*num images of the batch into one square grid.
+    image = np.zeros((224*num, 224*num, 3), dtype=np.uint8)
+    for i in range(num):
+        for j in range(num):
+            if i*num+j >= len(x):  # the batch may hold fewer images than the grid
+                break
+            img = x[i*num+j].numpy().transpose(1,2,0)
+            image[i*224:(i+1)*224, j*224:(j+1)*224] = (img).astype(np.uint8)
+
+    if args.write_path:
+        Image.fromarray(image).save(args.write_path)
+
+
diff --git a/examples/write_dataset.py b/examples/write_dataset.py
new file mode 100644
index 00000000..a9b38bef
--- /dev/null
+++ b/examples/write_dataset.py
@@ -0,0 +1,106 @@
+"""example usage:
+export IMAGENET_DIR=/path/to/pytorch/format/imagenet/directory/
+export WRITE_DIR=/your/path/here/
+write_dataset train 500 0.50 90
+write_path=$WRITE_DIR/train500_0.5_90.ffcv
+echo "Writing ImageNet train dataset to ${write_path}"
+python examples/write_dataset.py \
+ --cfg.data_dir=$IMAGENET_DIR \
+ --cfg.write_path=$write_path \
+ --cfg.max_resolution=500 \
+ --cfg.write_mode=smart \
+ --cfg.compress_probability=0.50 \
+ --cfg.jpeg_quality=90
+"""
+from PIL import Image
+from torch.utils.data import Subset
+from ffcv.writer import DatasetWriter
+from ffcv.fields import IntField, RGBImageField
+import torchvision
+from torchvision.datasets import ImageFolder
+import torchvision.datasets as torch_datasets
+
+from argparse import ArgumentParser
+from fastargs import Section, Param
+from fastargs.validation import And, OneOf
+from fastargs.decorators import param, section
+from fastargs import get_current_config
+import cv2
+import numpy as np
+
+# hack resizer
+# def resizer(image, target_resolution):
+# if target_resolution is None:
+# return image
+# original_size = np.array([image.shape[1], image.shape[0]])
+# ratio = target_resolution / original_size.min()
+# if ratio < 1:
+# new_size = (ratio * original_size).astype(int)
+# image = cv2.resize(image, tuple(new_size), interpolation=cv2.INTER_AREA)
+# return image
+# from ffcv.fields import rgb_image
+# rgb_image.resizer = resizer
+
+Section('cfg', 'arguments to give the writer').params(  # every entry is injected into main() via @param
+    dataset=Param(And(str, OneOf(['cifar', 'imagenet'])), 'Which dataset to write', default='imagenet'),
+    data_dir=Param(str, 'Where to find the PyTorch dataset', required=True),
+    write_path=Param(str, 'Where to write the new dataset', required=True),
+    write_mode=Param(str, 'Mode: raw, smart, jpg or png', required=False, default='smart'),
+    max_resolution=Param(int, 'Max image side length. 0 any size.', required=False,default=0),
+    num_workers=Param(int, 'Number of workers to use', default=16),
+    chunk_size=Param(int, 'Chunk size for writing', default=100),
+    jpeg_quality=Param(float, 'Quality of jpeg images', default=90),
+    subset=Param(int, 'How many images to use (-1 for all)', default=-1),
+    compress_probability=Param(float, 'compress probability', default=0.5),
+    threshold=Param(int, 'threshold for smart mode to compress by jpeg', default=286432),
+)
+
+@section('cfg')
+@param('dataset')
+@param('data_dir')
+@param('write_path')
+@param('max_resolution')
+@param('num_workers')
+@param('chunk_size')
+@param('subset')
+@param('jpeg_quality')
+@param('write_mode')
+@param('compress_probability')
+@param('threshold')
+def main(dataset, data_dir, write_path, max_resolution, num_workers,
+         chunk_size, subset, jpeg_quality, write_mode,
+         compress_probability, threshold):
+    """Write an ImageFolder ('imagenet') or CIFAR-10 ('cifar') dataset to ``write_path`` with DatasetWriter."""
+    if dataset == 'imagenet':
+        my_dataset = ImageFolder(root=data_dir)
+    elif dataset == 'cifar':
+        my_dataset = torch_datasets.CIFAR10(root=data_dir, train=True, download=True)
+    else:
+        raise ValueError('Unknown dataset')
+
+    if subset > 0:  # the default of -1 keeps every sample
+        my_dataset = Subset(my_dataset, range(subset))
+    writer = DatasetWriter(write_path, {
+        'image': RGBImageField(write_mode=write_mode,
+                               max_resolution=None if max_resolution==0 else max_resolution,
+                               compress_probability=compress_probability,
+                               jpeg_quality=jpeg_quality,
+                               smart_threshold=threshold),
+        'label': IntField(),
+    }, num_workers=num_workers)
+
+    writer.from_indexed_dataset(my_dataset, chunksize=chunk_size,shuffle_indices=False)
+
+if __name__ == '__main__':
+    config = get_current_config()
+    parser = ArgumentParser()
+    config.augment_argparse(parser)
+    config.collect_argparse_args(parser)
+    config.validate(mode='stderr')
+    config.summary()
+    args=config.get().cfg
+    assert args.write_path.endswith('.ffcv'), 'write_path must end with .ffcv'
+    # Save the settings beside the dataset: swap only the trailing suffix, and close the file.
+    with open(args.write_path[:-len('.ffcv')] + '.meta', 'w') as file:
+        file.write(str(args.__dict__))
+    main()
diff --git a/ffcv-conda.yml b/ffcv-conda.yml
index f332ea51..f3f39ca9 100644
--- a/ffcv-conda.yml
+++ b/ffcv-conda.yml
@@ -1,4 +1,4 @@
-name: ffcv19
+name: ffcv
channels:
- pytorch
- defaults
diff --git a/ffcv/.DS_Store b/ffcv/.DS_Store
deleted file mode 100644
index f1fad9b3..00000000
Binary files a/ffcv/.DS_Store and /dev/null differ
diff --git a/ffcv/benchmarks/decorator.py b/ffcv/benchmarks/decorator.py
index 4e8d75a7..9651eab3 100644
--- a/ffcv/benchmarks/decorator.py
+++ b/ffcv/benchmarks/decorator.py
@@ -1,3 +1,4 @@
+import tracemalloc
from itertools import product
from time import time
from collections import defaultdict
@@ -46,6 +47,8 @@ def run_all(runs=3, warm_up=1, pattern='*'):
for args in it_args:
# with redirect_stderr(FakeSink()):
+ # Start tracing memory allocations
+ tracemalloc.start()
if True:
benchmark: Benchmark = cls(**args)
with benchmark:
@@ -57,7 +60,9 @@ def run_all(runs=3, warm_up=1, pattern='*'):
start = time()
benchmark.run()
timings.append(time() - start)
-
+ # Stop tracing memory allocations
+ current, peak = tracemalloc.get_traced_memory()
+ tracemalloc.stop()
median_time = np.median(timings)
throughput = None
@@ -66,16 +71,13 @@ def run_all(runs=3, warm_up=1, pattern='*'):
throughput = args['n'] / median_time
unit = 'it/sec'
- if throughput < 1:
- unit = 'sec/it'
- throughput = 1 /throughput
-
- throughput = np.round(throughput * 10) / 10
results[suite_name].append({
**args,
'time': median_time,
- 'throughput': str(throughput) + ' ' + unit
+ f'throughput ({unit})': f"{throughput:.2f}",
+ 'current_memory (MB)': current / 10**6,
+ 'peak_memory (MB)': peak / 10**6,
})
it_args.close()
it_suite.close()
diff --git a/ffcv/benchmarks/suites/image_read.py b/ffcv/benchmarks/suites/image_read.py
index 89f09f46..7ce54d57 100644
--- a/ffcv/benchmarks/suites/image_read.py
+++ b/ffcv/benchmarks/suites/image_read.py
@@ -42,18 +42,21 @@ def __getitem__(self, index):
'length': [3000],
'mode': [
'raw',
- 'jpg'
+ 'jpg',
+ 'png',
],
'num_workers': [
1,
8,
- 16
+ 16,
+ 32,
],
'batch_size': [
500
],
'size': [
(32, 32), # CIFAR
+ (224,224),
(300, 500), # ImageNet
],
'compile': [
@@ -83,13 +86,12 @@ def __enter__(self):
self.handle.__enter__()
name = self.handle.name
- writer = DatasetWriter(self.length, name, {
+ writer = DatasetWriter(name, {
'index': IntField(),
'value': RGBImageField(write_mode=self.mode)
})
- with writer:
- writer.write_pytorch_dataset(self.dataset, num_workers=-1, chunksize=100)
+ writer.from_indexed_dataset(self.dataset, chunksize=100)
reader = Reader(name)
manager = OSCacheManager(reader)
diff --git a/ffcv/benchmarks/suites/jpeg_decode.py b/ffcv/benchmarks/suites/jpeg_decode.py
index 31fc7860..51c74449 100644
--- a/ffcv/benchmarks/suites/jpeg_decode.py
+++ b/ffcv/benchmarks/suites/jpeg_decode.py
@@ -14,9 +14,9 @@
@benchmark({
'n': [500],
'source_image': ['../../../test_data/pig.png'],
- 'image_width': [500, 256, 1024],
- 'quality': [50, 90],
- 'compile': [True]
+ 'image_width': [224, 500, 1024],
+ 'quality': [50, 80, 90, 95],
+ 'compile': [True],
})
class JPEGDecodeBenchmark(Benchmark):
diff --git a/ffcv/benchmarks/suites/memory_read.py b/ffcv/benchmarks/suites/memory_read.py
index e6072516..2666fa34 100644
--- a/ffcv/benchmarks/suites/memory_read.py
+++ b/ffcv/benchmarks/suites/memory_read.py
@@ -59,13 +59,12 @@ def __enter__(self):
handle = self.handle.__enter__()
name = handle.name
dataset = DummyDataset(self.num_samples, self.size_bytes)
- writer = DatasetWriter(self.num_samples, name, {
+ writer = DatasetWriter(name, {
'index': IntField(),
'value': BytesField()
- })
+ }, num_workers=-1)
- with writer:
- writer.write_pytorch_dataset(dataset, num_workers=-1, chunksize=100)
+ writer.from_indexed_dataset(dataset, chunksize=100)
reader = Reader(name)
manager = OSCacheManager(reader)
diff --git a/ffcv/fields/rgb_image.py b/ffcv/fields/rgb_image.py
index b6420f11..b90dbde8 100644
--- a/ffcv/fields/rgb_image.py
+++ b/ffcv/fields/rgb_image.py
@@ -12,7 +12,7 @@
from ..pipeline.state import State
from ..pipeline.compiler import Compiler
from ..pipeline.allocation_query import AllocationQuery
-from ..libffcv import imdecode, memcpy, resize_crop
+from ..libffcv import *
if TYPE_CHECKING:
from ..memory_managers.base import MemoryManager
@@ -21,6 +21,7 @@
IMAGE_MODES = Dict()
IMAGE_MODES['jpg'] = 0
IMAGE_MODES['raw'] = 1
+IMAGE_MODES['png'] = 2
def encode_jpeg(numpy_image, quality):
@@ -33,6 +34,11 @@ def encode_jpeg(numpy_image, quality):
return result.reshape(-1)
+def encode_png(numpy_image):
+ # NOTE(review): the RGB->BGR swap (cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)) is left disabled — confirm the decoder expects this channel order.
+ result = cv2.imencode('.png', numpy_image)[1]  # [1] is the encoded buffer; [0] is the success flag
+ result = np.frombuffer(result, np.uint8)
+ return result.reshape(-1)  # flat 1-D uint8 array
def resizer(image, target_resolution):
if target_resolution is None:
@@ -41,7 +47,7 @@ def resizer(image, target_resolution):
ratio = target_resolution / original_size.max()
if ratio < 1:
new_size = (ratio * original_size).astype(int)
- image = cv2.resize(image, tuple(new_size), interpolation=cv2.INTER_AREA)
+ image = cv2.resize(image, tuple(new_size), interpolation=cv2.INTER_CUBIC)
return image
@@ -101,22 +107,24 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca
consider RandomResizedCropRGBImageDecoder or CenterCropRGBImageDecoder
instead."""
raise TypeError(msg)
-
- biggest_shape = (max_height, max_width, 3)
+
+ max_shape = ((np.uint64(widths)*np.uint64(heights)*3).max(),)
my_dtype = np.dtype('