BartForSequenceClassification 源码阅读
源码 阅读
时间:2023-09-14 09:13:14
class BartForSequenceClassification(BartPretrainedModel):
    """BART with a sequence-level classification/regression head.

    The wrapped ``BartModel`` is run on the inputs, the last hidden state at
    the final ``<eos>`` position of every example is taken as the sentence
    representation, and that vector is passed through
    ``BartClassificationHead`` to produce the logits.
    """

    def __init__(self, config: BartConfig, **kwargs):
        """Build the base model and the classification head from ``config``.

        Args:
            config: Model configuration; ``d_model``, ``num_labels`` and
                ``classifier_dropout`` are read here to size the head.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        # NOTE(review): the positional arguments are presumably
        # (input_dim, inner_dim, num_classes, dropout) - confirm against
        # BartClassificationHead's signature.
        self.classification_head = BartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )
        # Initialise the two head layers with the base model's weight-init routine.
        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
        output_type=Seq2SeqSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching is switched off whenever labels are supplied.
        if labels is not None:
            use_cache = False

        # The <eos> positions are located in `input_ids`, so embeddings alone
        # are not enough to pool a sentence representation.
        if input_ids is None and inputs_embeds is not None:
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
            )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]  # last hidden state

        # Move the mask next to the hidden states so the boolean indexing below
        # also works when inputs and outputs live on different devices.
        eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)

        # Every row must hold the same number of <eos> tokens, otherwise the
        # selected states cannot be reshaped into (batch, n_eos, hidden).
        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")

        eos_states = hidden_states[eos_mask, :]
        # With nothing selected, view(batch, -1, hidden) would fail with an
        # opaque "ambiguous -1" error; report the real cause instead.
        if eos_states.size(0) == 0:
            raise ValueError(
                "Could not find any <eos> token in `input_ids`; each example needs at least one "
                "to pool a sentence representation."
            )

        # Keep only the state at the last <eos> of each example.
        sentence_representation = eos_states.view(
            hidden_states.size(0), -1, hidden_states.size(-1)
        )[:, -1, :]
        logits = self.classification_head(sentence_representation)

        loss = None
        if labels is not None:
            # Keep labels and logits on the same device for the loss call.
            labels = labels.to(logits.device)

            # Infer the problem type once and store it on the config so later
            # calls take the same branch.
            if self.config.problem_type is None:
                if self.config.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.config.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.config.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Tuple form: (loss?, logits, *remaining base-model outputs).
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
相关文章
- Node学习笔记 - Koa源码阅读
- Vue2.6源码(2):$mount方法干了啥
- Mybatis源码:@MapperScan解析过程
- Go-Excelize API源码阅读(九)——SetSheetBackground(sheet, picture string)
- Go-Excelize API源码阅读(十七)——GetPageLayout、SetPageMargins
- Kubernetes 学习(九)Kubernetes 源码阅读之正式篇------核心组件之 Scheduler
- 客服客户聊天系统源码分享[通俗易懂]
- Vue3移动端组件库Varlet源码主题阅读之一:本地启动服务时都做了什么
- vue源码分析-组件
- java多线程与高并发:LockSupport、淘宝面试题与源码阅读方法论
- 深入AQS源码阅读与强软弱虚4种引用以及ThreadLocal原理与源码
- Android ContentProvider_2 源码解析
- react源码分析事件系统
- 揭秘 OpenTelemetry-Collector 源码内幕
- 阅读源码入门实践系列之 element ui(1)
- react源码之Fiber架构
- 带你实现react源码的核心功能
- React源码分析5-commit6
- sqlmap 源码分析(二)初始化
- 22.2k stars的GitHub辅助阅读源码神器
- 国产开源仿钉钉流程设计器源码,前端基于wflow工程创建,100%开源
- 【Android 系统开发】使用 Source InSight 阅读 Android 源码
- mold源码阅读 其一 读取输入文件
- Golang流媒体实战之五:lal推流服务源码阅读
- BitXHub 跨链插件(Fabric)源码解读
- Golang流媒体实战之七:hls拉流服务源码阅读
- JDK 源码阅读 : DirectByteBuffer详解编程语言
- 学习Linux源码:提升编程技能的无穷乐趣(阅读linux源码)
- 深度分析TP6中Redis源码解析(tp redis 源码)
- CI框架源码阅读,系统常量文件constants.php的配置
- mysqld_safe启动脚本源码阅读、分析