Using PySpark and PyTorch together with TFRecords

TLDR : Jump to the meaty part here.

I joined Iterable as a data scientist when there wasn’t a lot of deep learning happening. Having recently graduated from a predominantly research-focused school (pre TF 2.0), I had reasons to be biased towards PyTorch. In school, Pandas sufficed for every project we had; industry-scale big-data problems didn’t really exist there. Once I joined Iterable, I soon realized that Pandas could not handle multi-billion-row datasets. Before long I was trying to figure out how to use PySpark without letting go of my love for PyTorch.

I was actually surprised by the lack of tools that would let you use a big-data tool like Spark with a popular deep learning library like PyTorch.

Problem with Spark DataFrames and PyTorch

Up to PyTorch 1.1 there were no iterable-style datasets, which (almost) meant that to keep using the DataLoader API you needed a map-style dataset, where an index i maps to the i-th element of the dataset. It also meant you had to implement both the __getitem__ and __len__ methods yourself. Building an out-of-the-box solution for PySpark DataFrames (which could be written out as ~200 files) was definitely non-trivial.
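To make the distinction concrete, here is a minimal sketch of the two protocols using plain-Python stand-ins. The class names and the read_file callable are hypothetical; in real PyTorch code the classes would subclass torch.utils.data.Dataset and torch.utils.data.IterableDataset (the latter is available from PyTorch 1.2).

```python
# Sketch only: plain-Python stand-ins for the two dataset protocols.

class MapStyleRows:
    """Map-style: random access by index, so the length must be known."""

    def __init__(self, rows):
        self.rows = list(rows)  # everything is materialised up front

    def __getitem__(self, i):
        return self.rows[i]

    def __len__(self):
        return len(self.rows)


class IterableStyleRows:
    """Iterable-style: sequential access only, so rows can be streamed file by file."""

    def __init__(self, files, read_file):
        self.files = files
        self.read_file = read_file  # hypothetical callable: path -> iterator of rows

    def __iter__(self):
        for path in self.files:
            yield from self.read_file(path)
```

The map-style version has to know every row (or at least the total count) before training starts, which is awkward when the data lives in a couple of hundred Spark output files. The iterable-style version only ever needs to open the next file.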

Solution 1 : Petastorm by Uber solves the problem of Spark and PyTorch (partially).

Around this time I started exploring my options and came across Petastorm, built by some folks at Uber. In their own words, “Petastorm library enables building deep learning models from datasets in Apache Parquet format”. This was definitely the best option at the time (in fact I was passionate enough to open a few issues, review PRs and contribute to the repo). I used it for a few months, and the drawbacks I found were 1) verbosity and 2) training often being stuck waiting on I/O while my GPU sat idle. Lastly, they had written their own DataLoader, which meant there was no option of using PyTorch’s num_workers. Having their own DataLoader also meant more complications down the road for distributed training.

Solution 2 : Enter IterableDatasets + TFRecords

In August of 2019, I was lucky to be part of the PyTorch Summer Hackathon (where our team came first 🏆). That is when I first heard of iterable-style datasets. Fast forward a few months, and I stumbled upon Vahid Kazemi’s TFRecordReader for PyTorch, which leverages them.

TFRecords + PyTorch + Spark = ???

Some people might be struck by the irony of using TFRecords (from Google) and PyTorch (from Facebook) together. Sure, TFRecords aren’t meant for PyTorch. However, the good thing is that a TFRecord writer for Spark has existed since 2017, and someone recently came up with a library that makes using TFRecord files in PyTorch bearable. The combination of the two is what makes this blog post possible.

How to use TFRecords with PyTorch DataLoader?

Now that I’ve given some background, I’ll jump straight into the code and how I solved this problem. We’ll create a fake dataframe, write it as tfrecords and try to train a model.

1. Create Data

For the purpose of this blog, we’ll create a PySpark dataframe of fake data with a Float column, an Array(Float) column and an Integer column.

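# imports
import os
import numpy as np
from pyspark.sql.types import IntegerType, ArrayType, FloatType, StructField, StructType

# define schema
schema = StructType([
    StructField("float_feature", FloatType(), True),
    StructField("array_feature", ArrayType(FloatType()), True),
    StructField("label", IntegerType(), True)
])

# number of rows in dataframe
N = 100000

# length of each array_feature value (illustrative assumption, any fixed length works)
ARRAY_LEN = 3

# create fake data: one [float, list of floats, int] row per record, matching the schema
data = [
  [ float(np.random.randn()),
    [float(v) for v in np.random.randn(ARRAY_LEN)],
    int(np.random.randint(0, 2)),
 ] for x in range(N)]

# create dataframe from fake data and schema
# (assumes an active SparkSession named `spark`, as provided by Databricks notebooks)
df = spark.createDataFrame(data, schema)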

2. Write dataframe as tfrecords

To write a Spark dataframe in tfrecords format, you’ll need the spark-tensorflow-connector, built by people at TensorFlow, here.
We use Databricks at Iterable, and most of the Databricks Runtimes (ML) come with the tf-connector preinstalled.

Once you have installed the connector, all you need to run is :

tfrecords_directory = "/dbfs/ml/pytorch_data"
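A minimal sketch of the write step itself, assuming the connector’s documented tfrecords format and its recordType option. The helper name write_tfrecords is hypothetical, and the dbfs:/ path in the usage comment assumes the usual Databricks mapping between dbfs:/ and the local /dbfs mount used in tfrecords_directory.

```python
def write_tfrecords(df, path):
    """Write a Spark DataFrame as TFRecord files of tf.train.Example records.

    Assumes the spark-tensorflow-connector is installed on the cluster.
    """
    (df.write
       .format("tfrecords")              # data source registered by the connector
       .option("recordType", "Example")  # one tf.train.Example per row
       .mode("overwrite")                # replace any previous output at `path`
       .save(path))


# Usage on Databricks (assumed path mapping, see above):
# write_tfrecords(df, "dbfs:/ml/pytorch_data")
```

Spark writes one file per partition into that directory, which is why the next step lists the directory instead of opening a single file.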

(In the extra section below I go through how to create index files for your dataframe, which will allow you to leverage DataLoader parallelism using num_workers.)

3. Find all tfrecords files

If you’ve used Spark enough, you’ll know that Spark usually doesn’t create one big file but rather numerous small files (as many as the number of partitions of your RDD). Each of the files created by the tfrecords connector starts with part-r-. Each of these files is a complete, readable file by itself and doesn’t require the other parts.

import os

# get all files that start with part-r inside tfrecords_directory i.e where you saved your dataframe
tfrecords_files = [os.path.join(tfrecords_directory, file) for file in os.listdir(tfrecords_directory) if file.startswith("part-r-")]

4. Create Description for our TFRecords files (optional)

To read a TFRecords file, we’ll need something called a description. The tfrecord repo describes the description as “A dictionary of key and values where keys are the name of the features and values correspond to data type. The data type can be “byte”, “float” or “int”.”

In our case the three columns that we created have a description that can be defined as

description = {
  "float_feature" : "float",
  "array_feature" : "float",
  "label" : "int"
}

(Notice how both float_feature and array_feature are of type float. This is possible because, if you look inside the code here, you’ll see that an array is created regardless. So even a single floating-point number comes back as [1.0] instead of 1.0; similarly, an array of floats comes back as [[1.0, 2.0]] instead of [1.0, 2.0].)
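If the length-one arrays get in the way, one option is a small per-record transform that unwraps them. The sketch below is plain Python; the helper name unwrap_scalars is hypothetical, and it assumes each record arrives as a dict mapping feature names to array-like values. Recent versions of the tfrecord library accept such a callable through the transform argument of TFRecordDataset, but check the version you have installed.

```python
def unwrap_scalars(record, scalar_keys=("float_feature", "label")):
    """Return a copy of `record` where single-value features are plain scalars."""
    out = dict(record)
    for key in scalar_keys:
        # every feature is stored as an array, so a scalar is its first (only) element
        out[key] = record[key][0]
    return out


# toy record shaped like the three columns created earlier
example = {"float_feature": [0.5], "array_feature": [1.0, 2.0], "label": [1]}
unwrapped = unwrap_scalars(example)
```

Leaving array_feature untouched keeps its shape consistent across records, which is what the DataLoader needs in order to stack records into a batch.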

5. Create TFRecords Dataset

Once we know where all our files are, and we have their description, all we need to do is create a TFRecordDataset for each of them. Because we are using PySpark, we’ll probably have multiple part-r- files, so the only other nuance left is to chain all these small datasets together using ChainDataset.

from torch.utils.data import DataLoader, ChainDataset
from tfrecord.torch.dataset import TFRecordDataset

# create a TFrecordDataset for each file, this is a List[TFRecordDatasets]
mini_tfrdatasets = [TFRecordDataset(file, index_path=None, description=description) for file in tfrecords_files]

# create one big dataset by chaining all the mini_datasets
main_dataset = ChainDataset(mini_tfrdatasets)
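Conceptually, chaining just means iterating over the per-file datasets one after another, without ever loading them all at once. A minimal plain-Python sketch of that behaviour (the function name chain_datasets and the toy records are illustrative, not part of PyTorch):

```python
def chain_datasets(datasets):
    """Yield every record of every dataset, in order, one dataset at a time."""
    for dataset in datasets:
        yield from dataset


# toy stand-ins for the records of two part-r- files
part_0 = [{"label": 0}, {"label": 1}]
part_1 = [{"label": 1}]

records = list(chain_datasets([part_0, part_1]))
```

Because records come out file by file, their order follows the order of the list of files passed in.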

Now that we have a Dataset, we can easily use it with a DataLoader (or without one, for those who like writing their own batching functionality).

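# if you want to access it without dataloader
data = next(iter(main_dataset))
print(data)

# if you like dataloaders (recommended way)
dataloader = DataLoader(main_dataset, batch_size=5)

for batch in dataloader:
    # each batch is a dict keyed by feature name; replace the print with your training step
    print(batch)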
