Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Jan 17, 2024
1 parent 319fe5b commit 9d9e5a7
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions olmo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def max_steps(self) -> int:
if self.cfg.max_duration.endswith("T"):
# convert to float *first* to handle scientific notation
max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
- tokens_remaining = max_tokens - self.global_train_tokens_seen
+ tokens_remaining = max(max_tokens - self.global_train_tokens_seen, 0)
steps_remaining = tokens_remaining // self.tokens_per_batch
return self.global_step + steps_remaining
elif self.cfg.max_duration.endswith("ep"):
Expand All @@ -163,7 +163,7 @@ def max_tokens(self) -> int:
if isinstance(self.cfg.max_duration, int):
return (
self.global_train_tokens_seen
- + min(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
+ + max(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
)
elif isinstance(self.cfg.max_duration, str):
if self.cfg.max_duration.endswith("T"):
Expand All @@ -176,7 +176,7 @@ def max_tokens(self) -> int:
# convert to float *first* to handle scientific notation
return (
self.global_train_tokens_seen
- + min(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
+ + max(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
)
else:
raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
Expand Down

0 comments on commit 9d9e5a7

Please sign in to comment.