zl Programming Tutorials

Loss in YOLOv3

Time: 2023-09-14 09:15:53

Notes on the YOLOv3 loss. Today I finally finished reading the yolov3-tf2 source code.

Computing the loss in YOLOv3

Code

        
        # center-point (xy) loss = obj_mask * box_loss_scale * sum((px1 - tx1)**2 + (px2 - tx2)**2)
        # because of obj_mask, only grid cells that contain an object contribute to this loss
        xy_loss = obj_mask * box_loss_scale * \
            tf.reduce_sum(tf.square(true_xy - pred_xy), axis=-1)
        
        wh_loss = obj_mask * box_loss_scale * \
            tf.reduce_sum(tf.square(true_wh - pred_wh), axis=-1)
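        # objectness loss: cells with an object always count;
        # background cells count only where ignore_mask is 1
        obj_loss = binary_crossentropy(true_obj, pred_obj)
        obj_loss = obj_mask * obj_loss + \
            (1 - obj_mask) * ignore_mask * obj_loss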
        # TODO: use binary_crossentropy instead
        class_loss = obj_mask * sparse_categorical_crossentropy(
            true_class_idx, pred_class)


        # 6. sum over (gridx, gridy, anchors) => one value per batch element, shape (batch,)
        xy_loss = tf.reduce_sum(xy_loss, axis=(1, 2, 3))
        wh_loss = tf.reduce_sum(wh_loss, axis=(1, 2, 3))
        obj_loss = tf.reduce_sum(obj_loss, axis=(1, 2, 3))
        class_loss = tf.reduce_sum(class_loss, axis=(1, 2, 3))

        return xy_loss + wh_loss + obj_loss + class_loss
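
To see how the masks in the excerpt above behave, here is a minimal NumPy sketch of the xy and objectness terms on a tiny 2x2 grid with one anchor. It is an illustrative re-implementation, not the yolov3-tf2 code: the helper names, the tensor layout `(batch, grid, grid, anchors, ...)`, the constant `box_loss_scale` of 2.0, the choice of `obj_mask` as `true_obj` with its last axis squeezed, and all input values are assumptions made for the example.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Element-wise BCE, averaged over the last axis (hypothetical NumPy stand-in
    # for the Keras function used in the excerpt).
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    bce = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return bce.mean(axis=-1)

def masked_losses(true_xy, pred_xy, true_obj, pred_obj, box_loss_scale, ignore_mask):
    # Assumption: obj_mask is true_obj without its trailing axis of size 1.
    obj_mask = np.squeeze(true_obj, axis=-1)

    # Squared xy error, kept only in cells that contain an object.
    xy_loss = obj_mask * box_loss_scale * np.sum(np.square(true_xy - pred_xy), axis=-1)

    # Objectness: object cells always count, background cells only if not ignored.
    obj_loss = binary_crossentropy(true_obj, pred_obj)
    obj_loss = obj_mask * obj_loss + (1 - obj_mask) * ignore_mask * obj_loss

    # Sum over grid rows, grid columns and anchors: one value per batch element.
    return xy_loss.sum(axis=(1, 2, 3)), obj_loss.sum(axis=(1, 2, 3))

shape = (1, 2, 2, 1)  # (batch, grid, grid, anchors), made-up sizes

true_obj = np.zeros(shape + (1,))
true_obj[0, 0, 0, 0, 0] = 1.0            # exactly one cell holds an object
pred_obj = np.full(shape + (1,), 0.5)    # the model is undecided everywhere

true_xy = np.zeros(shape + (2,))
pred_xy = np.zeros(shape + (2,))
pred_xy[0, 0, 0, 0] = [0.1, 0.2]         # error in the object cell (counted)
pred_xy[0, 1, 1, 0] = [0.9, 0.9]         # error in a background cell (masked out)

box_loss_scale = np.full(shape, 2.0)     # assumed constant scale
ignore_mask = np.ones(shape)
ignore_mask[0, 0, 1, 0] = 0.0            # one background cell is ignored

xy_total, obj_total = masked_losses(
    true_xy, pred_xy, true_obj, pred_obj, box_loss_scale, ignore_mask)
print(xy_total, obj_total)
```

With these inputs only the object cell contributes to the xy term, 2.0 * (0.1**2 + 0.2**2) = 0.1, and the large error in the background cell is dropped. The objectness term sums the cross-entropy of the object cell and of the two background cells that are not ignored, so three cells at ln 2 each.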