Skip to content

Commit

Permalink
make the desired changes as said here -> X-PLUG#22 (comment)
Browse files Browse the repository at this point in the history
  • Loading branch information
wttc-nitr committed May 7, 2023
1 parent 64cfe45 commit 6bbb72e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions apex_22.01_pp/csrc/mlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at

// create output/workspace tensor
auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
auto reserved_space = at::empty({reserved_size}, inputs[0].type());
auto reserved_space = at::empty({static_cast<long>(reserved_size)}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, inputs[0].type());

Expand Down Expand Up @@ -135,7 +135,7 @@ std::vector<at::Tensor> mlp_backward(
get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());

// auto work_space = at::empty({work_size*4}, at::kByte);
auto work_space = at::empty({work_size / sizeof(scalar_t)}, inputs[0].type());
auto work_space = at::empty({static_cast<long>(work_size / sizeof(scalar_t))}, inputs[0].type());

auto result = mlp_bp<scalar_t>(
inputs[0].data_ptr<scalar_t>(),
Expand Down
2 changes: 1 addition & 1 deletion interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def get_model(checkpoint_path=None, tokenizer_path=None, peft_config=None, devic
tokenizer.eod_id = tokenizer.eos_token_id
img_processor = ImageProcessor()

model = model.to(device)
model = model.to(dtype)
model = model.to(device)
return model, tokenizer, img_processor


Expand Down

0 comments on commit 6bbb72e

Please sign in to comment.