From 00f8361ac2adadb9940b027e8ded42bf7ecd541a Mon Sep 17 00:00:00 2001
From: SangwonYoon
Date: Thu, 13 Apr 2023 17:35:31 +0000
Subject: [PATCH] monologg/koelectra-base-finetuned-nsmc / preprocessing data #27
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py         | 60 +++++++++++++++++++++++++++++++++++++++++--------
 wandb_tuning.py |  2 +-
 2 files changed, 52 insertions(+), 10 deletions(-)

diff --git a/main.py b/main.py
index 83f5fd0..7e5412d 100644
--- a/main.py
+++ b/main.py
@@ -15,6 +15,10 @@ import random
 import nltk
 from nltk.corpus import stopwords
+import re
+from soynlp.normalizer import repeat_normalize
+from soynlp.tokenizer import RegexTokenizer
+# from konlpy.tag import Hannanum TODO
 
 
 def compute_pearson_correlation(pred):
     preds = pred.predictions.flatten()
@@ -59,6 +63,18 @@ def __init__(self, state, data_file, text_columns, target_columns=None, delete_c
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.inputs, self.targets = self.preprocessing(self.data)
 
+        # create Korean stemmer and lemmatizer
+        # self.stemmer = Hannanum() TODO
+
+        # Download the Korean stopword file
+        # wget -O korean_stopwords.txt https://www.ranks.nl/stopwords/korean
+
+        # Read the file contents and build the stopword list
+        # with open('korean_stopwords.txt', 'r', encoding='utf-8') as f:
+        #     stopwords = f.read().splitlines()
+
+        # self.stopwords = stopwords
+
     def __getitem__(self, idx):
         if len(self.targets) == 0:
             return torch.tensor(self.inputs[idx])
@@ -73,11 +89,37 @@ def remove_stopwords(self, text):
         words = [word for word in words if word not in stopwords]
         return ' '.join(words)
 
+    def preprocess_text_wrapper(self, text_list):
+        text1, text2 = text_list
+        return self.preprocess_text(text1), self.preprocess_text(text2)
+
+    def preprocess_text(self, text):
+        # create Korean tokenizer using soynlp library
+        # tokenizer = RegexTokenizer()
+
+        # Normalize repeated characters (keep at most two repetitions)
+        text = repeat_normalize(text, num_repeats=2)
+        # Remove stopwords
+        # text = ' '.join([token for token in text.split() if not token in stopwords])
+        # Convert uppercase letters to lowercase
+        text = text.lower()
+        # Replace "" with "사람"
+        text = re.sub('', '사람', text)
+        # Remove all characters except Korean syllables, English letters, and whitespace
+        text = re.sub('[^가-힣a-z\\s]', '', text)
+        # Split the text into tokens, e.g. "안녕하세요" -> "안녕", "하", "세요"
+        # tokens = tokenizer.tokenize(text)
+        # Stemming
+        # tokens = [self.stemmer.morphs(token)[0] for token in text.split()]
+        # join tokens back into sentence
+        # text = ' '.join(tokens)
+        return text
+
     def tokenizing(self, dataframe):
         data = []
         for idx, item in tqdm(dataframe.iterrows(), desc='Tokenizing', total=len(dataframe)):
-            text = '[SEP]'.join([item[text_column] for text_column in self.text_columns])
+            text = '[SEP]'.join([self.preprocess_text(item[text_column]) for text_column in self.text_columns])  # apply text preprocessing
             outputs = self.tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True,
                                      max_length=self.max_length)
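For reference, a minimal standalone sketch of the cleaning steps that the new preprocess_text method applies, assuming soynlp is installed (pip install soynlp). The sample sentence and expected output are illustrative only, and the placeholder-token replacement from the hunk above is omitted here because its source pattern is not shown in the patch.

    # Standalone sketch of the added preprocessing (illustrative; not part of the patch).
    import re

    from soynlp.normalizer import repeat_normalize


    def preprocess_text(text):
        # Collapse long character repetitions to at most two, e.g. "오오오오" -> "오오"
        text = repeat_normalize(text, num_repeats=2)
        # Lowercase any English letters
        text = text.lower()
        # Keep only Korean syllables, lowercase Latin letters, and whitespace
        text = re.sub(r'[^가-힣a-z\s]', '', text)
        return text


    if __name__ == '__main__':
        print(preprocess_text("정말 좋아요오오오오!!! GOOD"))  # -> 정말 좋아요오오 good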
@@ -98,24 +140,24 @@ def preprocessing(self, data):
 
 if __name__ == '__main__':
     seed_everything(42)
 
-    model = AutoModelForSequenceClassification.from_pretrained("lighthouse/mdeberta-v3-base-kor-further",num_labels=1,ignore_mismatched_sizes=True)
+    model = AutoModelForSequenceClassification.from_pretrained("monologg/koelectra-base-finetuned-nsmc",num_labels=1,ignore_mismatched_sizes=True)
     #model = AutoModelForSequenceClassification.from_pretrained("E:/nlp/checkpoint/best_acc/checkpoint-16317",num_labels=1,ignore_mismatched_sizes=True)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
 
-    Train_textDataset = Train_val_TextDataset('train','./data/train.csv',['sentence_1', 'sentence_2'],'label','binary-label',max_length=512,model_name="lighthouse/mdeberta-v3-base-kor-further")
-    Val_textDataset = Train_val_TextDataset('val','./data/dev.csv',['sentence_1', 'sentence_2'],'label','binary-label',max_length=512,model_name="lighthouse/mdeberta-v3-base-kor-further")
+    Train_textDataset = Train_val_TextDataset('train','./data/train.csv',['sentence_1', 'sentence_2'],'label','binary-label',max_length=512,model_name="monologg/koelectra-base-finetuned-nsmc")
+    Val_textDataset = Train_val_TextDataset('val','./data/dev.csv',['sentence_1', 'sentence_2'],'label','binary-label',max_length=512,model_name="monologg/koelectra-base-finetuned-nsmc")
 
     args = TrainingArguments(
-        "E:/nlp/checkpoint/best_acc_mdeberta",
+        "./checkpoint/baseline_Test_fine_3.073982620831417e-05",
         evaluation_strategy = "epoch",
         save_strategy = "epoch",
-        learning_rate=0.00002340865224868444, #0.000005
-        per_device_train_batch_size=8,
-        per_device_eval_batch_size=8,
+        learning_rate=0.00003073982620831417,
+        per_device_train_batch_size=16,
+        per_device_eval_batch_size=16,
         num_train_epochs=8,
-        weight_decay=0.5,
+        weight_decay=0.2,
         load_best_model_at_end=True,
         dataloader_num_workers = 4,
         logging_steps=200,
diff --git a/wandb_tuning.py b/wandb_tuning.py
index 6e489d6..68d56f0 100644
--- a/wandb_tuning.py
+++ b/wandb_tuning.py
@@ -115,7 +115,7 @@ def preprocessing(self, data):
 
 # hyperparameter sweep config
 sweep_config = {
-    'method': 'random',
+    'method': 'random'
     'parameters' = parameters_dict
 }
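As committed, the wandb_tuning.py hunk leaves sweep_config syntactically invalid: the trailing comma after 'random' is dropped while a 'parameters' entry (written with '=' rather than ':') still follows. Below is a minimal sketch of a valid random-search sweep config; the parameter ranges, project name, and train function are illustrative assumptions, not values taken from the repository.

    # Hypothetical, syntactically valid version of the sweep config; all values are illustrative.
    parameters_dict = {
        'learning_rate': {'distribution': 'uniform', 'min': 1e-5, 'max': 5e-5},
        'per_device_train_batch_size': {'values': [8, 16, 32]},
        'weight_decay': {'values': [0.1, 0.2, 0.5]},
    }

    sweep_config = {
        'method': 'random',             # random search over the parameter space
        'parameters': parameters_dict,  # dict keys use ':' and entries are comma-separated
    }

    # Typical usage (assumes a `train` function exists in wandb_tuning.py; project name is a placeholder):
    # import wandb
    # sweep_id = wandb.sweep(sweep_config, project='koelectra-sts')
    # wandb.agent(sweep_id, function=train, count=10)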