Skip to content

Commit

Permalink
Merge pull request #134 from instructlab/logging-updates
Browse files (browse the repository at this point in the history)
Logging updates
  • Loading branch information
aldopareja authored Jul 10, 2024
2 parents 2b744af + 07c7f63 commit 6164dbf
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
8 changes: 6 additions & 2 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,8 @@ def main(args):
     import yaml

     metric_logger = AsyncStructuredLogger(
-        args.output_dir + "/training_params_and_metrics.jsonl"
+        args.output_dir
+        + f"/training_params_and_metrics_global{os.environ['RANK']}.jsonl"
     )
     if os.environ["LOCAL_RANK"] == "0":
         print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m")
Expand Down Expand Up @@ -658,7 +659,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     print(f"\033[92mRunning command: {' '.join(command)}\033[0m")
     process = None
     try:
-        process = StreamablePopen(command)
+        process = StreamablePopen(
+            f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log",
+            command,
+        )

     except KeyboardInterrupt:
         print("Process interrupted by user")
Expand Down
24 changes: 13 additions & 11 deletions src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,23 @@ class StreamablePopen(subprocess.Popen):
     Provides a way of reading stdout and stderr line by line.
     """

-    def __init__(self, *args, **kwargs):
+    def __init__(self, output_file, *args, **kwargs):
         # remove the stderr and stdout from kwargs
         kwargs.pop("stderr", None)
         kwargs.pop("stdout", None)

-        super().__init__(*args, **kwargs)
-        while True:
-            if self.stdout:
-                output = self.stdout.readline().strip()
-                print(output)
-            if self.stderr:
-                error = self.stderr.readline().strip()
-                print(error, file=sys.stderr)
-            if self.poll() is not None:
-                break
+        super().__init__(
+            *args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs
+        )
+        with open(output_file, "wb") as full_log_file:
+            while True:
+                byte = self.stdout.read(1)
+                if byte:
+                    sys.stdout.buffer.write(byte)
+                    sys.stdout.flush()
+                    full_log_file.write(byte)
+                else:
+                    break


 def make_collate_fn(pad_token_id, is_granite=False, max_batch_len=60000):
Expand Down

0 comments on commit 6164dbf

Please sign in to comment.