Local Inference and Incremental Training with Very Large-Scale Sparse Parameters (DistributedLookupTable)
When saving parameters (save_inference_model / save_persistables), a model that uses very large-scale sparse parameters behaves differently from an ordinary model: the large-scale sparse parameters are saved directly on the PSERVER side (each PSERVER writes a copy containing only the parameters stored on that PSERVER), while the remaining parameters are saved on the TRAINER side as usual.
The relevant code is located at:
https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/contrib/utils/lookup_table_utils.py
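For orientation, here is a minimal sketch of what the trainer-side save call can look like under this setup; `exe`, `main_program`, `feed_names`, and `fetch_vars` are placeholders from a hypothetical training script, not names defined in this repository:

```python
import paddle.fluid as fluid

# Hypothetical training-side save: `exe`, `main_program`, `feed_names`, and
# `fetch_vars` come from the user's own training script.
# Saves the TRAINER-side parameters; with a DistributedLookupTable, the
# large-scale sparse table itself is written out on each PSERVER instead.
fluid.io.save_persistables(executor=exe, dirname="./model_dir",
                           main_program=main_program)

# Or save a pruned inference network plus its (dense) parameters:
fluid.io.save_inference_model(dirname="./inference_model",
                              feeded_var_names=feed_names,
                              target_vars=fetch_vars,
                              executor=exe)
```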
- Local inference
Local inference means merging the parameters produced by multi-node training with very large-scale sparse parameters (the TRAINER-side parameters plus the sparse shards from every PSERVER) on one machine, and running inference on that single node. If you have a server with enough memory to hold the entire sparse parameter table, local inference is a fast way to validate training results.
General steps for local inference:
- Merge the model parameters produced during training (the TRAINER-side parameters and the parameters from every PSERVER).
- Build the inference network; it must include the optimizer, i.e. Optimizer + minimize.
- Add the Distributed Transpiler information to build the multi-node network.
- Call t.get_trainer_program(wait_port=False) to obtain the TRAINER-side program.
- Call lookup_table_utils.convert_dist_to_sparse_program(main_program) to convert the large-scale distributed sparse parameters into local sparse parameters.
- Add the input-data FEED and FETCH targets (see save_inference_model), then call lookup_table_utils.get_inference_model(inference_program, feeds, fetch_targets) to save a network for local inference.
- Call lookup_table_utils.load_persistables_for_inference(executor=exe, dirname=model_dir, program=inference_program, lookup_table_var_name='dis_emb') to load the model and parameters locally; the name of the large-scale sparse parameter must be specified explicitly.
- Multi-node inference: omitted.
- Multi-node incremental training: omitted.
DEMO (the function below walks through the steps above end to end):
```python
import os
import paddle.fluid as fluid
from paddle.fluid.contrib.utils import lookup_table_utils

import net  # user-defined module that builds the model


def dis_infer(use_cuda, model_dir, batch_size, is_local,
              use_parallel_executor, is_distribute):
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    (data_list, predict, auc_var, cur_auc, auc_var_cos, cur_auc_cos,
     avg_cost, label, py_reader, user_emb_cross, ad_emb_cross) = net.model()
    feeder = fluid.DataFeeder(feed_list=data_list, place=place)

    # The optimizer must be added so the transpiled program matches training.
    optimizer = fluid.optimizer.SGD(0.001)
    optimizer.minimize(avg_cost)

    # `test_data_reader` is a user-supplied sample reader (omitted here):
    # test_reader = paddle.batch(test_data_reader, batch_size=batch_size)

    # Rebuild the multi-node network with the training cluster configuration.
    port = os.getenv("PADDLE_PORT", "6174")
    # Default: the 100 placeholder pserver IPs 12.0,12.1,...,12.99 (ip,ip...)
    pserver_ips = os.getenv("PADDLE_PSERVERS",
                            ",".join("12.%d" % i for i in range(100)))
    eplist = []
    for ip in pserver_ips.split(","):
        eplist.append(':'.join([ip, port]))
    pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
    trainer_num = int(os.getenv("PADDLE_TRAINERS_NUM", "100"))
    current_endpoint = os.getenv("POD_IP", "12.0") + ":" + port
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))

    config = fluid.DistributeTranspilerConfig()
    config.slice_var_up = True
    t = fluid.DistributeTranspiler(config=config)
    t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainer_num)
    main_program = t.get_trainer_program(wait_port=False)

    # Convert the distributed lookup table into a local sparse parameter.
    inference_program = lookup_table_utils.convert_dist_to_sparse_program(
        main_program)

    # Prune to an inference network; the large-scale sparse parameter name
    # ('dis_emb') must be given explicitly when loading.
    feeds = [x.name for x in data_list]
    fetch_targets = [predict, auc_var, auc_var_cos, user_emb_cross,
                     ad_emb_cross]
    inference_program = lookup_table_utils.get_inference_model(
        inference_program, feeds, fetch_targets)
    lookup_table_utils.load_persistables_for_inference(
        executor=exe, dirname=model_dir, program=inference_program,
        lookup_table_var_name='dis_emb')
```
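Once the parameters are loaded, inference runs like any single-node program. A minimal usage sketch, continuing inside dis_infer; `test_data_reader` is a placeholder for a user-supplied reader that yields samples matching data_list:

```python
import paddle  # for paddle.batch

# Hypothetical driver, continuing inside dis_infer(): batch a user-supplied
# reader and run the loaded inference program on each mini-batch.
test_reader = paddle.batch(test_data_reader, batch_size=batch_size)
for batch in test_reader():
    results = exe.run(inference_program,
                      feed=feeder.feed(batch),
                      fetch_list=fetch_targets)
    predict_val = results[0]  # first fetch target: `predict`
```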