Skip to content

Commit

Permalink
add hist model
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Feb 7, 2024
1 parent 9726a8d commit a84963c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions qlib/contrib/model/pytorch_hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def fit(

if self.model_path is not None:
self.logger.info("Loading pretrained model...")
pretrained_model.load_state_dict(torch.load(self.model_path))
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=torch.device('cpu')))

model_dict = self.HIST_model.state_dict()
pretrained_dict = {
Expand Down Expand Up @@ -429,7 +429,8 @@ def cal_cos_similarity(self, x, y): # the 2nd dimension of x and y are the same
return cos_similarity

def forward(self, x, concept_matrix):
device = torch.device(torch.get_device(x))
# device = torch.device(torch.get_device(x))
device = torch.device("cpu")

x_hidden = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
x_hidden = x_hidden.permute(0, 2, 1) # [N, T, F]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_version(rel_path: str) -> str:
"lightgbm>=3.3.0",
"tornado",
"joblib>=0.17.0",
"ruamel.yaml>=0.16.12",
"ruamel.yaml==0.17.21",
"pymongo==3.7.2", # For task management
"scikit-learn>=0.22",
"dill",
Expand Down

0 comments on commit a84963c

Please sign in to comment.