That was one word too many in the post title. If you have a better title, please suggest one.
Often we want to quantize our images: instead of an 8-bit image with 256 gray levels, we want a coarser one, i.e. going from 256 intensity values down to 64. I needed to do this for a VAE model I was building, where instead of a grayscale image with 256 possible pixel values, I wanted just 64.
Many people use a histogram-based approach, but I believe KMeans has often proven itself to be a good color quantizer. Here we fit KMeans on the training dataset first, and then call predict inside our data loader's transform.
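As a minimal illustration of the idea (toy intensity values, not MNIST), KMeans with a handful of clusters snaps a set of pixel values onto a few representative levels:

```python
import numpy as np
from sklearn.cluster import KMeans

# Toy 1-D "pixel" intensities in [0, 1]; sklearn expects one sample per row.
pixels = np.array([0.02, 0.05, 0.48, 0.52, 0.95, 0.97]).reshape(-1, 1)

kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(pixels)
print(kmeans.predict(pixels))           # one cluster index (0-2) per pixel
print(kmeans.cluster_centers_.ravel())  # the 3 representative intensities
```

Each pixel is replaced by the index of its nearest cluster center, so the whole image is described by just 3 levels here (64 in our case below).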
The steps break down as follows:
- Get access to the images we want to quantize, either through a PyTorch DataLoader or directly as a data array.
- Sample n images from the dataset (here I use 3,000). The idea is that we don't need to look at every image to form good clusters; a sample gives a decent picture.
- Fit sklearn's KMeans on the sampled data.
- Create a PyTorch transform that calls kmeans.predict.
```python
import numpy as np
import torchvision
from sklearn.cluster import KMeans

n_clusters = 64

dataset = torchvision.datasets.MNIST(folder, train=True, transform=None,
                                     target_transform=None, download=True)
numpy_train_data = dataset.data.numpy()  # `train_data` is deprecated in newer torchvision
random_indices = np.random.choice(len(numpy_train_data), 3000, replace=False)
# One pixel per row, scaled to [0, 1] to match what ToTensor() produces later.
numpy_sub_train_data = numpy_train_data[random_indices].reshape(-1, 1) / 255
kmeans = KMeans(n_clusters=n_clusters).fit(numpy_sub_train_data)  # `n_jobs` was removed in sklearn 1.0
```
```python
import torch
from torchvision import transforms

# Map each pixel to its cluster index; convert the tensor to numpy for sklearn,
# and the predicted labels back to a tensor.
kmeans_lambda = transforms.Lambda(
    lambda x: torch.from_numpy(kmeans.predict(x.view(-1, 1).numpy())))
transformer = transforms.Compose([
    transforms.ToTensor(),
    kmeans_lambda,
])
training_dataset = torchvision.datasets.MNIST(folder, train=True, transform=transformer,
                                              target_transform=None, download=True)
```
```python
# Note: lambdas can't be pickled, so on platforms that spawn worker processes
# (Windows, macOS) set num_workers=0 or replace the lambda with a named callable.
train_loader = torch.utils.data.DataLoader(training_dataset,
                                           batch_size=32,
                                           num_workers=8)
for i, (data, label) in enumerate(train_loader):
    break
```
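One thing to keep in mind: after the transform, each pixel is a cluster index, not an intensity, so you can't plot the batches directly. A minimal sketch (on a fake image, independent of the code above) of mapping indices back through `kmeans.cluster_centers_` to recover viewable intensities:

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
pixels = rng.random((28 * 28, 1))  # one fake 28x28 grayscale image, values in [0, 1]

kmeans = KMeans(n_clusters=64, n_init=10, random_state=0).fit(pixels)
labels = kmeans.predict(pixels)            # one index in 0..63 per pixel
decoded = kmeans.cluster_centers_[labels]  # approximate intensities again
image = decoded.reshape(28, 28)            # ready for e.g. plt.imshow
```

The same indexing trick works on the batches from the loader above, since `cluster_centers_` is just a (64, 1) lookup table.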
If any of this is unclear, please feel free to comment below with questions. If you know of a better way, feel free to share and I'll be happy to update the post.