Skip to content

Commit

Permalink
fix: Grammar 모델 negative score 오류 수정
Browse files Browse the repository at this point in the history
[#67]
  • Loading branch information
9ooDa committed Apr 3, 2024
1 parent a9218d1 commit f579155
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
12 changes: 1 addition & 11 deletions Models/grammar_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import json
import os

import nltk
import torch
import uvicorn
from fastapi import FastAPI, Form
from gector.gector import GECToR, load_verb_dict, predict_verbose
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer
from typing_extensions import Annotated
from utils import gram_metrics, gram_out_json, gram_visualizer_json
Expand All @@ -24,7 +22,6 @@
# }

model_path = "gotutiyan/gector-roberta-large-5k"
# input_path = os.path.join(gector_path, "input", "macominatya.json")

model = GECToR.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
Expand All @@ -40,9 +37,7 @@ async def upload_json(
gector_path = "./gector"
verb_path = "./gector/data/verb-form-vocab.txt"

nltk.download("punkt")
print(text)
srcs = sent_tokenize(text)
srcs = gram_visualizer_json.process_input_text(text)
encode, decode = load_verb_dict(verb_path)

transcription = text
Expand Down Expand Up @@ -79,11 +74,6 @@ async def upload_json(
with open(check, "w", encoding="utf-8") as c:
json.dump(checker_data, c, indent="\t")

# Run this part if you want the metric json file
# metric = os.path.join(gector_path, "metric", "metric.json")
# with open(metric, "w", encoding="utf-8") as r:
# json.dump(metric_data, r, indent="\t")

# Final output
phase = "phase_2" # either "phase_1" or "phase_2"
score_type = "pwc" # "ec" or "psc" or "pwc"
Expand Down
18 changes: 18 additions & 0 deletions Models/utils/gram_visualizer_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@
from transformers import AutoTokenizer


def process_input_text(input_text: str) -> List[str]:
    """Split *input_text* into sentences on ``.``, ``!`` and ``?``.

    Every terminator character closes the current sentence and is kept as
    that sentence's last character; surrounding whitespace is stripped from
    each sentence. Any text left after the final terminator is returned as
    one last sentence, unless it is empty or whitespace-only.

    Note that each terminator ends a sentence on its own, so a run such as
    ``"..."`` produces one-character sentences for the extra marks.

    Args:
        input_text: Raw text to split.

    Returns:
        The sentences in order of appearance. Never contains empty strings.

    >>> process_input_text("Hi there. How are you? ")
    ['Hi there.', 'How are you?']
    """
    punctuation_marks = {".", "!", "?"}

    sentences = []
    start = 0
    for i, char in enumerate(input_text):
        if char in punctuation_marks:
            # The slice always ends with the terminator itself, so the
            # stripped sentence can never be empty here.
            sentences.append(input_text[start:i + 1].strip())
            start = i + 1

    # Strip before testing: trailing whitespace/newlines after the last
    # terminator must not be emitted as an empty sentence.
    tail = input_text[start:].strip()
    if tail:
        sentences.append(tail)

    return sentences


def visualizer_json(iteration_log: List[List[Dict]], out_sentence):
# Generate a string to visualize the predictions.
outer_data = {}
Expand Down

0 comments on commit f579155

Please sign in to comment.