import torch
from torch import nn


class KoBERTClassifier(nn.Module):
    """Sentence classifier: KoBERT encoder followed by a single linear head."""

    def __init__(self, bert, hidden_size=768, num_classes=6, dr_rate=None):
        super(KoBERTClassifier, self).__init__()  # was super(BERTClassifier, ...): fixed to match the class name
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        # 1 for real tokens, 0 for padding, based on each sequence's valid length
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        _, pooler = self.bert(input_ids=token_ids,
                              token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask.to(token_ids.device))
        # Fall back to the raw pooled output when no dropout rate is given
        # (the original left `out` undefined in that case)
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)
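To make the masking concrete, here is a small standalone illustration of what gen_attention_mask computes (the toy batch and lengths are made up for this example): for a batch of length-5 sequences with valid lengths [3, 1], only the leading positions are attended to.

toks = torch.zeros(2, 5, dtype=torch.long)
mask = torch.zeros_like(toks)
for i, v in enumerate([3, 1]):
    mask[i][:v] = 1
print(mask.float())
# tensor([[1., 1., 1., 0., 0.],
#         [1., 0., 0., 0., 0.]])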
from kobert.pytorch_kobert import get_pytorch_kobert_model

# Pick the GPU when available; the original assumed `device` was defined earlier
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

bertmodel, vocab = get_pytorch_kobert_model()
model = KoBERTClassifier(bertmodel, dr_rate=0.5).to(device)
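As a quick sanity check, a dummy batch can be pushed through the model. This is only a sketch: the batch size, sequence length, and random token ids are arbitrary, and it presumes, as the forward pass above does, that the KoBERT encoder returns a (sequence_output, pooled_output) tuple.

batch_size, seq_len = 2, 64
token_ids = torch.randint(0, len(vocab), (batch_size, seq_len), device=device)  # random ids, illustration only
valid_length = torch.tensor([10, 20])   # number of non-padding tokens per example
segment_ids = torch.zeros(batch_size, seq_len, dtype=torch.long, device=device)

logits = model(token_ids, valid_length, segment_ids)
print(logits.shape)  # torch.Size([2, 6]): one score per class for each example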