Skip to content

Commit

Permalink
Raise error when passing an invalid lang code to HuggingFaceNmtEngine (
Browse files Browse the repository at this point in the history
  • Loading branch information
ddaspit authored Oct 30, 2023
1 parent f2015aa commit 4f374c5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
33 changes: 25 additions & 8 deletions machine/translation/huggingface/hugging_face_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import gc
from math import exp, prod
from typing import Iterable, List, Sequence, Tuple, Union, cast
from typing import Any, Iterable, List, Sequence, Tuple, Union, cast

import torch # pyright: ignore[reportMissingImports]
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, TranslationPipeline
Expand Down Expand Up @@ -30,15 +30,32 @@ def __init__(
model_config = AutoConfig.from_pretrained(str(model), label2id={}, id2label={}, num_labels=0)
model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(model), config=model_config))
self._tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True)
if "prefix" not in pipeline_kwargs:

src_lang = pipeline_kwargs.get("src_lang")
tgt_lang = pipeline_kwargs.get("tgt_lang")
if (
src_lang is not None
and tgt_lang is not None
and "prefix" not in pipeline_kwargs
and (model.name_or_path.startswith("t5-") or model.name_or_path.startswith("google/mt5-"))
):
pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
else:
additional_special_tokens = self._tokenizer.additional_special_tokens
if (
"src_lang" in pipeline_kwargs
and "tgt_lang" in pipeline_kwargs
and (model.name_or_path.startswith("t5-") or model.name_or_path.startswith("google/mt5-"))
src_lang is not None
and src_lang not in cast(Any, self._tokenizer).lang_code_to_id
and src_lang not in additional_special_tokens
):
src_lang = pipeline_kwargs["src_lang"]
tgt_lang = pipeline_kwargs["tgt_lang"]
pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
raise ValueError(f"'{src_lang}' is not a valid language code.")

if (
tgt_lang is not None
and tgt_lang not in cast(Any, self._tokenizer).lang_code_to_id
and tgt_lang not in additional_special_tokens
):
raise ValueError(f"'{tgt_lang}' is not a valid language code.")

self._pipeline = _TranslationPipeline(
model=model,
tokenizer=self._tokenizer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

skip("skipping Hugging Face tests on MacOS", allow_module_level=True)

from pytest import approx
from pytest import approx, raises

from machine.translation.huggingface import HuggingFaceNmtEngine

Expand Down Expand Up @@ -36,3 +36,8 @@ def test_translate_greedy() -> None:
assert result.translation == "skaberskaber Dollar Dollar Dollar ፤ gerekir gerekir"
assert result.confidences[0] == approx(1.08e-05, 0.01)
assert str(result.alignment) == "2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7"


def test_construct_invalid_lang() -> None:
with raises(ValueError):
HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="qaa", tgt_lang="es")

0 comments on commit 4f374c5

Please sign in to comment.