diff --git a/examples/language_model/gpt/README.md b/examples/language_model/gpt/README.md
index f385671da106a..1f72fd891d77c 100644
--- a/examples/language_model/gpt/README.md
+++ b/examples/language_model/gpt/README.md
@@ -14,6 +14,7 @@ GPT-[2](https://cdn.openai.com/better-language-models/language_models_are_unsupe
 ├── decompress.sh # 数据集解压脚本
 ├── deploy/ # 模型部署的inference脚本
 ├── export_model.py # 导出预测部署的模型脚本
+├── faster_gpt/ # 使用 FasterGPT 高性能预测 sample
 ├── lr.py # 学习率控制
 ├── predict.py # 生成文本示例demo
 ├── README.md # 文档
diff --git a/examples/language_model/gpt/faster_gpt/README.md b/examples/language_model/gpt/faster_gpt/README.md
new file mode 100644
index 0000000000000..6a95a1c50a53f
--- /dev/null
+++ b/examples/language_model/gpt/faster_gpt/README.md
@@ -0,0 +1,193 @@
+# FasterGPT Usage
+
+This example integrates NVIDIA [Faster Transformer](https://github.com/NVIDIA/FasterTransformer/tree/v3.1) to accelerate inference, providing FasterGPT prediction in both float32 and float16. The instructions below describe how to use FasterGPT.
+
+## Environment requirements
+
+* PaddlePaddle 2.1.0 or later, or an appropriate develop build
+* CMake >= 3.10
+* CUDA 10.1 (must be consistent with the PaddlePaddle build)
+* The gcc version must match the one used to build PaddlePaddle, e.g. gcc 8.2
+* Python3 is recommended
+* The environment required by [Faster Transformer](https://github.com/NVIDIA/FasterTransformer/tree/v3.1#setup)
+
+## Quick start
+
+We provide a GPU custom op that integrates FasterGPT. The sections below describe how to compile and use this custom op, first from the Python dygraph API and then from the C++ inference library.
+
+## Using the custom op in Python dygraph mode
+
+### Compiling the custom op
+
+To use the custom op in Python dygraph mode, the C++ and CUDA sources must be compiled into a dynamic library. A corresponding CMakeLists.txt is provided, and the steps below show how to build it. The same build instructions can also be found under the custom op directory `PaddleNLP/paddlenlp/ops/`.
+
+#### Clone PaddleNLP
+
+First, since the released paddlenlp Python package does not include the FasterGPT libraries, they must be compiled from source against the current environment. You can either use the paddlenlp installed as a Python package, or clone PaddleNLP from GitHub and build against the clone:
+
+``` sh
+git clone https://github.com/PaddlePaddle/PaddleNLP.git
+```
+
+Next, set the environment variable so that the cloned paddlenlp is used, and change into the custom op directory to prepare for compilation:
+
+``` sh
+export PYTHONPATH=$PWD/PaddleNLP/:$PYTHONPATH
+cd PaddleNLP/paddlenlp/ops/
+```
+
+#### Compile
+
+Before compiling, make sure the installed PaddlePaddle is 2.1.0 or later, or is built from the latest develop branch, and works correctly.
+
+The custom op can be compiled with the following steps:
+
+``` sh
+mkdir build
+cd build/
+cmake .. -DSM=xx -DCMAKE_BUILD_TYPE=Release -DPY_CMD=python3.x -DWITH_GPT=ON
+make -j
+cd ../
+```
+
+Where:
+* `-DSM`: the compute capability of the GPU in use. For example, specify 70 for V100 or 75 for T4.
+* `-DPY_CMD`: the Python interpreter to compile against. If `-DPY_CMD` is not specified, the Python version behind the system `python` command is used by default.
+* `-DWITH_GPT`: whether to build the dynamic library with the FasterGPT custom op.
+
+
+Finally, the build produces `libdecoding_op.so` under `./build/lib/`, which is the library needed to run FasterGPT decoding.
+
+### Using GPT-2 decoding for high-performance inference
+
+In your Python script, call the `FasterGPT` API and pass in the location of `libdecoding_op.so` to use FasterGPT for prediction:
+
+``` python
+from paddlenlp.ops import FasterGPT
+from paddlenlp.transformers import GPTLMHeadModel, GPTTokenizer
+
+MODEL_CLASSES = {
+    "gpt2-medium-en": (GPTLMHeadModel, GPTTokenizer),
+}
+
+model_class, tokenizer_class = MODEL_CLASSES[args.model_name_or_path]
+tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+model = model_class.from_pretrained(args.model_name_or_path)
+
+bos_id = tokenizer.convert_tokens_to_ids(args.start_token)
+eos_id = tokenizer.convert_tokens_to_ids(args.end_token)
+
+# Define model
+gpt = FasterGPT(
+    model=model,
+    topk=args.topk,
+    topp=args.topp,
+    max_out_len=args.max_out_len,
+    bos_id=bos_id,
+    eos_id=eos_id,
+    temperature=args.temperature,
+    decoding_lib=args.decoding_lib,
+    use_fp16_decoding=args.use_fp16_decoding)
+```
+
+Currently, the GPT-2 example only supports a `batch size` of `1`, or batches in which all input samples have the same length. In addition, only topk-sampling and topp-sampling are supported; beam search is not.
+
+See `./infer.py` for a more complete example.
+
+#### Run GPT-2 decoding on PaddlePaddle
+
+To run a decoding-only test with PaddlePaddle (float32):
+
+``` sh
+export CUDA_VISIBLE_DEVICES=0
+python infer.py --model_name_or_path gpt2-medium-en --decoding_lib ./build/lib/libdecoding_op.so --batch_size 1 --topk 4 --topp 0.0 --max_out_len 32 --start_token "<|endoftext|>" --end_token "<|endoftext|>" --temperature 1.0
+```
+
+The options are as follows:
+* `--model_name_or_path`: the name or path of the pretrained model.
+* `--decoding_lib`: the path to `libdecoding_op.so`. If the library does not exist, it will be produced automatically by JIT compilation.
+* `--batch_size`: the number of samples in a batch.
+* `--topk`: the `k` used for topk-sampling. Defaults to 4.
+* `--topp`: the threshold used for topp-sampling. Defaults to 0.0, which disables topp-sampling.
+* `--max_out_len`: the maximum generation length.
+* `--start_token`: the start token string used for unconditional generation.
+* `--end_token`: the end token string that terminates generation.
+* `--temperature`: the sampling temperature.
+* `--use_fp16_decoding`: whether to use fp16 for decoding.
+
+
+## Using the custom op from the C++ inference library
+
+### Compiling the custom op
+
+To use the custom op from the C++ inference library, the C++ and CUDA sources **together with the C++ inference demo** are compiled into an executable. Because the inference library is integrated differently from Python, this process does not produce a dynamic library for the custom op; it directly produces the executable. A corresponding CMakeLists.txt is provided, and the steps below show how to build and obtain the demo.
+
+#### Clone PaddleNLP
+
+First, clone PaddleNLP as before:
+
+``` sh
+git clone https://github.com/PaddlePaddle/PaddleNLP.git
+```
+
+Next, change into the custom op directory of the cloned paddlenlp to prepare for compilation:
+
+``` sh
+cd PaddleNLP/paddlenlp/ops/
+```
+
+#### Compile
+
+Before compiling, make sure the installed PaddlePaddle inference library is built from the latest develop branch and works correctly.
+
+The custom op can be compiled with the following steps:
+
+``` sh
+mkdir build
+cd build/
+cmake .. -DSM=xx -DWITH_GPT=ON -DCMAKE_BUILD_TYPE=Release -DPADDLE_LIB=/path/to/paddle_inference_lib/ -DDEMO=./demo/gpt.cc -DWITH_STATIC_LIB=OFF -DON_INFER=ON -DWITH_MKL=ON
+make -j
+cd ../
+```
+
+Notes:
+* `xx` is the compute capability of the GPU in use. For example, specify 70 for V100 or 75 for T4.
+* `-DPADDLE_LIB` specifies the path of the PaddlePaddle inference library, `/path/to/paddle_inference_install_dir/`, which should be organized as follows:
+  ```text
+  .
+  ├── CMakeCache.txt
+  ├── paddle/
+  │   ├── include/
+  │   └── lib/
+  ├── third_party/
+  │   ├── cudaerror/
+  │   ├── install/
+  │   └── threadpool/
+  └── version.txt
+  ```
+* `-DDEMO` specifies the location of the demo that uses the inference library, e.g. -DDEMO=./demo/gpt.cc. An absolute path is preferred; if a relative path is used, it must be relative to `PaddleNLP/paddlenlp/ops/faster_transformer/src/`.
+* `-DWITH_GPT`: to build the GPT inference executable, add `-DWITH_GPT=ON`.
+* **When using the custom op with the inference library, be sure to enable the `-DON_INFER=ON` option; otherwise, the inference executable will not be produced.**
+
+#### Run GPT decoding on PaddlePaddle
+
+To run GPT inference with the Paddle Inference library, first export the inference model with the `./export_model.py` script:
+
+``` sh
+python ./export_model.py --model_name_or_path gpt2-medium-en --decoding_lib ./build/lib/libdecoding_op.so --topk 4 --topp 0.0 --max_out_len 32 --start_token "<|endoftext|>" --end_token "<|endoftext|>" --temperature 1.0 --inference_model_dir ./infer_model/
+```
+
+The options have the same meaning as those of `infer.py` above. The additional `--inference_model_dir` option specifies where the model files, vocabulary and other resources are saved. With the gpt2-medium-en model, the saved `./infer_model/` directory is organized as follows:
+
+``` text
+.
+├── gpt.pdiparams       # saved parameter file
+├── gpt.pdiparams.info  # variable description info, not used for inference
+├── gpt.pdmodel         # saved model file
+├── merges.txt          # BPE merge rules
+└── vocab.txt           # vocabulary
+```
+
+Likewise, after compilation, a `gpt` executable can be found under `PaddleNLP/paddlenlp/ops/build/bin/`. Run it with the corresponding arguments:
+
+``` sh
+cd bin/
+./gpt -batch_size 1 -gpu_id 0 -model_dir path/to/model -vocab_dir path/to/vocab -start_token "<|endoftext|>" -end_token "<|endoftext|>"
+```
diff --git a/examples/language_model/gpt/faster_gpt/export_model.py b/examples/language_model/gpt/faster_gpt/export_model.py
new file mode 100644
index 0000000000000..0a680e953353b
--- /dev/null
+++ b/examples/language_model/gpt/faster_gpt/export_model.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import os
+import numpy as np
+from attrdict import AttrDict
+import argparse
+import time
+
+import paddle
+
+import yaml
+from pprint import pprint
+
+from paddlenlp.ops import FasterGPT
+from paddlenlp.transformers import GPTModel, GPTLMHeadModel
+from paddlenlp.transformers import GPTChineseTokenizer, GPTTokenizer
+
+from paddlenlp.utils.log import logger
+
+MODEL_CLASSES = {
+    "gpt-cpm-large-cn": (GPTLMHeadModel, GPTChineseTokenizer),
+    "gpt2-medium-en": (GPTLMHeadModel, GPTTokenizer),
+}
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_name_or_path",
+        default="gpt2-medium-en",
+        type=str,
+        help="The model name to specify the gpt to use. Can be one of ['gpt2-en', 'gpt2-medium-en', 'gpt-cpm-large-cn']. "
+    )
+    parser.add_argument(
+        "--decoding_lib",
+        default="../../build/lib/libdecoding_op.so",
+        type=str,
+        help="Path of libdecoding_op.so. ")
+    parser.add_argument(
+        "--inference_model_dir",
+        default="./infer_model/",
+        type=str,
+        help="Path to save inference model of gpt. ")
+    parser.add_argument(
+        "--topk",
+        default=4,
+        type=int,
+        help="The number of candidates (k) used in top-k sampling. 
") + parser.add_argument( + "--topp", + default=0.0, + type=float, + help="The probability threshold to procedure topp sampling. ") + parser.add_argument( + "--max_out_len", default=32, type=int, help="Maximum output length. ") + parser.add_argument( + "--start_token", + default="<|endoftext|>", + type=str, + help="The start token. Defaults to <|endoftext|>. ") + parser.add_argument( + "--end_token", + default="<|endoftext|>", + type=str, + help="The end token. Defaults to <|endoftext|>. ") + parser.add_argument( + "--temperature", + default=1.0, + type=float, + help="The temperature to set. ") + parser.add_argument( + "--use_fp16_decoding", + action="store_true", + help="Whether to use fp16 decoding to predict. ") + args = parser.parse_args() + return args + + +def do_predict(args): + place = "gpu" + place = paddle.set_device(place) + + model_class, tokenizer_class = MODEL_CLASSES[args.model_name_or_path] + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + logger.info('Loading the model parameters, please wait...') + model = model_class.from_pretrained( + args.model_name_or_path, max_predict_len=args.max_out_len) + + bos_id = tokenizer.convert_tokens_to_ids(args.start_token) + eos_id = tokenizer.convert_tokens_to_ids(args.end_token) + + gpt = FasterGPT( + model=model, + topk=args.topk, + topp=args.topp, + max_out_len=args.max_out_len, + bos_id=bos_id, + eos_id=eos_id, + temperature=args.temperature, + decoding_lib=args.decoding_lib, + use_fp16_decoding=args.use_fp16_decoding) + + # Set evaluate mode + gpt.eval() + + # Convert dygraph model to static graph model + gpt = paddle.jit.to_static( + gpt, + input_spec=[ + # input_ids + paddle.static.InputSpec( + shape=[None, None], dtype="int32") + ]) + + # Save converted static graph model + paddle.jit.save(gpt, os.path.join(args.inference_model_dir, "gpt")) + logger.info("GPT has been saved to {}".format(args.inference_model_dir)) + + gpt.save_resources(tokenizer, args.inference_model_dir) + + +if __name__ == "__main__": + args = parse_args() + pprint(args) + do_predict(args) diff --git a/examples/language_model/gpt/faster_gpt/infer.py b/examples/language_model/gpt/faster_gpt/infer.py new file mode 100644 index 0000000000000..2f04cb2fab446 --- /dev/null +++ b/examples/language_model/gpt/faster_gpt/infer.py @@ -0,0 +1,140 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +import os +import numpy as np +from attrdict import AttrDict +import argparse +import time + +import paddle + +import yaml +from pprint import pprint + +from paddlenlp.ops import FasterGPT +from paddlenlp.transformers import GPTModel, GPTLMHeadModel +from paddlenlp.transformers import GPTChineseTokenizer, GPTTokenizer + +from paddlenlp.utils.log import logger + +MODEL_CLASSES = { + "gpt-cpm-large-cn": (GPTLMHeadModel, GPTChineseTokenizer), + "gpt2-medium-en": (GPTLMHeadModel, GPTTokenizer), +} + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name_or_path", + default="gpt2-medium-en", + type=str, + help="The model name to specify the gpt to use. Can be one of ['gpt2-en', 'gpt2-medium-en', 'gpt-cpm-large-cn']. " + ) + parser.add_argument( + "--decoding_lib", + default="../../../../paddlenlp/ops/build/lib/libdecoding_op.so", + type=str, + help="Path of libdecoding_op.so. ") + parser.add_argument( + "--batch_size", default=1, type=int, help="Batch size. ") + parser.add_argument( + "--topk", + default=4, + type=int, + help="The number of candidate to procedure beam search. ") + parser.add_argument( + "--topp", + default=0.0, + type=float, + help="The probability threshold to procedure topp sampling. ") + parser.add_argument( + "--max_out_len", default=32, type=int, help="Maximum output length. ") + parser.add_argument( + "--start_token", + default="<|endoftext|>", + type=str, + help="The start token. Defaults to <|endoftext|>. ") + parser.add_argument( + "--end_token", + default="<|endoftext|>", + type=str, + help="The end token. Defaults to <|endoftext|>. ") + parser.add_argument( + "--temperature", + default=1.0, + type=float, + help="The temperature to set. ") + parser.add_argument( + "--use_fp16_decoding", + action="store_true", + help="Whether to use fp16 decoding to predict. ") + args = parser.parse_args() + return args + + +def do_predict(args): + place = "gpu" + place = paddle.set_device(place) + + model_class, tokenizer_class = MODEL_CLASSES[args.model_name_or_path] + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + logger.info('Loading the model parameters, please wait...') + model = model_class.from_pretrained( + args.model_name_or_path, max_predict_len=args.max_out_len) + model.eval() + + bos_id = tokenizer.convert_tokens_to_ids(args.start_token) + eos_id = tokenizer.convert_tokens_to_ids(args.end_token) + + # Define model + gpt = FasterGPT( + model=model, + topk=args.topk, + topp=args.topp, + max_out_len=args.max_out_len, + bos_id=bos_id, + eos_id=eos_id, + temperature=args.temperature, + decoding_lib=args.decoding_lib, + use_fp16_decoding=args.use_fp16_decoding) + + # Set evaluate mode + gpt.eval() + input_ids = np.array( + [[bos_id] for i in range(args.batch_size * 1)]).astype("int32").reshape( + [args.batch_size, 1]) + input_ids = paddle.to_tensor(input_ids) + + with paddle.no_grad(): + for i in range(100): + # For warmup. 
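+            # The first 50 of the 100 iterations serve as warmup; the device
+            # is synchronized at iteration 50, and the remaining 50 runs are
+            # averaged into the decoding time reported below.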
+ if 50 == i: + paddle.fluid.core._cuda_synchronize(place) + start = time.time() + out_seq = gpt(input_ids) + paddle.fluid.core._cuda_synchronize(place) + logger.info("Average test time for decoding is %f ms" % ( + (time.time() - start) / 50 * 1000)) + output_sequence = out_seq.numpy().transpose() + for i in range(args.batch_size): + print("========== Sample-%d ==========" % i) + print(tokenizer.convert_ids_to_string(output_sequence[i][1:])) + + +if __name__ == "__main__": + args = parse_args() + pprint(args) + do_predict(args) diff --git a/examples/machine_translation/transformer/export_model.py b/examples/machine_translation/transformer/export_model.py index 9bf1564102bbd..5d37d39980406 100644 --- a/examples/machine_translation/transformer/export_model.py +++ b/examples/machine_translation/transformer/export_model.py @@ -68,7 +68,9 @@ def do_export(args): bos_id=args.bos_idx, eos_id=args.eos_idx, beam_size=args.beam_size, - max_out_len=args.max_out_len) + max_out_len=args.max_out_len, + rel_len=args.use_rel_len, + alpha=args.alpha) # Load the trained model assert args.init_from_params, ( diff --git a/examples/machine_translation/transformer/faster_transformer/README.md b/examples/machine_translation/transformer/faster_transformer/README.md index e766f0e485aec..999f4745fdc66 100644 --- a/examples/machine_translation/transformer/faster_transformer/README.md +++ b/examples/machine_translation/transformer/faster_transformer/README.md @@ -256,16 +256,18 @@ python export_model.py --config ../configs/transformer.base.yaml --decoding_lib ```text └── infer_model/ ├── transformer.pdiparams + ├── transformer.pdiparams.info └── transformer.pdmodel ``` + ### 使用 PaddlePaddle 预测库预测 自定义 op 编译完成后,在 `paddlenlp/ops/build/bin/` 路径下将会看到 `transformer_e2e` 的一个可执行文件。通过设置对应的设置参数完成执行的过程。 ``` sh cd bin/ -./transformer_e2e -batch_size -beam_size -gpu_id -model_dir -vocab_dir -data_dir +./transformer_e2e -batch_size -gpu_id -model_dir -vocab_dir -data_dir ``` 这里的 `` 即是上文说到导出的 paddle inference 模型。 @@ -275,7 +277,7 @@ cd bin/ ``` sh cd bin/ ../third-party/build/bin/decoding_gemm 8 5 8 64 38512 256 512 0 -./transformer_e2e -batch_size 8 -beam_size 5 -gpu_id 0 -model_dir ./infer_model/ -vocab_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/vocab_all.bpe.33708 -data_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en +./transformer_e2e -batch_size 8 -gpu_id 0 -model_dir ./infer_model/ -vocab_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/vocab_all.bpe.33708 -data_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en ``` 其中: diff --git a/examples/machine_translation/transformer/faster_transformer/export_model.py b/examples/machine_translation/transformer/faster_transformer/export_model.py index c2ef1258711e4..f7dbb59daf393 100644 --- a/examples/machine_translation/transformer/faster_transformer/export_model.py +++ b/examples/machine_translation/transformer/faster_transformer/export_model.py @@ -94,53 +94,49 @@ def do_predict(args): place = paddle.set_device(place) reader.adapt_vocab_size(args) - test_program = paddle.static.Program() - startup_program = paddle.static.Program() - with paddle.static.program_guard(test_program, startup_program): - src_word = paddle.static.data( - name="src_word", shape=[None, None], dtype="int64") - - # Define model - transformer = FasterTransformer( - src_vocab_size=args.src_vocab_size, - trg_vocab_size=args.trg_vocab_size, - max_length=args.max_length + 1, - num_encoder_layers=args.n_layer, - 
num_decoder_layers=args.n_layer, - n_head=args.n_head, - d_model=args.d_model, - d_inner_hid=args.d_inner_hid, - dropout=args.dropout, - weight_sharing=args.weight_sharing, - bos_id=args.bos_idx, - eos_id=args.eos_idx, - decoding_strategy=args.decoding_strategy, - beam_size=args.beam_size, - max_out_len=args.max_out_len, - decoding_lib=args.decoding_lib, - use_fp16_decoding=args.use_fp16_decoding, - rel_len=args.use_rel_len, - alpha=args.alpha) - - finished_seq = transformer(src_word=src_word) - - test_program = test_program.clone(for_test=True) - - exe = paddle.static.Executor(place) - exe.run(startup_program) + # Define model + transformer = FasterTransformer( + src_vocab_size=args.src_vocab_size, + trg_vocab_size=args.trg_vocab_size, + max_length=args.max_length + 1, + num_encoder_layers=args.n_layer, + num_decoder_layers=args.n_layer, + n_head=args.n_head, + d_model=args.d_model, + d_inner_hid=args.d_inner_hid, + dropout=args.dropout, + weight_sharing=args.weight_sharing, + bos_id=args.bos_idx, + eos_id=args.eos_idx, + decoding_strategy=args.decoding_strategy, + beam_size=args.beam_size, + max_out_len=args.max_out_len, + decoding_lib=args.decoding_lib, + use_fp16_decoding=args.use_fp16_decoding, + rel_len=args.use_rel_len, + alpha=args.alpha) + + # Set evaluate mode + transformer.eval() # Load checkpoint. - transformer.export_params( - init_from_params=os.path.join(args.init_from_params, - "transformer.pdparams"), - place=place) - - paddle.static.save_inference_model( - os.path.join(args.inference_model_dir, "transformer"), - feed_vars=src_word, - fetch_vars=finished_seq, - executor=exe, - program=test_program) + transformer.load(init_from_params=os.path.join(args.init_from_params, + "transformer.pdparams")) + + # Convert dygraph model to static graph model + transformer = paddle.jit.to_static( + transformer, + input_spec=[ + # src_word + paddle.static.InputSpec( + shape=[None, None], dtype="int64") + ]) + + # Save converted static graph model + paddle.jit.save(transformer, + os.path.join(args.inference_model_dir, "transformer")) + logger.info("Transformer has been saved to {}".format( + args.inference_model_dir)) if __name__ == "__main__": diff --git a/paddlenlp/ops/README.md b/paddlenlp/ops/README.md index d53d30d147a2c..83e0c7021bc6f 100644 --- a/paddlenlp/ops/README.md +++ b/paddlenlp/ops/README.md @@ -13,7 +13,7 @@ ## 使用环境说明 -* 本项目依赖于 PaddlePaddle 2.0.1 及以上版本或适当的 develop 版本 +* 本项目依赖于 PaddlePaddle 2.1.0 及以上版本或适当的 develop 版本 * CMake >= 3.10 * CUDA 10.1(需要 PaddlePaddle 框架一致) * gcc 版本需要与编译 PaddlePaddle 版本一致,比如使用 gcc8.2 @@ -236,7 +236,7 @@ cd ../ ``` sh cd bin/ -./transformer_e2e -batch_size -beam_size -gpu_id -model_dir -vocab_dir -data_dir +./transformer_e2e -batch_size -gpu_id -model_dir -vocab_dir -data_dir ``` 举例说明: @@ -244,7 +244,7 @@ cd bin/ ``` sh cd bin/ ../third-party/build/bin/decoding_gemm 8 5 8 64 38512 256 512 0 -./transformer_e2e -batch_size 8 -beam_size 5 -gpu_id 0 -model_dir ./infer_model/ -vocab_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/vocab_all.bpe.33708 -data_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en +./transformer_e2e -batch_size 8 -gpu_id 0 -model_dir ./infer_model/ -vocab_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/vocab_all.bpe.33708 -data_dir DATA_HOME/WMT14ende/WMT14.en-de/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en ``` 其中: @@ -265,10 +265,11 @@ python ./faster_transformer/sample/gpt_export_model_sample.py --model_name_or_pa ``` text . 
-├── gpt.pdiparams # 保存的参数文件 -├── gpt.pdmodel # 保存的模型文件 -├── merges.txt # bpe -└── vocab.txt # 词表 +├── gpt.pdiparams # 保存的参数文件 +├── gpt.pdiparams.info # 保存的一些变量描述信息,预测不会用到 +├── gpt.pdmodel # 保存的模型文件 +├── merges.txt # bpe +└── vocab.txt # 词表 ``` 同理,完成编译后,可以在 `build/bin/` 路径下将会看到 `gpt` 的一个可执行文件。通过设置对应的设置参数完成执行的过程。 diff --git a/paddlenlp/ops/faster_transformer/sample/gpt_export_model_sample.py b/paddlenlp/ops/faster_transformer/sample/gpt_export_model_sample.py index d1dac2ba47d6b..0a680e953353b 100644 --- a/paddlenlp/ops/faster_transformer/sample/gpt_export_model_sample.py +++ b/paddlenlp/ops/faster_transformer/sample/gpt_export_model_sample.py @@ -24,14 +24,14 @@ from pprint import pprint from paddlenlp.ops import FasterGPT -from paddlenlp.transformers import GPTModel, GPTForGreedyGeneration +from paddlenlp.transformers import GPTModel, GPTLMHeadModel from paddlenlp.transformers import GPTChineseTokenizer, GPTTokenizer from paddlenlp.utils.log import logger MODEL_CLASSES = { - "gpt-cpm-large-cn": (GPTForGreedyGeneration, GPTChineseTokenizer), - "gpt2-medium-en": (GPTForGreedyGeneration, GPTTokenizer), + "gpt-cpm-large-cn": (GPTLMHeadModel, GPTChineseTokenizer), + "gpt2-medium-en": (GPTLMHeadModel, GPTTokenizer), } @@ -45,7 +45,7 @@ def parse_args(): ) parser.add_argument( "--decoding_lib", - default="../build/lib/libdecoding_op.so", + default="../../build/lib/libdecoding_op.so", type=str, help="Path of libdecoding_op.so. ") parser.add_argument( @@ -89,51 +89,44 @@ def parse_args(): def do_predict(args): - paddle.enable_static() place = "gpu" place = paddle.set_device(place) - test_program = paddle.static.Program() - startup_program = paddle.static.Program() - with paddle.static.program_guard(test_program, startup_program): - model_class, tokenizer_class = MODEL_CLASSES[args.model_name_or_path] - tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) - logger.info('Loading the model parameters, please wait...') - model, state_to_load = model_class.from_pretrained( - args.model_name_or_path, max_predict_len=args.max_out_len) - - bos_id = tokenizer.convert_tokens_to_ids(args.start_token) - eos_id = tokenizer.convert_tokens_to_ids(args.end_token) - - input_ids = paddle.static.data( - name="ids", shape=[None, None], dtype="int32") - # Define model - gpt = FasterGPT( - model=model, - topk=args.topk, - topp=args.topp, - max_out_len=args.max_out_len, - bos_id=bos_id, - eos_id=eos_id, - temperature=args.temperature, - decoding_lib=args.decoding_lib, - use_fp16_decoding=args.use_fp16_decoding) - - finished_seq = gpt(input_ids) - - test_program = test_program.clone(for_test=True) - - exe = paddle.static.Executor(place) - exe.run(startup_program) - - gpt.export_params(state_to_load=state_to_load, place=place) - - paddle.static.save_inference_model( - os.path.join(args.inference_model_dir, "gpt"), - feed_vars=input_ids, - fetch_vars=finished_seq, - executor=exe, - program=test_program) + model_class, tokenizer_class = MODEL_CLASSES[args.model_name_or_path] + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + logger.info('Loading the model parameters, please wait...') + model = model_class.from_pretrained( + args.model_name_or_path, max_predict_len=args.max_out_len) + + bos_id = tokenizer.convert_tokens_to_ids(args.start_token) + eos_id = tokenizer.convert_tokens_to_ids(args.end_token) + + gpt = FasterGPT( + model=model, + topk=args.topk, + topp=args.topp, + max_out_len=args.max_out_len, + bos_id=bos_id, + eos_id=eos_id, + temperature=args.temperature, + 
decoding_lib=args.decoding_lib, + use_fp16_decoding=args.use_fp16_decoding) + + # Set evaluate mode + gpt.eval() + + # Convert dygraph model to static graph model + gpt = paddle.jit.to_static( + gpt, + input_spec=[ + # input_ids + paddle.static.InputSpec( + shape=[None, None], dtype="int32") + ]) + + # Save converted static graph model + paddle.jit.save(gpt, os.path.join(args.inference_model_dir, "gpt")) + logger.info("GPT has been saved to {}".format(args.inference_model_dir)) gpt.save_resources(tokenizer, args.inference_model_dir) diff --git a/paddlenlp/ops/faster_transformer/sample/gpt_sample.py b/paddlenlp/ops/faster_transformer/sample/gpt_sample.py index 9cb2a18b288b1..f194eb97ba403 100644 --- a/paddlenlp/ops/faster_transformer/sample/gpt_sample.py +++ b/paddlenlp/ops/faster_transformer/sample/gpt_sample.py @@ -24,14 +24,14 @@ from pprint import pprint from paddlenlp.ops import FasterGPT -from paddlenlp.transformers import GPTModel, GPTForGreedyGeneration +from paddlenlp.transformers import GPTModel, GPTLMHeadModel from paddlenlp.transformers import GPTChineseTokenizer, GPTTokenizer from paddlenlp.utils.log import logger MODEL_CLASSES = { - "gpt-cpm-large-cn": (GPTForGreedyGeneration, GPTChineseTokenizer), - "gpt2-medium-en": (GPTForGreedyGeneration, GPTTokenizer), + "gpt-cpm-large-cn": (GPTLMHeadModel, GPTChineseTokenizer), + "gpt2-medium-en": (GPTLMHeadModel, GPTTokenizer), } diff --git a/paddlenlp/ops/faster_transformer/src/demo/gpt.cc b/paddlenlp/ops/faster_transformer/src/demo/gpt.cc index 853097a0d19bb..d79072d478958 100644 --- a/paddlenlp/ops/faster_transformer/src/demo/gpt.cc +++ b/paddlenlp/ops/faster_transformer/src/demo/gpt.cc @@ -191,7 +191,8 @@ class DataReader { std::vector& data_input_vec, int max_len, int batch_size) { - auto ids_t = predictor->GetInputHandle("ids"); + auto ids_name = predictor->GetInputNames(); + auto ids_t = predictor->GetInputHandle(ids_name[0]); std::vector ids_vec; ids_vec.resize(max_len * batch_size); for (int i = 0; i < batch_size; ++i) { diff --git a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py index 50e0239bec480..eeb23e01faaa7 100644 --- a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py +++ b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py @@ -260,6 +260,71 @@ def load(self, init_from_params): self.load_dict(model_dict) def export_params(self, init_from_params, place): + ''' + This method is used for load static graph from dygraph checkpoint + or export inference model using static graph. + + Args: + init_from_params (string): + The path to dygraph checkpoint. + place (paddle.Place): + The place to execute static graph. + + Example: + .. 
code-block:: + paddle.enable_static() + place = "gpu" + place = paddle.set_device(place) + reader.adapt_vocab_size(args) + + test_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(test_program, startup_program): + src_word = paddle.static.data( + name="src_word", shape=[None, None], dtype="int64") + + # Define model + transformer = FasterTransformer( + src_vocab_size=args.src_vocab_size, + trg_vocab_size=args.trg_vocab_size, + max_length=args.max_length + 1, + num_encoder_layers=args.n_layer, + num_decoder_layers=args.n_layer, + n_head=args.n_head, + d_model=args.d_model, + d_inner_hid=args.d_inner_hid, + dropout=args.dropout, + weight_sharing=args.weight_sharing, + bos_id=args.bos_idx, + eos_id=args.eos_idx, + decoding_strategy=args.decoding_strategy, + beam_size=args.beam_size, + max_out_len=args.max_out_len, + decoding_lib=args.decoding_lib, + use_fp16_decoding=args.use_fp16_decoding, + rel_len=args.use_rel_len, + alpha=args.alpha) + + finished_seq = transformer(src_word=src_word) + + test_program = test_program.clone(for_test=True) + + exe = paddle.static.Executor(place) + exe.run(startup_program) + + # Load checkpoint. + transformer.export_params( + init_from_params=os.path.join(args.init_from_params, + "transformer.pdparams"), + place=place) + + paddle.static.save_inference_model( + os.path.join(args.inference_model_dir, "transformer"), + feed_vars=src_word, + fetch_vars=finished_seq, + executor=exe, + program=test_program) + ''' # Load the trained model assert init_from_params, ( "Please set init_from_params to load the infer model.")