From a122043a5c78a08b7f710143db567d7ab961c25a Mon Sep 17 00:00:00 2001 From: Shintaku Date: Thu, 8 Jun 2023 18:03:56 +0800 Subject: [PATCH 1/3] Update net.py fix a network bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit i是第几个task,需要乘上每个task的expert数量而不是task数量 --- models/multitask/ple/net.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/models/multitask/ple/net.py b/models/multitask/ple/net.py index e36d624a9..1fd17bff3 100644 --- a/models/multitask/ple/net.py +++ b/models/multitask/ple/net.py @@ -179,8 +179,7 @@ def forward(self, input_data): # task-specific expert part for i in range(0, self.task_num): for j in range(0, self.exp_per_task): - linear_out = self._param_expert[i * self.task_num + j]( - input_data[i]) + linear_out = self._param_expert[i * self.exp_per_task + j](input_data[i]) expert_output = F.relu(linear_out) expert_outputs.append(expert_output) # shared expert part From aa3ba7d0c7889fc2e191a84098a3064f4a79d1e1 Mon Sep 17 00:00:00 2001 From: Shintaku Date: Mon, 12 Jun 2023 14:27:28 +0800 Subject: [PATCH 2/3] Update net.py fixed code style --- models/multitask/ple/net.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/multitask/ple/net.py b/models/multitask/ple/net.py index 1fd17bff3..2490b7695 100644 --- a/models/multitask/ple/net.py +++ b/models/multitask/ple/net.py @@ -179,7 +179,8 @@ def forward(self, input_data): # task-specific expert part for i in range(0, self.task_num): for j in range(0, self.exp_per_task): - linear_out = self._param_expert[i * self.exp_per_task + j](input_data[i]) + linear_out = self._param_expert[i * self.exp_per_task + j]( + input_data[i]) expert_output = F.relu(linear_out) expert_outputs.append(expert_output) # shared expert part From 4b2f898f1da692b9515661d4a13f6ccbd8f94a3f Mon Sep 17 00:00:00 2001 From: Shintaku Date: Sun, 14 Apr 2024 14:57:52 +0800 Subject: [PATCH 3/3] Update reader_helper.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在不同节点上的文件顺序可能不一致,split_file_list可能读到相同的文件。sort以后保证每个节点的文件列表顺序一致,拆分读取个节点不会读到重复文件。 --- tools/utils/static_ps/reader_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/utils/static_ps/reader_helper.py b/tools/utils/static_ps/reader_helper.py index 73b0c0540..310b08463 100755 --- a/tools/utils/static_ps/reader_helper.py +++ b/tools/utils/static_ps/reader_helper.py @@ -73,7 +73,7 @@ def get_infer_reader(input_var, config): def get_file_list(data_path, config): assert os.path.exists(data_path) - file_list = [data_path + "/%s" % x for x in os.listdir(data_path)] + file_list = [data_path + "/%s" % x for x in sorted(os.listdir(data_path))] if config.get("runner.split_file_list"): logger.info("Split file list for worker {}".format(fleet.worker_index( )))