Skip to content

Commit

Permalink
refactor: 중복 오류 수정 및 output에 index 추가
Browse files Browse the repository at this point in the history
[#73]
  • Loading branch information
9ooDa committed Jul 16, 2024
1 parent f579155 commit f2fec41
Show file tree
Hide file tree
Showing 4 changed files with 385 additions and 132 deletions.
25 changes: 15 additions & 10 deletions Models/grammar_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json
import os
import json

import torch
import uvicorn
from fastapi import FastAPI, Form
from gector.gector import GECToR, load_verb_dict, predict_verbose
from transformers import AutoTokenizer
from typing_extensions import Annotated

from gector.gector import GECToR, load_verb_dict, predict_verbose
from utils import gram_metrics, gram_out_json, gram_visualizer_json

app = FastAPI()
Expand Down Expand Up @@ -65,7 +66,8 @@ async def upload_json(
# Grammar Error Correction Process
final_corrected_sents, iteration_log = predict_verbose(**predict_args)
checker_data = gram_visualizer_json.visualizer_json(
iteration_log, final_corrected_sents
iteration_log=iteration_log,
out_sentence=final_corrected_sents
)

# dump visualized checker .json file
Expand All @@ -77,17 +79,20 @@ async def upload_json(
# Final output
phase = "phase_2" # either "phase_1" or "phase_2"
score_type = "pwc" # "ec" or "psc" or "pwc"
error_count_type = "new" # "new" or "old"
out_path = os.path.join(gector_path, "real", f"grammar_{phase}.json")
score = gram_metrics.get_score(checker_data=checker_data, score_type=score_type)
token_path = os.path.join(gector_path, "data", "token_labels.txt")

score = gram_metrics.get_score(checker_data=checker_data, score_type=score_type,
error_count_type=error_count_type, token_path=token_path)

print(
gram_out_json.create_json(
phase=phase, out_path=out_path, score=score, check_data=checker_data
)
gram_out_json.create_json(phase=phase, out_path=out_path, score=score,
checker_data=checker_data, token_path=token_path)
)

return gram_out_json.create_json(
phase=phase, out_path=out_path, score=score, check_data=checker_data
)
return gram_out_json.create_json(phase=phase, out_path=out_path, score=score,
checker_data=checker_data, token_path=token_path)

except Exception as e:
return {"text": None, "status": f"Error: {str(e)}"}
Expand Down
97 changes: 68 additions & 29 deletions Models/utils/gram_metrics.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,80 @@
import string
from typing import Dict, Literal, get_args

from .gram_out_json import get_cleaned_token_list, get_scrs_tok
from gram_out_json import get_cleaned_token_list, get_scrs_tok, get_tag_grammar, get_phase_2_inner_data


# ec = error count, psc = per sentence count, pwc = per word count
_STYPES = Literal["ec", "psc", "pwc"]
_ETYPES = Literal["new", "old"]


def get_error_count(
    checker_data: Dict,
    error_count_type: _ETYPES = "new",
    token_path: str = "../gector/data/token_labels.txt"
):
    """Count grammar errors detected across all sentences in *checker_data*.

    Args:
        checker_data: mapping of original sentence -> inner dict with at least
            the keys ``"edited"`` (bool) and ``"fin_sentence"`` (corrected text).
            NOTE(review): exact inner schema comes from gram_visualizer_json —
            confirm against its output.
        error_count_type: ``"new"`` counts de-duplicated errors via
            get_phase_2_inner_data; ``"old"`` keeps the legacy per-action-tag
            count (may double-count overlapping edits).
        token_path: path to the GECToR token-label file consumed by
            get_cleaned_token_list.

    Returns:
        Total number of detected errors (int).
    """
    error_count = 0
    ctl = get_cleaned_token_list(token_path)

    # "new" = version with duplicate errors removed
    if error_count_type == "new":
        og_list = list(checker_data.keys())
        tag_grammar = get_tag_grammar(ctl)

        for og_sent, inner_dict in checker_data.items():
            gector_dict = get_scrs_tok(inner_dict, ctl)
            sid = og_list.index(og_sent)  # sentence index within the document
            corr_sent = inner_dict["fin_sentence"]
            edited = inner_dict["edited"]
            ref_word_list = gector_dict["og_word"]
            tag_list = gector_dict["full_tag"]
            if edited:
                inner = get_phase_2_inner_data(
                    sid=sid, sent=og_sent, corr_sent=corr_sent, edited=edited,
                    ref_word_list=ref_word_list, tag_list=tag_list,
                    tag_grammar=tag_grammar,
                )
                # ref_word holds one entry per (de-duplicated) error
                error_count += len(inner["ref_word"])

    # "old" = legacy version (counts every action tag)
    elif error_count_type == "old":
        metric_data = {"main": {}}
        for og_sent, inner_dict in checker_data.items():
            if inner_dict["edited"]:
                metric_data["main"][og_sent] = get_scrs_tok(inner_dict, ctl)
            else:
                metric_data["main"][og_sent] = {
                    "edited": False,
                    "corrected_sentence": inner_dict["fin_sentence"],
                }

        for v in metric_data["main"].values():
            if v["edited"]:
                error_count += len(v["action_tag"])

    return error_count


def get_error_rate_sen(
    checker_data: Dict,
    error_count_type: _ETYPES = "new",
    token_path: str = "../gector/data/token_labels.txt"
):
    """Return the average number of errors per sentence, rounded to 2 dp.

    Args:
        checker_data: mapping of original sentence -> checker inner dict
            (same structure expected by get_error_count).
        error_count_type: forwarded to get_error_count ("new" or "old").
        token_path: forwarded to get_error_count.

    Raises:
        ZeroDivisionError: if checker_data is empty.
    """
    # One key per original sentence, so the key count is the sentence count.
    sentence_count = len(checker_data)
    error_count = get_error_count(
        checker_data=checker_data,
        error_count_type=error_count_type,
        token_path=token_path,
    )

    return round(error_count / sentence_count, 2)


def get_error_rate_word(
checker_data: Dict,
error_count_type: _ETYPES = "new",
token_path: str = "../gector/data/token_labels.txt"
):
og_list = list(checker_data.keys())
error_count = get_error_count(checker_data=checker_data)
error_count = get_error_count(checker_data=checker_data, error_count_type=error_count_type, token_path=token_path)

word_count = 0
for sen in og_list:
Expand All @@ -54,18 +86,25 @@ def get_error_rate_word(
return round(result, 2)


def get_score(
    checker_data: Dict,
    score_type: _STYPES = "pwc",
    error_count_type: _ETYPES = "new",
    token_path: str = "../gector/data/token_labels.txt"
):
    """Dispatch to the requested grammar-scoring metric.

    Args:
        checker_data: mapping of original sentence -> checker inner dict.
        score_type: "ec" = raw error count, "psc" = errors per sentence,
            "pwc" = errors per word.
        error_count_type: "new" (de-duplicated) or "old" (legacy) counting.
        token_path: path to the GECToR token-label file.

    Returns:
        int for "ec", float rounded to 2 dp for "psc"/"pwc".

    Raises:
        AssertionError: if score_type / error_count_type is not a valid option.
            NOTE(review): assert is stripped under `python -O`; consider
            raising ValueError instead.
    """
    # Validate both option arguments against their Literal alias values.
    soptions = get_args(_STYPES)
    eoptions = get_args(_ETYPES)
    assert score_type in soptions, f"'{score_type}' is not in {soptions}"
    assert error_count_type in eoptions, f"'{error_count_type}' is not in {eoptions}"

    if score_type == "ec":
        return get_error_count(checker_data=checker_data,
                               error_count_type=error_count_type,
                               token_path=token_path)
    elif score_type == "psc":
        return get_error_rate_sen(checker_data=checker_data,
                                  error_count_type=error_count_type,
                                  token_path=token_path)
    elif score_type == "pwc":
        return get_error_rate_word(checker_data=checker_data,
                                   error_count_type=error_count_type,
                                   token_path=token_path)

if __name__ == "__main__":
token_path = "../gector/data/token_labels.txt"
get_cleaned_token_list(token_path)
Loading

0 comments on commit f2fec41

Please sign in to comment.