We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
尝试在星辰开源代码库中的modeling_telechat添加TelechatForSequenceClassification方法类(分别参照qwen和星辰自己代码),会分别出现无法加载模型的错误和训练损失不下降的情况。需要AI公司帮忙一起看看怎么支持AutoModelForSequenceClassification任务。
class TelechatForSequenceClassification_tele(TelechatPreTrainedModel): # _tied_weights_keys = ["lm_head.weight"] #_keys_to_ignore_on_load_missing = [ r"lm_head.weight"] def init(self, config: TelechatConfig): super().init(config) self.num_labels = config.num_labels self.transformer = TelechatModel(config) self.lm_head = nn.Linear(config.hidden_size, self.num_labels, bias=False) self.post_init()
def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings: torch.Tensor): self.lm_head = new_embeddings def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, past_key_values: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> dict: if past_key_values: input_ids = input_ids[:, -1].unsqueeze(-1) if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, } ) return model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 pooled_logits = lm_logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: labels = labels.to(lm_logits.device) shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() #batch_size, seq_length, num_labels = shift_logits.shape loss_fct = CrossEntropyLoss() #loss = loss_fct(shift_logits.view(batch_size * seq_length, num_labels), shift_labels.view(batch_size * seq_length)) loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits,#lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, )
class TelechatForSequenceClassification_qwen(TelechatPreTrainedModel): def init(self, config): super().init(config) self.num_labels = config.num_labels self.model = TelechatModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) #print('************************', return_dict, loss, self.config) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output #print(f'**********pooled_logits={pooled_logits}***transformer_outputs={transformer_outputs}***********', pooled_logits, transformer_outputs) return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, )
加载模型的代码如下:
The text was updated successfully, but these errors were encountered:
好的,我们这边看一下
Sorry, something went wrong.
我先用了qwen2模型做分类,数据集是:https://huggingface.co/datasets/knowledgator/Scientific-text-classification 使用deepspeed进行训练,
loss是正常的:
后续我再去试试telechat-7b,到时候给你反馈。
No branches or pull requests
尝试在星辰开源代码库中的modeling_telechat添加TelechatForSequenceClassification方法类(分别参照qwen和星辰自己代码),会分别出现无法加载模型的错误和训练损失不下降的情况。需要AI公司帮忙一起看看怎么支持AutoModelForSequenceClassification任务。
class TelechatForSequenceClassification_tele(TelechatPreTrainedModel):
# _tied_weights_keys = ["lm_head.weight"]
#_keys_to_ignore_on_load_missing = [ r"lm_head.weight"]
def init(self, config: TelechatConfig):
super().init(config)
self.num_labels = config.num_labels
self.transformer = TelechatModel(config)
self.lm_head = nn.Linear(config.hidden_size, self.num_labels, bias=False)
self.post_init()
class TelechatForSequenceClassification_qwen(TelechatPreTrainedModel):
def init(self, config):
super().init(config)
self.num_labels = config.num_labels
self.model = TelechatModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
加载模型的代码如下:
The text was updated successfully, but these errors were encountered: