# TensorFlow Datasets (TFDS): A Treasure Trove of Ready-to-Use Datasets
If you're working with TensorFlow or JAX and want quick access to high-quality, standardized datasets for training, testing, or benchmarking — look no further than TensorFlow Datasets (TFDS).
TFDS is an awesome collection of preprocessed, curated datasets for machine learning. It supports datasets across a wide range of domains — from images and text to audio and structured data — and is designed to seamlessly integrate with both TensorFlow and JAX workflows.
## What is TensorFlow Datasets?
TFDS is a library that provides:

- ✅ Ready-to-use datasets with standard splits
- Automatic preprocessing, caching, and versioning
- Integration with TensorFlow via `tf.data.Dataset`
- Compatibility with JAX and PyTorch via `.as_numpy_iterator()`
It helps you go from dataset loading to model training in just a few lines.
## Installation

```shell
pip install tensorflow-datasets
```

If you're using JAX, install the optional dependencies:

```shell
pip install tensorflow-datasets[jax]
```
## Loading a Dataset

```python
import tensorflow_datasets as tfds

ds, info = tfds.load('mnist', split='train', with_info=True)
print(info)
```
You can access:

- `train`, `test`, and `validation` splits
- Dataset metadata like features, size, and citation
- Iterators compatible with `tf.data`
## Example: Load and Visualize CIFAR-10

```python
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

ds = tfds.load('cifar10', split='train', shuffle_files=True)

for example in ds.take(1):
    image, label = example['image'], example['label']
    plt.imshow(image.numpy())
    plt.title(f"Label: {label.numpy()}")
    plt.show()
```
## Categories of Datasets

TFDS includes 1,000+ datasets across many domains:

| Domain | Examples |
|---|---|
| Vision | MNIST, CIFAR, ImageNet, COCO |
| NLP | IMDB, GLUE, SQuAD, WikiText |
| Audio | Speech Commands, LibriSpeech |
| Structured | Titanic, California Housing |
| Scientific | Galaxy Zoo, QM9, OpenML |
| Reinforcement | Atari, dSprites, BSuite |
## ⚙️ Custom Splits & Preprocessing

You can define your own splits and transformations:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

ds = tfds.load('mnist', split='train[:80%]', as_supervised=True)
ds = ds.map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))
ds = ds.shuffle(1000).batch(32)
```

This integrates smoothly with TensorFlow's `Model.fit()` pipeline.
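The `train[:80%]` split string selects a percentage of the split's examples. As a rough mental model it maps to an index range over the split; the floor-division rounding below is an assumption for the sketch, since TFDS resolves the real boundaries internally:

```python
# Toy illustration of percent slicing like 'train[:80%]' over a split of
# known size. The rounding here is an assumption; TFDS computes the actual
# boundaries itself.
def percent_slice(num_examples, start_pct=0, end_pct=100):
    lo = num_examples * start_pct // 100
    hi = num_examples * end_pct // 100
    return range(lo, hi)

first_80 = percent_slice(60000, 0, 80)    # MNIST's train split has 60,000 examples
last_20 = percent_slice(60000, 80, 100)   # e.g. a held-out validation slice
```

The same syntax composes with the rest of the pipeline, so `'train[80%:]'` gives you a validation set without touching the official test split.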
## JAX and NumPy Support

TFDS also works with JAX or NumPy:

```python
for batch in tfds.as_numpy(ds):
    x, y = batch['image'], batch['label']
```

This is especially useful for research, simulations, or custom training loops.
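A custom loop only needs the iterator to yield plain dict batches, which is the shape `tfds.as_numpy(ds)` produces. Here is a minimal sketch of that loop, with a stub generator standing in for the dataset so it runs standalone:

```python
# Stand-in for tfds.as_numpy(ds): yields dict batches of plain arrays/lists.
# Shapes and feature names mirror the MNIST examples above (illustration only).
def fake_numpy_batches(num_batches=3, batch_size=8):
    for _ in range(num_batches):
        yield {'image': [[0.0] * 784] * batch_size, 'label': [0] * batch_size}

steps = 0
for batch in fake_numpy_batches():
    x, y = batch['image'], batch['label']
    # A real loop would compute the loss and update parameters here.
    steps += 1
```

Because the batches are framework-agnostic, the same loop body works whether the update step is written in JAX, NumPy, or PyTorch.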
## Use Cases

- Fast prototyping with popular datasets
- Benchmarking models with standardized data
- Teaching and experimentation
- Data exploration and visualization
## Building Custom Datasets

You can even create your own dataset builder using the TFDS framework by subclassing `tfds.core.GeneratorBasedBuilder`. Scaffold a new dataset with:

```shell
tfds new dataset_name
```

Then define `_info()`, `_split_generators()`, and `_generate_examples()` in your builder script.
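The contract for `_generate_examples()` is to yield `(key, example)` pairs, with keys unique within the split. A plain-Python sketch of that shape (the feature names and file list here are hypothetical placeholders):

```python
# Sketch of the _generate_examples() contract: yield (unique_key, example)
# pairs. 'image_path' and 'label' are placeholder feature names.
def generate_examples(filepaths):
    for idx, path in enumerate(filepaths):
        # A real builder would read and decode the file at `path` here.
        yield idx, {'image_path': path, 'label': idx % 10}

examples = dict(generate_examples(['cat_0.png', 'dog_1.png', 'cat_2.png']))
```

In an actual builder, the features you yield must match the schema declared in `_info()`, and `_split_generators()` decides which files feed each split.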
## Offline Use & Caching

Once downloaded, TFDS caches datasets locally (by default in `~/tensorflow_datasets`) so they load much faster the next time.

You can also pre-download datasets in a CI/CD or offline environment using:

```shell
tfds build mnist --data_dir=/data
```
๐ฏ Final Thoughts
TensorFlow Datasets (TFDS) is a powerful ally when you're experimenting, testing, or deploying ML models. It removes the tedious parts of data wrangling and lets you focus on what matters most — building great models.
With just a few lines of code, you get access to hundreds of high-quality, preprocessed datasets, ready to use in production or in research.
## Useful Links: