FastHugs: Language Modelling with Transformers and Fastai
Train a transformer language model from scratch or fine-tune a pretrained one using fastai and HuggingFace.
This aims to be an end-to-end description, with code, of how to train a transformer language model using fastai (v2) and HuggingFace. Enjoy!
TL;DR
Main interesting bits in this notebook:
- Provides full code to train a transformer (RoBERTa) using a Masked Language Model task
- Utilises many of HuggingFace's tokenizer features within fastai
- Make predictions of masked tokens like this:
Before we get started
- First off, huge thanks as always to both the Fastai and HuggingFace teams for giving so much back to the community through their open-source work
For an example of text sequence classification using HuggingFace and fastai, have a look at my previous notebook here
This tutorial is heavily based on HuggingFace's "How to train a new language model from scratch using Transformers and Tokenizers" tutorial; I highly recommend checking that out too. I try to highlight throughout where code has been used from, borrowed from or inspired by HuggingFace's code.
MLM Transform
I feel the most useful thing in this notebook is the MLMTokensLabels transform*. This carries out the Masked Language Model (MLM) task that RoBERTa was originally trained on.
This transform takes token ids (tokens after they have been numericalized), selects a subset of them and either masks a certain proportion of them (for prediction) or replaces them with other random token ids (for regularisation). It also creates our labels by copying the input token ids and masking the tokens that do not need to be predicted, so that no loss is calculated on them.
Note that if you wish to train BERT or another transformer language model you will probably need to use a different task; e.g. BERT was trained on 2 tasks simultaneously, MLM and Next Sentence Prediction (NSP). Have a look at the blog posts or arXiv paper for the transformer of interest to find which task(s) were used to pretrain it.
*This transform code is a re-write of the mask_tokens function used in HuggingFace's tutorial, code here
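To make the idea concrete, here is a tiny, purely illustrative sketch of the masking and labelling logic on made-up token ids (the real transform, shown later in this notebook, works on whole tokenized sequences and respects special and padding tokens):
import torch   # already available in this notebook; shown for completeness
ids = torch.tensor([713, 16, 10, 372, 1569, 328])   # made-up token ids for one sequence
labels = ids.clone()
# Pretend positions 1 and 4 were randomly selected (~15% of tokens in a real sequence)
selected = torch.tensor([False, True, False, False, True, False])
labels[~selected] = -100            # loss will only be computed where labels != -100
# Most selected positions get the mask token (id 50264 for roberta-base, assumed here),
# the rest are swapped for random token ids or left unchanged
ids[1] = 50264
ids, labels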
Pretraining + Fine-Tuning:
As shown in ULMFiT, MultiFiT, and elsewhere, you will get better results on your downstream task if you first fine-tune your pretrained language model on text from the same domain as that task, e.g. fine-tuning a Wikipedia-pretrained language model on IMDB reviews before training an IMDB movie review classifier.
1/ Really excited about this one! "Don't Stop Pretraining: Adapt Language Models to Domains and Tasks" is live! With @anmarasovic, @swabhz, @kylelostat, @i_beltagy, Doug Downey, and @nlpnoah, to appear at ACL2020.
Paper: https://t.co/hVbSQYnclk
Code: https://t.co/7wKgE1mUme
— Suchin Gururangan (@ssgrn) April 24, 2020
Using a Custom Tokenizer?
This code has not been tested with a custom tokenizer. You may want to use one if your text is very specific to a certain domain; if so, you'll have to add a number of attributes to your tokenizer to be able to use the code here. I really recommend the HuggingFace language model tutorial linked above for an example of training your own tokenizer on your own dataset.
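For reference, a rough (untested here) sketch of training your own byte-level BPE tokenizer with HuggingFace's tokenizers library, along the lines of their tutorial; the file path, vocab size and output directory below are placeholder assumptions, and the save method name may differ across tokenizers versions:
from tokenizers import ByteLevelBPETokenizer
custom_tok = ByteLevelBPETokenizer()
custom_tok.train(files=['my_domain_corpus.txt'],        # placeholder: one or more plain-text files
                 vocab_size=30_000, min_frequency=2,
                 special_tokens=['<s>', '<pad>', '</s>', '<unk>', '<mask>'])
custom_tok.save_model('my_tokenizer')                   # writes vocab.json and merges.txt
# The saved files could then be loaded with e.g. RobertaTokenizerFast.from_pretrained('my_tokenizer'),
# though you'd still need to add the extra attributes mentioned above to use it with this notebook.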
Data
We'll use the IMDB_SAMPLE here, pretending we are fine-tuning our transformer model before doing sentiment classification on IMDB. If you are pretraining a language model from scratch you'd aim to use a larger, more generic source such as a Wikipedia dataset. fastai have the full WikiText103 (100 million tokens) dataset available for easy download here if you'd like to train an English language model from scratch:
path = untar_data(URLs.WIKITEXT)
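For the rest of this notebook we work with the IMDB sample rather than WikiText. A minimal sketch (assuming the usual fastai v2 text imports are already in scope) of how the DataFrame df used below is loaded:
path = untar_data(URLs.IMDB_SAMPLE)
df = pd.read_csv(path/'texts.csv')  # columns: label, text, is_valid
df.head()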
HuggingFace Auto Classes
HuggingFace have a number of useful "Auto" classes that enable you to create different models and tokenizers by changing just the model name.
- AutoModelWithLMHead will define our language model for us. This can either be a pretrained model or a randomly initialised model
- AutoTokenizer will load our tokenizer and enable us to grab our vocab
- AutoConfig will define the model architecture and settings; note that we use the pretrained config here for ease of use, but one can easily modify this config if needed
- model_name is the model architecture (and optionally model weights) you'd like to use
- Language Models tested so far with this notebook: roberta-base
- You can find all of HuggingFace's models at https://huggingface.co/models; most, but not all, of them are supported by AutoModel, AutoConfig and AutoTokenizer
We can now easily call whichever transformer we like as below:
model_name = 'roberta-base'
lm_model_class = AutoModelWithLMHead
config_dict = AutoConfig.from_pretrained(model_name)
#collapse
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer_vocab=tokenizer.get_vocab()
tokenizer_vocab_ls = [k for k, v in sorted(tokenizer_vocab.items(), key=lambda item: item[1])]
print(f'Tokenizer "{tokenizer.__class__}" vocab length is : {len(tokenizer_vocab_ls)}')
tokenizer.special_tokens_map
class FastHugsTokenizer():
"""
transformer_tokenizer : takes the tokenizer that has been loaded from the tokenizer class
model_name : model type set by the user
max_seq_len : override default sequence length, typically 512 for bert-like models.
`transformer_tokenizer.max_len_single_sentence` and `transformer_tokenizer.max_len_sentences_pair`
both account for the need to add additional special tokens, i.e. for RoBERTa-base
max_len_single_sentence==510, leaving space for the 2 additional special tokens
to be added for the model's default 512 positional embeddings
pair : whether a single sentence (sequence) or pair of sentences are used
Returns:
- Tokenized text, up to the max sequence length set by the user or the tokenizer default
"""
def __init__(self, transformer_tokenizer=None, model_name='roberta', max_seq_len=None,
pretrained=True, pair=False, **kwargs):
self.model_name, self.tok, self.max_seq_len=model_name, transformer_tokenizer, max_seq_len
if pretrained:
if self.max_seq_len:
if pair: assert self.max_seq_len<=self.tok.max_len_sentences_pair, 'WARNING: max_seq_len needs to be less than or equal to transformer_tokenizer.max_len_sentences_pair'
else: assert self.max_seq_len<=self.tok.max_len_single_sentence, 'WARNING: max_seq_len needs to be less than or equal to transformer_tokenizer.max_len_single_sentence'
else:
if pair: self.max_seq_len=ifnone(max_seq_len, self.tok.max_len_sentences_pair)
else: self.max_seq_len=ifnone(max_seq_len, self.tok.max_len_single_sentence)
def do_tokenize(self, o:str):
"""Returns tokenized text, adds prefix space if needed, limits the maximum sequence length"""
if 'roberta' in self.model_name: tokens=self.tok.tokenize(o, add_prefix_space=True)[:self.max_seq_len]
else: tokens = self.tok.tokenize(o)[:self.max_seq_len]
return tokens
def de_tokenize(self, o):
"""Return string from tokens"""
text=self.tok.convert_tokens_to_string(o)
return text
def __call__(self, items):
for o in items: yield self.do_tokenize(o)
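As a quick, illustrative sanity check (the exact sub-word pieces you get will depend on the tokenizer and on your transformers version):
demo_tok = FastHugsTokenizer(transformer_tokenizer=tokenizer, model_name=model_name)
toks = demo_tok.do_tokenize('This movie was surprisingly good')
toks[:6], demo_tok.de_tokenize(toks)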
The Fastai bit
fasthugstok and our tok_fn
Let's incorporate the HuggingFace tokenizer into fastai-v2's framework by specifying a function called fasthugstok that we can then pass on to Tokenizer.from_df. (Note .from_df is the only method I have tested.)
Max Sequence Length
max_seq_len is the longest sequence our tokenizer will output. We can set the maximum sequence length for the tokenizer by changing max_seq_len; if left as None it uses the tokenizer's default, typically 512. 1024 or even 2048 can also be used depending on your GPU memory, but note that when using pretrained models you won't be able to use a max_seq_len larger than the default.
max_seq_len = None
sentence_pair=False
fasthugstok = partial(FastHugsTokenizer, transformer_tokenizer=tokenizer, model_name=model_name,
max_seq_len=max_seq_len, pair=sentence_pair)
We create an MLMTokenizer class which inherits from fastai's Tokenizer in order to fully decode our numericalized tokens back into readable text:
#collapse
class MLMTokenizer(Tokenizer):
def __init__(self, tokenizer, rules=None, counter=None, lengths=None, mode=None, sep=' ', **kwargs):
super().__init__(tokenizer, rules, counter, lengths, mode, sep)
def _detokenize1(self, o):return self.tokenizer.de_tokenize(o)
def decodes(self, o): return TitledStr(str(self._detokenize1(o)))
We set up fastai's Tokenizer.from_df, passing rules=[fix_html] to clean up some of the HTML messiness in our text. If you do not want any rules then you should pass rules=[] to override fastai's default text processing rules.
#collapse
fastai_tokenizer = MLMTokenizer.from_df(text_cols='text', res_col_name='text', tok_func=fasthugstok,
rules=[fix_html], post_rules=[])
fastai_tokenizer.rules
class AddSpecialTokens(Transform):
"Add special token_ids to the numericalized tokens for Sequence Classification"
def __init__(self, tokenizer):
self.tok=tokenizer
def encodes(self, o):
return(TensorText(self.tok.build_inputs_with_special_tokens(list(o))))
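For roberta-base, build_inputs_with_special_tokens should wrap a sequence with <s> (id 0) and </s> (id 2); a quick illustration with made-up token ids:
# The ids 100, 200, 300 are arbitrary placeholders, not meaningful tokens
AddSpecialTokens(tokenizer)(TensorText([100, 200, 300]))
# expected for roberta-base: TensorText([  0, 100, 200, 300,   2])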
#collapse
class MLMTokensLabels(Transform):
'''
MLM task
- Select a subset of the input token ids, given by `mlm_probability`
- Mask `mask_token_prob` of these with the mask token
- Replace half of the remaining selected tokens with random token ids
- This code mostly comes from the `mask_tokens` function here https://github.com/huggingface/transformers/blob/a21d4fa410dc3b4c62f93aa0e6bbe4b75a101ee9/examples/run_language_modeling.py#L66
Returns: input ids and labels
'''
def __init__(self, tokenizer=None, mlm_probability=0.15, mask_token_prob=0.8):
self.tok, self.mlm_probability, self.mask_token_prob=tokenizer, mlm_probability, mask_token_prob
def _gen_probability_matrix(self, labels):
# We sample a few tokens in each sequence for masked-LM training (with probability mlm_probability, defaults to 0.15 in Bert/RoBERTa)
probability_matrix = torch.full(labels.shape, self.mlm_probability)
special_tokens_mask = self.tok.get_special_tokens_mask(labels.tolist(), already_has_special_tokens=True)
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
if self.tok._pad_token is not None:
padding_mask = labels.eq(self.tok.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
return probability_matrix
def _replace_with_mask(self, inputs, labels, masked_indices):
# `mask_token_prob` (80% by default) of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, self.mask_token_prob)).bool() & masked_indices
inputs[indices_replaced] = self.tok.convert_tokens_to_ids(self.tok.mask_token)
return inputs, indices_replaced
def _replace_with_other(self, inputs, labels, masked_indices, indices_replaced):
# (1-`mask_token_prob`)/2 of the time (10% by default), we replace masked input tokens with a random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tok), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
return inputs
def encodes(self, inputs):
if self.tok.mask_token is None:
raise ValueError("This tokenizer does not have a mask token which is necessary for masked language modeling.")
labels = inputs.clone()
# Get probability of whether a token will be masked
probability_matrix = self._gen_probability_matrix(labels)
# Create random mask indices according to probability matrix
masked_indices = torch.bernoulli(probability_matrix).bool()
# Mask the labels for indices that are NOT masked, we only compute loss on masked tokens
labels[~masked_indices] = -100
# Randomly replace with mask token
inputs, indices_replaced = self._replace_with_mask(inputs, labels, masked_indices)
# Randomly replace with a random token
inputs = self._replace_with_other(inputs, labels, masked_indices, indices_replaced)
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return (inputs,labels)
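A quick, illustrative check of the full encoding pipeline on one review (the masked positions will differ on every call, since the transform draws new random probabilities each time):
seq = AddSpecialTokens(tokenizer)(Numericalize(vocab=tokenizer_vocab_ls)(fastai_tokenizer.encodes(df.text[0])))
x, y = MLMTokensLabels(tokenizer)(seq)
# y should be -100 everywhere except at the ~15% of positions selected for prediction
(y != -100).float().mean(), x[:15], y[:15]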
We change decodes in our Numericalize class to deal with the <loss_mask> tokens:
# collapse
@Numericalize
def decodes(self,o):
'Add the ability to parse masks for the loss function, set as `-100`'
if isinstance(o, tuple): o=o[0]
tmp_vocab=self.vocab.copy()
tmp_vocab.append('<loss_mask>')
o=[-1 if o_ == -100 else o_ for o_ in o]
return L(tmp_vocab[o_] for o_ in o if tmp_vocab[o_] != PAD)
And we modify Datasets so as to not wrap our tuple in another tuple:
# collapse
@delegates(Datasets)
class Datasets(Datasets):
"Doesn't create a tuple in __getitem__ as x is already a tuple"
def __init__(self, items=None, tfms=None, tls=None, n_inp=None, dl_type=None, **kwargs):
super().__init__(items=items, tfms=tfms, tls=tls, n_inp=n_inp, dl_type=dl_type, **kwargs)
def __getitem__(self, it):
# same as Datasets.__getitem__ but not wrapped in a tuple
res = [tl[it] for tl in self.tls]
return res[0] if is_indexer(it) else list(zip(*res))
Our dataset is now ready to be created. Let's look at some of the (x, y) pairs that will be passed to the model. When -100 is passed to our loss function (nn.CrossEntropyLoss) it will be ignored in the calculation. Our model will also ignore any padding tokens (usually defined as 1) when they are passed to it.
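This works because PyTorch's cross-entropy loss ignores any target equal to its ignore_index, which defaults to -100; a minimal illustration:
import torch, torch.nn as nn               # already imported in this notebook; shown for completeness
loss_fn = nn.CrossEntropyLoss()            # ignore_index=-100 by default
logits  = torch.randn(4, 10)               # 4 positions, vocabulary of 10
targets = torch.tensor([3, -100, -100, 7]) # only positions 0 and 3 contribute to the loss
loss_fn(logits, targets)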
#collapse-hide
splits = ColSplitter()(df)
tfms=[attrgetter("text"), fastai_tokenizer, Numericalize(vocab=tokenizer_vocab_ls),
AddSpecialTokens(tokenizer), MLMTokensLabels(tokenizer)]
dsets = Datasets(df, splits=splits, tfms=[tfms], dl_type=SortedDL)
dsets[0][0][:20], dsets[0][1][:20]
#collapse
def pad_mlm_input(samples, pad_idx=1, pad_fields=[0,1], pad_first=False, max_seq_len=None, backwards=False):
"Function that collect `samples` and adds padding, modified `max_len_l` in fastai's `pad_input`"
pad_fields = L(pad_fields)
#max_len_l = ifnone(max_seq_len, pad_fields.map(lambda f: max([len(s[f]) for s in samples])))
max_len_l = pad_fields.map(lambda f: max_seq_len)
if backwards: pad_first = not pad_first
def _f(field_idx, x):
if isinstance(x, tuple): x=(x[0]) ## Added this line too, removes tuple if present
if field_idx not in pad_fields: return x
idx = pad_fields.items.index(field_idx) #TODO: remove items if L.index is fixed
sl = slice(-len(x), sys.maxsize) if pad_first else slice(0, len(x))
pad = x.new_zeros(max_len_l[idx]-x.shape[0])+pad_idx
x1 = torch.cat([pad, x] if pad_first else [x, pad])
if backwards: x1 = x1.flip(0)
return retain_type(x1, x)
return [tuple(map(lambda idxx: _f(*idxx), enumerate(s))) for s in samples]
def transformer_mlm_padding(tokenizer=None, max_seq_len=None, sentence_pair=False):
'Uses `pad_fields=[0,1]` to pad both input and label'
if tokenizer.padding_side == 'right': pad_first=False
else: pad_first=True
max_seq_len = ifnone(max_seq_len, tokenizer.max_len)
return partial(pad_mlm_input, pad_fields=[0,1], pad_first=pad_first,
pad_idx=tokenizer.pad_token_id, max_seq_len=max_seq_len)
#collapse
padding=transformer_mlm_padding(tokenizer)
bs=4
dls = dsets.dataloaders(bs=bs, before_batch=[padding])
Check our batch
We can see that our special RoBERTa tokens ('<s>', '</s>'), which translate to 0 and 2 in its vocab, have been added to the start and end of each sequence in the batch. You can look up these indices in tokenizer.get_vocab() to confirm this. We can also see that most of the tokens in our target (text_) are masked out, as we only want to calculate the loss on the ~15% of the text tokens that have been masked.
#collapse
b=dls.one_batch()
b[0].size(), b[1].size()
#collapse
dls.show_batch()
class LMModel(nn.Module):
def __init__(self, lm_model_class=None, tokenizer=None, model_name=None, config_dict=None, pretrained=False):
super().__init__()
self.tok=tokenizer
if pretrained: self.model = lm_model_class.from_pretrained(model_name)
else: self.model = lm_model_class.from_config(config_dict)
self.model = self.model.module if hasattr(self.model, "module") else self.model
self.model.resize_token_embeddings(len(tokenizer))
def forward(self, input_ids):
attention_mask = (input_ids!=self.tok.pad_token_id).type(input_ids.type())
return self.model(input_ids, attention_mask=attention_mask)[0] # only return the prediction_scores (and not hidden states and attention)
Pretrained Language Model
Let's fine-tune our pretrained language model. We would typically do this before training a downstream model on our specific task. Note that here we are not training the language model head on its own before training the full model, but we could do so if we created a splitter and passed it to our learner (a rough sketch of such a splitter is given after the learner below).
To load the pretrained HuggingFace model, just use pretrained=True when calling your model:
model = LMModel(lm_model_class=lm_model_class, tokenizer=tokenizer, model_name=model_name,
config_dict=config_dict, pretrained=True)
#collapse
opt_func = partial(Adam, decouple_wd=True)
loss = CrossEntropyLossFlat()
learn = Learner(dls, model, opt_func=opt_func, #splitter=model_splitter,
loss_func=loss, metrics=[accuracy, Perplexity()]).to_fp16()
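The learner above has splitter=model_splitter commented out. A rough, untested sketch of what such a splitter might look like, grouping the transformer body and the LM head separately so they could be frozen or given discriminative learning rates:
def model_splitter(model):
    "Hypothetical splitter: one parameter group for the transformer body, one for the LM head"
    hf_model = model.model
    prefix = hf_model.base_model_prefix  # e.g. 'roberta' for RobertaForMaskedLM
    body_params = [p for n, p in hf_model.named_parameters() if n.startswith(prefix)]
    head_params = [p for n, p in hf_model.named_parameters() if not n.startswith(prefix)]
    return [body_params, head_params]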
We check our learning rate finder
#collapse-hide
learn.lr_find(suggestions=True, stop_div=False)
We do some training
#collapse-hide
learn.fit_one_cycle(10, lr_max=1e-4)
And we see how our loss progressed
Let's look at the model's predictions
Manually checking how well our model predicts masked tokens is a simple way to see how it is training.
The function get_mask_pred takes a masked string given by the user and returns the topk predictions given by the model for that masked token. With it we can sanity check that our model has learned something useful!
*Note that get_mask_pred is mostly code from FillMaskPipeline in HuggingFace's Transformers repo, full credit to them!
#collapse
def get_mask_pred(model, masked_text:str, topk:int=5):
"Code lightly modified from `FillMaskPipeline` in the HuggingFace Transformers library"
aa=fastai_tokenizer.encodes(masked_text)
bb=Numericalize(vocab=tokenizer_vocab_ls)(aa)
cc=AddSpecialTokens(tokenizer)(bb)
outs=model(cc.unsqueeze(0).cuda())
masked_index = (cc == tokenizer.mask_token_id).nonzero().item()
logits = outs[0, masked_index, :]
probs = logits.softmax(dim=0)
values, predictions = probs.topk(topk)
result=[]
for i, vv in enumerate(zip(values.tolist(), predictions.tolist())):
v, p =vv
tokens = cc.numpy()
if i == 0: result.append({"word":"Input text", "score": 0., "token": 0, "sequence": tokenizer.decode(tokens)})
tokens[masked_index] = p
tokens = tokens[np.where(tokens != tokenizer.pad_token_id)]
w = tokenizer.decode(p)
result.append({"word":w, "score": v, "token": p, "sequence": tokenizer.decode(tokens)})
return pd.DataFrame(result)
Here we can input our own masked sentence and see how the model does. Note that even without fine-tuning the performance below would still be strong, as pretrained RoBERTa is already a very capable model.
model1 = learn.model
text2 = 'I was walking to <mask> when I came across a cat on the road'
pred2 = get_mask_pred(model1, text2);pred2.head()
Not bad at all! Now let's see how it does on a movie review; we look at an example from our validation set. We mask the word might from the first sentence of the review, "... shows what might happen...".
mask_indices=[7]
txts=df.text.values
masked_text = txts[800].split(' ') # our validation split starts at index 800
masked_text[mask_indices[0]] = '<mask>'
masked_text = " ".join(masked_text)
pred1 = get_mask_pred(model1, masked_text);pred1.head()
Boom, pretty darn good! Let's try the same example, replacing ancient in "discovery of ancient documents".
mask_indices=[54]
txts=df.text.values
masked_text = txts[800].split(' ') # our validation split starts at index 800
masked_text[mask_indices[0]] = '<mask>'
masked_text = " ".join(masked_text)
pred1 = get_mask_pred(model, masked_text);pred1.head()
Again, pretty solid predictions!
#collapse
model = LMModel(lm_model_class=lm_model_class, tokenizer=tokenizer, model_name=model_name,
config_dict=config_dict, pretrained=False)
opt_func = partial(Adam, decouple_wd=True)
loss = CrossEntropyLossFlat()
learn = Learner(dls, model, opt_func=opt_func, loss_func=loss, metrics=[accuracy, Perplexity()]).to_fp16()
model2=learn.model
text2 = 'I was walking to <mask> when I came across a cat on the road'
pred2 = get_mask_pred(model2, text2);pred2.head()
Pretty bad 👎, and see how unconfident it is in its predictions! This doesn't perform well because we have only used 800 movie reviews to train our model; we'd need a lot more text to get decent results!
Again, just for fun, let's see how it does on a movie review; we look at the same example from our validation set, masking the word might from the first sentence of the review, "... shows what might happen...".
mask_indices=[7]
txts=df.text.values
masked_text = txts[800].split(' ') # our validation split starts at index 800
masked_text[mask_indices[0]] = '<mask>'
masked_text = " ".join(masked_text)
pred1 = get_mask_pred(model2, masked_text);pred1.head()
Ewww..
mask_indices=[54]
txts=df.text.values
masked_text = txts[800].split(' ') # our validation split starts at index 800
masked_text[mask_indices[0]] = '<mask>'
masked_text = " ".join(masked_text)
pred1 = get_mask_pred(model2, masked_text);pred1.head()
Yuck!
Notes & Hacky Bits
Notes
- The validation set will change slightly due to random masking. While the data in the validation set remains constant, different tokens will be masked each time the validation dataloader is called, because MLMTokensLabels draws new random probabilities each time. If a perfectly reproducible validation set is needed then you'll probably have to create a separate transform for its masking and set its split_idx = 1 (see the sketch after these notes).
- I didn't have time to get learn.predict working. One issue that needs to be fixed is that the MLMTokensLabels transform shouldn't be called on your masked input text, as it will add more masks, which you don't want.
- FastHugsTokenizer will have to be modified to:
  - enable sequence lengths larger than the tokenizer default
  - use a non-pretrained tokenizer (e.g. one you trained yourself)
- The HuggingFace encode_plus or batch_encode_plus functions are great and I would have used them, but they don't play nicely with fastai's multiprocessing
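A hypothetical sketch of the split_idx idea mentioned above; fastai only applies a transform with split_idx=1 to the validation set (and split_idx=0 to the training set), so you could give each split its own masking transform:
# Illustrative only: making the validation masking truly reproducible would still
# require controlling the random number generator used inside the transform.
mlm_train = MLMTokensLabels(tokenizer); mlm_train.split_idx = 0  # applied to the training split only
mlm_valid = MLMTokensLabels(tokenizer); mlm_valid.split_idx = 1  # applied to the validation split only
tfms = [attrgetter("text"), fastai_tokenizer, Numericalize(vocab=tokenizer_vocab_ls),
        AddSpecialTokens(tokenizer), mlm_train, mlm_valid]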
Hacks
- I had to overwrite __getitem__ in the Datasets class so that it wouldn't return a tuple, as what it thinks is our x is actually our (x, y). Wrapping this tuple in another tuple causes headaches down the line. Creating a custom Datasets class and inheriting from it didn't work, as learn.predict calls on Datasets and not the custom dataset class.
- The function get_mask_pred (used to view predictions of masked text) is mostly code from FillMaskPipeline in HuggingFace's Transformers repo, full credit to them!
Give me a shout 📣
That's it for this one! I hope you found it useful and learned a thing or two. If you have any questions or would like to get in touch, you can find me on Twitter @mcgenergy