bert4torch.losses module¶
- class bert4torch.losses.ContrastiveLoss(*args: Any, **kwargs: Any)[source]¶
对比损失:减小正例之间的距离,增大正例和反例之间的距离 公式:labels * distance_matrix.pow(2) + (1-labels)*F.relu(margin-distance_matrix).pow(2) https://www.sbert.net/docs/package_reference/losses.html
- Parameters
margin – float, 距离参数,distance>margin时候不参加梯度回传,默认为0.5
size_average – bool, 是否对loss在样本维度上求均值,默认为True
online – bool, 是否使用OnlineContrastiveLoss, 即仅计算困难样本的loss, 默认为False
- class bert4torch.losses.FocalLoss(*args: Any, **kwargs: Any)[source]¶
Multi-class Focal loss implementation
- class bert4torch.losses.MultilabelCategoricalCrossentropy(*args: Any, **kwargs: Any)[source]¶
多标签分类的交叉熵; 说明:y_true和y_pred的shape一致,y_true的元素非0即1, 1表示对应的类为目标类,0表示对应的类为非目标类。 警告:请保证y_pred的值域是全体实数,换言之一般情况下y_pred不用加激活函数,尤其是不能加sigmoid或者softmax!预测阶段则输出y_pred大于0的类。如有疑问,请仔细阅读并理解本文。 参考:https://kexue.fm/archives/7359
- class bert4torch.losses.RDropLoss(*args: Any, **kwargs: Any)[source]¶
R-Drop的Loss实现,官方项目:https://github.com/dropreg/R-Drop
- Parameters
alpha – float, 控制rdrop的loss的比例
rank – str, 指示y_pred的排列方式, 支持[‘adjacent’, ‘updown’]
- forward(*args)[source]¶
支持两种方式: 一种是y_pred, y_true, 另一种是y_pred1, y_pred2, y_true
- Parameters
y_pred – torch.Tensor, 第一种方式的样本预测值, shape=[btz*2, num_labels]
y_true – torch.Tensor, 样本真实值, 第一种方式shape=[btz*2,], 第二种方式shape=[btz,]
y_pred1 – torch.Tensor, 第二种方式的样本预测值, shape=[btz, num_labels]
y_pred2 – torch.Tensor, 第二种方式的样本预测值, shape=[btz, num_labels]
- class bert4torch.losses.SparseMultilabelCategoricalCrossentropy(*args: Any, **kwargs: Any)[source]¶
稀疏版多标签分类的交叉熵; 请保证y_pred的值域是全体实数,换言之一般情况下y_pred不用加激活函数,尤其是不能加sigmoid或者softmax,预测阶段则输出y_pred大于0的类; 详情请看:https://kexue.fm/archives/7359 。
- class bert4torch.losses.TemporalEnsemblingLoss(*args: Any, **kwargs: Any)[source]¶
TemporalEnsembling的实现,思路是在监督loss的基础上,增加一个mse的一致性损失loss
pytorch第三方实现:https://github.com/ferretj/temporal-ensembling
使用的时候,train_dataloader的shffle必须未False
- class bert4torch.losses.UDALoss(*args: Any, **kwargs: Any)[source]¶
UDALoss,使用时候需要继承一下,因为forward需要使用到global_step和total_steps https://arxiv.org/abs/1904.12848
- Parameters
tsa_schedule – str, tsa策略,可选[‘linear_schedule’, ‘exp_schedule’, ‘log_schedule’]
start_p – float, tsa生效概率下限, 默认为0
end_p – float, tsa生效概率上限, 默认为1
return_all_loss – bool, 是否返回所有的loss,默认为True
- Returns
loss