Second round of fixes for OOM errors
johnml1135 committed Nov 9, 2023
1 parent 562bab2 commit c61f841
Showing 5 changed files with 32 additions and 12 deletions.
9 changes: 8 additions & 1 deletion machine/jobs/huggingface/hugging_face_nmt_model_factory.py
@@ -1,3 +1,4 @@
+import logging
 from pathlib import Path
 from typing import Any, cast

@@ -15,6 +16,8 @@
 from ..nmt_model_factory import NmtModelFactory
 from ..shared_file_service import SharedFileService

+logger = logging.getLogger(__name__)
+

 class HuggingFaceNmtModelFactory(NmtModelFactory):
     def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
@@ -67,7 +70,11 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
             add_unk_trg_tokens=self._config.huggingface.tokenizer.add_unk_trg_tokens,
         )

-    def create_engine(self) -> TranslationEngine:
+    def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
+        if half_previous_batch_size:
+            self._config.huggingface.generate_params.batch_size = max(
+                self._config.huggingface.generate_params.batch_size // 2, 1
+            )
         return HuggingFaceNmtEngine(
             self._model,
             src_lang=self._config.src_lang,
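The factory change above halves the configured generation batch size each time create_engine is called with half_previous_batch_size=True, and the max(..., 1) floor keeps repeated retries from ever requesting a batch size of 0. A minimal sketch of that behaviour, assuming a hypothetical starting value of 16:

    batch_size = 16  # assumed starting value, not a repository default
    for _ in range(6):
        print(batch_size, end=" ")  # prints: 16 8 4 2 1 1
        batch_size = max(batch_size // 2, 1)
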
24 changes: 14 additions & 10 deletions machine/jobs/nmt_engine_build_job.py
@@ -81,17 +81,16 @@ def run(
             inference_step_count = sum(1 for _ in src_pretranslations)
         with ExitStack() as stack:
             phase_progress = stack.enter_context(progress_reporter.start_next_phase())
-            engine = stack.enter_context(self._nmt_model_factory.create_engine())
             src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations())
             writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer())
             current_inference_step = 0
             phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
             batch_size = self._config["batch_size"]
-            translate_batch = TranslateBatch(batch_size)
+            translate_batch = TranslateBatch(stack, self._nmt_model_factory)
             for pi_batch in batch(src_pretranslations, batch_size):
                 if check_canceled is not None:
                     check_canceled()
-                translate_batch.translate(engine, pi_batch, writer)
+                translate_batch.translate(pi_batch, writer)
                 current_inference_step += len(pi_batch)
                 phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))

@@ -100,25 +99,30 @@


 class TranslateBatch:
-    def __init__(self, initial_batch_size):
-        self.batch_size = initial_batch_size
+    def __init__(self, stack: ExitStack, nmt_model_factory: NmtModelFactory):
+        self._stack = stack
+        self._nmt_model_factory = nmt_model_factory
+        self._engine = self._stack.enter_context(self._nmt_model_factory.create_engine())

     def translate(
         self,
-        engine: TranslationEngine,
         batch: Sequence[PretranslationInfo],
         writer: PretranslationWriter,
     ) -> None:
         while True:
             source_segments = [pi["translation"] for pi in batch]
             outer_batch_size = len(source_segments)
             try:
-                for step in range(0, outer_batch_size, self.batch_size):
-                    for i, result in enumerate(engine.translate_batch(source_segments[step : step + self.batch_size])):
+                for step in range(0, outer_batch_size, self._engine.get_batch_size()):
+                    for i, result in enumerate(
+                        self._engine.translate_batch(source_segments[step : step + self._engine.get_batch_size()])
+                    ):
                         batch[i + step]["translation"] = result.translation
                 for i in range(len(source_segments)):
                     writer.write(batch[i])
                 break
             except Exception:
-                self.batch_size = max(self.batch_size // 2, 1)
-                logger.info(f"Out of memory error, reducing batch size to {self.batch_size}")
+                logger.info(f"Out of memory error, reducing batch size to {self._engine.get_batch_size() // 2}")
+                self._engine = self._stack.enter_context(
+                    self._nmt_model_factory.create_engine(half_previous_batch_size=True)
+                )
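
TranslateBatch now owns its engine: it creates one through the ExitStack up front, and whenever translate_batch raises (treated as an out-of-memory condition), it logs the reduction, asks the factory for a replacement engine at half the previous batch size, and retries the same outer batch from the beginning. A self-contained simulation of that recovery path, using hypothetical stand-ins (FlakyEngine and Factory below are illustrative, not repository classes) in place of the real engine, factory, and writer:

    class FlakyEngine:
        """Simulates an engine that runs out of memory on sub-batches larger than 4."""

        def __init__(self, batch_size: int) -> None:
            self._batch_size = batch_size

        def get_batch_size(self) -> int:
            return self._batch_size

        def translate_batch(self, segments):
            if len(segments) > 4:
                raise RuntimeError("CUDA out of memory")  # simulated OOM
            return [s.upper() for s in segments]  # placeholder "translations"


    class Factory:
        """Mimics create_engine(half_previous_batch_size=...) from the factory above."""

        def __init__(self, batch_size: int) -> None:
            self._batch_size = batch_size

        def create_engine(self, half_previous_batch_size: bool = False) -> FlakyEngine:
            if half_previous_batch_size:
                self._batch_size = max(self._batch_size // 2, 1)
            return FlakyEngine(self._batch_size)


    factory = Factory(batch_size=16)
    engine = factory.create_engine()
    segments = [f"segment {i}" for i in range(10)]
    while True:
        try:
            size = engine.get_batch_size()
            results = []
            for step in range(0, len(segments), size):
                results.extend(engine.translate_batch(segments[step : step + size]))
            break  # the whole outer batch succeeded
        except RuntimeError:
            engine = factory.create_engine(half_previous_batch_size=True)
    print(engine.get_batch_size(), len(results))  # 4 10: succeeds after dropping 16 -> 8 -> 4

Because the whole outer batch is retried, sub-batches that already succeeded before the failure are translated again on the next attempt.
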
2 changes: 1 addition & 1 deletion machine/jobs/nmt_model_factory.py
@@ -29,7 +29,7 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
         ...

     @abstractmethod
-    def create_engine(self) -> TranslationEngine:
+    def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
         ...

     @abstractmethod
5 changes: 5 additions & 0 deletions machine/translation/huggingface/hugging_face_nmt_engine.py
@@ -56,6 +56,8 @@ def __init__(
         ):
             raise ValueError(f"'{tgt_lang}' is not a valid language code.")

+        self._batch_size = int(pipeline_kwargs.get("batch_size"))  # type: ignore[assignment]
+
         self._pipeline = _TranslationPipeline(
             model=model,
             tokenizer=self._tokenizer,
@@ -71,6 +73,9 @@ def translate_n(self, n: int, segment: Union[str, Sequence[str]]) -> Sequence[Tr
     def translate_batch(self, segments: Sequence[Union[str, Sequence[str]]]) -> Sequence[TranslationResult]:
         return [results[0] for results in self.translate_n_batch(1, segments)]

+    def get_batch_size(self) -> int:
+        return self._batch_size
+
     def translate_n_batch(
         self, n: int, segments: Sequence[Union[str, Sequence[str]]]
     ) -> Sequence[Sequence[TranslationResult]]:
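With these two additions, HuggingFaceNmtEngine records the batch_size it was constructed with and reports it through get_batch_size(). Since the constructor reads it via int(pipeline_kwargs.get("batch_size")), callers appear to be expected to always supply a batch_size. A hypothetical usage sketch; the model identifier and language codes are placeholders, and passing a model name rather than a loaded model is an assumption (the factory above passes an already-loaded model):

    engine = HuggingFaceNmtEngine(
        "facebook/nllb-200-distilled-600M",  # placeholder model name/path
        src_lang="eng_Latn",
        tgt_lang="fra_Latn",
        batch_size=16,
    )
    print(engine.get_batch_size())  # -> 16
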
4 changes: 4 additions & 0 deletions machine/translation/translation_engine.py
@@ -20,6 +20,10 @@ def translate_n(self, n: int, segment: Union[str, Sequence[str]]) -> Sequence[Tr
     def translate_batch(self, segments: Sequence[Union[str, Sequence[str]]]) -> Sequence[TranslationResult]:
         ...

+    @abstractmethod
+    def get_batch_size(self) -> int:
+        ...
+
     @abstractmethod
     def translate_n_batch(
         self, n: int, segments: Sequence[Union[str, Sequence[str]]]
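The base TranslationEngine gains the same abstract method, so every concrete engine now has to report the batch size it actually translates with. A duck-typed sketch of the implementer side (it does not derive from the real abstract class, which declares further methods):

    class FixedBatchEngine:
        """Toy example: store the configured batch size and report it unchanged."""

        def __init__(self, batch_size: int) -> None:
            self._batch_size = batch_size

        def get_batch_size(self) -> int:
            return self._batch_size
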
