diff --git a/models/textcnn.py b/models/textcnn.py index ca6ce87..abbe20e 100644 --- a/models/textcnn.py +++ b/models/textcnn.py @@ -17,6 +17,7 @@ def __init__(self, emb_dim, mlp_dims, dataset, dropout): self.bert = BertModel.from_pretrained('hfl/chinese-bert-wwm-ext').requires_grad_(False) elif dataset == 'en': self.bert = RobertaModel.from_pretrained('roberta-base').requires_grad_(False) + self.embedding = self.bert.embeddings feature_kernel = {1: 64, 2: 64, 3: 64, 5: 64, 10: 64} self.convs = cnn_extractor(feature_kernel, emb_dim) @@ -25,9 +26,8 @@ def __init__(self, emb_dim, mlp_dims, dataset, dropout): def forward(self, **kwargs): inputs = kwargs['content'] - masks = kwargs['content_masks'] - bert_feature = self.bert(inputs, attention_mask = masks).last_hidden_state - feature = self.convs(bert_feature) + feature = self.embedding(inputs) + feature = self.convs(feature) output = self.mlp(feature) return torch.sigmoid(output.squeeze(1))