How to build a convincing reddit personality with GPT-2 and BERT
Last month, I experimented with building a reddit comment bot that generated natural language replies by combining two pre-trained deep learning models: GPT-2 and BERT. I wrote another post on the motivation and background, but here I want to give a step-by-step walkthrough so others can work with what I've built. If you prefer, you can jump straight to the project code. To see the work I based this on, see this and this.
Model overview
Before getting into the nitty-gritty, I want to give a general overview of the process I'll be using. This flow diagram shows the 3 models that I needed to train, as well as the process for hooking the models together to generate the output.
There are quite a few steps, but I hope it doesn't get too confusing. Check out my previous post for an even higher-level architecture overview. Here are the steps I'll be explaining in this post.
- step 0: get some reddit comment data from your favorite subreddits and format it into strings that look like "comment [SEP] reply"
- step 1: fine-tune GPT-2 to generate reddit text in the format "comment [SEP] reply"
- step 2: fine-tune two BERT classifiers to:
  - a: differentiate real replies from GPT-2 generated ones
  - b: predict how many upvotes comments will get
- step 3: use praw to download current comments
- step 4: use the fine-tuned GPT-2 to generate many replies for each comment
- step 5: pass the generated replies to the two BERT models to predict their realisticness and number of upvotes
- step 6: use some criteria for choosing which replies to submit
- step 7: use praw to submit the chosen replies
- step 8: chuckle with amusement
Getting lots of reddit comment data
As with any machine learning project, nothing can start until you have data from which to train your model.
The data I used to fine-tune the models came from a large database of previously retrieved reddit comments. There is an ongoing project that scrapes many sites around the web and stores the results in a set of Google BigQuery tables. I was very surprised that I couldn't find a central page about such a big project, but I used a few reddit and Medium posts to piece together the format of the queries I'd need.
To start, I downloaded comment and reply data for the subreddits 'writing', 'scifi', 'sciencefiction', 'MachineLearning', 'philosophy', 'cogsci', 'neuro', and 'Futurology'. This query pulls the comments for a specific year and month ({ym}) from BigQuery:
SELECT s.subreddit AS subreddit, s.selftext AS submission, a.body AS comment, b.body AS reply,
       s.score AS submission_score, a.score AS comment_score, b.score AS reply_score,
       s.author AS submission_author, a.author AS comment_author, b.author AS reply_author
FROM `fh-bigquery.reddit_comments.{ym}` a
LEFT JOIN `fh-bigquery.reddit_comments.{ym}` b ON CONCAT('t1_', a.id) = b.parent_id
LEFT JOIN `fh-bigquery.reddit_posts.{ym}` s ON CONCAT('t3_', s.id) = a.parent_id
WHERE b.body IS NOT NULL
  AND s.selftext IS NOT NULL
  AND s.selftext != ''
  AND b.author != s.author
  AND b.author != a.author
  AND s.subreddit IN ('writing', 'scifi', 'sciencefiction', 'MachineLearning', 'philosophy', 'cogsci', 'neuro', 'Futurology')
I used the BigQuery Python API to automate generating the queries I needed to download the data across a number of months in 2017 and 2018. This script iterated over the time periods I needed and downloaded them to local disk in the raw_data/ folder.
In the end, I want to be able to prime the GPT-2 network with a comment and generate a reply. To do this, I needed to reformat the data so each example contains both parts, separated by a special [SEP] string that tells the algorithm which part is which. Each line of the training data file looks like the following.
"a bunch of primary comment text [SEP] all of the reply text"
After I train the model on this format, I can feed the trained model a string like "some new primary comment text [SEP]", and it will start to generate the remaining "some new reply" that it thinks fits best based on the training data. I'll explain in more detail below how to feed this kind of data into the GPT-2 fine-tuning script. For now, you can use this script to convert the data into the format that GPT-2 fine-tuning needs and save it as gpt2_finetune.csv.
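A minimal sketch of that conversion, assuming you already have (comment, reply) pairs loaded from the raw_data/ files. The helper names are hypothetical, and the single-column-with-header layout reflects my understanding of how gpt-2-simple reads CSV datasets (one training document per row), so check it against the library before relying on it.

```python
import csv

SEP = " [SEP] "


def clean(text: str) -> str:
    """Collapse newlines and repeated whitespace so each example stays on one line."""
    return " ".join(text.split())


def to_training_line(comment: str, reply: str) -> str:
    """Build one 'comment [SEP] reply' training example."""
    return clean(comment) + SEP + clean(reply)


def write_finetune_csv(pairs, path: str = "gpt2_finetune.csv") -> None:
    """Write a single-column CSV (with a header row) of training examples."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)  # handles quoting of commas and quotes in reddit text
        writer.writerow(["text"])
        for comment, reply in pairs:
            writer.writerow([to_training_line(comment, reply)])
```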
Fine tuning GPT-2 and generating text for reddit
The major advantage of using GPT-2 is that it has been pre-trained on a massive dataset of millions of pages of text from the internet. However, if you were to use GPT-2 straight "out-of-the-box," you'd end up generating text that could look like anything you might find on the internet. Sometimes it'll generate a news article, sometimes a cooking blog recipe, sometimes a rage-filled Facebook post. You don't have much control, and therefore you won't be able to use it to effectively generate reddit comments.
To overcome this issue, I needed to "fine-tune" the pre-trained model. Fine-tuning means taking a model that was already trained on a big dataset, and then continuing to train it on just the specific type of data that you want to use it on. This process (somewhat magically) allows you to take a lot of the general information about language from the big pretrained model, and sculpt that down with all the specific information about the exact output format you are trying to generate.
Fine-tuning is a standard process, but it still isn't super easy to do. I'm not an expert deep learning researcher, but fortunately for me, a really wonderful expert had already built some incredibly simple wrapper utilities called gpt-2-simple for making fine-tuning GPT-2, well... simple.
The best part is that the author of gpt-2-simple even set up a Google Colab notebook that walks through fine-tuning. In case you haven't heard, Google Colab is an amazing FREE (as in beer) resource that lets you run a Python Jupyter notebook on a Google GPU server. Full disclosure: I am officially a lifetime fanboy of Google for making free tiers of Google App Engine, BigQuery, and Google Colab.
You can follow along with the tutorial notebook to learn all about how to fine-tune a GPT-2 model with gpt-2-simple. For my use case, I took all of that code and condensed and reformatted it a little to make my own GPT-2 fine-tuning notebook that runs off the gpt2_finetune.csv file I generated in the previous step. Just like in the original tutorial, you need to give the notebook permission to read and write from your Google Drive. The model is then saved to your Google Drive so that later scripts can reload it.
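For reference, the core of a gpt-2-simple fine-tuning run looks roughly like this. It is a sketch rather than my notebook verbatim: the 124M model size, step count, logging intervals, and run name are assumptions, the dataset is assumed to be in the working directory, and `mount_gdrive`/`copy_checkpoint_to_gdrive` only work inside Colab.

```python
def finetune_gpt2(dataset: str = "gpt2_finetune.csv",
                  model_name: str = "124M",
                  steps: int = 1000,
                  run_name: str = "run1") -> None:
    """Fine-tune GPT-2 on the 'comment [SEP] reply' CSV and back the checkpoint up to Drive."""
    # lazy import: gpt-2-simple needs TensorFlow and is best run on a Colab GPU
    import gpt_2_simple as gpt2

    gpt2.download_gpt2(model_name=model_name)  # fetch the pretrained weights once
    gpt2.mount_gdrive()                         # Colab only: prompts for Drive permission
    sess = gpt2.start_tf_sess()
    gpt2.finetune(sess,
                  dataset=dataset,
                  model_name=model_name,
                  steps=steps,
                  run_name=run_name,
                  print_every=10,     # log the loss every 10 steps
                  sample_every=200,   # print a sample generation every 200 steps
                  save_every=500)     # write a checkpoint every 500 steps
    # a later notebook can restore this with gpt2.copy_checkpoint_from_gdrive(run_name=...)
    gpt2.copy_checkpoint_to_gdrive(run_name=run_name)
```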
Training BERT models for fake detection and upvote prediction
Even after fine-tuning, the output of this model, while usually somewhat reasonable, is often pretty weird. To improve the quality of responses, I adapted the concept of GANs to create another meta-model that can throw out all the really weird replies. So I use GPT-2 to generate 10 or more candidate responses for every comment, and then use another model to filter out the best replies to release.
To determine the best, I actually want to do two things:
- Filter out unrealistic replies
- For the realistic replies, pick the one that I predict will have the most upvotes
So in order to do this, I have to train two classifiers: one to predict the probability of being a real reply and another to predict the probability of being a high scoring reply. There are lots of ways to perform this prediction task, but one of the most successful language models recently built for this kind of thing is another deep learning architecture called Bidirectional Encoder Representations from Transformers, or BERT. One big benefit of using this model is that, as with GPT-2, researchers have pre-trained networks on very large corpora of data that I would never have the financial means to access.
Again, I'm not the biggest expert in working with deep learning infrastructure, so luckily, other brilliant TensorFlow Hub experts wrote a Google Colab tutorial for fine-tuning text classifier models using a pretrained BERT network. So all I had to do was combine the two with some glue.
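With both classifiers in hand, step 6 reduces to a small selection rule: filter on realism, then maximize predicted upvotes. This sketch assumes each candidate already carries both BERT scores; the `Candidate` type and the 0.9 realism cut-off are hypothetical choices for illustration.

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    reply: str
    p_real: float     # discriminator output in [0, 1]: how likely the reply is human-written
    p_upvoted: float  # upvote model output in [0, 1]: how likely the reply scores highly


def choose_reply(candidates, min_p_real: float = 0.9):
    """Drop candidates that look fake, then return the one with the best upvote score (or None)."""
    realistic = [c for c in candidates if c.p_real >= min_p_real]
    if not realistic:
        return None  # nothing convincing enough; skip this comment
    return max(realistic, key=lambda c: c.p_upvoted)
```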
In the next section, I'll walk through the fine-tuning and some model evaluation, but if you'd like to get a jumpstart and don't want to bother fine-tuning yourself, you can download the three fine-tuned models from here, here and here.
BERT Discriminator model performance
The realisticness model was trained much like the discriminator in a traditional GAN, except that the generator is not updated in response to it. I had another Colab notebook generate thousands of fakes and then created a dataset that combined my fakes with thousands of real comments. I then fed that dataset into a BERT realisticness fine-tuning notebook to train and evaluate. The model has amazingly high power to distinguish real comments from fake ones.
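Assembling the discriminator's training data is mostly bookkeeping. A minimal sketch, assuming real replies are labeled 1 and GPT-2 fakes 0 (the post doesn't state the convention) and a simple shuffled hold-out split; both function names are hypothetical.

```python
import random


def build_discriminator_dataset(real_replies, fake_replies, seed: int = 0):
    """Label real replies 1 and GPT-2 fakes 0, then shuffle them together."""
    rows = [(text, 1) for text in real_replies] + [(text, 0) for text in fake_replies]
    random.Random(seed).shuffle(rows)  # seeded so the split is reproducible
    return rows


def holdout_split(rows, test_fraction: float = 0.2):
    """Keep the last test_fraction of the (already shuffled) rows for evaluation."""
    n_test = round(len(rows) * test_fraction)
    return rows[: len(rows) - n_test], rows[len(rows) - n_test:]
```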
BERT Realisticness Model Metrics
- auc: 0.9933777
- eval_accuracy: 0.9986961
- f1_score: 0.99929225
- precision: 0.9988883
- recall: 0.99969655
- true_positives: 9884; false_negatives: 3
- true_negatives: 839; false_positives: 11
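These numbers are internally consistent, which you can check by recomputing the derived metrics from the four confusion-matrix counts. The evaluation set is heavily imbalanced (9,887 positives vs 850 negatives), so the recall on the minority class, 839/850 ≈ 0.987, is a useful complement to the headline accuracy. Which class the notebook treated as positive isn't stated; the sketch below only does the arithmetic.

```python
def classification_metrics(tp: int, fp: int, tn: int, fn: int) -> dict:
    """Derive the usual binary-classification metrics from confusion-matrix counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)          # true positive rate
    specificity = tn / (tn + fp)     # recall on the negative class
    accuracy = (tp + tn) / (tp + fp + tn + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "specificity": specificity,
            "accuracy": accuracy, "f1": f1}


metrics = classification_metrics(tp=9884, fp=11, tn=839, fn=3)
for name, value in metrics.items():
    print(f"{name}: {value:.7f}")
```

Precision, recall, accuracy, and F1 come out matching the reported values to within float32 rounding.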
Going forward, every reply that the generator creates can be run through this BERT discriminator to get a score from 0 to 1 based on how realistic it is. I then filter the candidates so that only the replies predicted most likely to be real are kept.
To predict how many upvotes a reply will get, I built another model in a similar way. This time the model was trained just on a dataset of real reddit comments, to predict how many upvotes they actually got.
This model also had surprisingly high predictive accuracy. This ROC curve shows that we can get a lot of true positives without too many false positives. For more on what true positive and false positive mean, see this article.
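Because its results are reported as an ROC curve, the upvote model appears to be a binary classifier of "high scoring" vs not, matching the two-classifier plan described earlier. The cut-off between the two classes isn't given; this sketch assumes the median score, which is purely a hypothetical choice.

```python
import statistics


def upvote_labels(scores, threshold=None):
    """Turn raw reddit scores into 0/1 'high scoring' labels.

    Assumption: by default a reply counts as high scoring if it beats the
    median score, so roughly half the examples end up positive.
    """
    if threshold is None:
        threshold = statistics.median(scores)
    return [1 if s > threshold else 0 for s in scores]
```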
ROC curve for BERT-based upvote prediction
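To make the ROC curve concrete: each point is the (false positive rate, true positive rate) pair you get by calling everything at or above some score threshold "positive". Here is a small pure-Python sketch of that threshold sweep and the trapezoidal AUC, independent of whatever metric code the BERT notebook used.

```python
def roc_points(labels, scores):
    """Return (fpr, tpr) points, sweeping the threshold from the highest score down.

    Requires at least one positive (label 1) and one negative (label 0).
    """
    pos = sum(labels)
    neg = len(labels) - pos
    pairs = sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # examples tied at this score cross the threshold together
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / neg, tp / pos))
    return points


def auc(points):
    """Area under the ROC curve using the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2
               for (x0, y0), (x1, y1) in zip(points, points[1:]))
```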
Buoyed by the model cross-validation performance, I was excited to hook it up to a real-time commenting system and start shipping my bot's thoughts!
Pulling real-time comments with PRAW
Although I could generate the training sets using data on BigQuery, most of that data is actually a couple of months old. Replying to months-old comments is a very non-human thing to do on social media sites, so it was important to be able to pull down fresh data from reddit somehow.
Fortunately, I could use the praw library along with the following snippet to get all comments from the top 5 "rising" posts in a handful of subreddits that I thought would produce some interesting responses.
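import praw

reddit = praw.Reddit("bot")  # site name in praw.ini (assumed)
comments = []
for subreddit_name in ['sciencefiction', 'artificial', 'scifi', 'BurningMan', 'writing', 'MachineLearning', 'randonauts']:
    subreddit = reddit.subreddit(subreddit_name)
    for h in subreddit.rising(limit=5):
        h.comments.replace_more(limit=0)  # drop "load more" stubs
        comments.extend(h.comments.list())  # flatten the tree
I could then run each comment through the generator and discriminators to produce a reply.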
Running the generator and discriminators
Finally, I just had to build something to reload all the fine-tuned models and pass the new reddit comments through them to get replies. In an ideal world, I would have run both the GPT-2 and the BERT models in one script that could be run from end to end. Unfortunately, a quirk in the way the designers implemented the gpt-2-simple package made it impossible to have two computation graphs instantiated in the same environment.
So instead, I ran a GPT-2 generator notebook on its own to download new comments, generate a batch of candidate replies, and store them in CSV files on my Google Drive. Then I reloaded the candidates in a separate BERT discriminator notebook to pick the best replies and submit them back to reddit.
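The two-notebook handoff only needs a shared file format. Here is a sketch of both sides, assuming the model was fine-tuned on the CSV from earlier (so gpt-2-simple wraps each example in `<|startoftext|>`/`<|endoftext|>`), a session already restored with `gpt2.load_gpt2`, and a `praw.Reddit` instance for posting. The column names, `length=200`, and all helper functions are hypothetical.

```python
import csv

START, END, SEP = "<|startoftext|>", "<|endoftext|>", "[SEP]"


def make_prompt(comment: str) -> str:
    """Prime the fine-tuned model with a comment followed by the separator."""
    return START + " ".join(comment.split()) + " " + SEP


def extract_reply(generated: str) -> str:
    """Keep only the text after the first [SEP] and drop any end token."""
    reply = generated.split(SEP, 1)[1] if SEP in generated else generated
    return reply.split(END, 1)[0].strip()


def generate_candidates(sess, comment: str, n: int = 10, run_name: str = "run1"):
    """Generator notebook: sample n candidate replies for one comment."""
    import gpt_2_simple as gpt2  # lazy import; sess must already hold the fine-tuned model
    texts = gpt2.generate(sess, run_name=run_name, prefix=make_prompt(comment),
                          truncate=END, nsamples=n, batch_size=n, length=200,
                          return_as_list=True)
    return [extract_reply(t) for t in texts]


def write_candidates(rows, path: str) -> None:
    """Generator notebook: save (comment_id, comment, reply) rows for the BERT notebook."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["comment_id", "comment", "reply"])
        writer.writerows(rows)


def read_candidates(path: str):
    """Discriminator notebook: load the candidates back as dicts."""
    with open(path, newline="", encoding="utf-8") as f:
        return list(csv.DictReader(f))


def submit(reddit, comment_id: str, text: str):
    """Discriminator notebook: post the chosen reply under the original comment."""
    return reddit.comment(id=comment_id).reply(text)
```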
You can view the whole workflow in my github repo for the project or in my Google Drive folder. Please submit issues to the project if you think things can be explained more clearly, or if you find bugs.
Last Step: Chuckle with Amusement
I submitted all my replies under the reddit account of tupperware-party (which hopefully won't get shut down for trademark shit). You can check out some highlights from the model output here or see the full list of comments to inspect everything the system outputted. I've also shared a folder on Google Drive with all of the candidate responses and their scores from the BERT models if you want to take a look.
Finally, I know there are definitely some ethical considerations when creating something like this. You can read my thoughts on that here. In short, please try to use this responsibly and spread the word that we are living in a world where this is possible. And if you have a problem, tell me on Twitter. I swear it'll really be me who responds.