Generating lyrics with RNNs
[Snoop Dogg]
That's a to anything changed in the living down there
I'm gonna tell 'em to try to see these ****** about to get that money
I'm from the streets on the telly as that make me that dash seems
I'm on the game to kill and pull a game to
I ain't seen the time when you're saying that we all in a car
[Drake: Ice-T]
Smash the most hat, my man can pay her
Look at his shit
That's why I got to be that here all with the way they don't see the money
Them cars go to ass, it's just a reason to be my take
To my dog, we just wanting to be at the game shit
What they do, get a game to the bang, but you love you
So I love you, you see this *****, you on the danger route but I won't be
But I may fuck around and see your dome
He ain't got next then I don't get 'em
If you search around in the output, you can almost imagine that the lyrics sort of make sense occasionally. The neat thing about these models is that they are only trained on characters, not words, yet still manage to capture some of the longer-range structure of the text. A simpler network trained on the same input text produced output like this:
[GZA]
Fuck who got it it left burn this way and you did
And I keep the long of a city ****** still blows ain't beat
Get you now you're can't say they bropels
Call the speakin on who's metches in the come game
won't upstice up, I do the motherfuckn
Meres to us with all 19th I like the hundred
They got compore money thing, succoss hoes
[Verse Two: Wather Dogg]
It's a bitch deconolin ways boys
We all me
I start the heaf in the crime shotue
I's most it lang and some I got my slick
Upita the line when it's good with a dark can I heart
Should let my app.. One up, the sacument black
The most noticeable thing is that there seem to be more misspellings. It's hard to say whether the lyrical content is any better or worse, since neither set makes much sense anyway.
The lyrics dataset used for training was 7MB of lyrics from The Original Hip Hop Lyrics Archive: specifically, all the lyrics from 15 of the top artists that I recognized, such as Eminem and Snoop Dogg. The first set of outputs above was generated by a 3-layer LSTM RNN with 1024 nodes per layer, the second by a 1-layer LSTM RNN with 64 nodes per layer.
How do different size networks compare? Here are some results using the cross entropy loss metric for different sizes of networks:
layers | nodes per layer | parameters | training time | loss (bits/char) |
------ | --------------- | ---------- | ------------- | ---------------- |
1      | 64              | 47k        | 1h46m         | 2.28             |
1      | 128             | 128k       | 2h20m         | 2.14             |
3      | 64              | 113k       | 4h35m         | 2.28             |
3      | 128             | 391k       | 6h0m          | 2.11             |
3      | 1024            | 21M        | 3h30m*        | 1.99             |
*The last model actually trained for much longer than 3h30m, but because of the size of the network it overfit the training data, and the validation loss was much higher after the 4h mark.
Training was done using rnn.py with the following settings:
Terminal
python rnn.py \
-input_file lyrics.txt \
-num_layers <layers> \
-model lstm \
-rnn_size <size> \
-train_frac 0.8 \
-val_frac 0.1
Sampling was done with the same program and took about 17m:
Terminal
python rnn.py -temperature 0.8 -sample <checkpoint> -sample_length 100000
Training and sampling were done on a ~1.3 TFLOPS Nvidia GTX 750 Ti.
Rhyming
RNNs seem to be able to capture features like closing parenthetical statements or brackets, which requires longer-range knowledge of what came before than something like a Markov chain (discussed below) offers. Can RNNs capture the rhyming nature of lyrics in the generated text, without hand-coding or manually tweaking the network specifically to produce rhyming text?
One way to measure this is to invent some rhyme metric to measure how much a set of lyrics rhymes and compare the RNN output to the input.
We could look at the last syllable of each line and see if it has a similar sound to the next line. To convert words into sounds, there's the nice CMU Pronouncing Dictionary, which maps dictionary words to phonemes.
Here's a quote from a Shakespeare sonnet:
1 Who will believe my verse in time to come,
2 If it were filled with your most high deserts?
3 Though yet heaven knows it is but as a tomb
4 Which hides your life, and shows not half your parts.
5 If I could write the beauty of your eyes,
6 And in fresh numbers number all your graces,
7 The age to come would say This poet lies;
8 Such heavenly touches neer touched earthly faces.
When converted to phonemes it looks like this:
1 HH UW W IH L B IH L IY V M AY V ER S IH N T AY M T UW K AH M
2 IH F IH T W ER F IH L D W IH DH Y AO R M OW S T HH AY D EH Z ER T S
3 DH OW Y EH T HH EH V AH N N OW Z IH T IH Z B AH T AE Z AH T UW M
4 W IH CH HH AY D Z Y AO R L AY F AH N D SH OW Z N AA T HH AE F Y AO R P AA R T S
5 IH F AY K UH D R AY T DH AH B Y UW T IY AH V Y AO R AY Z
6 AH N D IH N F R EH SH N AH M B ER Z N AH M B ER AO L Y AO R G R EY S IH Z
7 DH AH EY JH T UW K AH M W UH D S EY DH IH S P OW AH T L AY Z
8 S AH CH HH EH V AH N L IY T AH CH AH Z N IH R T AH CH T ER TH L IY F EY S AH Z
Looking at the last syllables of the lines, "eyes" rhymes with "lies" and not as much with "faces" even though the endings of the words (-es) are the same. In the phonemes those words appear as "AY Z", "L AY Z", and "F EY S AH Z". The last sounds of "eyes" and "lies" are the same, "AY Z", while "faces" only shares the "Z" sound.
For a simpler approach, we could look at the last vowel sound of the last word on each line and see if it matches any nearby lines. For example:
[Verse 3 - Drake]
Uh, why is words from a decent man
Back when I was trying to put a ring on Alicia hand
This lost boy got fly without Peter Pan
And my delivery has got me buzzing like the Pizza Man
In this excerpt "man" is being rhymed with "hand", but they don't share the same last syllable (the "d" on "hand" isn't pronounced in the song). Each line in this section rhymes with a nearby line except "[Verse 3 - Drake]", which is not filtered out in this metric. Looking at +/- 2 lines, the percent of rhyming lines for various datasets looks like this:
If we can't identify the syllables of the last word on a line, we ignore that line for calculating the percentage. The sonnets (which have a very regular rhyme structure) have the highest score, with the lyrics coming in about halfway between the sonnets and 1MB of some random Project Gutenberg text which doesn't appear to contain any poetry.
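Here's a rough sketch of that check. It assumes you've already loaded the CMU Pronouncing Dictionary into a plain dict mapping an uppercase word like "HAND" to a phoneme list like ["HH", "AE1", "N", "D"] (the cmudict argument below is that assumed structure, not a real library object). Lines whose last word can't be looked up are skipped, matching the rule above.
Python
def last_vowel_sound(word, cmudict):
    """Return the final vowel phoneme of a word (stress digits stripped), or None."""
    phones = cmudict.get(word.upper())
    if not phones:
        return None
    # CMU vowel phonemes end with a stress digit (0, 1, or 2); consonants don't
    vowels = [p.rstrip("012") for p in phones if p[-1].isdigit()]
    return vowels[-1] if vowels else None

def rhyme_score(lines, cmudict, window=2):
    """Percent of lines whose last word's final vowel sound matches a nearby line."""
    sounds = []
    for line in lines:
        words = line.split()
        sounds.append(last_vowel_sound(words[-1], cmudict) if words else None)
    scored = rhymed = 0
    for i, sound in enumerate(sounds):
        if sound is None:
            continue  # ignore lines whose last word we can't look up
        scored += 1
        nearby = sounds[max(0, i - window):i] + sounds[i + 1:i + 1 + window]
        if sound in nearby:
            rhymed += 1
    return 100.0 * rhymed / scored if scored else 0.0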
If we run this rhyming check on the output of the RNN, the result is 31%, which is the lowest out of all of these data sources. As a baseline, I randomly chose words from the lyrics file and put each one on a line by itself:
with
you
Bang
pimpin
now
We
******
F
old
game
This text has a rhyme score of 33%, so slightly better than the RNN. Training on phonemes instead of characters also produced rhyming no better than random chance. It's possible that this particular RNN either cannot notice such a subtle feature of the input, or that the information stored by the network decays too quickly for it to reproduce a feature like rhyming. This is not an exhaustive test; the RNN used here could be insufficiently powerful or incorrectly trained.
The rhyming metric used here is not very advanced either: end-of-line rhymes seem adequate for Shakespearean sonnets, but hip hop often uses more complex rhyming patterns, as discussed in this analysis of rhyming patterns from the Broadway musical "Hamilton".
Recurrent Neural Networks
So what are RNNs all about? Recurrent neural networks (RNNs) let you predict (or generate) sequences by training the network on input data of the same sort you later want to predict or generate. The version used above is a kind of RNN called a Long Short-Term Memory (LSTM) network. There's a great overview with diagrams on Christopher Olah's blog that explains how LSTMs work and how they differ from basic RNNs.
To see an example of how an RNN can learn the structure of some text, we can train it on the following simple pattern:
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz
The only tricky part about this pattern is that the RNN can't just put a newline character after each "z"; it needs to put one after every third "z". The input for training the network is just the same line repeated 8,000 times.
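Generating that training file is a one-liner. Here's a sketch, assuming exactly 8,000 repetitions and the file name pattern.txt used in the training command below:
Python
import string

# three alphabets per line, followed by a newline, repeated 8,000 times
line = string.ascii_lowercase * 3 + "\n"
with open("pattern.txt", "w") as f:
    f.write(line * 8000)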
If we train a single layer LSTM RNN on this data, with 64 hidden nodes, we end up with ~25k total parameters (various weights and biases) that describe the network. The command to train it using the rnn.py program looks like this:
Terminal
python rnn.py \
-input_file pattern.txt \
-num_layers 1 \
-model lstm \
-rnn_size 64 \
-train_frac 0.8 \
-val_frac 0.1 \
-max_epochs 200
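As a rough sanity check on that ~25k figure, here's how the parameter count breaks down for a single-layer LSTM with a 27-symbol one-hot input, 64 hidden units, and a softmax output layer. The exact number depends on implementation details (for example whether there's a separate input embedding), so treat this as an estimate:
Python
vocab_size = 27    # a-z plus newline
hidden = 64

# each LSTM cell has 4 gates, each with weights for [input, previous hidden] and a bias
lstm_params = 4 * ((vocab_size + hidden) * hidden + hidden)   # 23,552
# the output layer projects the hidden state back to character probabilities
output_params = hidden * vocab_size + vocab_size              # 1,755
print(lstm_params + output_params)                            # 25,307, i.e. ~25k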
For each iteration we take the next character from the training input and provide it to the network. The network outputs what it thinks the following character will be in the form of a probability for each possible output character:
RNN("a") => {"a": 0.01, "b": 0.5, "c": 0.01, ...}
An interesting feature of the RNN is that it keeps internal state, so it has some context for which characters came before. If we make that state explicit, it looks like this:
RNN("a", initial_state) => {"a": 0.01, "b": 0.5, "c": 0.01, ...}, state1
RNN("b", state1) => {"a": 0.01, "b": 0.01, "c": 0.5, ...}, state2
Loss Function
In order to train the network, we need some function to minimize (a "loss" function) so we can figure out how to set the weights of the network. A common metric for this is the cross entropy loss, which estimates how much information (entropy, measured in bits) we need to choose the next character given the output of the network.
Since the network outputs a probability for each character, and there is only one correct character in the training set, we can take -log2(probability) for the correct character to see how right the network was. If the network outputs a probability of 1.0 for the correct character, -log2(1.0) = 0 bits, meaning no additional information was needed. If the network outputs 0.0 for the correct character, -log2(0) = infinity bits, so hopefully that never happens. The probabilities output by the network can get very close to 0.0 or 1.0 but never quite reach them, so we should not see any infinities in practice. If the network outputs 0.5, -log2(0.5) = 1.0 bit, so we would need one additional bit of information to choose the correct character.
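Here's a minimal sketch of that calculation over a sequence; predictions and targets are names invented for this sketch, where predictions is a list of probability dicts like the RNN outputs shown above and targets is the string of correct next characters:
Python
import math

def cross_entropy_bits(predictions, targets):
    """Average -log2(probability assigned to the correct character)."""
    total = 0.0
    for probs, correct in zip(predictions, targets):
        total += -math.log2(probs[correct])
    return total / len(targets)

# the network assigned 0.5 to the right answer both times: 1 bit per character
print(cross_entropy_bits([{"a": 0.5, "b": 0.5}, {"a": 0.5, "b": 0.5}], "ab"))  # 1.0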
Entropy
As an example for how entropy works, let's say we have an alphabet of 27 characters:
"ABCDEFGHIJKLMNOPQRSTUVWXYZ "
If we generate some text by randomly choosing between them with equal probability, the entropy is -log2(1/27) = log2(27) = 4.76 bits, which is the maximum entropy you can have for an alphabet of 27 symbols.
Actual English text, however, is not uniformly random; it contains various redundancies, which is why it often compresses well from the 8-bits-per-character ASCII encoding. For example, consider the string "technolog". The next character could be "y", or perhaps "i" as in "technological", but most other characters like "z" are pretty unlikely. If these two choices have roughly 50% probability each, then the conditional entropy of the next character here would be about 1 bit.
As a frame of reference, super badass scientist and inventor of information entropy Claude Shannon estimated the conditional entropy of English text (reduced to 27 characters), given the preceding 100 characters, at something on the order of 1 bit per character. Compressing Project Gutenberg text, filtered down to a similar 27 characters, with Python's zlib module reduces the ASCII text from 8 bits per character to 2.84 bits per character.
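Here's roughly how that zlib number can be measured. The exact filtering is my assumption (uppercasing and keeping only A-Z plus space), and gutenberg.txt is a placeholder for whichever Project Gutenberg text you use:
Python
import zlib

def bits_per_char(text):
    # reduce to a 27-symbol alphabet: A-Z plus space
    filtered = "".join(c for c in text.upper() if "A" <= c <= "Z" or c == " ")
    compressed = zlib.compress(filtered.encode("ascii"), 9)
    return 8.0 * len(compressed) / len(filtered)

print(bits_per_char(open("gutenberg.txt").read()))  # the sample used here came out to 2.84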
Training
So to measure the cross entropy loss of this RNN, we give it each character from the data one at a time and ask it to predict the next character. We then average the -log2(probability) values for each prediction to get the overall loss of the network. Note that the loss output by rnn.py is in nats, not bits, but I've converted everything to bits (by dividing by ln 2) for this article so it's consistent.
For our text, the underlying conditional entropy is 0 bits per character, since you always know which character comes next by looking at, at most, the previous 79 characters. The cross entropy loss for the training starts off at 4.7 average bits per character, approximately equal to log2(27), where 27 is the number of unique symbols in this file (a-z plus newline). So initially the network is just randomly guessing which character will show up next.

By the 5th pass through the input data, the loss has dropped to 0.04. At this point the model can predict the alphabet part, but it can't decide when "z" is followed by "a" and when it's followed by a newline. Here's the output from sampling that model:
abcdefghijklmnopqrstuvwxyz
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz
abcdefhhijklmnopqrstuvwxyzabcdffghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz
abcdefghijkkmnopqrstuvwxyz
Whenever the model encounters a "z" in the input, the output probabilities look like this:
{
"\n": 0.3326,
"a": 0.6670,
...
}
At around the 65th time through the training data, the loss has dropped to 0.00 and the model figures out that it's every third "z" that is followed by a newline:
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz
When the model encounters a "z" inside a line, the output looks like:
{
"\n": 0.0001,
"a": 0.9997,
...
}
When the "z" is at the end of a line, the output looks like:
{
"\n": 0.9793,
"a": 0.0203,
...
}
Now that the model is trained, it can be used to predict the next character given some input, like "abcdefghijkl" => "m", or it can be used to generate the pattern we trained it on. Not the most useful of networks, but it's nice to know it works for a simple example like this. The training time to get to 0.00 loss was about 8m.
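Generation just repeats the prediction step in a loop, feeding each sampled character (and the updated state) back into the network. Here's a sketch of how sampling with a temperature like the -temperature 0.8 flag works; rnn_step is a stand-in for one step of the trained network (the same shape as the RNN("a", state) examples above), not an actual function in rnn.py:
Python
import random

def sample_char(probs, temperature=0.8):
    # lower temperatures sharpen the distribution, higher ones flatten it
    weights = {c: p ** (1.0 / temperature) for c, p in probs.items()}
    total = sum(weights.values())
    r = random.uniform(0, total)
    for c, w in weights.items():
        r -= w
        if r <= 0:
            return c
    return c  # guard against floating point rounding

def generate(rnn_step, initial_state, seed_char, length):
    state, char, out = initial_state, seed_char, []
    for _ in range(length):
        probs, state = rnn_step(char, state)
        char = sample_char(probs)
        out.append(char)
    return "".join(out)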
Markov Chains
How do RNNs compare to the conceptually much simpler Markov chain? Markov chains like this are dead simple to program. From your training text you build a table that, for each prefix of n-1 characters, gives the probabilities of what the nth character will be (an order-n model). You do this by looking at each prefix in the text, keeping track of how many times each character follows it, then dividing by the total number of times the prefix appears to get a probability.
Here's what that looks like in Python:
Python
def generate_table(text, n):
    """Build an order-n model: map each (n-1)-character prefix to the
    probabilities of the characters that follow it in the training text."""
    if n == 0:
        # n = 0 assigns each character an equal probability
        chars = set(text)
        return {"": {char: 1.0 / len(chars) for char in chars}}
    # count how many times each character follows each prefix
    table = {}
    for i in range(len(text) - n + 1):
        prefix = text[i:i + n - 1]
        char = text[i + n - 1]
        if prefix not in table:
            table[prefix] = {}
        if char not in table[prefix]:
            table[prefix][char] = 0
        table[prefix][char] += 1
    # normalize counts to get probabilities
    for prefix, chars in table.items():
        total = sum(chars.values())
        for char, count in chars.items():
            chars[char] = count / total
    return table
Let's train an order-2 model on an example sentence and look at the entry for the prefix " " (a single space), i.e. the distribution of characters that follow a space:
Python
generate_table("the small black cat jumps over the calico cat", 2)[" "]
{
    "c": 0.375,
    "b": 0.125,
    "j": 0.125,
    "o": 0.125,
    "s": 0.125,
    "t": 0.125,
}
The space character occurs 8 times in the input, so the counts of the characters that follow it are divided by 8. Since "c" follows a space 3 times ("cat" twice and "calico" once), it gets 3/8 = 0.375, while the other five characters each follow a space once and get 1/8 = 0.125.
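For completeness, here's a sketch of generating text from such a table. It assumes the seed is at least n-1 characters long and that every prefix it runs into exists in the table (likely for lyric-sized training data, but a fallback would be needed in general):
Python
import random

def generate_text(table, n, seed, length):
    out = list(seed)
    for _ in range(length):
        prefix = "".join(out[len(out) - (n - 1):]) if n > 1 else ""
        dist = table[prefix]
        # sample the next character according to its probability
        r = random.random()
        for char, p in dist.items():
            r -= p
            if r <= 0:
                break
        out.append(char)
    return "".join(out)

table = generate_table("the small black cat jumps over the calico cat", 3)
print(generate_text(table, 3, "the", 40))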
Predicting the next character is exactly what we are doing with RNNs, although in a much more complicated way. How well do Markov chains do at the same job? Since we use cross-entropy loss to measure the performance of an RNN, looking at the same thing for Markov chains would be ideal. Cross-entropy uses log(probability) and the probability of some characters here is 0, so we have to change the Markov chain model a little bit to account for that. For this comparison:
- We use the same lyrics training/validation data sets as the RNN
- All missing characters for a prefix share a probability of 0.01 (and the rest are re-normalized) to avoid getting a log(0)
- If a prefix is not found in the table, we back off to successively shorter prefixes until a match is found (a sketch of this evaluation follows below)
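Here's a rough sketch of that evaluation, reusing generate_table from above to build one table per order. The exact smoothing used for the original comparison may differ, so treat the details here as assumptions:
Python
import math

def markov_cross_entropy(train_text, val_text, n, floor=0.01):
    # one table per order so we can back off from an order-n prefix toward shorter ones
    tables = {k: generate_table(train_text, k) for k in range(1, n + 1)}
    alphabet = set(train_text) | set(val_text)
    bits = 0.0
    count = 0
    for i in range(n - 1, len(val_text)):
        char = val_text[i]
        # back off until the prefix is in the table; k = 1 (the empty prefix) always matches
        for k in range(n, 0, -1):
            prefix = val_text[i - (k - 1):i]
            if prefix in tables[k]:
                dist = tables[k][prefix]
                break
        if char in dist:
            p = dist[char] * (1.0 - floor)           # re-normalize the seen characters
        else:
            p = floor / (len(alphabet) - len(dist))  # unseen characters share the floor
        bits += -math.log2(p)
        count += 1
    return bits / count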
As we increase the order of the model, we quickly overfit the training set and the validation loss takes a hit. For a given data set, with a high enough value of n, we should be able to perfectly reproduce the training set from the first n characters. That doesn't help us on the validation set though, since the training process never sees it. The lowest validation loss occurs at n = 5 for this data set with a cross entropy loss of ~2.4 bits per character. With a larger data set, we should have less overfitting for the same values of n. However, the number of parameters to represent progressively larger Markov chains like this quickly becomes large.
For an n = 5 model using the lyrics dataset, which has ~100 unique symbols, we have a 100**4 x 100 matrix of probabilities, but since we store it as a sparse hash table, we end up with only ~500k parameters. The simplest RNN we used in the first section had a cross entropy loss of ~2.3 bits per character and uses ~50k parameters. How is this model able to capture the same amount of structure (at least as measured by cross entropy loss) with far fewer parameters? It's possible that the way an RNN models text is closer to the underlying structure of the text than the way a Markov chain does.
Mutual Information
One interesting way to look at the difference between Markov chains and RNNs is with an information theory property called mutual information. Given two random variables x and y, the mutual information I(x,y) is the amount of information (in bits, for example) that the two variables share. If the variables have no relation, they are independent, and the mutual information is 0, since knowing the value of x tells you nothing about the value of y and vice versa. Henry Lin and Max Tegmark cover mutual information applied to Markov chains and RNNs in their paper Critical Behavior from Deep Dynamics: A Hidden Dimension in Natural Language, along with a very neat video, Connections Between Physics and Deep Learning.
The idea is that RNNs and other hierarchical language models can capture more of the underlying structure of text than a Markov chain, which, as a purely 1-dimensional model, cannot.
There are a number of equivalent expressions for the mutual information, but the one that makes the most sense to me is:
I(x, y) = S(x) + S(y) - S(x, y)
where S(x) is the Shannon entropy of the random variable x and S(x, y) is the joint entropy of the pair. We don't know the actual entropy, but we can estimate it from some data like this:
Python
import math

def entropy(symbols):
    """Estimate the Shannon entropy (in bits) of a sequence of symbols."""
    # count how many times each symbol appears
    hist = {}
    for s in symbols:
        if s not in hist:
            hist[s] = 0
        hist[s] += 1
    # sum -p * log2(p) over the observed symbol frequencies
    e = 0.0
    total = sum(hist.values())
    for count in hist.values():
        p = count / total
        e -= p * math.log(p, 2)
    return e

entropy("010101010101010101010101010101010101010101")  # 1.0 bits
entropy("000000000000000000000000000000000000000001")  # 0.16 bits
entropy("000000000000000000000000000000000000000000")  # 0.0 bits
This estimates the overall entropy, not the conditional entropy we looked at earlier. The more text we have, the better the estimate.
For the text we are looking at, we want the mutual information between symbols that are a certain distance apart, so x can be the text and y the same text offset by some number of characters. For example, if the distance is 3:
Python
data = "the small black cat jumps over the calico cat"
x = data[:-3]         # "the small black cat jumps over the calico "
y = data[3:]          # " small black cat jumps over the calico cat"
xy = list(zip(x, y))  # [('t', ' '), ('h', 's'), ('e', 'm'), (' ', 'a'), ...]
I = entropy(x) + entropy(y) - entropy(xy)  # 2.50 bits
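Wrapping that in a function makes it easy to compute the curve over a range of distances; corpus.txt is a placeholder for whichever text you want the curve for (the Wikipedia dump, the RNN output, and so on):
Python
def mutual_information(text, distance):
    x = text[:-distance]
    y = text[distance:]
    return entropy(x) + entropy(y) - entropy(list(zip(x, y)))

text = open("corpus.txt").read()
for distance in (1, 2, 4, 8, 16, 32, 64):
    print(distance, mutual_information(text, distance))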
For any finite text, this estimate is biased low, so the Critical Behavior paper uses a fancier estimator:
Where ψ(x) is the digamma function, N is the length of the text, and K is the number of unique characters in the text.
If we use this to estimate the mutual information of 100MB of English Wikipedia text (the wikitext source code of the pages), we get a graph like this:
If we estimate the mutual information from 10MB of Wikipedia text generated by an RNN (LSTM, 3 layers, 128 nodes per layer) and text generated by a Markov chain (order 5) we get a graph like this:
In Critical Behavior, the authors state that mutual information in the Markov chain text falls off exponentially (something like y = a ** x), while in the Wikipedia text it falls off roughly as a power law (y = x ** a). The exponential decays much more quickly than the power law, so the Markov chain is unable to retain information from the source material over long distances. The RNN seems to do a better job at it, and the theory is that the RNN more closely models the hierarchical structure of the source text.
For reference, here's a power law with the equation distance**(-np.e/5) and an exponential with the equation np.exp(-distance/5) (normalized to be 1 at distance=1):
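Those two reference curves are easy to reproduce. Here's a minimal matplotlib sketch; the log-log axes are my choice, since a power law shows up as a straight line there:
Python
import numpy as np
import matplotlib.pyplot as plt

distance = np.arange(1, 1000)
power_law = distance ** (-np.e / 5)         # already 1.0 at distance = 1
exponential = np.exp(-distance / 5)
exponential = exponential / exponential[0]  # normalize to 1.0 at distance = 1

plt.loglog(distance, power_law, label="power law: distance**(-e/5)")
plt.loglog(distance, exponential, label="exponential: exp(-distance/5)")
plt.xlabel("distance")
plt.ylabel("normalized value")
plt.legend()
plt.show()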
Implementation Notes
The RNN implementation, rnn.py, is written in Python using Tensorflow, based roughly on Andrej Karpathy's char-rnn. I wanted to write a single-file implementation that doesn't outsource the network to a built-in Tensorflow library.
It splits the input into parallel streams that can be processed at the same time, so that the GPU implementation is faster than running on the CPU. For rnn.py, with a 3-layer LSTM with 1024 nodes per layer, the GPU (Nvidia GTX 750 Ti) trains about 12 times faster than the CPU (Intel i3-4130T).
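Here's one common way to do that splitting (not necessarily exactly what rnn.py does): cut the encoded text into batch_size parallel streams and walk down all of them at once in seq_length chunks, so the network sees a whole batch of sequences at every step:
Python
import numpy as np

def batches(data, batch_size, seq_length):
    # trim the tail so the data divides evenly into batch_size parallel streams
    stream_len = len(data) // batch_size
    streams = np.array(data[:stream_len * batch_size]).reshape(batch_size, stream_len)
    # each chunk pairs inputs with the characters that follow them (the targets)
    for start in range(0, stream_len - seq_length, seq_length):
        inputs = streams[:, start:start + seq_length]
        targets = streams[:, start + 1:start + seq_length + 1]
        yield inputs, targets

# data stands for the text already encoded as integer character ids
data = np.random.randint(0, 27, size=100000)
for inputs, targets in batches(data, batch_size=50, seq_length=50):
    pass  # feed each (inputs, targets) batch to the network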