
Quantizing images using KMeans in Pytorch DataLoader

That was one too many words in the post title. If you have a better title, please suggest one.

Often we want to binarize an image, or more generally reduce the number of distinct values it contains. For example, instead of an 8-bit image with 256 possible intensity values, we might want only 64 (which fits in 6 bits). I needed this for a VAE model I was building: instead of a grayscale image with 256 possible intensity values per pixel, I wanted just 64 possible values.
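For comparison, the simplest way to go from 256 to 64 levels is uniform quantization: split 0 to 255 into 64 equal-width bins of 4 values each. This is not the method used in this post (KMeans places its levels where the pixel values actually are); it is a minimal NumPy sketch of the arithmetic, and uniform_quantize is a hypothetical helper.

```python
import numpy as np

def uniform_quantize(img: np.ndarray, levels: int = 64) -> np.ndarray:
    """Hypothetical helper: map 8-bit pixels (0..255) to `levels` equal-width bins.

    Returns bin indices in [0, levels); assumes `levels` divides 256 (a power of two)."""
    assert 256 % levels == 0, "levels must divide 256"
    step = 256 // levels  # 4 intensity values per bin when levels == 64
    return img // step

img = np.array([[0, 3, 4, 127], [128, 200, 252, 255]], dtype=np.uint8)
print(uniform_quantize(img).tolist())  # [[0, 0, 1, 31], [32, 50, 63, 63]]
```

With levels=2 the same function binarizes the image (pixels below 128 map to 0, the rest to 1).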

Some people use histogram-based approaches and the like, but I believe KMeans has often proven itself to be a good color quantizer. Here we first fit KMeans on a sample of the training dataset, and then call predict inside the transform that the DataLoader applies to each image.
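To show what "KMeans as a quantizer" means for grayscale pixels, here is a minimal NumPy sketch of 1-D k-means (Lloyd's algorithm): assign every pixel to its nearest center, move each center to the mean of its pixels, repeat. kmeans_1d and assign are hypothetical helpers, not scikit-learn's API; the quantile initialization is a deterministic stand-in for scikit-learn's default k-means++ initialization, chosen so the toy example is reproducible.

```python
import numpy as np

def assign(values: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """Index of the nearest center for each value; `centers` must be sorted ascending.
    In 1-D the nearest-center regions are intervals split at the midpoints between centers."""
    midpoints = (centers[:-1] + centers[1:]) / 2
    return np.searchsorted(midpoints, values)

def kmeans_1d(values: np.ndarray, k: int, iters: int = 100) -> np.ndarray:
    """Lloyd's algorithm on scalar pixel values; returns sorted cluster centers."""
    values = np.asarray(values, dtype=np.float64).ravel()
    # Deterministic init: spread k centers over the quantiles of the distinct values
    # (assumes at least k distinct values)
    centers = np.quantile(np.unique(values), (np.arange(k) + 0.5) / k)
    for _ in range(iters):
        labels = assign(values, centers)  # assignment step
        counts = np.bincount(labels, minlength=k)
        sums = np.bincount(labels, weights=values, minlength=k)
        # Update step: mean of each cluster; empty clusters keep their old center
        new_centers = np.sort(np.where(counts > 0, sums / np.maximum(counts, 1), centers))
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return centers

# Toy "image" whose pixels cluster around three gray levels
pixels = np.array([0.0, 0.02, 0.05, 0.48, 0.5, 0.52, 0.95, 0.97, 1.0])
centers = kmeans_1d(pixels, k=3)
labels = assign(pixels, centers)  # the analogue of predict(): one index per pixel
quantized = centers[labels]       # map indices back to gray levels
print(labels.tolist())            # [0, 0, 0, 1, 1, 1, 2, 2, 2]
```

Because grayscale pixels are scalars, the nearest-center search reduces to a searchsorted over midpoints, which avoids building an (N, k) distance matrix. In scikit-learn the equivalent of centers[labels] is kmeans.cluster_centers_[labels].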

The steps break down as follows:

  1. Get access to the images we want to quantize, either through a PyTorch Dataset/DataLoader or directly as a data array.
  2. Sample n images from the dataset (here I use 3000). We don't need every image to build the clusters; a random sample gives a decent picture of the pixel-value distribution.
  3. Fit sklearn's KMeans on the pixel values of the sampled images.
  4. Create a torchvision transform that calls kmeans.predict on each image.
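```python
import numpy as np
import torch
import torchvision
from torchvision import transforms
from sklearn.cluster import KMeans

folder, n_clusters = "./data", 64  # placeholders: download directory and number of gray levels
# Fit KMeans on the pixels of 3000 random training images, scaled to [0, 1]
dataset = torchvision.datasets.MNIST(folder, train=True, transform=None, target_transform=None, download=True)
numpy_train_data = dataset.data.numpy()  # dataset.train_data is deprecated in favor of .data
random_indices = np.random.choice(len(numpy_train_data), 3000, replace=False)
numpy_sub_train_data = numpy_train_data[random_indices].reshape(-1, 1) / 255  # shape (3000 * 28 * 28, 1)
kmeans = KMeans(n_clusters=n_clusters).fit(numpy_sub_train_data)  # n_jobs was removed in scikit-learn 1.0
# ToTensor scales pixels to [0, 1] like the fitting data; predict returns one cluster index per pixel
kmeans_lambda = transforms.Lambda(
    lambda x: torch.from_numpy(kmeans.predict(x.view(-1, 1).double().numpy())).view(x.shape))
transformer = transforms.Compose([
    transforms.ToTensor(),
    kmeans_lambda
])
training_dataset = torchvision.datasets.MNIST(folder, train=True, transform=transformer, target_transform=None, download=True)
train_loader = torch.utils.data.DataLoader(
    training_dataset, batch_size=32, num_workers=8)
for i, (data, label) in enumerate(train_loader):
    break  # data: (32, 1, 28, 28) tensor of cluster indices in [0, n_clusters)
```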
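Two details of the listing above may matter in practice. KMeans numbers its clusters in arbitrary order, so index 5 is not necessarily darker than index 6, which can matter if a VAE treats the 64 indices as ordered levels. Also, a lambda cannot be pickled, which breaks num_workers > 0 when DataLoader workers are started with spawn (the default on Windows and macOS). A minimal sketch that addresses both, assuming the fitted centers are passed in (for example kmeans.cluster_centers_): SortedCenterQuantize is a hypothetical class, and it swaps kmeans.predict for an equivalent (up to ties) 1-D nearest-center lookup over sorted centers.

```python
import pickle
import numpy as np

class SortedCenterQuantize:
    """Hypothetical picklable transform: pixel value -> index of the nearest center,
    with indices ordered from darkest (0) to brightest (k - 1)."""

    def __init__(self, centers):
        # e.g. kmeans.cluster_centers_, shape (k, 1); sorting makes index order follow intensity
        self.centers = np.sort(np.asarray(centers, dtype=np.float64).ravel())
        # In 1-D, nearest-center regions are intervals split at the midpoints between centers
        self.midpoints = (self.centers[:-1] + self.centers[1:]) / 2

    def __call__(self, x):
        # Accepts a NumPy array (or anything np.asarray understands); output has x's shape
        return np.searchsorted(self.midpoints, np.asarray(x, dtype=np.float64))

    def dequantize(self, indices):
        """Map indices back to gray levels, e.g. to view a quantized image or a VAE sample."""
        return self.centers[np.asarray(indices)]

quantize = SortedCenterQuantize([[0.9], [0.1], [0.5]])  # unsorted, as KMeans may return them
img = np.array([[0.0, 0.4], [0.8, 1.0]])
idx = quantize(img)
print(idx.tolist())                       # [[0, 1], [2, 2]]
print(quantize.dequantize(idx).tolist())  # [[0.1, 0.5], [0.9, 0.9]]
restored = pickle.loads(pickle.dumps(quantize))  # unlike a lambda, this instance pickles
```

In the Compose above, SortedCenterQuantize(kmeans.cluster_centers_) could take the place of kmeans_lambda after ToTensor, since np.asarray accepts a CPU tensor; the DataLoader's default collate function then converts the returned NumPy arrays into tensors.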

If any of this is unclear, feel free to ask questions in the comments below. If you know of a better way, please share it and I'll be happy to update the post.

Published in deep learning pytorch
