Skip to content

Commit

Permalink
Bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hemildesai committed Jul 18, 2024
1 parent 0dabb76 commit e4d1dbf
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 26 deletions.
11 changes: 5 additions & 6 deletions nemo/collections/llm/models/llama2_7b.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import nemo_sdk as sdk
import pytorch_lightning as pl

from nemo import lightning as nl
Expand All @@ -8,7 +7,7 @@
from nemo.collections.llm.models.log.default import default_log
from nemo.collections.llm.models.optim.adam import adam_with_cosine_annealing
from nemo.collections.llm.peft.api import gpt_lora
from nemo.collections.llm.utils import factory
from nemo.collections.llm.utils import Partial, factory

NAME = "llama2_7b"

Expand Down Expand Up @@ -37,8 +36,8 @@ def hf_resume() -> nl.AutoResume:


@factory(name=NAME, for_task="llm.pretrain")
def pretrain_recipe() -> sdk.Partial:
return sdk.Partial(
def pretrain_recipe() -> Partial:
return Partial(
pretrain,
model=model,
trainer=trainer,
Expand All @@ -49,8 +48,8 @@ def pretrain_recipe() -> sdk.Partial:


@factory(name=NAME, for_task="llm.finetune")
def finetune_recipe() -> sdk.Partial:
return sdk.Partial(
def finetune_recipe() -> Partial:
return Partial(
finetune,
model=model,
trainer=trainer,
Expand Down
11 changes: 5 additions & 6 deletions nemo/collections/llm/models/llama3_8b.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import nemo_sdk as sdk
import pytorch_lightning as pl

from nemo import lightning as nl
Expand All @@ -8,7 +7,7 @@
from nemo.collections.llm.models.log.default import default_log
from nemo.collections.llm.models.optim.adam import adam_with_cosine_annealing
from nemo.collections.llm.peft.api import gpt_lora
from nemo.collections.llm.utils import factory
from nemo.collections.llm.utils import Partial, factory

NAME = "llama3_8b"

Expand Down Expand Up @@ -37,8 +36,8 @@ def hf_resume() -> nl.AutoResume:


@factory(name=NAME, for_task="llm.pretrain")
def pretrain_recipe() -> sdk.Partial:
return sdk.Partial(
def pretrain_recipe() -> Partial:
return Partial(
pretrain,
model=model,
trainer=trainer,
Expand All @@ -49,8 +48,8 @@ def pretrain_recipe() -> sdk.Partial:


@factory(name=NAME, for_task="llm.finetune")
def finetune_recipe() -> sdk.Partial:
return sdk.Partial(
def finetune_recipe() -> Partial:
return Partial(
finetune,
model=model,
trainer=trainer,
Expand Down
7 changes: 3 additions & 4 deletions nemo/collections/llm/models/llama3_8b_16k.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import nemo_sdk as sdk
import pytorch_lightning as pl

from nemo import lightning as nl
Expand All @@ -7,7 +6,7 @@
from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel
from nemo.collections.llm.models.log.default import default_log
from nemo.collections.llm.models.optim.adam import adam_with_cosine_annealing
from nemo.collections.llm.utils import factory
from nemo.collections.llm.utils import Partial, factory

NAME = "llama3_8b_16k"

Expand Down Expand Up @@ -36,8 +35,8 @@ def trainer(devices=8) -> nl.Trainer:


@factory(name=NAME, for_task="llm.pretrain")
def pretrain_recipe() -> sdk.Partial:
return sdk.Partial(
def pretrain_recipe() -> Partial:
return Partial(
pretrain,
model=model,
trainer=trainer,
Expand Down
7 changes: 3 additions & 4 deletions nemo/collections/llm/models/llama3_8b_64k.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import nemo_sdk as sdk
import pytorch_lightning as pl

from nemo import lightning as nl
Expand All @@ -7,7 +6,7 @@
from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel
from nemo.collections.llm.models.log.default import default_log
from nemo.collections.llm.models.optim.adam import adam_with_cosine_annealing
from nemo.collections.llm.utils import factory
from nemo.collections.llm.utils import Partial, factory

NAME = "llama3_8b_64k"

Expand Down Expand Up @@ -36,8 +35,8 @@ def trainer(devices=8) -> nl.Trainer:


@factory(name=NAME, for_task="llm.pretrain")
def pretrain_recipe() -> sdk.Partial:
return sdk.Partial(
def pretrain_recipe() -> Partial:
return Partial(
pretrain,
model=model,
trainer=trainer,
Expand Down
11 changes: 5 additions & 6 deletions nemo/collections/llm/models/mistral.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import nemo_sdk as sdk
import pytorch_lightning as pl

from nemo import lightning as nl
Expand All @@ -8,7 +7,7 @@
from nemo.collections.llm.models.log.default import default_log
from nemo.collections.llm.models.optim.adam import adam_with_cosine_annealing
from nemo.collections.llm.peft.api import gpt_lora
from nemo.collections.llm.utils import factory
from nemo.collections.llm.utils import Partial, factory

NAME = "mistral"

Expand Down Expand Up @@ -37,8 +36,8 @@ def hf_resume() -> nl.AutoResume:


@factory(name=NAME, for_task="llm.pretrain")
def pretrain_recipe() -> sdk.Partial:
return sdk.Partial(
def pretrain_recipe() -> Partial:
return Partial(
pretrain,
model=model,
trainer=trainer,
Expand All @@ -49,8 +48,8 @@ def pretrain_recipe() -> sdk.Partial:


@factory(name=NAME, for_task="llm.finetune")
def finetune_recipe() -> sdk.Partial:
return sdk.Partial(
def finetune_recipe() -> Partial:
return Partial(
finetune,
model=model,
trainer=trainer,
Expand Down

0 comments on commit e4d1dbf

Please sign in to comment.