I'm trying to implement a first-order version of ProtoMAML (https://arxiv.org/pdf/1903.03096.pdf) for a sequence labelling task. If I use BERT as the encoder, I run into the following error at the line diffopt.step(loss):

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

If I instead use an LSTM as the encoder, training runs without error. Could the graph somehow be getting freed because BERT is such a large model?
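For context, here is a minimal sketch (plain PyTorch, independent of higher and BERT) of the generic situation that produces this message: backpropagating a second time through a graph whose saved buffers were already freed by the first backward pass.

import torch

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()
y.backward()  # frees the graph's saved intermediate buffers
y.backward()  # raises the same RuntimeError unless the first call used retain_graph=True

In my code below, however, diffopt.step(loss) is only called once per computed loss, and the explicit torch.autograd.grad call already passes retain_graph=True, so it's not obvious to me where a second backward through the same graph happens.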
Here is self-contained code that replicates the issue; it occurs on both CPU and GPU. In the __main__ block at the bottom you can set the encoder to bert or lstm. The transformers library from HuggingFace is required to run it.
import higher
import torch
from torch import nn, optim
from transformers import BertModel, BertTokenizer
from torch.nn import functional as F

class BaseModel(nn.Module):
    def __init__(self, encoder, max_length, device):
        super(BaseModel, self).__init__()
        self.max_length = max_length
        self.device = device
        if encoder == 'bert':
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.encoder = BertModel.from_pretrained('bert-base-uncased')
            self.encoder.pooler.dense.weight.requires_grad = False
            self.encoder.pooler.dense.bias.requires_grad = False
        elif encoder == 'lstm':
            self.encoder = nn.LSTM(batch_first=True, input_size=32, hidden_size=768)
        self.linear = nn.Linear(768, 192)
        self.to(self.device)

    def encode_text(self, text):
        if isinstance(self.encoder, BertModel):
            encode_result = self.tokenizer.batch_encode_plus(
                text, return_token_type_ids=False, max_length=self.max_length,
                pad_to_max_length=True, return_tensors='pt')
            for key in encode_result:
                encode_result[key] = encode_result[key].to(self.device)
            return encode_result
        elif isinstance(self.encoder, nn.LSTM):
            # dummy random features stand in for real LSTM inputs in this repro
            return torch.randn((len(text), 32, 32), device=self.device)

    def forward(self, inputs):
        if isinstance(self.encoder, BertModel):
            out, _ = self.encoder(inputs['input_ids'], attention_mask=inputs['attention_mask'])
        elif isinstance(self.encoder, nn.LSTM):
            out, _ = self.encoder(inputs)
        out = out[:, 1:-1, :]
        out = self.linear(out)
        return out


class ProtoMAML:
    def __init__(self, device, encoder):
        self.output_layer_weight = None
        self.output_layer_bias = None
        self.learner = BaseModel(encoder=encoder, max_length=32, device=device)
        self.inner_optimizer = optim.SGD([p for p in self.learner.parameters() if p.requires_grad], lr=0.001)
        self.loss_fn = nn.CrossEntropyLoss()
        self.output_lr = 0.001
        self.device = device
        self.updates = 5

    def output_layer(self, input, weight, bias):
        return F.linear(input, self.output_layer_weight + weight, self.output_layer_bias + bias)

    def initialize_with_proto_weights(self, support_repr, support_label, n_classes):
        # ProtoMAML output-layer init from prototypes: W_k = 2 * c_k, b_k = -||c_k||^2
        prototypes = self.build_prototypes(support_repr, support_label, n_classes)
        weight = 2 * prototypes
        bias = -torch.norm(prototypes, dim=1) ** 2
        self.output_layer_weight = torch.zeros_like(weight, requires_grad=True)
        self.output_layer_bias = torch.zeros_like(bias, requires_grad=True)
        return weight, bias

    def build_prototypes(self, data_repr, data_label, num_outputs):
        n_dim = data_repr.shape[2]
        data_repr = data_repr.view(-1, n_dim)
        data_label = data_label.view(-1)
        prototypes = torch.zeros((num_outputs, n_dim), device=self.device)
        for c in range(num_outputs):
            idx = torch.nonzero(data_label == c).view(-1)
            if idx.nelement() != 0:
                prototypes[c] = torch.mean(data_repr[idx], dim=0)
        return prototypes

    def initialize_output_layer(self, n_classes):
        self.output_layer_weight = torch.randn((n_classes, 768), requires_grad=True)
        self.output_layer_bias = torch.randn(n_classes, requires_grad=True)

    def train(self, support_text, labels, n_classes, n_iter):
        for itr in range(n_iter):
            print('Iteration ', itr)
            self.learner.zero_grad()
            self.initialize_output_layer(n_classes)
            x = self.learner.encode_text(support_text)
            y = labels.to(self.device)
            output_repr = self.learner(x)
            init_weights, init_bias = self.initialize_with_proto_weights(output_repr, y, n_classes)
            with higher.innerloop_ctx(self.learner, self.inner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=False) as (flearner, diffopt):
                for i in range(self.updates):
                    output = flearner(x)
                    output = self.output_layer(output, init_weights, init_bias)
                    output = output.view(output.size()[0] * output.size()[1], -1)
                    loss = self.loss_fn(output, y)
                    output_weight_grad, output_bias_grad = torch.autograd.grad(
                        loss, [self.output_layer_weight, self.output_layer_bias], retain_graph=True)
                    self.output_layer_weight = self.output_layer_weight - self.output_lr * output_weight_grad
                    self.output_layer_bias = self.output_layer_bias - self.output_lr * output_bias_grad
                    # the RuntimeError quoted above is raised here when encoder == 'bert'
                    diffopt.step(loss)


if __name__ == '__main__':
    encoder = 'bert'  # or 'lstm'
    support_text = [['This is a support text']] * 64
    labels = torch.randint(0, 10, (64 * 30, ))
    n_classes = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ProtoMAML(device=device, encoder=encoder)
    model.train(support_text, labels, n_classes, n_iter=10)