Skip to content

Commit

Permalink
upd model
Browse files Browse the repository at this point in the history
  • Loading branch information
easezyc committed Feb 28, 2023
1 parent c4b75ee commit f16b29d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions models/textcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, emb_dim, mlp_dims, dataset, dropout):
self.bert = BertModel.from_pretrained('hfl/chinese-bert-wwm-ext').requires_grad_(False)
elif dataset == 'en':
self.bert = RobertaModel.from_pretrained('roberta-base').requires_grad_(False)
self.embedding = self.bert.embeddings

feature_kernel = {1: 64, 2: 64, 3: 64, 5: 64, 10: 64}
self.convs = cnn_extractor(feature_kernel, emb_dim)
Expand All @@ -25,9 +26,8 @@ def __init__(self, emb_dim, mlp_dims, dataset, dropout):

def forward(self, **kwargs):
inputs = kwargs['content']
masks = kwargs['content_masks']
bert_feature = self.bert(inputs, attention_mask = masks).last_hidden_state
feature = self.convs(bert_feature)
feature = self.embedding(inputs)
feature = self.convs(feature)
output = self.mlp(feature)
return torch.sigmoid(output.squeeze(1))

Expand Down

0 comments on commit f16b29d

Please sign in to comment.