
training a CNN to classify my bad doodles

machine learning, tensorflow, python

i built a CNN (convolutional neural network) that classifies doodles and runs in the browser. check out the demo here.

the dataset

i started with Google's Quick, Draw! dataset and narrowed it down to a subset of 250,000 images: 25,000 for each of the 10 categories i chose:

Cat, Umbrella, Windmill, Octopus, Tractor, Bicycle, Cruise Ship, Helicopter, Bowtie, House
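here's a minimal sketch of how a balanced subset like this could be built. it assumes each category has already been loaded as a numpy array of flattened 28x28 bitmaps (shape `(n, 784)`), and the `CATEGORIES` list and `build_balanced_subset` helper are hypothetical names using the ten labels from the results table further down, not code from the actual project.

```python
import numpy as np

# Hypothetical label order; these are the ten classes in the results table.
CATEGORIES = ["cat", "umbrella", "windmill", "octopus", "tractor",
              "bicycle", "cruise ship", "helicopter", "bowtie", "house"]


def build_balanced_subset(arrays, per_class, seed=0):
    """Sample `per_class` drawings from each category and return (x, y).

    `arrays` maps category name -> array of shape (n, 784), i.e. flattened
    28x28 grayscale bitmaps (an assumption about how the data was loaded).
    """
    rng = np.random.default_rng(seed)
    xs, ys = [], []
    for label, name in enumerate(CATEGORIES):
        data = arrays[name]
        # Sample without replacement so no drawing is picked twice.
        idx = rng.choice(len(data), size=per_class, replace=False)
        xs.append(data[idx].reshape(-1, 28, 28))
        ys.append(np.full(per_class, label, dtype=np.int64))
    x = np.concatenate(xs)
    y = np.concatenate(ys)
    # Shuffle so the categories are interleaved instead of stored in blocks.
    order = rng.permutation(len(x))
    return x[order], y[order]
```

with `per_class=25_000` this gives the 250,000 images described above, with every label equally represented.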

example doodles from the dataset

i then applied data augmentation using random rotations, slight scaling, and small translations.
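the three augmentations can be combined into one random affine transform. below is a minimal numpy sketch (nearest-neighbour sampling, inverse mapping from output pixels back to source pixels). the ranges (±15°, 0.9 to 1.1 scale, ±2 px shift) are illustrative guesses, not the values used for the real model, and empty pixels are filled with 0 on the assumption that the background is black. in tensorflow, the keras preprocessing layers `RandomRotation`, `RandomZoom` and `RandomTranslation` are a more typical way to get the same effect.

```python
import numpy as np


def random_affine(img, rng, max_deg=15.0, scale_range=(0.9, 1.1), max_shift=2.0):
    """Randomly rotate, scale and translate a 2D image (nearest neighbour).

    Parameter ranges are illustrative assumptions.
    """
    h, w = img.shape
    theta = np.deg2rad(rng.uniform(-max_deg, max_deg))
    s = rng.uniform(*scale_range)
    tx, ty = rng.uniform(-max_shift, max_shift, size=2)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0

    # Output pixel grid, relative to the image centre and minus the shift.
    yy, xx = np.mgrid[0:h, 0:w]
    x0, y0 = xx - cx - tx, yy - cy - ty
    # Inverse mapping: rotate by -theta and divide by the scale to find
    # which source pixel lands on each output pixel.
    cos, sin = np.cos(theta), np.sin(theta)
    src_x = (cos * x0 + sin * y0) / s + cx
    src_y = (-sin * x0 + cos * y0) / s + cy

    xi = np.rint(src_x).astype(int)
    yi = np.rint(src_y).astype(int)
    valid = (xi >= 0) & (xi < w) & (yi >= 0) & (yi < h)
    out = np.zeros_like(img)  # pixels mapped from outside the image stay 0
    out[valid] = img[yi[valid], xi[valid]]
    return out
```

computing the source coordinate for every output pixel (rather than pushing source pixels forward) avoids holes in the transformed image.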

augmented dataset examples

the model

the architecture is a standard CNN:

  • input: 28x28 grayscale images
  • conv layers: two blocks of Conv2D + MaxPooling + Dropout
  • dense layers: flattened into fully connected layers with dropout
  • output: softmax over the 10 categories
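a keras sketch of that layout is below. the post only gives the overall structure, so the filter counts (32, 64), 3x3 kernels, dense width (128), dropout rates and optimizer are all assumptions for illustration.

```python
import tensorflow as tf

NUM_CLASSES = 10


def build_model():
    # Layer sizes and dropout rates are illustrative assumptions.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28, 1)),                # 28x28 grayscale
        # Block 1: Conv2D + MaxPooling + Dropout
        tf.keras.layers.Conv2D(32, 3, padding="same", activation="relu"),
        tf.keras.layers.MaxPooling2D(),                   # 28x28 -> 14x14
        tf.keras.layers.Dropout(0.25),
        # Block 2: Conv2D + MaxPooling + Dropout
        tf.keras.layers.Conv2D(64, 3, padding="same", activation="relu"),
        tf.keras.layers.MaxPooling2D(),                   # 14x14 -> 7x7
        tf.keras.layers.Dropout(0.25),
        # Dense head with dropout
        tf.keras.layers.Flatten(),                        # 7*7*64 = 3136 features
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
    ])


model = build_model()
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",  # integer labels 0..9
              metrics=["accuracy"])
```

with these sizes most of the ~422k parameters sit in the first dense layer (3136 x 128 weights), which is typical for small CNNs like this.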

training

the model hit 95% validation accuracy after about 20 epochs with early stopping.
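in keras this is usually done with `tf.keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=..., restore_best_weights=True)`. the plain-python function below is a simplified sketch of the idea behind it, not the callback's exact implementation: stop once the validation score hasn't improved for `patience` epochs in a row. the patience value of 3 is an assumption.

```python
def early_stopping_epoch(val_scores, patience=3, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for per-epoch validation scores.

    Higher is better. Training stops once the score has failed to beat the
    best so far by more than `min_delta` for `patience` consecutive epochs.
    Epochs are 1-indexed.
    """
    best, best_epoch, wait = float("-inf"), 0, 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best + min_delta:
            best, best_epoch, wait = score, epoch, 0  # new best, reset counter
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_scores), best_epoch  # never triggered: ran every epoch
```

for example, scores of `[0.80, 0.90, 0.93, 0.95, 0.949, 0.948, 0.95]` stop at epoch 7 with epoch 4 as the best, since an equal score doesn't count as an improvement. restoring the best weights means the final model is the one from epoch 4, not the last one trained.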

training curves showing convergence to 95% validation accuracy

validation

i evaluated the model on a separate, smaller set of 10,000 images (1,000 per category).

class        precision  recall  f1-score  support
cat               0.95    0.97      0.96     1000
umbrella          0.98    0.98      0.98     1000
windmill          0.93    0.97      0.95     1000
octopus           0.98    0.97      0.98     1000
tractor           0.93    0.93      0.93     1000
bicycle           0.95    0.96      0.95     1000
cruise ship       0.97    0.96      0.97     1000
helicopter        0.96    0.93      0.94     1000
bowtie            0.97    0.94      0.95     1000
house             0.98    0.99      0.98     1000
accuracy                            0.96    10000
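every number in the table above can be derived from the confusion matrix. here's a small numpy sketch, assuming the sklearn convention (rows = true labels, columns = predicted labels); `per_class_metrics` is a hypothetical helper, not the code that produced the table.

```python
import numpy as np


def per_class_metrics(cm):
    """Precision, recall, F1, support and overall accuracy from a confusion matrix.

    Assumes rows are true labels and columns are predicted labels.
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)                  # correct predictions per class
    support = cm.sum(axis=1)          # examples that truly belong to each class
    predicted = cm.sum(axis=0)        # times each class was predicted
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    accuracy = tp.sum() / cm.sum()
    return precision, recall, f1, support, accuracy
```

as a sanity check on the table: 399 errors out of 10,000 gives 9,601 / 10,000 = 0.9601 accuracy, which rounds to 0.96. and because every class has the same support (1,000), accuracy equals the mean of the per-class recalls, which also average to 0.960 in the table.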

in the confusion matrix, bicycles and tractors are the pair most often confused, followed by helicopters and windmills.
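one way to rank the most-confused pairs is to add the matrix to its transpose (so "bicycle predicted as tractor" and "tractor predicted as bicycle" count together) and sort the off-diagonal entries. a minimal sketch, with `most_confused_pairs` as a hypothetical helper:

```python
import numpy as np


def most_confused_pairs(cm, labels, k=3):
    """Return the k class pairs with the most mix-ups, counting both directions."""
    cm = np.asarray(cm)
    sym = cm + cm.T                               # a->b and b->a errors combined
    i, j = np.triu_indices(len(labels), k=1)      # each unordered pair once, no diagonal
    counts = sym[i, j]
    order = np.argsort(counts)[::-1][:k]          # largest counts first
    return [(labels[i[o]], labels[j[o]], int(counts[o])) for o in order]
```

counting both directions matches how pairs are described here ("bicycles and tractors"); use the raw `cm` instead of `sym` if the direction of the error matters.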

confusion matrix

where it gets things wrong

the model isn't perfect: out of 10,000 images, it got 399 wrong. looking through the mistakes, i noticed a lot of them are bad data in the first place.
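finding the mistakes is just a comparison of true and predicted labels. the sketch below also splits them into fixed-size batches for plotting; the batch size of 25 (a 5x5 grid) is a guess, though it does reproduce the 16 batches used for the 399 errors below (15 full batches plus one of 24).

```python
import numpy as np


def misclassified_batches(y_true, y_pred, batch_size=25):
    """Indices of wrong predictions, split into batches for plotting grids."""
    wrong = np.flatnonzero(np.asarray(y_true) != np.asarray(y_pred))
    return [wrong[i:i + batch_size] for i in range(0, len(wrong), batch_size)]
```

here `y_pred` would be the argmax of the model's softmax output for each test image.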

batch of incorrect predictions

another batch of incorrect predictions

deploying to the browser

the final model gets converted to TensorFlow.js format and loaded client-side. the canvas captures your drawing, preprocesses it (resize to 28x28, normalize), and runs inference. the whole pipeline runs in under 100ms on most devices.
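the real preprocessing runs in javascript with TensorFlow.js; the python version below only illustrates the same steps so the shapes and value ranges are easy to check. it assumes a square canvas whose side is a multiple of 28 (e.g. 280x280), resizes by averaging blocks of pixels, and inverts dark-on-light canvas strokes to match light-on-dark training bitmaps. all of those choices are assumptions, not details from the actual demo.

```python
import numpy as np


def preprocess_canvas(pixels, invert=True):
    """Turn a square grayscale canvas capture into a (1, 28, 28, 1) model input.

    Assumes 0-255 pixel values and a side length that is a multiple of 28.
    """
    img = np.asarray(pixels, dtype=np.float32)
    side = img.shape[0]
    if img.shape != (side, side) or side % 28:
        raise ValueError("expected a square canvas whose side is a multiple of 28")
    f = side // 28
    # Resize by averaging each f x f block into one output pixel.
    small = img.reshape(28, f, 28, f).mean(axis=(1, 3))
    small /= 255.0                    # normalize to [0, 1]
    if invert:
        small = 1.0 - small           # dark strokes on white -> light on black
    return small.reshape(1, 28, 28, 1)  # batch and channel dims for the model
```

whatever resize and normalization the browser uses has to match what the training images went through, otherwise accuracy on live drawings can be much worse than on the test set.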

try it yourself.

all incorrect predictions

all 399 incorrect predictions across 16 batches.

incorrect batches 1 to 16

all 10,000 predictions

every single prediction the model made on the test set.

predictions batches 1 to 40