Pytorch学习笔记9-损失函数
由于时效问题,该文某些代码、技术可能已经过期,请注意!!!本文最后更新于:3 年前
各种损失函数
一般来说,监督学习的目标函数由损失函数和正则化项组成。(Objective = Loss + Regularization)
Pytorch中的损失函数一般在训练模型时候指定。
注意Pytorch中内置的损失函数的参数和tensorflow不同,是y_pred在前,y_true在后,而Tensorflow是y_true在前,y_pred在后。
对于回归模型,通常使用的内置损失函数是均方损失函数nn.MSELoss 。
对于二分类模型,通常使用的是二元交叉熵损失函数nn.BCELoss (输入已经是sigmoid激活函数之后的结果)或者 nn.BCEWithLogitsLoss (输入尚未经过nn.Sigmoid激活函数) 。
对于多分类模型,一般推荐使用交叉熵损失函数 nn.CrossEntropyLoss。(y_true需要是一维的,是类别编码。y_pred未经过nn.Softmax激活。)
此外,如果多分类的y_pred经过了nn.LogSoftmax激活,可以使用nn.NLLLoss损失函数(The negative log likelihood loss)。
这种方法和直接使用nn.CrossEntropyLoss等价。
如果有需要,也可以自定义损失函数,自定义损失函数需要接收两个张量y_pred,y_true作为输入参数,并输出一个标量作为损失函数值。
Pytorch中的正则化项一般通过自定义的方式和损失函数一起添加作为目标函数。
如果仅仅使用L2正则化,也可以利用优化器的weight_decay参数来实现相同的效果。
内置损失函数
1 |
|
内置的损失函数一般有类的实现和函数的实现两种形式。
如:nn.BCE 和 F.binary_cross_entropy 都是二元交叉熵损失函数,前者是类的实现形式,后者是函数的实现形式。
实际上类的实现形式通常是调用函数的实现形式并用nn.Module封装后得到的。
一般我们常用的是类的实现形式。它们封装在torch.nn模块下,并且类名以Loss结尾。
常用的一些内置损失函数说明如下。
nn.MSELoss(均方误差损失,也叫做L2损失,用于回归)
nn.L1Loss (L1损失,也叫做绝对值误差损失,用于回归)
nn.SmoothL1Loss (平滑L1损失,当输入在-1到1之间时,平滑为L2损失,用于回归)
nn.BCELoss (二元交叉熵,用于二分类,输入已经过nn.Sigmoid激活,对不平衡数据集可以用weigths参数调整类别权重)
nn.BCEWithLogitsLoss (二元交叉熵,用于二分类,输入未经过nn.Sigmoid激活)
nn.CrossEntropyLoss (交叉熵,用于多分类,要求label为稀疏编码,输入未经过nn.Softmax激活,对不平衡数据集可以用weigths参数调整类别权重)
nn.NLLLoss (负对数似然损失,用于多分类,要求label为稀疏编码,输入经过nn.LogSoftmax激活)
nn.CosineSimilarity(余弦相似度,可用于多分类)
nn.AdaptiveLogSoftmaxWithLoss (一种适合非常多类别且类别分布很不均衡的损失函数,会自适应地将多个小类别合成一个cluster)
更多损失函数的介绍参考如下知乎文章:
《PyTorch的十八个损失函数》
自定义损失函数
自定义损失函数接收两个张量y_pred,y_true作为输入参数,并输出一个标量作为损失函数值。
也可以对nn.Module进行子类化,重写forward方法实现损失的计算逻辑,从而得到损失函数的类的实现。
下面是一个Focal Loss的自定义实现示范。Focal Loss是一种对binary_crossentropy的改进损失函数形式。
它在样本不均衡和存在较多易分类的样本时相比binary_crossentropy具有明显的优势。
它有两个可调参数,alpha参数和gamma参数。其中alpha参数主要用于衰减负样本的权重,gamma参数主要用于衰减容易训练样本的权重。
从而让模型更加聚焦在正样本和困难样本上。这就是为什么这个损失函数叫做Focal Loss。
详见《5分钟理解Focal Loss与GHM——解决样本不平衡利器》
自定义L1和L2正则化项
通常认为L1 正则化可以产生稀疏权值矩阵,即产生一个稀疏模型,可以用于特征选择。
而L2 正则化可以防止模型过拟合(overfitting)。一定程度上,L1也可以防止过拟合。
下面以一个二分类问题为例,演示给模型的目标函数添加自定义L1和L2正则化项的方法。
这个范例同时演示了上一个部分的FocalLoss的使用
示例参考 https://www.heywhale.com/mw/project/5f33d61caf3980002cb83d18
通过优化器实现L2正则化
如果仅仅需要使用L2正则化,那么也可以利用优化器的weight_decay参数来实现。
weight_decay参数可以设置参数在训练过程中的衰减,这和L2正则化的作用效果等价。
Pytorch的优化器支持一种称之为Per-parameter options的操作,就是对每一个参数进行特定的学习率,权重衰减率指定,以满足更为细致的要求。
搬运自:
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!