对QKV的理解,先废一下话...

计算流程参考:https://zhuanlan.zhihu.com/p/82312421

给定一组query,和当前输入样本input(每个样本都有各自的key),经过空间变化后input→query。

计算query和key之间的相关性r,r中的每个元素就可以看做是input中每个样本和其他样本的关系向量,然后根据query和key的相关性得到value的加权和(即,注意力)。

本质上,QKV的计算就是矩阵之间的计算。通过样本输入向量input和训练过程中训练得到的三个矩阵WQ、WK、WV分别进行矩阵计算,得到QKV三向量。

虽然QKV三者的属性不在同一空间,其实是有一定潜在联系的,通过某种变换(个人认为就是矩阵之间的线性相乘相加计算)将三者属性映射到一个相似的空间。注意力机制不是为了query去找value,而是根据当前query获取value的加权和来获取对输入样本的注意力。

训练参数说明

args.lr_backbone默认为1e-5,则train_backbone默认为true,通过设置backbone的lr来设置是否训练网络时接收backbone的梯度从而让backbone也训练。

--lr:网络学习率。
--lr_backbone:主干网络的学习率。
--batch_size:如果设置太高了显卡带不动,会显示爆显存的错,就需要降低batch_size。
--weight_decay:权重衰减系数。
--epochs: 训练的轮次,所有图片训练一次是一轮。
 --lr_drop:学习率衰减epoch(每训练epoch后,学习率线性递减变化)。
--clip_max_norm:梯度衰减最大范数。
--frozen_weights:是否固定住参数的权重,类似于迁移学习的微调。
--backbone:backbone的网络结构选择,默认'resnet50'。
--dilation:是否进行空洞卷积,区分是否DC5模块。
--position_embedding:用于图像特征的位置编码方法类型,可选['sine', 'learned'],默认是'sine'。
--enc_layers:Encoder中的layer数量。
--dec_layers:Decoder中的layer数量。
--dim_feedforward:FFN(前馈神经网络)的通道数。
--hidden_dim:中间维度。默认256。
--dropout:神经元以dropout设置的概率随机失活,一般用在全连接层,为了防止或减轻过拟合。
--nheads:头个数。
--num_queries:查询位置个数,即多少个目标框。
--pre_norm:在网络中,是否提前进行normalization。
--mask:
--no_aux_loss:  是否在解码器使用辅助损耗。
--set_cost_class:类别损失权重。
--set_cost_bbox:L1定位框损失函数权重。
--set_cost_giou:giou定位框损失函数权重。
--mask_loss_coef:mask损失置信度。
--dice_loss_coef:轮廓区域损失,分割使用
--bbox_loss_coef:预测框损失。
--giou_loss_coef:giou损失。主要有类别、预测框和giou损失函数。L1 Loss整体不如giou。
--eos_coef:针对背景分类的权重,即为无目标样本设置的权重,默认0.1。
--dataset_file:存储数据的格式,默认为‘coco’。
--coco_path:是存放自己数据集train、test、val文件夹的文件夹路径
--remove_difficult:是否移除difficult。
--output_dir:训练以及验证模型保存的文件夹路径。
--device:‘cpu’ or ‘cuda’,default='cuda'
--seed:随机种子seed。
--resume:预处理模型以及断点训练的load模型的地址。
--start_epoch:指定开始训练的第几个epoch。如果设置resume了,会根据resume指定模型中的epoch接着训练。
--eval:测试模型。
--num_workers:加载数据(batch)的线程数目,参考:num_workers的设置
--world_size:一般别超过1,是分布式训练需要用到的设置。
--dist_url:设置分布式训练。

错误及解决办法

在测试图片的时候报错:output with shape [1, 640, 640] doesn't match the broadcast shape [3, 640, 640]。(我单独输出图像的shape,明明是三通道的,不知道为什么在推断模型的时候给判断成单通的了...)

解决:待测试图像是灰度图,需要将灰度图转换成RGB三通道的彩色图,在inference_img.py文件中添加:

im = Image.open(img_path).convert('RGB')

DETR训练自己数据集心得-LMLPHP

11-22 11:51