commit 9fea321af39d59b012fe5e740da73c89a10bdf84
Author: tanzk
Date:   Tue Jun 18 10:59:34 2024 +0800

    init

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..751e57c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,126 @@
+## DeepLabv3+: a PyTorch implementation of the "Encoder-Decoder with Atrous Separable Convolution" semantic segmentation model
+---
+
+### Contents
+1. [Top News](#top-news)
+2. [Related Repositories](#related-repositories)
+3. [Performance](#performance)
+4. [Environment](#environment)
+5. [Downloads](#downloads)
+6. [How to Train](#how-to-train)
+7. [How to Predict](#how-to-predict)
+8. [How to Evaluate mIoU](#how-to-evaluate-miou)
+9. [Reference](#reference)
+
+## Top News
+**`2022-04`**: **Added multi-GPU training.**
+
+**`2022-03`**: **Major update: added step and cosine learning-rate schedules, a choice of Adam or SGD optimizers, and adaptive scaling of the learning rate with batch size.**
+The original repository from the BiliBili video is at: https://github.com/bubbliiiing/deeplabv3-plus-pytorch/tree/bilibili
+
+**`2020-08`**: **Created the repository: multiple backbones, mIoU evaluation of the data, annotation-data processing, extensive comments, and more.**
+
+## Related Repositories
+| Model | Repository |
+| :----- | :----- |
+| Unet | https://github.com/bubbliiiing/unet-pytorch |
+| PSPnet | https://github.com/bubbliiiing/pspnet-pytorch |
+| deeplabv3+ | https://github.com/bubbliiiing/deeplabv3-plus-pytorch |
+| hrnet | https://github.com/bubbliiiing/hrnet-pytorch |
+
+### Performance
+| Training data | Weights file | Test data | Input size | mIoU |
+| :-----: | :-----: | :------: | :------: | :------: |
+| VOC12+SBD | [deeplab_mobilenetv2.pth](https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/deeplab_mobilenetv2.pth) | VOC-Val12 | 512x512 | 72.59 |
+| VOC12+SBD | [deeplab_xception.pth](https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/deeplab_xception.pth) | VOC-Val12 | 512x512 | 76.95 |
+
+### Environment
+torch==1.2.0
+
+### Notes
+The deeplab_mobilenetv2.pth and deeplab_xception.pth provided with this code were trained on the augmented VOC dataset. Make sure the backbone setting matches the weights for both training and prediction.
+
+### Downloads
+The deeplab_mobilenetv2.pth and deeplab_xception.pth weights required for training can be downloaded from Baidu Netdisk:
+Link: https://pan.baidu.com/s/1IQ3XYW-yRWQAy7jxCUHq8Q  Extraction code: qqq4
+
+The augmented VOC dataset is also available on Baidu Netdisk:
+Link: https://pan.baidu.com/s/1vkk3lMheUm6IjTXznlg7Ng  Extraction code: 44mk
+
+### How to Train
+#### a. Training on the VOC dataset
+1. Place the provided VOC dataset in VOCdevkit (there is no need to run voc_annotation.py).
+2. Set the parameters in train.py; the defaults already match the VOC dataset, so only backbone and model_path need to be changed.
+3. Run train.py to start training.
+
+#### b. Training on your own dataset
+1. This project trains on data in VOC format.
+2. Before training, put the label files in SegmentationClass under VOCdevkit/VOC2007.
+3. Before training, put the image files in JPEGImages under VOCdevkit/VOC2007.
+4. Before training, run voc_annotation.py to generate the corresponding txt split files (a sketch of what such a script does follows this list).
+5. In train.py, choose the backbone and the downsample factor. The available backbones are mobilenet and xception; the downsample factor can be 8 or 16. Note that the pretrained weights must match the chosen backbone.
+6. Set num_classes in train.py to the number of classes plus 1.
+7. Run train.py to start training.
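+
+voc_annotation.py itself is not shown in this commit; purely as a hedged illustration, a split script of this kind typically just shuffles the annotation filenames and writes the ImageSets txt files. The paths and the 9:1 train/val ratio below are assumptions, not this repository's actual defaults:
+```python
+import os
+import random
+
+seg_dir = "VOCdevkit/VOC2007/SegmentationClass"    # assumed label folder
+out_dir = "VOCdevkit/VOC2007/ImageSets/Segmentation"
+
+# one id per label PNG, without the extension
+ids = [f[:-4] for f in os.listdir(seg_dir) if f.endswith(".png")]
+random.shuffle(ids)
+split = int(0.9 * len(ids))                        # assumed 9:1 train/val split
+
+for name, subset in [("train.txt", ids[:split]), ("val.txt", ids[split:])]:
+    with open(os.path.join(out_dir, name), "w") as f:
+        f.write("\n".join(subset))
+```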
+
+### How to Predict
+#### a. Using pretrained weights
+1. After downloading and unzipping the repository, you can run predict.py directly to predict with the mobilenet backbone. To use the xception backbone instead, download deeplab_xception.pth from Baidu Netdisk, put it in model_data, change backbone and model_path in deeplab.py, then run predict.py and enter
+```python
+img/street.jpg
+```
+to run a prediction.
+2. Settings inside predict.py also enable FPS testing, whole-folder testing, and video detection.
+
+#### b. Using your own trained weights
+1. Train the model following the training steps.
+2. In deeplab.py, change model_path, num_classes, and backbone in the following block so that they match your trained model; **model_path points to the weights file under the logs folder, num_classes is the number of classes to predict plus 1, and backbone is the backbone feature-extraction network that was used**.
+```python
+_defaults = {
+    #----------------------------------------#
+    #   model_path points to the weights file
+    #   under the logs folder
+    #----------------------------------------#
+    "model_path"        : 'model_data/deeplab_mobilenetv2.pth',
+    #----------------------------------------#
+    #   number of classes to distinguish + 1
+    #----------------------------------------#
+    "num_classes"       : 21,
+    #----------------------------------------#
+    #   backbone network to use
+    #----------------------------------------#
+    "backbone"          : "mobilenet",
+    #----------------------------------------#
+    #   input image size
+    #----------------------------------------#
+    "input_shape"       : [512, 512],
+    #----------------------------------------#
+    #   downsample factor, usually 8 or 16;
+    #   use the same value as during training
+    #----------------------------------------#
+    "downsample_factor" : 16,
+    #--------------------------------#
+    #   blend controls whether the
+    #   prediction is blended with
+    #   the original image
+    #--------------------------------#
+    "blend"             : True,
+    #-------------------------------#
+    #   whether to use CUDA;
+    #   set to False if there is no GPU
+    #-------------------------------#
+    "cuda"              : True,
+}
+```
+3. Run predict.py and enter
+```python
+img/street.jpg
+```
+to run a prediction.
+4. Settings inside predict.py also enable FPS testing, whole-folder testing, and video detection.
+
+### How to Evaluate mIoU
+1. Set num_classes in get_miou.py to the number of classes to predict plus 1.
+2. Set name_classes in get_miou.py to the class names to distinguish.
+3. Run get_miou.py to obtain the mIoU.
+
+### Reference
+https://github.com/ggyyzm/pytorch_segmentation
+https://github.com/bonlime/keras-deeplab-v3-plus
diff --git a/VOCdevkit/VOC2007/ImageSets/Segmentation/README.md b/VOCdevkit/VOC2007/ImageSets/Segmentation/README.md
new file mode 100644
index 0000000..9042c5f
--- /dev/null
+++ b/VOCdevkit/VOC2007/ImageSets/Segmentation/README.md
@@ -0,0 +1,2 @@
+Holds the txt files that list the image file names.
+
diff --git a/VOCdevkit/VOC2007/ImageSets/Segmentation/test.txt b/VOCdevkit/VOC2007/ImageSets/Segmentation/test.txt
new file mode 100644
index 0000000..e69de29
diff --git a/VOCdevkit/VOC2007/ImageSets/Segmentation/train.txt b/VOCdevkit/VOC2007/ImageSets/Segmentation/train.txt
new file mode 100644
index 0000000..9b2dc1d
--- /dev/null
+++ b/VOCdevkit/VOC2007/ImageSets/Segmentation/train.txt
@@ -0,0 +1,295 @@
+01-1 +01-2 +01-3 +01-4 +02-1 +02-2 +03-1 +03-2 +03-3 +04-1 +04-2 +04-3 +05-1 +05-2 +05-3 +06-1 +06-2 +06-3 +07-1 +07-2 +08-1 +08-2 +09-1 +10-1 +100-1 +11-1 +116_01_01 +117_01_01 +118_01_01 +119_01_01 +11_01_01 +12-1 +120_01_01 +121_01_01 +122_01_01 +123_01_01 +124_01_01 +125_01_01 +126_01_01 +127_01_01 +128_01_01 +129_01_01 +12_01_01 +13-1 +13-3 +130_01_01 +132_01_01 +134_01_01 +138_01_01 +139_01_01 +14-1 +140_01_01 +141_01_01 +142_01_01 +143_01_01 +144_01_01 +145_01_01 +146_01_01 +147_01_01 +148_01_01 +149_01_01 +14_01_01 +15-1 +150_01_01 +151_01_01 +152_01_01 +153_01_01 +154_01_01 +155_01_01 +156_01_01 +157_01_01 +158_01_01 +159_01_01 +15_01_01 +16-1 +16-3 +160_01_01 +161_01_01 +162_01_01 +163_01_01 +164_01_01 +165_01_01 +166_01_01 +168_01_01 +16_01_01 +17-1 +171_01_01 +172_01_01 +173_01_01 +174_01_01 +175_01_01 +176_01_01 +177_01_01 +178_01_01 +179_01_01 +18-1 +180_01_01 +181_01_01 +182_01_01 +183_01_01 +184_01_01 +185_01_01 +187_01_01 +188_01_01 +189_01_01 +19-1 +190_01_01 +191_01_01 +192_01_01 +194_01_01 +195_01_01 +196_01_01 +198_01_01 +199_01_01 +19_01_01 +19_01_02 +1_01_02 +1_01_03 +1_01_04 +1_02_01 +1_02_02 +1_02_03 +1_02_04 +1_03_01 +1_03_02 +1_03_03 +1_03_04 +20-1 +200_01_01 +201_01_01 +202_01_01 +203_01_01 +204_01_01 +206_01_01 +207_01_01 +208_01_01 +21-1 +21_01_01 +22-1 +22-3 +22_02_01 +23-1 +239_01_01 +23_01_01 +24-1 +240_01_01 +241_01_01 +242_01_01 +243_01_01 +244_01_01 +245_01_01 +246_01_01 +247_01_01 +248_01_01 +249_01_01 +24_02_01 +25-1 +25-2 +250_01_01 +251_01_01 +25_01_01 +27-3 +27_01_01 +28-1 +28_01_01 +29-1 +2_01_01 +2_01_02 +2_01_03 +2_01_04 +2_02_01 +2_02_02 +2_02_03 +2_02_04 +2_03_01 +2_03_02 +2_03_03 +2_03_04 +30-1 +31-1 +32-1 +34-1 +35-1 +35-3 +36-1 +36-3 +37-1 +37-2 +38-1 +39-1 +39-2 +3_01_01 +3_01_03 +3_01_04 +3_02_01 +3_02_02 +3_02_03 +3_02_04 +3_03_02 +3_03_03 +3_03_04 +40-1 +40-3 +41-1 +41-2 +41_02_01 +42_02_01 +43-1 +44-1 +45-1 +47-1 +49-1 +4_01_01 +4_01_02 +4_01_03 +4_01_04 +4_02_01 +4_02_02 +4_02_03 +4_02_04 +4_03_02 +4_03_03 +50-1 +51-1 +52-1 +53-1 +54-1 +56-1 +57-1 +58-1 +59-1 +5_01_01 +5_01_02 +5_01_03 +5_01_04 +5_02_01 +5_02_02 +5_02_03 +5_02_04
+5_03_03 +5_03_04 +60-1 +61-1 +62-1 +64-1 +65-1 +66-1 +67-1 +68-1 +69-1 +6_01_01 +6_01_02 +6_01_03 +6_01_04 +6_02_02 +6_02_03 +6_02_04 +6_03_01 +6_03_02 +6_03_03 +6_03_04 +70-1 +71-1 +72-1 +73-1 +74-1 +75-1 +75-2 +76-1 +77-1 +78-1 +79-1 +7_01_01 +80-1 +81-1 +82-1 +83-1 +84-1 +85-1 +86-1 +87-1 +88-1 +89-1 +8_01_01 +91-1 +92-1 +93-1 +94-1 +95-1 +96-1 +97-1 +97-4 +98-1 +98-2 +99-1 diff --git a/VOCdevkit/VOC2007/ImageSets/Segmentation/trainval.txt b/VOCdevkit/VOC2007/ImageSets/Segmentation/trainval.txt new file mode 100644 index 0000000..ff632fb --- /dev/null +++ b/VOCdevkit/VOC2007/ImageSets/Segmentation/trainval.txt @@ -0,0 +1,328 @@ +01-1 +01-2 +01-3 +01-4 +02-1 +02-2 +03-1 +03-2 +03-3 +04-1 +04-2 +04-3 +05-1 +05-2 +05-3 +06-1 +06-2 +06-3 +07-1 +07-2 +08-1 +08-2 +09-1 +10-1 +100-1 +10_01_01 +11-1 +116_01_01 +117_01_01 +118_01_01 +119_01_01 +11_01_01 +12-1 +120_01_01 +121_01_01 +122_01_01 +123_01_01 +124_01_01 +125_01_01 +126_01_01 +127_01_01 +128_01_01 +129_01_01 +12_01_01 +13-1 +13-3 +130_01_01 +131_01_01 +132_01_01 +133_01_01 +134_01_01 +135_01_01 +136_01_01 +137_01_01 +138_01_01 +139_01_01 +14-1 +140_01_01 +141_01_01 +142_01_01 +143_01_01 +144_01_01 +145_01_01 +146_01_01 +147_01_01 +148_01_01 +149_01_01 +14_01_01 +15-1 +150_01_01 +151_01_01 +152_01_01 +153_01_01 +154_01_01 +155_01_01 +156_01_01 +157_01_01 +158_01_01 +159_01_01 +15_01_01 +16-1 +16-3 +160_01_01 +161_01_01 +162_01_01 +163_01_01 +164_01_01 +165_01_01 +166_01_01 +167_01_01 +168_01_01 +169_01_01 +16_01_01 +17-1 +170_01_01 +171_01_01 +172_01_01 +173_01_01 +174_01_01 +175_01_01 +176_01_01 +177_01_01 +178_01_01 +179_01_01 +18-1 +180_01_01 +181_01_01 +182_01_01 +183_01_01 +184_01_01 +185_01_01 +186_01_01 +187_01_01 +188_01_01 +189_01_01 +19-1 +190_01_01 +191_01_01 +192_01_01 +193_01_01 +194_01_01 +195_01_01 +196_01_01 +197_01_01 +198_01_01 +199_01_01 +19_01_01 +19_01_02 +1_01_01 +1_01_02 +1_01_03 +1_01_04 +1_02_01 +1_02_02 +1_02_03 +1_02_04 +1_03_01 +1_03_02 +1_03_03 +1_03_04 +20-1 +200_01_01 +201_01_01 +202_01_01 +203_01_01 +204_01_01 +205_01_01 +206_01_01 +207_01_01 +208_01_01 +20_01_02 +21-1 +21_01_01 +22-1 +22-3 +22_02_01 +23-1 +239_01_01 +23_01_01 +24-1 +240_01_01 +241_01_01 +242_01_01 +243_01_01 +244_01_01 +245_01_01 +246_01_01 +247_01_01 +248_01_01 +249_01_01 +24_02_01 +25-1 +25-2 +250_01_01 +251_01_01 +252_01_01 +25_01_01 +26-1 +27-3 +27_01_01 +28-1 +28_01_01 +29-1 +2_01_01 +2_01_02 +2_01_03 +2_01_04 +2_02_01 +2_02_02 +2_02_03 +2_02_04 +2_03_01 +2_03_02 +2_03_03 +2_03_04 +30-1 +31-1 +32-1 +33-1 +34-1 +35-1 +35-3 +36-1 +36-3 +37-1 +37-2 +38-1 +39-1 +39-2 +3_01_01 +3_01_02 +3_01_03 +3_01_04 +3_02_01 +3_02_02 +3_02_03 +3_02_04 +3_03_01 +3_03_02 +3_03_03 +3_03_04 +40-1 +40-3 +41-1 +41-2 +41_02_01 +42-1 +42_02_01 +43-1 +44-1 +45-1 +46-1 +47-1 +48-1 +49-1 +4_01_01 +4_01_02 +4_01_03 +4_01_04 +4_02_01 +4_02_02 +4_02_03 +4_02_04 +4_03_01 +4_03_02 +4_03_03 +4_03_04 +50-1 +51-1 +52-1 +53-1 +53-2 +54-1 +55-1 +56-1 +57-1 +58-1 +59-1 +5_01_01 +5_01_02 +5_01_03 +5_01_04 +5_02_01 +5_02_02 +5_02_03 +5_02_04 +5_03_01 +5_03_02 +5_03_03 +5_03_04 +60-1 +61-1 +62-1 +63-1 +64-1 +65-1 +66-1 +67-1 +68-1 +69-1 +6_01_01 +6_01_02 +6_01_03 +6_01_04 +6_02_01 +6_02_02 +6_02_03 +6_02_04 +6_03_01 +6_03_02 +6_03_03 +6_03_04 +70-1 +71-1 +72-1 +73-1 +74-1 +75-1 +75-2 +76-1 +77-1 +78-1 +79-1 +7_01_01 +80-1 +81-1 +82-1 +83-1 +84-1 +85-1 +86-1 +87-1 +88-1 +89-1 +8_01_01 +90-1 +91-1 +92-1 +93-1 +94-1 +95-1 +96-1 +97-1 +97-4 +98-1 +98-2 +99-1 +9_01_01 diff --git a/VOCdevkit/VOC2007/ImageSets/Segmentation/val.txt b/VOCdevkit/VOC2007/ImageSets/Segmentation/val.txt 
new file mode 100644 index 0000000..ebf9121 --- /dev/null +++ b/VOCdevkit/VOC2007/ImageSets/Segmentation/val.txt @@ -0,0 +1,33 @@ +10_01_01 +131_01_01 +133_01_01 +135_01_01 +136_01_01 +137_01_01 +167_01_01 +169_01_01 +170_01_01 +186_01_01 +193_01_01 +197_01_01 +1_01_01 +205_01_01 +20_01_02 +252_01_01 +26-1 +33-1 +3_01_02 +3_03_01 +42-1 +46-1 +48-1 +4_03_01 +4_03_04 +53-2 +55-1 +5_03_01 +5_03_02 +63-1 +6_02_01 +90-1 +9_01_01 diff --git a/VOCdevkit1/VOC2007/ImageSets/Segmentation/README.md b/VOCdevkit1/VOC2007/ImageSets/Segmentation/README.md new file mode 100644 index 0000000..9042c5f --- /dev/null +++ b/VOCdevkit1/VOC2007/ImageSets/Segmentation/README.md @@ -0,0 +1,2 @@ +存放的是指向文件名称的txt + diff --git a/VOCdevkit1/VOC2007/ImageSets/Segmentation/test.txt b/VOCdevkit1/VOC2007/ImageSets/Segmentation/test.txt new file mode 100644 index 0000000..e69de29 diff --git a/VOCdevkit1/VOC2007/ImageSets/Segmentation/train.txt b/VOCdevkit1/VOC2007/ImageSets/Segmentation/train.txt new file mode 100644 index 0000000..6c0e27c --- /dev/null +++ b/VOCdevkit1/VOC2007/ImageSets/Segmentation/train.txt @@ -0,0 +1,2105 @@ +00001 +00002 +00003 +00004 +00005 +00006 +00007 +00008 +00010 +00011 +00012 +00013 +00014 +00015 +00016 +00018 +00019 +00020 +00021 +00022 +00023 +00024 +00025 +00026 +00028 +00029 +00030 +00031 +00032 +00034 +00035 +00036 +00037 +00038 +00039 +00040 +00041 +00043 +00044 +00045 +00046 +00047 +00048 +00049 +00050 +00052 +00053 +00054 +00055 +00056 +00057 +00058 +00059 +00060 +00061 +00062 +00063 +00064 +00065 +00066 +00067 +00068 +00069 +00071 +00072 +00073 +00074 +00075 +00076 +00077 +00078 +00079 +00080 +00081 +00083 +00084 +00085 +00086 +00087 +00088 +00090 +00091 +00093 +00094 +00095 +00096 +00097 +00098 +00099 +00100 +00101 +00103 +00104 +00105 +00106 +00107 +00108 +00109 +00110 +00112 +00113 +00114 +00115 +00116 +00118 +00119 +00120 +00121 +00122 +00123 +00124 +00125 +00126 +00127 +00128 +00129 +00132 +00134 +00135 +00136 +00139 +00140 +00141 +00142 +00143 +00144 +00146 +00147 +00148 +00149 +00150 +00151 +00152 +00153 +00154 +00155 +00156 +00157 +00158 +00159 +00160 +00161 +00163 +00164 +00165 +00166 +00167 +00168 +00169 +00170 +00172 +00173 +00174 +00175 +00176 +00177 +00178 +00179 +00180 +00181 +00182 +00183 +00184 +00186 +00187 +00188 +00190 +00192 +00193 +00194 +00195 +00196 +00197 +00198 +00199 +00200 +00201 +00202 +00204 +00205 +00206 +00207 +00209 +00210 +00212 +00213 +00214 +00215 +00216 +00217 +00219 +00220 +00221 +00222 +00224 +00225 +00226 +00227 +00228 +00229 +00230 +00231 +00232 +00235 +00236 +00237 +00238 +00239 +00240 +00241 +00242 +00243 +00244 +00245 +00246 +00247 +00248 +00250 +00251 +00252 +00253 +00254 +00256 +00257 +00258 +00259 +00260 +00261 +00262 +00264 +00265 +00266 +00267 +00268 +00269 +00270 +00271 +00272 +00273 +00274 +00275 +00276 +00277 +00279 +00280 +00281 +00282 +00283 +00284 +00285 +00286 +00287 +00288 +00289 +00290 +00291 +00292 +00293 +00294 +00295 +00296 +00298 +00299 +00300 +00301 +00302 +00303 +00304 +00305 +00306 +00307 +00308 +00309 +00310 +00311 +00312 +00313 +00314 +00315 +00316 +00317 +00318 +00321 +00323 +00324 +00325 +00326 +00328 +00330 +00331 +00332 +00333 +00334 +00335 +00336 +00337 +00338 +00339 +00340 +00341 +00342 +00343 +00344 +00345 +00346 +00347 +00348 +00349 +00350 +00351 +00352 +00353 +00354 +00355 +00356 +00357 +00358 +00359 +00360 +00361 +00362 +00363 +00364 +00365 +00366 +00367 +00368 +00369 +00370 +00371 +00372 +00374 +00375 +00376 +00377 +00378 +00379 +00380 +00381 +00382 +00383 +00384 +00386 +00387 +00388 +00389 +00390 +00391 +00392 
+00393 +00394 +00396 +00397 +00398 +00399 +00400 +00401 +00402 +00403 +00404 +00405 +00406 +00407 +00408 +00409 +00410 +00412 +00413 +00414 +00415 +00416 +00418 +00419 +00420 +00421 +00423 +00424 +00425 +00426 +00427 +00428 +00430 +00431 +00432 +00433 +00434 +00435 +00436 +00437 +00439 +00440 +00441 +00442 +00443 +00444 +00445 +00446 +00447 +00448 +00449 +00450 +00451 +00452 +00454 +00455 +00456 +00458 +00459 +00460 +00461 +00462 +00463 +00464 +00465 +00466 +00467 +00468 +00469 +00471 +00472 +00473 +00474 +00475 +00476 +00477 +00478 +00479 +00480 +00481 +00482 +00483 +00484 +00485 +00486 +00487 +00488 +00489 +00490 +00491 +00492 +00494 +00495 +00496 +00499 +00500 +00501 +00502 +00503 +00504 +00505 +00506 +00507 +00508 +00509 +00510 +00511 +00512 +00513 +00514 +00515 +00517 +00518 +00519 +00520 +00521 +00522 +00523 +00524 +00525 +00526 +00527 +00528 +00529 +00531 +00532 +00533 +00534 +00535 +00536 +00537 +00538 +00539 +00540 +00541 +00542 +00543 +00544 +00545 +00546 +00547 +00548 +00549 +00550 +00551 +00552 +00553 +00554 +00555 +00556 +00557 +00558 +00559 +00560 +00561 +00562 +00563 +00564 +00565 +00566 +00567 +00568 +00569 +00570 +00571 +00572 +00573 +00574 +00575 +00576 +00577 +00578 +00579 +00580 +00581 +00582 +00583 +00584 +00585 +00586 +00587 +00589 +00590 +00591 +00592 +00593 +00594 +00596 +00597 +00598 +00599 +00600 +00602 +00603 +00605 +00606 +00607 +00608 +00609 +00610 +00611 +00612 +00613 +00614 +00615 +00617 +00618 +00619 +00620 +00621 +00622 +00623 +00624 +00625 +00626 +00627 +00628 +00629 +00630 +00631 +00632 +00633 +00634 +00637 +00638 +00639 +00640 +00641 +00642 +00643 +00644 +00645 +00646 +00647 +00648 +00649 +00650 +00651 +00652 +00653 +00654 +00655 +00656 +00657 +00658 +00659 +00660 +00661 +00662 +00663 +00664 +00665 +00666 +00667 +00668 +00670 +00671 +00672 +00673 +00674 +00675 +00676 +00677 +00678 +00679 +00680 +00681 +00682 +00683 +00684 +00685 +00686 +00687 +00688 +00689 +00690 +00691 +00692 +00693 +00694 +00695 +00697 +00698 +00699 +00700 +00702 +00703 +00704 +00705 +00706 +00707 +00708 +00710 +00711 +00712 +00713 +00714 +00715 +00717 +00718 +00719 +00720 +00721 +00722 +00723 +00725 +00726 +00727 +00728 +00729 +00730 +00731 +00733 +00734 +00735 +00736 +00737 +00738 +00739 +00740 +00742 +00743 +00744 +00745 +00746 +00747 +00749 +00750 +00751 +00752 +00753 +00754 +00755 +00758 +00759 +00760 +00761 +00762 +00763 +00764 +00765 +00766 +00767 +00768 +00769 +00770 +00771 +00772 +00773 +00774 +00775 +00776 +00777 +00778 +00779 +00781 +00782 +00783 +00784 +00785 +00786 +00787 +00788 +00790 +00791 +00792 +00793 +00794 +00795 +00796 +00797 +00798 +00799 +00800 +00801 +00802 +00803 +00804 +00805 +00806 +00807 +00808 +00809 +00810 +00811 +00813 +00814 +00815 +00816 +00817 +00818 +00819 +00820 +00821 +00822 +00823 +00824 +00825 +00826 +00827 +00828 +00829 +00830 +00831 +00832 +00833 +00834 +00835 +00836 +00838 +00839 +00840 +00841 +00842 +00843 +00844 +00845 +00846 +00847 +00848 +00849 +00851 +00852 +00853 +00854 +00855 +00856 +00857 +00859 +00860 +00861 +00862 +00863 +00864 +00865 +00866 +00868 +00870 +00871 +00872 +00873 +00874 +00875 +00877 +00878 +00880 +00881 +00882 +00883 +00885 +00886 +00888 +00889 +00890 +00891 +00892 +00893 +00894 +00895 +00896 +00897 +00898 +00899 +00900 +00902 +00903 +00904 +00905 +00907 +00908 +00909 +00912 +00913 +00914 +00915 +00916 +00917 +00918 +00919 +00920 +00921 +00922 +00923 +00924 +00925 +00927 +00928 +00929 +00930 +00931 +00932 +00933 +00934 +00935 +00936 +00937 +00938 +00939 +00940 +00941 +00942 +00943 +00944 +00945 +00946 +00947 +00948 
+00949 +00950 +00951 +00952 +00953 +00954 +00955 +00956 +00957 +00958 +00959 +00960 +00961 +00962 +00963 +00964 +00966 +00967 +00968 +00969 +00970 +00971 +00972 +00973 +00974 +00975 +00976 +00977 +00978 +00980 +00981 +00982 +00983 +00984 +00985 +00988 +00989 +00990 +00992 +00993 +00994 +00995 +00996 +00997 +00998 +00999 +01000 +01002 +01003 +01004 +01005 +01006 +01007 +01008 +01009 +01010 +01011 +01012 +01013 +01015 +01016 +01017 +01018 +01019 +01020 +01021 +01022 +01023 +01024 +01025 +01027 +01028 +01029 +01030 +01031 +01032 +01033 +01034 +01036 +01037 +01038 +01039 +01040 +01041 +01042 +01043 +01044 +01045 +01046 +01047 +01049 +01051 +01052 +01053 +01054 +01055 +01056 +01057 +01058 +01059 +01060 +01061 +01062 +01063 +01064 +01065 +01066 +01067 +01068 +01070 +01071 +01072 +01073 +01074 +01075 +01077 +01078 +01079 +01080 +01081 +01082 +01084 +01085 +01086 +01087 +01088 +01089 +01090 +01091 +01094 +01095 +01096 +01097 +01098 +01099 +01100 +01101 +01102 +01103 +01104 +01105 +01106 +01107 +01108 +01109 +01110 +01111 +01112 +01114 +01115 +01116 +01117 +01118 +01119 +01120 +01121 +01122 +01123 +01125 +01127 +01128 +01129 +01130 +01132 +01133 +01134 +01135 +01136 +01137 +01138 +01139 +01141 +01143 +01144 +01145 +01146 +01147 +01148 +01149 +01150 +01151 +01152 +01153 +01154 +01155 +01156 +01157 +01158 +01160 +01162 +01163 +01164 +01165 +01166 +01167 +01169 +01170 +01171 +01172 +01173 +01174 +01175 +01176 +01177 +01179 +01180 +01181 +01182 +01183 +01184 +01186 +01187 +01188 +01189 +01191 +01192 +01194 +01195 +01196 +01197 +01198 +01199 +01200 +01201 +01202 +01204 +01205 +01206 +01207 +01208 +01209 +01210 +01211 +01212 +01213 +01214 +01215 +01216 +01217 +01218 +01220 +01221 +01222 +01223 +01225 +01226 +01227 +01228 +01229 +01230 +01231 +01232 +01233 +01234 +01235 +01236 +01237 +01238 +01239 +01240 +01241 +01242 +01243 +01244 +01245 +01246 +01247 +01248 +01249 +01250 +01251 +01252 +01253 +01255 +01256 +01257 +01260 +01262 +01263 +01264 +01265 +01266 +01268 +01269 +01270 +01272 +01273 +01274 +01275 +01276 +01277 +01278 +01280 +01281 +01282 +01283 +01285 +01286 +01287 +01288 +01289 +01290 +01292 +01293 +01294 +01295 +01296 +01297 +01298 +01299 +01300 +01301 +01302 +01303 +01304 +01305 +01306 +01307 +01308 +01309 +01310 +01311 +01312 +01315 +01316 +01317 +01318 +01319 +01321 +01322 +01323 +01326 +01327 +01328 +01329 +01330 +01331 +01332 +01333 +01334 +01335 +01336 +01337 +01338 +01339 +01340 +01341 +01342 +01343 +01344 +01345 +01346 +01347 +01348 +01349 +01350 +01351 +01353 +01354 +01355 +01356 +01357 +01358 +01359 +01360 +01361 +01362 +01363 +01365 +01366 +01367 +01368 +01369 +01370 +01371 +01372 +01373 +01374 +01375 +01376 +01377 +01378 +01379 +01380 +01381 +01382 +01383 +01384 +01385 +01386 +01387 +01388 +01389 +01390 +01391 +01392 +01393 +01394 +01395 +01397 +01398 +01399 +01400 +01401 +01402 +01403 +01404 +01405 +01406 +01407 +01408 +01409 +01410 +01411 +01412 +01414 +01415 +01416 +01417 +01418 +01419 +01420 +01422 +01424 +01425 +01426 +01427 +01428 +01429 +01430 +01431 +01432 +01433 +01435 +01437 +01438 +01439 +01440 +01441 +01442 +01443 +01444 +01445 +01446 +01448 +01449 +01450 +01451 +01452 +01453 +01455 +01456 +01457 +01458 +01459 +01460 +01461 +01462 +01463 +01464 +01466 +01467 +01468 +01470 +01471 +01473 +01474 +01475 +01476 +01477 +01478 +01479 +01481 +01482 +01483 +01484 +01485 +01486 +01487 +01488 +01490 +01491 +01492 +01493 +01494 +01496 +01497 +01498 +01499 +01500 +01501 +01502 +01503 +01504 +01505 +01506 +01507 +01508 +01509 +01510 +01512 +01513 +01514 +01515 +01516 +01518 +01519 
+01520 +01521 +01522 +01523 +01524 +01525 +01526 +01527 +01528 +01530 +01531 +01532 +01533 +01535 +01536 +01537 +01538 +01539 +01540 +01542 +01543 +01544 +01546 +01547 +01548 +01549 +01550 +01551 +01552 +01555 +01556 +01557 +01558 +01559 +01560 +01561 +01562 +01563 +01564 +01565 +01566 +01567 +01568 +01569 +01570 +01571 +01572 +01573 +01574 +01575 +01576 +01577 +01578 +01579 +01580 +01581 +01582 +01583 +01584 +01585 +01586 +01587 +01588 +01589 +01590 +01591 +01592 +01593 +01594 +01596 +01598 +01599 +01600 +01601 +01602 +01603 +01604 +01605 +01606 +01607 +01608 +01609 +01610 +01611 +01612 +01613 +01614 +01615 +01616 +01617 +01619 +01620 +01621 +01622 +01623 +01624 +01626 +01627 +01628 +01632 +01633 +01634 +01635 +01636 +01637 +01638 +01639 +01642 +01643 +01644 +01645 +01646 +01647 +01648 +01649 +01650 +01651 +01652 +01653 +01654 +01655 +01656 +01657 +01658 +01659 +01660 +01661 +01662 +01663 +01664 +01665 +01666 +01667 +01668 +01669 +01671 +01672 +01673 +01674 +01675 +01676 +01677 +01678 +01679 +01680 +01681 +01682 +01683 +01685 +01686 +01687 +01688 +01689 +01690 +01692 +01693 +01694 +01695 +01696 +01697 +01698 +01699 +01700 +01701 +01702 +01703 +01704 +01705 +01706 +01707 +01708 +01709 +01710 +01711 +01712 +01713 +01714 +01715 +01716 +01717 +01718 +01719 +01720 +01721 +01722 +01723 +01724 +01725 +01726 +01727 +01728 +01729 +01730 +01731 +01732 +01733 +01734 +01735 +01736 +01737 +01738 +01739 +01740 +01741 +01742 +01743 +01744 +01745 +01746 +01747 +01748 +01749 +01751 +01752 +01753 +01754 +01755 +01756 +01758 +01759 +01760 +01761 +01762 +01763 +01764 +01765 +01766 +01767 +01768 +01769 +01770 +01771 +01772 +01774 +01775 +01776 +01778 +01780 +01781 +01783 +01784 +01785 +01786 +01789 +01790 +01791 +01792 +01794 +01796 +01797 +01799 +01800 +01801 +01802 +01803 +01804 +01805 +01806 +01807 +01808 +01809 +01810 +01811 +01812 +01813 +01814 +01815 +01817 +01818 +01820 +01822 +01823 +01824 +01825 +01826 +01827 +01828 +01829 +01830 +01831 +01832 +01834 +01835 +01836 +01837 +01838 +01839 +01840 +01841 +01842 +01843 +01844 +01845 +01846 +01847 +01848 +01849 +01850 +01851 +01852 +01853 +01854 +01855 +01856 +01857 +01858 +01859 +01860 +01861 +01863 +01864 +01865 +01866 +01867 +01868 +01869 +01870 +01871 +01872 +01873 +01874 +01875 +01876 +01877 +01878 +01879 +01880 +01882 +01883 +01884 +01885 +01886 +01887 +01888 +01889 +01890 +01891 +01893 +01894 +01895 +01896 +01897 +01898 +01899 +01900 +01901 +01903 +01904 +01905 +01906 +01907 +01908 +01909 +01910 +01911 +01912 +01913 +01914 +01915 +01916 +01917 +01918 +01919 +01920 +01921 +01922 +01923 +01924 +01925 +01926 +01927 +01928 +01929 +01930 +01931 +01932 +01933 +01934 +01935 +01936 +01937 +01941 +01942 +01943 +01944 +01945 +01946 +01947 +01949 +01951 +01952 +01953 +01954 +01956 +01957 +01958 +01959 +01960 +01961 +01962 +01963 +01965 +01966 +01967 +01968 +01969 +01970 +01971 +01972 +01973 +01974 +01975 +01976 +01977 +01978 +01979 +01980 +01981 +01982 +01983 +01984 +01985 +01986 +01987 +01988 +01989 +01991 +01992 +01993 +01994 +01995 +01996 +01997 +01998 +01999 +02000 +02001 +02002 +02003 +02004 +02005 +02006 +02007 +02008 +02009 +02010 +02011 +02012 +02013 +02015 +02016 +02017 +02018 +02020 +02021 +02022 +02023 +02024 +02025 +02026 +02027 +02028 +02029 +02030 +02031 +02032 +02033 +02034 +02035 +02036 +02037 +02038 +02039 +02040 +02041 +02042 +02043 +02044 +02045 +02046 +02047 +02048 +02050 +02051 +02052 +02053 +02054 +02055 +02056 +02057 +02058 +02059 +02060 +02062 +02063 +02064 +02065 +02067 +02069 +02070 +02071 +02072 +02073 +02074 +02075 +02076 +02077 
+02078 +02079 +02080 +02081 +02083 +02085 +02086 +02087 +02088 +02089 +02092 +02094 +02096 +02097 +02098 +02099 +02100 +02101 +02102 +02103 +02104 +02105 +02106 +02107 +02108 +02109 +02110 +02111 +02112 +02113 +02114 +02115 +02116 +02117 +02119 +02120 +02122 +02123 +02124 +02125 +02126 +02127 +02128 +02129 +02130 +02131 +02132 +02134 +02135 +02136 +02137 +02138 +02139 +02140 +02141 +02142 +02143 +02144 +02145 +02146 +02147 +02148 +02149 +02150 +02151 +02152 +02153 +02154 +02155 +02156 +02158 +02159 +02160 +02162 +02163 +02164 +02165 +02166 +02167 +02168 +02169 +02170 +02171 +02172 +02174 +02175 +02176 +02177 +02178 +02179 +02180 +02181 +02182 +02184 +02185 +02186 +02187 +02188 +02189 +02190 +02191 +02192 +02193 +02194 +02195 +02196 +02197 +02198 +02199 +02201 +02202 +02205 +02206 +02207 +02208 +02209 +02210 +02211 +02213 +02214 +02215 +02216 +02217 +02218 +02219 +02220 +02223 +02224 +02225 +02226 +02227 +02228 +02231 +02232 +02233 +02234 +02235 +02236 +02237 +02238 +02239 +02240 +02242 +02243 +02244 +02245 +02246 +02247 +02248 +02249 +02250 +02251 +02252 +02254 +02255 +02256 +02257 +02258 +02259 +02260 +02261 +02262 +02264 +02265 +02266 +02267 +02268 +02269 +02270 +02271 +02272 +02273 +02274 +02275 +02276 +02277 +02278 +02280 +02281 +02282 +02283 +02284 +02285 +02286 +02287 +02288 +02289 +02290 +02291 +02292 +02293 +02294 +02295 +02296 +02297 +02298 +02299 +02300 +02301 +02302 +02303 +02304 +02305 +02306 +02307 +02308 +02309 +02310 +02311 +02312 +02313 +02314 +02315 +02316 +02317 +02318 +02320 +02321 +02322 +02323 +02324 +02325 +02326 +02327 +02328 +02329 +02330 +02331 +02333 +02334 +02335 +02336 +02337 +02338 +02339 diff --git a/VOCdevkit1/VOC2007/ImageSets/Segmentation/trainval.txt b/VOCdevkit1/VOC2007/ImageSets/Segmentation/trainval.txt new file mode 100644 index 0000000..ec1d75b --- /dev/null +++ b/VOCdevkit1/VOC2007/ImageSets/Segmentation/trainval.txt @@ -0,0 +1,2339 @@ +00001 +00002 +00003 +00004 +00005 +00006 +00007 +00008 +00009 +00010 +00011 +00012 +00013 +00014 +00015 +00016 +00017 +00018 +00019 +00020 +00021 +00022 +00023 +00024 +00025 +00026 +00027 +00028 +00029 +00030 +00031 +00032 +00033 +00034 +00035 +00036 +00037 +00038 +00039 +00040 +00041 +00042 +00043 +00044 +00045 +00046 +00047 +00048 +00049 +00050 +00051 +00052 +00053 +00054 +00055 +00056 +00057 +00058 +00059 +00060 +00061 +00062 +00063 +00064 +00065 +00066 +00067 +00068 +00069 +00070 +00071 +00072 +00073 +00074 +00075 +00076 +00077 +00078 +00079 +00080 +00081 +00082 +00083 +00084 +00085 +00086 +00087 +00088 +00089 +00090 +00091 +00092 +00093 +00094 +00095 +00096 +00097 +00098 +00099 +00100 +00101 +00102 +00103 +00104 +00105 +00106 +00107 +00108 +00109 +00110 +00111 +00112 +00113 +00114 +00115 +00116 +00117 +00118 +00119 +00120 +00121 +00122 +00123 +00124 +00125 +00126 +00127 +00128 +00129 +00130 +00131 +00132 +00133 +00134 +00135 +00136 +00137 +00138 +00139 +00140 +00141 +00142 +00143 +00144 +00145 +00146 +00147 +00148 +00149 +00150 +00151 +00152 +00153 +00154 +00155 +00156 +00157 +00158 +00159 +00160 +00161 +00162 +00163 +00164 +00165 +00166 +00167 +00168 +00169 +00170 +00171 +00172 +00173 +00174 +00175 +00176 +00177 +00178 +00179 +00180 +00181 +00182 +00183 +00184 +00185 +00186 +00187 +00188 +00189 +00190 +00191 +00192 +00193 +00194 +00195 +00196 +00197 +00198 +00199 +00200 +00201 +00202 +00203 +00204 +00205 +00206 +00207 +00208 +00209 +00210 +00211 +00212 +00213 +00214 +00215 +00216 +00217 +00218 +00219 +00220 +00221 +00222 +00223 +00224 +00225 +00226 +00227 +00228 +00229 +00230 +00231 +00232 +00233 +00234 +00235 
+00236 +00237 +00238 +00239 +00240 +00241 +00242 +00243 +00244 +00245 +00246 +00247 +00248 +00249 +00250 +00251 +00252 +00253 +00254 +00255 +00256 +00257 +00258 +00259 +00260 +00261 +00262 +00263 +00264 +00265 +00266 +00267 +00268 +00269 +00270 +00271 +00272 +00273 +00274 +00275 +00276 +00277 +00278 +00279 +00280 +00281 +00282 +00283 +00284 +00285 +00286 +00287 +00288 +00289 +00290 +00291 +00292 +00293 +00294 +00295 +00296 +00297 +00298 +00299 +00300 +00301 +00302 +00303 +00304 +00305 +00306 +00307 +00308 +00309 +00310 +00311 +00312 +00313 +00314 +00315 +00316 +00317 +00318 +00319 +00320 +00321 +00322 +00323 +00324 +00325 +00326 +00327 +00328 +00329 +00330 +00331 +00332 +00333 +00334 +00335 +00336 +00337 +00338 +00339 +00340 +00341 +00342 +00343 +00344 +00345 +00346 +00347 +00348 +00349 +00350 +00351 +00352 +00353 +00354 +00355 +00356 +00357 +00358 +00359 +00360 +00361 +00362 +00363 +00364 +00365 +00366 +00367 +00368 +00369 +00370 +00371 +00372 +00373 +00374 +00375 +00376 +00377 +00378 +00379 +00380 +00381 +00382 +00383 +00384 +00385 +00386 +00387 +00388 +00389 +00390 +00391 +00392 +00393 +00394 +00395 +00396 +00397 +00398 +00399 +00400 +00401 +00402 +00403 +00404 +00405 +00406 +00407 +00408 +00409 +00410 +00411 +00412 +00413 +00414 +00415 +00416 +00417 +00418 +00419 +00420 +00421 +00422 +00423 +00424 +00425 +00426 +00427 +00428 +00429 +00430 +00431 +00432 +00433 +00434 +00435 +00436 +00437 +00438 +00439 +00440 +00441 +00442 +00443 +00444 +00445 +00446 +00447 +00448 +00449 +00450 +00451 +00452 +00453 +00454 +00455 +00456 +00457 +00458 +00459 +00460 +00461 +00462 +00463 +00464 +00465 +00466 +00467 +00468 +00469 +00470 +00471 +00472 +00473 +00474 +00475 +00476 +00477 +00478 +00479 +00480 +00481 +00482 +00483 +00484 +00485 +00486 +00487 +00488 +00489 +00490 +00491 +00492 +00493 +00494 +00495 +00496 +00497 +00498 +00499 +00500 +00501 +00502 +00503 +00504 +00505 +00506 +00507 +00508 +00509 +00510 +00511 +00512 +00513 +00514 +00515 +00516 +00517 +00518 +00519 +00520 +00521 +00522 +00523 +00524 +00525 +00526 +00527 +00528 +00529 +00530 +00531 +00532 +00533 +00534 +00535 +00536 +00537 +00538 +00539 +00540 +00541 +00542 +00543 +00544 +00545 +00546 +00547 +00548 +00549 +00550 +00551 +00552 +00553 +00554 +00555 +00556 +00557 +00558 +00559 +00560 +00561 +00562 +00563 +00564 +00565 +00566 +00567 +00568 +00569 +00570 +00571 +00572 +00573 +00574 +00575 +00576 +00577 +00578 +00579 +00580 +00581 +00582 +00583 +00584 +00585 +00586 +00587 +00588 +00589 +00590 +00591 +00592 +00593 +00594 +00595 +00596 +00597 +00598 +00599 +00600 +00601 +00602 +00603 +00604 +00605 +00606 +00607 +00608 +00609 +00610 +00611 +00612 +00613 +00614 +00615 +00616 +00617 +00618 +00619 +00620 +00621 +00622 +00623 +00624 +00625 +00626 +00627 +00628 +00629 +00630 +00631 +00632 +00633 +00634 +00635 +00636 +00637 +00638 +00639 +00640 +00641 +00642 +00643 +00644 +00645 +00646 +00647 +00648 +00649 +00650 +00651 +00652 +00653 +00654 +00655 +00656 +00657 +00658 +00659 +00660 +00661 +00662 +00663 +00664 +00665 +00666 +00667 +00668 +00669 +00670 +00671 +00672 +00673 +00674 +00675 +00676 +00677 +00678 +00679 +00680 +00681 +00682 +00683 +00684 +00685 +00686 +00687 +00688 +00689 +00690 +00691 +00692 +00693 +00694 +00695 +00696 +00697 +00698 +00699 +00700 +00701 +00702 +00703 +00704 +00705 +00706 +00707 +00708 +00709 +00710 +00711 +00712 +00713 +00714 +00715 +00716 +00717 +00718 +00719 +00720 +00721 +00722 +00723 +00724 +00725 +00726 +00727 +00728 +00729 +00730 +00731 +00732 +00733 +00734 +00735 +00736 +00737 +00738 +00739 +00740 +00741 +00742 
+00743 +00744 +00745 +00746 +00747 +00748 +00749 +00750 +00751 +00752 +00753 +00754 +00755 +00756 +00757 +00758 +00759 +00760 +00761 +00762 +00763 +00764 +00765 +00766 +00767 +00768 +00769 +00770 +00771 +00772 +00773 +00774 +00775 +00776 +00777 +00778 +00779 +00780 +00781 +00782 +00783 +00784 +00785 +00786 +00787 +00788 +00789 +00790 +00791 +00792 +00793 +00794 +00795 +00796 +00797 +00798 +00799 +00800 +00801 +00802 +00803 +00804 +00805 +00806 +00807 +00808 +00809 +00810 +00811 +00812 +00813 +00814 +00815 +00816 +00817 +00818 +00819 +00820 +00821 +00822 +00823 +00824 +00825 +00826 +00827 +00828 +00829 +00830 +00831 +00832 +00833 +00834 +00835 +00836 +00837 +00838 +00839 +00840 +00841 +00842 +00843 +00844 +00845 +00846 +00847 +00848 +00849 +00850 +00851 +00852 +00853 +00854 +00855 +00856 +00857 +00858 +00859 +00860 +00861 +00862 +00863 +00864 +00865 +00866 +00867 +00868 +00869 +00870 +00871 +00872 +00873 +00874 +00875 +00876 +00877 +00878 +00879 +00880 +00881 +00882 +00883 +00884 +00885 +00886 +00887 +00888 +00889 +00890 +00891 +00892 +00893 +00894 +00895 +00896 +00897 +00898 +00899 +00900 +00901 +00902 +00903 +00904 +00905 +00906 +00907 +00908 +00909 +00910 +00911 +00912 +00913 +00914 +00915 +00916 +00917 +00918 +00919 +00920 +00921 +00922 +00923 +00924 +00925 +00926 +00927 +00928 +00929 +00930 +00931 +00932 +00933 +00934 +00935 +00936 +00937 +00938 +00939 +00940 +00941 +00942 +00943 +00944 +00945 +00946 +00947 +00948 +00949 +00950 +00951 +00952 +00953 +00954 +00955 +00956 +00957 +00958 +00959 +00960 +00961 +00962 +00963 +00964 +00965 +00966 +00967 +00968 +00969 +00970 +00971 +00972 +00973 +00974 +00975 +00976 +00977 +00978 +00979 +00980 +00981 +00982 +00983 +00984 +00985 +00986 +00987 +00988 +00989 +00990 +00991 +00992 +00993 +00994 +00995 +00996 +00997 +00998 +00999 +01000 +01001 +01002 +01003 +01004 +01005 +01006 +01007 +01008 +01009 +01010 +01011 +01012 +01013 +01014 +01015 +01016 +01017 +01018 +01019 +01020 +01021 +01022 +01023 +01024 +01025 +01026 +01027 +01028 +01029 +01030 +01031 +01032 +01033 +01034 +01035 +01036 +01037 +01038 +01039 +01040 +01041 +01042 +01043 +01044 +01045 +01046 +01047 +01048 +01049 +01050 +01051 +01052 +01053 +01054 +01055 +01056 +01057 +01058 +01059 +01060 +01061 +01062 +01063 +01064 +01065 +01066 +01067 +01068 +01069 +01070 +01071 +01072 +01073 +01074 +01075 +01076 +01077 +01078 +01079 +01080 +01081 +01082 +01083 +01084 +01085 +01086 +01087 +01088 +01089 +01090 +01091 +01092 +01093 +01094 +01095 +01096 +01097 +01098 +01099 +01100 +01101 +01102 +01103 +01104 +01105 +01106 +01107 +01108 +01109 +01110 +01111 +01112 +01113 +01114 +01115 +01116 +01117 +01118 +01119 +01120 +01121 +01122 +01123 +01124 +01125 +01126 +01127 +01128 +01129 +01130 +01131 +01132 +01133 +01134 +01135 +01136 +01137 +01138 +01139 +01140 +01141 +01142 +01143 +01144 +01145 +01146 +01147 +01148 +01149 +01150 +01151 +01152 +01153 +01154 +01155 +01156 +01157 +01158 +01159 +01160 +01161 +01162 +01163 +01164 +01165 +01166 +01167 +01168 +01169 +01170 +01171 +01172 +01173 +01174 +01175 +01176 +01177 +01178 +01179 +01180 +01181 +01182 +01183 +01184 +01185 +01186 +01187 +01188 +01189 +01190 +01191 +01192 +01193 +01194 +01195 +01196 +01197 +01198 +01199 +01200 +01201 +01202 +01203 +01204 +01205 +01206 +01207 +01208 +01209 +01210 +01211 +01212 +01213 +01214 +01215 +01216 +01217 +01218 +01219 +01220 +01221 +01222 +01223 +01224 +01225 +01226 +01227 +01228 +01229 +01230 +01231 +01232 +01233 +01234 +01235 +01236 +01237 +01238 +01239 +01240 +01241 +01242 +01243 +01244 +01245 +01246 +01247 +01248 +01249 
+01250 +01251 +01252 +01253 +01254 +01255 +01256 +01257 +01258 +01259 +01260 +01261 +01262 +01263 +01264 +01265 +01266 +01267 +01268 +01269 +01270 +01271 +01272 +01273 +01274 +01275 +01276 +01277 +01278 +01279 +01280 +01281 +01282 +01283 +01284 +01285 +01286 +01287 +01288 +01289 +01290 +01291 +01292 +01293 +01294 +01295 +01296 +01297 +01298 +01299 +01300 +01301 +01302 +01303 +01304 +01305 +01306 +01307 +01308 +01309 +01310 +01311 +01312 +01313 +01314 +01315 +01316 +01317 +01318 +01319 +01320 +01321 +01322 +01323 +01324 +01325 +01326 +01327 +01328 +01329 +01330 +01331 +01332 +01333 +01334 +01335 +01336 +01337 +01338 +01339 +01340 +01341 +01342 +01343 +01344 +01345 +01346 +01347 +01348 +01349 +01350 +01351 +01352 +01353 +01354 +01355 +01356 +01357 +01358 +01359 +01360 +01361 +01362 +01363 +01364 +01365 +01366 +01367 +01368 +01369 +01370 +01371 +01372 +01373 +01374 +01375 +01376 +01377 +01378 +01379 +01380 +01381 +01382 +01383 +01384 +01385 +01386 +01387 +01388 +01389 +01390 +01391 +01392 +01393 +01394 +01395 +01396 +01397 +01398 +01399 +01400 +01401 +01402 +01403 +01404 +01405 +01406 +01407 +01408 +01409 +01410 +01411 +01412 +01413 +01414 +01415 +01416 +01417 +01418 +01419 +01420 +01421 +01422 +01423 +01424 +01425 +01426 +01427 +01428 +01429 +01430 +01431 +01432 +01433 +01434 +01435 +01436 +01437 +01438 +01439 +01440 +01441 +01442 +01443 +01444 +01445 +01446 +01447 +01448 +01449 +01450 +01451 +01452 +01453 +01454 +01455 +01456 +01457 +01458 +01459 +01460 +01461 +01462 +01463 +01464 +01465 +01466 +01467 +01468 +01469 +01470 +01471 +01472 +01473 +01474 +01475 +01476 +01477 +01478 +01479 +01480 +01481 +01482 +01483 +01484 +01485 +01486 +01487 +01488 +01489 +01490 +01491 +01492 +01493 +01494 +01495 +01496 +01497 +01498 +01499 +01500 +01501 +01502 +01503 +01504 +01505 +01506 +01507 +01508 +01509 +01510 +01511 +01512 +01513 +01514 +01515 +01516 +01517 +01518 +01519 +01520 +01521 +01522 +01523 +01524 +01525 +01526 +01527 +01528 +01529 +01530 +01531 +01532 +01533 +01534 +01535 +01536 +01537 +01538 +01539 +01540 +01541 +01542 +01543 +01544 +01545 +01546 +01547 +01548 +01549 +01550 +01551 +01552 +01553 +01554 +01555 +01556 +01557 +01558 +01559 +01560 +01561 +01562 +01563 +01564 +01565 +01566 +01567 +01568 +01569 +01570 +01571 +01572 +01573 +01574 +01575 +01576 +01577 +01578 +01579 +01580 +01581 +01582 +01583 +01584 +01585 +01586 +01587 +01588 +01589 +01590 +01591 +01592 +01593 +01594 +01595 +01596 +01597 +01598 +01599 +01600 +01601 +01602 +01603 +01604 +01605 +01606 +01607 +01608 +01609 +01610 +01611 +01612 +01613 +01614 +01615 +01616 +01617 +01618 +01619 +01620 +01621 +01622 +01623 +01624 +01625 +01626 +01627 +01628 +01629 +01630 +01631 +01632 +01633 +01634 +01635 +01636 +01637 +01638 +01639 +01640 +01641 +01642 +01643 +01644 +01645 +01646 +01647 +01648 +01649 +01650 +01651 +01652 +01653 +01654 +01655 +01656 +01657 +01658 +01659 +01660 +01661 +01662 +01663 +01664 +01665 +01666 +01667 +01668 +01669 +01670 +01671 +01672 +01673 +01674 +01675 +01676 +01677 +01678 +01679 +01680 +01681 +01682 +01683 +01684 +01685 +01686 +01687 +01688 +01689 +01690 +01691 +01692 +01693 +01694 +01695 +01696 +01697 +01698 +01699 +01700 +01701 +01702 +01703 +01704 +01705 +01706 +01707 +01708 +01709 +01710 +01711 +01712 +01713 +01714 +01715 +01716 +01717 +01718 +01719 +01720 +01721 +01722 +01723 +01724 +01725 +01726 +01727 +01728 +01729 +01730 +01731 +01732 +01733 +01734 +01735 +01736 +01737 +01738 +01739 +01740 +01741 +01742 +01743 +01744 +01745 +01746 +01747 +01748 +01749 +01750 +01751 +01752 +01753 +01754 +01755 +01756 
+01757 +01758 +01759 +01760 +01761 +01762 +01763 +01764 +01765 +01766 +01767 +01768 +01769 +01770 +01771 +01772 +01773 +01774 +01775 +01776 +01777 +01778 +01779 +01780 +01781 +01782 +01783 +01784 +01785 +01786 +01787 +01788 +01789 +01790 +01791 +01792 +01793 +01794 +01795 +01796 +01797 +01798 +01799 +01800 +01801 +01802 +01803 +01804 +01805 +01806 +01807 +01808 +01809 +01810 +01811 +01812 +01813 +01814 +01815 +01816 +01817 +01818 +01819 +01820 +01821 +01822 +01823 +01824 +01825 +01826 +01827 +01828 +01829 +01830 +01831 +01832 +01833 +01834 +01835 +01836 +01837 +01838 +01839 +01840 +01841 +01842 +01843 +01844 +01845 +01846 +01847 +01848 +01849 +01850 +01851 +01852 +01853 +01854 +01855 +01856 +01857 +01858 +01859 +01860 +01861 +01862 +01863 +01864 +01865 +01866 +01867 +01868 +01869 +01870 +01871 +01872 +01873 +01874 +01875 +01876 +01877 +01878 +01879 +01880 +01881 +01882 +01883 +01884 +01885 +01886 +01887 +01888 +01889 +01890 +01891 +01892 +01893 +01894 +01895 +01896 +01897 +01898 +01899 +01900 +01901 +01902 +01903 +01904 +01905 +01906 +01907 +01908 +01909 +01910 +01911 +01912 +01913 +01914 +01915 +01916 +01917 +01918 +01919 +01920 +01921 +01922 +01923 +01924 +01925 +01926 +01927 +01928 +01929 +01930 +01931 +01932 +01933 +01934 +01935 +01936 +01937 +01938 +01939 +01940 +01941 +01942 +01943 +01944 +01945 +01946 +01947 +01948 +01949 +01950 +01951 +01952 +01953 +01954 +01955 +01956 +01957 +01958 +01959 +01960 +01961 +01962 +01963 +01964 +01965 +01966 +01967 +01968 +01969 +01970 +01971 +01972 +01973 +01974 +01975 +01976 +01977 +01978 +01979 +01980 +01981 +01982 +01983 +01984 +01985 +01986 +01987 +01988 +01989 +01990 +01991 +01992 +01993 +01994 +01995 +01996 +01997 +01998 +01999 +02000 +02001 +02002 +02003 +02004 +02005 +02006 +02007 +02008 +02009 +02010 +02011 +02012 +02013 +02014 +02015 +02016 +02017 +02018 +02019 +02020 +02021 +02022 +02023 +02024 +02025 +02026 +02027 +02028 +02029 +02030 +02031 +02032 +02033 +02034 +02035 +02036 +02037 +02038 +02039 +02040 +02041 +02042 +02043 +02044 +02045 +02046 +02047 +02048 +02049 +02050 +02051 +02052 +02053 +02054 +02055 +02056 +02057 +02058 +02059 +02060 +02061 +02062 +02063 +02064 +02065 +02066 +02067 +02068 +02069 +02070 +02071 +02072 +02073 +02074 +02075 +02076 +02077 +02078 +02079 +02080 +02081 +02082 +02083 +02084 +02085 +02086 +02087 +02088 +02089 +02090 +02091 +02092 +02093 +02094 +02095 +02096 +02097 +02098 +02099 +02100 +02101 +02102 +02103 +02104 +02105 +02106 +02107 +02108 +02109 +02110 +02111 +02112 +02113 +02114 +02115 +02116 +02117 +02118 +02119 +02120 +02121 +02122 +02123 +02124 +02125 +02126 +02127 +02128 +02129 +02130 +02131 +02132 +02133 +02134 +02135 +02136 +02137 +02138 +02139 +02140 +02141 +02142 +02143 +02144 +02145 +02146 +02147 +02148 +02149 +02150 +02151 +02152 +02153 +02154 +02155 +02156 +02157 +02158 +02159 +02160 +02161 +02162 +02163 +02164 +02165 +02166 +02167 +02168 +02169 +02170 +02171 +02172 +02173 +02174 +02175 +02176 +02177 +02178 +02179 +02180 +02181 +02182 +02183 +02184 +02185 +02186 +02187 +02188 +02189 +02190 +02191 +02192 +02193 +02194 +02195 +02196 +02197 +02198 +02199 +02200 +02201 +02202 +02203 +02204 +02205 +02206 +02207 +02208 +02209 +02210 +02211 +02212 +02213 +02214 +02215 +02216 +02217 +02218 +02219 +02220 +02221 +02222 +02223 +02224 +02225 +02226 +02227 +02228 +02229 +02230 +02231 +02232 +02233 +02234 +02235 +02236 +02237 +02238 +02239 +02240 +02241 +02242 +02243 +02244 +02245 +02246 +02247 +02248 +02249 +02250 +02251 +02252 +02253 +02254 +02255 +02256 +02257 +02258 +02259 +02260 +02261 +02262 +02263 
+02264 +02265 +02266 +02267 +02268 +02269 +02270 +02271 +02272 +02273 +02274 +02275 +02276 +02277 +02278 +02279 +02280 +02281 +02282 +02283 +02284 +02285 +02286 +02287 +02288 +02289 +02290 +02291 +02292 +02293 +02294 +02295 +02296 +02297 +02298 +02299 +02300 +02301 +02302 +02303 +02304 +02305 +02306 +02307 +02308 +02309 +02310 +02311 +02312 +02313 +02314 +02315 +02316 +02317 +02318 +02319 +02320 +02321 +02322 +02323 +02324 +02325 +02326 +02327 +02328 +02329 +02330 +02331 +02332 +02333 +02334 +02335 +02336 +02337 +02338 +02339
diff --git a/VOCdevkit1/VOC2007/ImageSets/Segmentation/val.txt b/VOCdevkit1/VOC2007/ImageSets/Segmentation/val.txt
new file mode 100644
index 0000000..d206f87
--- /dev/null
+++ b/VOCdevkit1/VOC2007/ImageSets/Segmentation/val.txt
@@ -0,0 +1,234 @@
+00009 +00017 +00027 +00033 +00042 +00051 +00070 +00082 +00089 +00092 +00102 +00111 +00117 +00130 +00131 +00133 +00137 +00138 +00145 +00162 +00171 +00185 +00189 +00191 +00203 +00208 +00211 +00218 +00223 +00233 +00234 +00249 +00255 +00263 +00278 +00297 +00319 +00320 +00322 +00327 +00329 +00373 +00385 +00395 +00411 +00417 +00422 +00429 +00438 +00453 +00457 +00470 +00493 +00497 +00498 +00516 +00530 +00588 +00595 +00601 +00604 +00616 +00635 +00636 +00669 +00696 +00701 +00709 +00716 +00724 +00732 +00741 +00748 +00756 +00757 +00780 +00789 +00812 +00837 +00850 +00858 +00867 +00869 +00876 +00879 +00884 +00887 +00901 +00906 +00910 +00911 +00926 +00965 +00979 +00986 +00987 +00991 +01001 +01014 +01026 +01035 +01048 +01050 +01069 +01076 +01083 +01092 +01093 +01113 +01124 +01126 +01131 +01140 +01142 +01159 +01161 +01168 +01178 +01185 +01190 +01193 +01203 +01219 +01224 +01254 +01258 +01259 +01261 +01267 +01271 +01279 +01284 +01291 +01313 +01314 +01320 +01324 +01325 +01352 +01364 +01396 +01413 +01421 +01423 +01434 +01436 +01447 +01454 +01465 +01469 +01472 +01480 +01489 +01495 +01511 +01517 +01529 +01534 +01541 +01545 +01553 +01554 +01595 +01597 +01618 +01625 +01629 +01630 +01631 +01640 +01641 +01670 +01684 +01691 +01750 +01757 +01773 +01777 +01779 +01782 +01787 +01788 +01793 +01795 +01798 +01816 +01819 +01821 +01833 +01862 +01881 +01892 +01902 +01938 +01939 +01940 +01948 +01950 +01955 +01964 +01990 +02014 +02019 +02049 +02061 +02066 +02068 +02082 +02084 +02090 +02091 +02093 +02095 +02118 +02121 +02133 +02157 +02161 +02173 +02183 +02200 +02203 +02204 +02212 +02221 +02222 +02229 +02230 +02241 +02253 +02263 +02279 +02319 +02332
diff --git a/deeplab.py b/deeplab.py
new file mode 100644
index 0000000..d0cf30d
--- /dev/null
+++ b/deeplab.py
@@ -0,0 +1,397 @@
+import colorsys
+import copy
+import time
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from torch import nn
+
+from nets.deeplabv3_plus import DeepLab
+from utils.utils import cvtColor, preprocess_input, resize_image, show_config
+
+
+#-----------------------------------------------------------------------------------#
+#   To predict with your own trained model, three parameters must be changed:
+#   model_path, backbone and num_classes all need to be modified!
+#   If a shape mismatch is reported, make sure that model_path, backbone and
+#   num_classes were changed to match the training settings
+#-----------------------------------------------------------------------------------#
+class DeeplabV3(object):
+    _defaults = {
+        #-------------------------------------------------------------------#
+        #   model_path points to the weights file under the logs folder.
+        #   After training there are several weights files in logs; pick one
+        #   with a low validation loss. A low validation loss does not
+        #   guarantee a high mIoU, only that those weights generalize well
+        #   on the validation set.
+        #-------------------------------------------------------------------#
+        "model_path"        : 'model_data/last_epoch_weights1.pth',
+        #----------------------------------------#
+        #   number of classes to distinguish + 1
+        #----------------------------------------#
+        "num_classes"       : 46,
+        #----------------------------------------#
+        #   backbone network to use:
+        #   mobilenet
+        #   xception
+        #----------------------------------------#
+        "backbone"          : "mobilenet",
+        #----------------------------------------#
+        #   input image size
+        #----------------------------------------#
+        "input_shape"       : [1024, 1042],
+        #----------------------------------------#
+        #   downsample factor, usually 8 or 16;
+        #   use the same value as during training
+        #----------------------------------------#
+        "downsample_factor" : 16,
+        #-------------------------------------------------#
+        #   mix_type controls how the result is visualized
+        #
+        #   mix_type = 0: blend the original image with the segmentation
+        #   mix_type = 1: keep only the segmentation
+        #   mix_type = 2: remove the background, keeping only the
+        #                 original-image pixels of the targets
+        #-------------------------------------------------#
+        "mix_type"          : 0,
+        #-------------------------------#
+        #   whether to use CUDA;
+        #   set to False if there is no GPU
+        #-------------------------------#
+        "cuda"              : True,
+    }
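+    #---------------------------------------------------#
+    #   any of the defaults above can be overridden per
+    #   instance through keyword arguments, since __init__
+    #   below first copies _defaults into the instance and
+    #   then applies the kwargs,
+    #   e.g. DeeplabV3(mix_type=1, cuda=False)
+    #---------------------------------------------------#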
+    #---------------------------------------------------#
+    #   initialize DeepLab
+    #---------------------------------------------------#
+    def __init__(self, **kwargs):
+        self.__dict__.update(self._defaults)
+        for name, value in kwargs.items():
+            setattr(self, name, value)
+        #---------------------------------------------------#
+        #   assign a different color to each class
+        #---------------------------------------------------#
+        if self.num_classes <= 46:
+            self.colors = [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128), (128, 128, 128),
+                            (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
+                            (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), (128, 64, 12), (0, 0, 142), (119, 11, 32),
+                            (244, 164, 140), (188, 143, 143), (64, 224, 205), (127, 255, 0), (199, 97, 20), (189, 252, 201), (0, 255, 127), (160, 32, 240),
+                            (138, 42, 226), (255, 97, 0), (255, 215, 0), (255, 128, 0), (189, 252, 201), (240, 255, 240), (0, 130, 180), (152, 251, 152),
+                            (107, 142, 35), (153, 153, 153), (190, 153, 153), (250, 170, 30), (220, 220, 0), (107, 142, 35)]
+        else:
+            hsv_tuples  = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
+            self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
+            self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
+        #---------------------------------------------------#
+        #   build the model
+        #---------------------------------------------------#
+        self.generate()
+
+        show_config(**self._defaults)
+
+    #---------------------------------------------------#
+    #   load the model and its weights
+    #---------------------------------------------------#
+    def generate(self, onnx=False):
+        #-------------------------------#
+        #   load model and weights
+        #-------------------------------#
+        self.net = DeepLab(num_classes=self.num_classes, backbone=self.backbone, downsample_factor=self.downsample_factor, pretrained=False)
+
+        device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
+        self.net    = self.net.eval()
+        print('{} model and classes loaded.'.format(self.model_path))
+        if not onnx:
+            if self.cuda:
+                self.net = nn.DataParallel(self.net)
+                self.net = self.net.cuda()
+
+    #---------------------------------------------------#
+    #   detect an image
+    #---------------------------------------------------#
+    def detect_image(self, image, count=False, name_classes=None):
+        #---------------------------------------------------------#
+        #   convert the image to RGB here to avoid errors when a
+        #   grayscale image is passed in; the code only supports
+        #   RGB prediction, so any other format is converted
+        #---------------------------------------------------------#
+        image       = cvtColor(image)
+        #---------------------------------------------------#
+        #   keep a copy of the input image for drawing later
+        #---------------------------------------------------#
+        old_img     = copy.deepcopy(image)
+        orininal_h  = np.array(image).shape[0]
+        orininal_w  = np.array(image).shape[1]
+        #---------------------------------------------------------#
+        #   add gray bars to resize without distortion;
+        #   a plain resize would also work
+        #---------------------------------------------------------#
+        image_data, nw, nh  = resize_image(image, (self.input_shape[1], self.input_shape[0]))
+        #---------------------------------------------------------#
+        #   add the batch_size dimension
+        #---------------------------------------------------------#
+        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
+
+        with torch.no_grad():
+            images = torch.from_numpy(image_data)
+            if self.cuda:
+                images = images.cuda()
+
+            #---------------------------------------------------#
+            #   feed the image through the network
+            #---------------------------------------------------#
+            pr = self.net(images)[0]
+            #---------------------------------------------------#
+            #   get the per-pixel class probabilities
+            #---------------------------------------------------#
+            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()
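+            #--------------------------------------#
+            #   note: resize_image letterboxes the input, centering the
+            #   nh x nw content region inside input_shape, so the padding is
+            #   (input_shape[0] - nh) // 2 rows on top and bottom and
+            #   (input_shape[1] - nw) // 2 columns on the left and right;
+            #   the crop below removes exactly that padding
+            #--------------------------------------#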
+            #--------------------------------------#
+            #   crop away the gray bars
+            #--------------------------------------#
+            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
+                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
+            #---------------------------------------------------#
+            #   resize back to the original image size
+            #---------------------------------------------------#
+            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
+            #---------------------------------------------------#
+            #   take the class of each pixel
+            #---------------------------------------------------#
+            pr = pr.argmax(axis=-1)
+
+        #---------------------------------------------------------#
+        #   per-class pixel counting
+        #---------------------------------------------------------#
+        if count:
+            classes_nums        = np.zeros([self.num_classes])
+            total_points_num    = orininal_h * orininal_w
+            print('-' * 63)
+            print("|%25s | %15s | %15s|"%("Key", "Value", "Ratio"))
+            print('-' * 63)
+            for i in range(self.num_classes):
+                num     = np.sum(pr == i)
+                ratio   = num / total_points_num * 100
+                if num > 0:
+                    print("|%25s | %15s | %14.2f%%|"%(str(name_classes[i]), str(num), ratio))
+                    print('-' * 63)
+                classes_nums[i] = num
+            print("classes_nums:", classes_nums)
+
+        if self.mix_type == 0:
+            # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
+            # for c in range(self.num_classes):
+            #     seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
+            #     seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
+            #     seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
+            seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
+            #------------------------------------------------#
+            #   convert the array back to a PIL Image
+            #------------------------------------------------#
+            image   = Image.fromarray(np.uint8(seg_img))
+            #------------------------------------------------#
+            #   blend the segmentation with the original image
+            #------------------------------------------------#
+            image   = Image.blend(old_img, image, 0.7)
+
+        elif self.mix_type == 1:
+            seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
+            #------------------------------------------------#
+            #   convert the array back to a PIL Image
+            #------------------------------------------------#
+            image   = Image.fromarray(np.uint8(seg_img))
+
+        elif self.mix_type == 2:
+            seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')
+            #------------------------------------------------#
+            #   convert the array back to a PIL Image
+            #------------------------------------------------#
+            image   = Image.fromarray(np.uint8(seg_img))
+
+        return image
+
+    def get_FPS(self, image, test_interval):
+        #---------------------------------------------------------#
+        #   convert the image to RGB to avoid errors on grayscale input
+        #---------------------------------------------------------#
+        image       = cvtColor(image)
+        #---------------------------------------------------------#
+        #   add gray bars to resize without distortion
+        #---------------------------------------------------------#
+        image_data, nw, nh  = resize_image(image, (self.input_shape[1], self.input_shape[0]))
+        #---------------------------------------------------------#
+        #   add the batch_size dimension
+        #---------------------------------------------------------#
+        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
+
+        with torch.no_grad():
+            images = torch.from_numpy(image_data)
+            if self.cuda:
+                images = images.cuda()
+
+            #---------------------------------------------------#
+            #   feed the image through the network
+            #---------------------------------------------------#
+            pr = self.net(images)[0]
+            #---------------------------------------------------#
+            #   take the class of each pixel
+            #---------------------------------------------------#
+            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy().argmax(axis=-1)
+            #--------------------------------------#
+            #   crop away the gray bars
+            #--------------------------------------#
+            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
+                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
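+        #---------------------------------------------------#
+        #   the single inference above acts as a warm-up, so
+        #   one-off costs such as CUDA context setup stay out
+        #   of the timed loop below
+        #---------------------------------------------------#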
+        t1 = time.time()
+        for _ in range(test_interval):
+            with torch.no_grad():
+                #---------------------------------------------------#
+                #   feed the image through the network
+                #---------------------------------------------------#
+                pr = self.net(images)[0]
+                #---------------------------------------------------#
+                #   take the class of each pixel
+                #---------------------------------------------------#
+                pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy().argmax(axis=-1)
+                #--------------------------------------#
+                #   crop away the gray bars
+                #--------------------------------------#
+                pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
+                        int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
+        t2 = time.time()
+        tact_time = (t2 - t1) / test_interval
+        return tact_time
+
+    def convert_to_onnx(self, simplify, model_path):
+        import onnx
+        self.generate(onnx=True)
+
+        im                  = torch.zeros(1, 3, *self.input_shape).to('cpu')  # image size(1, 3, 512, 512) BCHW
+        input_layer_names   = ["images"]
+        output_layer_names  = ["output"]
+
+        # Export the model
+        print(f'Starting export with onnx {onnx.__version__}.')
+        torch.onnx.export(self.net,
+                        im,
+                        f               = model_path,
+                        verbose         = False,
+                        opset_version   = 12,
+                        training        = torch.onnx.TrainingMode.EVAL,
+                        do_constant_folding = True,
+                        input_names     = input_layer_names,
+                        output_names    = output_layer_names,
+                        dynamic_axes    = None)
+
+        # Checks
+        model_onnx = onnx.load(model_path)      # load onnx model
+        onnx.checker.check_model(model_onnx)    # check onnx model
+
+        # Simplify onnx
+        if simplify:
+            import onnxsim
+            print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
+            model_onnx, check = onnxsim.simplify(
+                model_onnx,
+                dynamic_input_shape=False,
+                input_shapes=None)
+            assert check, 'assert check failed'
+            onnx.save(model_onnx, model_path)
+
+        print('Onnx model saved as {}'.format(model_path))
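+    # A minimal usage sketch for the exported file (assumes the separate
+    # onnxruntime package is installed; it is not a dependency of this repo):
+    #
+    #   import onnxruntime as ort
+    #   session = ort.InferenceSession(model_path)
+    #   pred    = session.run(["output"], {"images": image_data})[0]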
+
+    def get_miou_png(self, image):
+        #---------------------------------------------------------#
+        #   Convert the image to RGB here, to prevent grayscale
+        #   images from raising errors during prediction.
+        #   The code only supports prediction on RGB images;
+        #   all other image types are converted to RGB.
+        #---------------------------------------------------------#
+        image       = cvtColor(image)
+        orininal_h  = np.array(image).shape[0]
+        orininal_w  = np.array(image).shape[1]
+        #---------------------------------------------------------#
+        #   Add gray bars to the image for a distortion-free resize.
+        #   Alternatively, you can resize directly for recognition.
+        #---------------------------------------------------------#
+        image_data, nw, nh  = resize_image(image, (self.input_shape[1], self.input_shape[0]))
+        #---------------------------------------------------------#
+        #   Add the batch_size dimension
+        #---------------------------------------------------------#
+        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
+
+        with torch.no_grad():
+            images = torch.from_numpy(image_data)
+            if self.cuda:
+                images = images.cuda()
+
+            #---------------------------------------------------#
+            #   Pass the image through the network for prediction
+            #---------------------------------------------------#
+            pr = self.net(images)[0]
+            #---------------------------------------------------#
+            #   Take the per-class probabilities of each pixel
+            #---------------------------------------------------#
+            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()
+            #--------------------------------------#
+            #   Crop out the gray-bar regions
+            #--------------------------------------#
+            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh),
+                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
+            #---------------------------------------------------#
+            #   Resize back to the original image size
+            #---------------------------------------------------#
+            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR)
+            #---------------------------------------------------#
+            #   Take the class of each pixel
+            #---------------------------------------------------#
+            pr = pr.argmax(axis=-1)
+
+        image = Image.fromarray(np.uint8(pr))
+        return image
diff --git a/get_miou.py b/get_miou.py
new file mode 100644
index 0000000..cf476ae
--- /dev/null
+++ b/get_miou.py
@@ -0,0 +1,62 @@
+import os
+
+from PIL import Image
+from tqdm import tqdm
+
+from deeplab import DeeplabV3
+from utils.utils_metrics import compute_mIoU, show_results
+
+'''
+Note the following points when running the metric evaluation:
+1. The images generated by this file are grayscale. Because the pixel values are small, they show
+   essentially nothing when viewed as PNGs, so a nearly all-black image is normal.
+2. This file computes the mIoU of the validation set. This repo currently uses the test set as the
+   validation set and does not split out a separate test set.
+'''
+if __name__ == "__main__":
+    #---------------------------------------------------------------------------#
+    #   miou_mode specifies what this file computes when it is run:
+    #   miou_mode = 0 runs the whole mIoU pipeline: obtain predictions, then compute mIoU.
+    #   miou_mode = 1 only obtains the predictions.
+    #   miou_mode = 2 only computes the mIoU.
+    #---------------------------------------------------------------------------#
+    miou_mode       = 0
+    #------------------------------#
+    #   number of classes + 1, e.g. 2 + 1
+    #------------------------------#
+    num_classes     = 47
+    #--------------------------------------------#
+    #   classes to distinguish, same as in json_to_dataset
+    #--------------------------------------------#
+    name_classes    = ["_background_", "pl5", "pl20", "pl30", "pl40", "pl50", "pl60", "pl70", "pl80", "pl100", "pl120", "pm20", "pm55","pr40","p11", "pn", "pne", "p26", "i2", "i4", "i5", "ip", "il60", "il80", "il100", "p5", "p10", "p23", "p3", "pg", "p19", "p12", "p6", "p27", "ph4", "ph4.5", "ph5", "pm30", "w55", "w59", "w13", "w57", "w32", "wo", "io", "po", "indicative"]
+    # name_classes = ["_background_","cat","dog"]
+    #-------------------------------------------------------#
+    #   Points to the folder containing the VOC dataset;
+    #   defaults to the VOC dataset in the root directory
+    #-------------------------------------------------------#
+    VOCdevkit_path  = 'VOCdevkit'
+
+    image_ids       = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),'r').read().splitlines()
+    gt_dir          = os.path.join(VOCdevkit_path, "VOC2007/SegmentationClass/")
+    miou_out_path   = "miou_out"
+    pred_dir        = os.path.join(miou_out_path, 'detection-results')
+
+    if miou_mode == 0 or miou_mode == 1:
+        if not os.path.exists(pred_dir):
+            os.makedirs(pred_dir)
+
+        print("Load model.")
+        deeplab = DeeplabV3()
+        print("Load model done.")
+
+        print("Get predict result.")
+        for image_id in tqdm(image_ids):
+            image_path  = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")
+            image       = Image.open(image_path)
+            image       = deeplab.get_miou_png(image)
+            image.save(os.path.join(pred_dir, image_id + ".png"))
+        print("Get predict result done.")
+
+    if miou_mode == 0 or miou_mode == 2:
+        print("Get miou.")
+        hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes)  # run the mIoU computation
+        print("Get miou done.")
+        show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)
\ No newline at end of file
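`compute_mIoU` lives in utils/utils_metrics.py, which is not shown in this part of the commit, so the sketch below is an illustration of the standard per-class IoU computation it is assumed to perform on the confusion matrix `hist` it returns, not the repo's actual implementation:

```python
import numpy as np

def per_class_iou(hist):
    # hist is a num_classes x num_classes confusion matrix:
    # hist[i, j] counts pixels of ground-truth class i predicted as class j.
    # IoU_c = TP / (TP + FP + FN), computed from the diagonal and the marginals.
    return np.diag(hist) / np.maximum(hist.sum(1) + hist.sum(0) - np.diag(hist), 1)

hist = np.array([[50, 2],
                 [3, 45]])
print(per_class_iou(hist))         # IoU per class
print(per_class_iou(hist).mean())  # mIoU
```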
diff --git a/json_to_dataset.py b/json_to_dataset.py
new file mode 100644
index 0000000..4b84bc0
--- /dev/null
+++ b/json_to_dataset.py
@@ -0,0 +1,73 @@
+import base64
+import json
+import os
+import os.path as osp
+
+import numpy as np
+import PIL.Image
+from labelme import utils
+
+'''
+Note the following when building your own semantic segmentation dataset:
+1. The labelme version used here is 3.16.7, and that version is recommended; some versions of
+   labelme raise the error: Too many dimensions: 3 > 2.
+   Install it from the command line with pip install labelme==3.16.7
+2. The label images generated here are 8-bit color images, which looks different from the dataset
+   format shown in the video. Although they look like color images, they are in fact 8-bit, and the
+   value of each pixel is the class that pixel belongs to.
+   So the format is actually the same as the VOC dataset in the video, and datasets produced this
+   way can be used normally.
+'''
+if __name__ == '__main__':
+    jpgs_path   = "datasets/JPEGImages"
+    pngs_path   = "datasets/SegmentationClass"
+    # classes = ["_background_", "pl5", "pl20", "pl30", "pl40", "pl50", "pl60", "pl70", "pl80", "pl100", "pl120", "pm20", "pm55","pr40","p11", "pn", "pne", "p26", "i2", "i4", "i5", "ip", "il60", "il80", "il100", "p5", "p10", "p23", "p3", "pg", "p19", "p12", "p6", "p27", "ph4", "ph4.5", "ph5", "pm30", "w55", "w59", "w13", "w57", "w32", "wo", "io", "po", "indicative"]
+    # classes = ["_background_","cat","dog"]
+    classes     = ["_background_", "Historical Village", "Hydraulic Facilities", "Historical Buildings", "Green", "Sky",
+                   "Water", "Bare Land",
+                   "Infrastructure", "Park Related", "Enclosure", "Garbage and Debris", "Electric Poles",
+                   "Modern Architecture", "Hard Surface",
+                   "Human Activities", "Identification", "Water Pollution"]
+
+    count = os.listdir("./datasets/before/")
+    for i in range(0, len(count)):
+        path = os.path.join("./datasets/before", count[i])
+
+        if os.path.isfile(path) and path.endswith('json'):
+            data = json.load(open(path))
+
+            if data['imageData']:
+                imageData = data['imageData']
+            else:
+                imagePath = os.path.join(os.path.dirname(path), data['imagePath'])
+                with open(imagePath, 'rb') as f:
+                    imageData = f.read()
+                    imageData = base64.b64encode(imageData).decode('utf-8')
+
+            img = utils.img_b64_to_arr(imageData)
+            label_name_to_value = {'_background_': 0}
+            for shape in data['shapes']:
+                label_name = shape['label']
+                if label_name in label_name_to_value:
+                    label_value = label_name_to_value[label_name]
+                else:
+                    label_value = len(label_name_to_value)
+                    label_name_to_value[label_name] = label_value
+
+            # label_values must be dense
+            label_values, label_names = [], []
+            for ln, lv in sorted(label_name_to_value.items(), key=lambda x: x[1]):
+                label_values.append(lv)
+                label_names.append(ln)
+            assert label_values == list(range(len(label_values)))
+
+            lbl = utils.shapes_to_label(img.shape, data['shapes'], label_name_to_value)
+
+            PIL.Image.fromarray(img).save(osp.join(jpgs_path, count[i].split(".")[0]+'.jpg'))
+
+            new = np.zeros([np.shape(img)[0], np.shape(img)[1]])
+            for name in label_names:
+                index_json = label_names.index(name)
+                index_all = classes.index(name)
+                new = new + index_all*(np.array(lbl) == index_json)
+
+            utils.lblsave(osp.join(pngs_path, count[i].split(".")[0]+'.png'), new)
+            print('Saved ' + count[i].split(".")[0] + '.jpg and ' + count[i].split(".")[0] + '.png')
diff --git a/logs/loss_2023_04_16_14_40_55/epoch_miou.txt b/logs/loss_2023_04_16_14_40_55/epoch_miou.txt
new file mode 100644
index 0000000..573541a
--- /dev/null
+++ b/logs/loss_2023_04_16_14_40_55/epoch_miou.txt
@@ -0,0 +1 @@
+0
diff --git a/logs/loss_2023_04_16_15_03_17/epoch_miou.txt b/logs/loss_2023_04_16_15_03_17/epoch_miou.txt
new file mode 100644
index 0000000..573541a
--- /dev/null
+++ b/logs/loss_2023_04_16_15_03_17/epoch_miou.txt
@@ -0,0 +1 @@
+0
diff --git a/logs/loss_2023_05_05_11_03_08/epoch_miou.txt b/logs/loss_2023_05_05_11_03_08/epoch_miou.txt
new file mode 100644
index 0000000..573541a
--- /dev/null
+++ b/logs/loss_2023_05_05_11_03_08/epoch_miou.txt
@@ -0,0 +1 @@
+0
diff --git a/logs/loss_2023_05_06_10_18_49/epoch_miou.txt b/logs/loss_2023_05_06_10_18_49/epoch_miou.txt
new file mode 100644
index 0000000..573541a
--- /dev/null
+++ b/logs/loss_2023_05_06_10_18_49/epoch_miou.txt
@@ -0,0 +1 @@
+0
diff --git a/logs/loss_2023_05_06_10_21_47/epoch_miou.txt
b/logs/loss_2023_05_06_10_21_47/epoch_miou.txt new file mode 100644 index 0000000..573541a --- /dev/null +++ b/logs/loss_2023_05_06_10_21_47/epoch_miou.txt @@ -0,0 +1 @@ +0 diff --git a/logs/loss_2023_05_06_10_22_54/epoch_miou.txt b/logs/loss_2023_05_06_10_22_54/epoch_miou.txt new file mode 100644 index 0000000..573541a --- /dev/null +++ b/logs/loss_2023_05_06_10_22_54/epoch_miou.txt @@ -0,0 +1 @@ +0 diff --git a/logs/loss_2023_05_06_10_23_46/epoch_loss.txt b/logs/loss_2023_05_06_10_23_46/epoch_loss.txt new file mode 100644 index 0000000..a18b29a --- /dev/null +++ b/logs/loss_2023_05_06_10_23_46/epoch_loss.txt @@ -0,0 +1,300 @@ +0.6197003625021688 +0.039477425878832094 +0.021815583949108316 +0.017008269736862 +0.014811055252271856 +0.013597223319878143 +0.012746048395383834 +0.012807543290346975 +0.012045896844938397 +0.010721664282603068 +0.010831192051265884 +0.011085463909984711 +0.009701294567747602 +0.010043841666726134 +0.00960726743146154 +0.009713078478967395 +0.009528612792350268 +0.009614691838530023 +0.008979973745330489 +0.009033944955301035 +0.008963118615372552 +0.008663091998451704 +0.008887246418425622 +0.008487651522090736 +0.008931122144361951 +0.007925359707966645 +0.008151701880360426 +0.008162347928748356 +0.00812305118924028 +0.008164252225352104 +0.008294204591699087 +0.007984961868182334 +0.00793478730888931 +0.008040947269492503 +0.00789135160055649 +0.007976733407487264 +0.00797758857762094 +0.008167726175413385 +0.007854951135265312 +0.008030839253769401 +0.007845229877357304 +0.008030030800263351 +0.007429942826407575 +0.007581196240230885 +0.0074647688961980915 +0.00780516351261548 +0.007857467600289634 +0.007772660604855152 +0.007275923568677483 +0.0074134994388519015 +0.007543342298661632 +0.00698753386273796 +0.0072625635161698435 +0.007725435333457481 +0.0073359590019329585 +0.007243735706008716 +0.0073306868472508375 +0.007583740644047157 +0.007169632133567429 +0.0070728671461690795 +0.007072581034932753 +0.007413626354570171 +0.007152383307795334 +0.007606737600274724 +0.007405726957496099 +0.006981491556075957 +0.006948879280768533 +0.006812293091328538 +0.0068052408720115235 +0.007279976445494552 +0.007174421689196883 +0.006878163205626148 +0.00700332477561308 +0.006634817236924664 +0.006955325338784109 +0.006905557131819682 +0.007238527825420914 +0.007335566426007636 +0.007129638260602243 +0.007070887029107249 +0.00687503706913498 +0.0070434755843996775 +0.006843726146265071 +0.007250396561801802 +0.00696170583549717 +0.006585396268093399 +0.0068075423090835545 +0.007294489346982862 +0.006918249383104873 +0.00639713606566984 +0.00699184124159694 +0.006749085086918375 +0.007148223049795262 +0.006741083779517462 +0.006767992923981705 +0.006374543556189271 +0.006559330907636127 +0.006634276142492607 +0.006561781389961026 +0.006353535225965479 +0.006518312053092731 +0.00663000394097979 +0.006518688552194902 +0.006821648500144425 +0.006526681042131136 +0.006539963696598157 +0.006351140145825031 +0.006738816389279306 +0.006629471270944711 +0.006451651816054498 +0.006166263602535902 +0.006461152977959922 +0.006396989321943458 +0.006274158426758766 +0.006374230302477206 +0.0063116708989503015 +0.006304902248987538 +0.006369911974823605 +0.006575406290203007 +0.006191673439382062 +0.006340822853863296 +0.006553388611031094 +0.006342110292677955 +0.006331088343642283 +0.006240377007866767 +0.006275904750426657 +0.006147952992230045 +0.006477536654857628 +0.006378386527770854 +0.006447824759067111 +0.006113713274323895 +0.006700844290584836 +0.006149753932610948 
+0.006174205401345748 +0.0064192439733558255 +0.006562992551951836 +0.0065121324476166255 +0.006085073481536977 +0.0063053527178540255 +0.006172239070867999 +0.006482244845472645 +0.006249061088031689 +0.006072588164904668 +0.005963505221399773 +0.006541932343873603 +0.005978238326231789 +0.006068465054679676 +0.006090859371722594 +0.006554332840656543 +0.006240069931691704 +0.0061388542280734265 +0.006210481463364917 +0.006177722078928646 +0.006109445868856752 +0.006113338514369251 +0.00638589895411816 +0.006318557595819971 +0.006174334730800709 +0.00610174891138768 +0.0063077077161549045 +0.006507697034478471 +0.006079476892845458 +0.006291852342832423 +0.006102473914141664 +0.0063502817660810945 +0.00603547172391603 +0.006014683098840339 +0.006272724872993256 +0.006180851822081571 +0.00613947869889339 +0.006143935481104079 +0.006179913735909551 +0.005999207273550586 +0.0060072671179394075 +0.006093241513384499 +0.006308488499111153 +0.006319105360894234 +0.006015585906743408 +0.006270462099530029 +0.0059296336665062566 +0.006237312077948321 +0.0061654297226841626 +0.006195027984062585 +0.005816533096124039 +0.006039109105375301 +0.006196394616723797 +0.00588476053879954 +0.006272064002244928 +0.006258616158862673 +0.006027525825560546 +0.006361286404645318 +0.006194578015826246 +0.0062392738952297905 +0.00614226611548886 +0.0060234401426402 +0.006119346779509104 +0.005998238505080285 +0.006297375983381538 +0.006022227832873993 +0.006159367399226463 +0.005989471511975526 +0.006161844625205028 +0.006265990433421255 +0.006152094558267341 +0.0058916246588430424 +0.006321959443631168 +0.00588216587521327 +0.00605715124025948 +0.006178620813180738 +0.005865271184186207 +0.006187718112850314 +0.0062731528347499755 +0.006187547586978671 +0.00586931534844518 +0.005711228119744122 +0.005954130673993506 +0.005987088050349782 +0.0062855214961475125 +0.0058403364900577485 +0.005747506121428464 +0.005805430818081766 +0.005949270271894454 +0.0059699209237882 +0.006146733909349591 +0.005705802928047262 +0.006105377703122511 +0.006335743928696436 +0.0058695421086716335 +0.006284519915819565 +0.006223392199650548 +0.006336943062284373 +0.005775144907996455 +0.006208539449524829 +0.00656147331872268 +0.0061425472279486495 +0.005778139089269585 +0.005739478614825769 +0.006324212335513997 +0.006061715773703396 +0.006094571465481172 +0.005793110492724231 +0.0064889906950741155 +0.006320204451226916 +0.006033010395268552 +0.006268571268150959 +0.005793830825586028 +0.005928737658940029 +0.0059360746009197414 +0.005921123769778772 +0.00580074743920151 +0.00600041084331478 +0.006065549175325823 +0.006224441010958228 +0.0059621825715980206 +0.006038792652233784 +0.006213274846356291 +0.005790815653258177 +0.006297106610720857 +0.00596215214027878 +0.005580317692923229 +0.005961467344189919 +0.0058768165754614675 +0.005721380075832878 +0.006118272315664748 +0.005865733843905844 +0.006051895516736879 +0.005892398204573009 +0.006239523179165155 +0.0064091585884614366 +0.00580754133513261 +0.005925991370594094 +0.005758367213261218 +0.005780674416010583 +0.005978336317611667 +0.006500314771969707 +0.006062155688170896 +0.005737047866668495 +0.006051646931957411 +0.006021980424641769 +0.006063678167749748 +0.006224919216328445 +0.006066763688130269 +0.005888556390947631 +0.005921683000655625 +0.006041925482163312 +0.0060582415669087055 +0.005892879506679316 +0.006194286557452704 +0.00605125249692809 +0.006066604279700539 +0.006147222191506126 +0.00621423757893561 +0.006306637827919605 +0.005826810339444491 
+0.005905158232150337 +0.005995549125678327 +0.006102899812071623 +0.006010078881714961 +0.006307000559028734 +0.005901990987328607 diff --git a/logs/loss_2023_05_06_10_23_46/epoch_miou.txt b/logs/loss_2023_05_06_10_23_46/epoch_miou.txt new file mode 100644 index 0000000..930fc30 --- /dev/null +++ b/logs/loss_2023_05_06_10_23_46/epoch_miou.txt @@ -0,0 +1,61 @@ +0 +2.167944609686367 +2.167944609686367 +2.167944609686367 +2.5163406635704697 +2.6689392280869644 +2.670944657433233 +2.703679533595729 +2.6839131514854766 +2.742344626635652 +2.702516516895955 +2.695090191212251 +2.735516962860369 +2.762820157003522 +2.803803187338103 +2.810150036825074 +2.785170587287734 +2.791059580646515 +2.7656206207612173 +2.750095742323767 +2.7199225496891963 +2.7440055010236017 +2.8082342409328964 +2.814418398352213 +2.8107268204411624 +2.7995198932005354 +2.8011092609779977 +2.813604539250568 +2.8117495424778216 +2.8551548916813605 +2.839891923746138 +2.830180601681894 +2.8565925775780086 +2.8434723293551856 +2.8561996859138605 +2.7949647514086173 +2.8332743978696695 +2.8262352776644186 +2.8354105541311005 +2.852989126642651 +2.8432223736426883 +2.8606846598024553 +2.861047605790151 +2.8600708076704904 +2.87727084635111 +2.851273289108007 +2.88282154652109 +2.8663933562737225 +2.8731297790316663 +2.87792437948789 +2.8820779746904246 +2.877737893574461 +2.901723775915173 +2.8798679126394533 +2.8967890362822537 +2.8869220510911777 +2.886316192013369 +2.8947920914409986 +2.8912124346752957 +2.8855903610238225 +2.8622060446293918 diff --git a/logs/loss_2023_05_06_10_23_46/epoch_val_loss.txt b/logs/loss_2023_05_06_10_23_46/epoch_val_loss.txt new file mode 100644 index 0000000..2d13500 --- /dev/null +++ b/logs/loss_2023_05_06_10_23_46/epoch_val_loss.txt @@ -0,0 +1,300 @@ +0.07189631975930312 +0.03571391471757971 +0.027422866623463302 +0.021274538132651098 +0.019434706291890348 +0.0181045015984825 +0.0172115751092547 +0.016695918306579877 +0.016110792691851485 +0.015716335768329686 +0.015145999877231902 +0.014976732689758828 +0.014389945386812604 +0.013972413999124848 +0.013699803758284142 +0.01339970230413922 +0.013178563901576502 +0.012990865278346786 +0.012821075763424923 +0.012758990862117759 +0.012705465907166744 +0.012287154946283534 +0.012462963366174492 +0.012506939679512689 +0.012573993242955927 +0.012206647625385687 +0.012096314859608638 +0.012331344660949605 +0.012048744427939427 +0.012072740058446753 +0.011987858901506868 +0.011968210740978348 +0.011763962387139427 +0.011847949594837325 +0.011792505705921814 +0.011772599175636625 +0.01168673218966558 +0.011574136976409575 +0.011669780023331786 +0.011759012158767417 +0.011602205529423624 +0.011422211994770271 +0.011385530939903753 +0.011649176711216569 +0.011735210436043041 +0.011452095663367674 +0.011474372008173117 +0.011400713970691994 +0.011594116976805803 +0.01162711290867421 +0.011419184762856057 +0.011850407154395663 +0.01142614646333045 +0.011498499998887038 +0.009839700976515124 +0.011730902996877658 +0.011111496353586173 +0.011294754430780122 +0.011366692901556862 +0.011263233063549831 +0.0110595160613543 +0.011138797135509807 +0.011222220363159632 +0.011109685597555905 +0.011056809348921323 +0.011218598698554882 +0.010987450376731055 +0.011019065277650952 +0.011006714190067402 +0.010988561950367072 +0.010986150756221393 +0.011355467606335878 +0.010989442997579944 +0.01098421217616776 +0.011198728754795316 +0.011363271527506155 +0.011044791474103414 +0.010770803253198492 +0.01112097201483517 +0.010638197153357083 +0.010831893212964823 
+0.010716704378739512 +0.010368685591323623 +0.010761569197899822 +0.010918065688797626 +0.010843201339694446 +0.010986347955747926 +0.010991036056958396 +0.011016894348672238 +0.01096486386136505 +0.010981223537939889 +0.010795045017810732 +0.010998073893053264 +0.010681337612713206 +0.010768676008065713 +0.010744555644562533 +0.010923682551445633 +0.010729674450603539 +0.01098158168767033 +0.01112275263936869 +0.010836047963399825 +0.010836834079939229 +0.010870976510040205 +0.010789354853653189 +0.01086848299821903 +0.010880841994015822 +0.010519333617312127 +0.010682976371126956 +0.01058533501104805 +0.010533275203137049 +0.010898187933168534 +0.01089936629708471 +0.010773020057842649 +0.010753131800748664 +0.010720436316754284 +0.010669665924947837 +0.010629855184655252 +0.010751658111232621 +0.010831298250383857 +0.01070010883669401 +0.010617443055299849 +0.01048437919435573 +0.010346002114037502 +0.010567514346270212 +0.010638203857273891 +0.010729511059692194 +0.010716044665153685 +0.010648740768625304 +0.010806642850090203 +0.010610603590913373 +0.010871732390710506 +0.010454296712474576 +0.01122654142693199 +0.01061759469227801 +0.010552432757384819 +0.010584884724611866 +0.010228246000820193 +0.010536318002589818 +0.010565901255427763 +0.010508602525203907 +0.010447882806304199 +0.010485738555997097 +0.010434291293394977 +0.010488350811446535 +0.01026412391039575 +0.010033505045455592 +0.010524095084261277 +0.01064964158235696 +0.010424483214215985 +0.010454660225338462 +0.010546684586282435 +0.010508425666244122 +0.010540762514775169 +0.010453253229758862 +0.010467348146605594 +0.010784732603371656 +0.010536088070286245 +0.010499941867551413 +0.010430208495657506 +0.009066679659460125 +0.009254690982272913 +0.010539520547950062 +0.010518653013197512 +0.01060755919376067 +0.010437106007132036 +0.010493663484872929 +0.01069221261021649 +0.010056320529688021 +0.01049118855936003 +0.01037861027851187 +0.010572533131223815 +0.010480090421785054 +0.010378657173814958 +0.010510338382024703 +0.010683570731559703 +0.010662932864165512 +0.010498907786376518 +0.01064154385300032 +0.010491117979560432 +0.010530859400550353 +0.010502434761167086 +0.01056128019338538 +0.010444778033757004 +0.010589376124071664 +0.01070712298829237 +0.010507027690844804 +0.01046000702852576 +0.010480330727095234 +0.010393654197004849 +0.010532267263223385 +0.010621749025223583 +0.01059725113084604 +0.010466801741256797 +0.010441247365790707 +0.010576903739751413 +0.010367687928072852 +0.010391760312406153 +0.010373844066634774 +0.010441387449551759 +0.01051324969639295 +0.010589498922164584 +0.010469761396083853 +0.01004613331390609 +0.010493891723132852 +0.010337132904357437 +0.010412660716422674 +0.010555118097570437 +0.010401249598262125 +0.010469417176048818 +0.010312642388302705 +0.010469515470723653 +0.010319499600004277 +0.01052881558907443 +0.010495687826889855 +0.010476063182256344 +0.010430275382281378 +0.01042456973087171 +0.010463053858357257 +0.010540624470289412 +0.010455657615615377 +0.01044396341553536 +0.010395392269731081 +0.010523632361457265 +0.010493589950532749 +0.010407621699289 +0.010477926068264863 +0.010362225199310944 +0.010507392584635266 +0.010485237970113241 +0.010403798909151349 +0.010422620135520038 +0.01039113621388016 +0.010345617493871471 +0.010484577702551052 +0.010535608792420605 +0.009211804059430444 +0.010571935406789697 +0.010426660244963292 +0.010452914501315561 +0.010462347955988913 +0.010413646770255833 +0.010408702667738343 +0.010426012686742791 
+0.010548672494317951 +0.010507516892113048 +0.010526582253454575 +0.01042245962689149 +0.01042446452353535 +0.010473434427945778 +0.010495701796728475 +0.010512662473423728 +0.010494389996767557 +0.010509672984159711 +0.010496956135692268 +0.010419984805750949 +0.008929582906822706 +0.010441352597212997 +0.010464571400320736 +0.010399415591282064 +0.010455412115773251 +0.01057258246723434 +0.0105080163539869 +0.010404359782114625 +0.010481919033517098 +0.01050947531094325 +0.010400464565589511 +0.010359236191765502 +0.010454744012255606 +0.010623272428096369 +0.010368483839556575 +0.010493359793427175 +0.009050809256411317 +0.01053778325400219 +0.010499262250959873 +0.01051076638094824 +0.010474678538418535 +0.010002020031920281 +0.010449292515565095 +0.01030946695566948 +0.010532466622454852 +0.010416264388838718 +0.010429385599905047 +0.01049036902760894 +0.010462773698863798 +0.01032512606089485 +0.010482839268535889 +0.010443417997709635 +0.010451143398752501 +0.010503844787559375 +0.00892035136448926 +0.010418647643307159 +0.010451512692239264 +0.010512368560865008 +0.01043968412479193 +0.010323734408051803 +0.010429111742896253 +0.010425899531049975 +0.010525008904394405 +0.010326696856846583 +0.010493756890104249 diff --git a/nets/__init__.py b/nets/__init__.py new file mode 100644 index 0000000..4287ca8 --- /dev/null +++ b/nets/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/nets/__pycache__/__init__.cpython-38.pyc b/nets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..9e1e87e Binary files /dev/null and b/nets/__pycache__/__init__.cpython-38.pyc differ diff --git a/nets/__pycache__/__init__.cpython-39.pyc b/nets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..1dbd4e1 Binary files /dev/null and b/nets/__pycache__/__init__.cpython-39.pyc differ diff --git a/nets/__pycache__/deeplabv3_plus.cpython-38.pyc b/nets/__pycache__/deeplabv3_plus.cpython-38.pyc new file mode 100644 index 0000000..223ca40 Binary files /dev/null and b/nets/__pycache__/deeplabv3_plus.cpython-38.pyc differ diff --git a/nets/__pycache__/deeplabv3_plus.cpython-39.pyc b/nets/__pycache__/deeplabv3_plus.cpython-39.pyc new file mode 100644 index 0000000..d12b386 Binary files /dev/null and b/nets/__pycache__/deeplabv3_plus.cpython-39.pyc differ diff --git a/nets/__pycache__/deeplabv3_training.cpython-38.pyc b/nets/__pycache__/deeplabv3_training.cpython-38.pyc new file mode 100644 index 0000000..22c744b Binary files /dev/null and b/nets/__pycache__/deeplabv3_training.cpython-38.pyc differ diff --git a/nets/__pycache__/deeplabv3_training.cpython-39.pyc b/nets/__pycache__/deeplabv3_training.cpython-39.pyc new file mode 100644 index 0000000..7849397 Binary files /dev/null and b/nets/__pycache__/deeplabv3_training.cpython-39.pyc differ diff --git a/nets/__pycache__/mobilenetv2.cpython-38.pyc b/nets/__pycache__/mobilenetv2.cpython-38.pyc new file mode 100644 index 0000000..eb5b70b Binary files /dev/null and b/nets/__pycache__/mobilenetv2.cpython-38.pyc differ diff --git a/nets/__pycache__/mobilenetv2.cpython-39.pyc b/nets/__pycache__/mobilenetv2.cpython-39.pyc new file mode 100644 index 0000000..c4103ed Binary files /dev/null and b/nets/__pycache__/mobilenetv2.cpython-39.pyc differ diff --git a/nets/__pycache__/xception.cpython-38.pyc b/nets/__pycache__/xception.cpython-38.pyc new file mode 100644 index 0000000..5f853a7 Binary files /dev/null and b/nets/__pycache__/xception.cpython-38.pyc differ diff --git a/nets/__pycache__/xception.cpython-39.pyc 
b/nets/__pycache__/xception.cpython-39.pyc
new file mode 100644
index 0000000..25a4544
Binary files /dev/null and b/nets/__pycache__/xception.cpython-39.pyc differ
diff --git a/nets/deeplabv3_plus.py b/nets/deeplabv3_plus.py
new file mode 100644
index 0000000..adb8608
--- /dev/null
+++ b/nets/deeplabv3_plus.py
@@ -0,0 +1,257 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nets.xception import xception
+from nets.mobilenetv2 import mobilenetv2
+
+
+class MobileNetV2(nn.Module):
+    def __init__(self, downsample_factor=8, pretrained=True):
+        super(MobileNetV2, self).__init__()
+        from functools import partial
+
+        model           = mobilenetv2(pretrained)
+        self.features   = model.features[:-1]
+
+        self.total_idx  = len(self.features)
+        self.down_idx   = [2, 4, 7, 14]
+
+        if downsample_factor == 8:
+            for i in range(self.down_idx[-2], self.down_idx[-1]):
+                self.features[i].apply(partial(self._nostride_dilate, dilate=2))
+            for i in range(self.down_idx[-1], self.total_idx):
+                self.features[i].apply(partial(self._nostride_dilate, dilate=4))
+        elif downsample_factor == 16:
+            for i in range(self.down_idx[-1], self.total_idx):
+                self.features[i].apply(partial(self._nostride_dilate, dilate=2))
+
+    def _nostride_dilate(self, m, dilate):
+        classname = m.__class__.__name__
+        if classname.find('Conv') != -1:
+            if m.stride == (2, 2):
+                m.stride = (1, 1)
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate // 2, dilate // 2)
+                    m.padding  = (dilate // 2, dilate // 2)
+            else:
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate, dilate)
+                    m.padding  = (dilate, dilate)
+
+    def forward(self, x):
+        low_level_features = self.features[:4](x)
+        x = self.features[4:](low_level_features)
+        return low_level_features, x
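For reference, a standalone re-statement of the `_nostride_dilate` rule above, applied to a toy stride-2 conv so the effect is visible. This is an illustration only (it uses an `isinstance` check instead of the class-name match in the repo's code):

```python
import torch
import torch.nn as nn

def nostride_dilate(m, dilate):
    if isinstance(m, nn.Conv2d):
        if m.stride == (2, 2):
            m.stride = (1, 1)  # cancel the downsampling
            if m.kernel_size == (3, 3):
                m.dilation = (dilate // 2, dilate // 2)
                m.padding  = (dilate // 2, dilate // 2)
        elif m.kernel_size == (3, 3):
            m.dilation = (dilate, dilate)
            m.padding  = (dilate, dilate)

conv = nn.Conv2d(32, 32, 3, stride=2, padding=1)
nostride_dilate(conv, dilate=2)
print(conv.stride, conv.dilation)   # (1, 1) (1, 1): the stride is removed
x = torch.randn(1, 32, 64, 64)
print(conv(x).shape)                # torch.Size([1, 32, 64, 64]) instead of the halved 32x32
```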
+
+
+#-----------------------------------------#
+#   ASPP feature-extraction module:
+#   extracts features with atrous (dilated)
+#   convolutions at different dilation rates
+#-----------------------------------------#
+"""
+The convolution uses dilation=6 * rate, which means the sampling points inside the kernel are
+spaced 6 * rate pixels apart. To keep the output feature map the same size as the input, the
+padding parameter is set to padding=6 * rate. The convolution therefore samples the input at that
+dilation, and the output feature map is zero-padded around the edges so its size does not shrink.
+
+bn_mom in this code is the momentum parameter of the Batch Normalization layers.
+In PyTorch's batch-normalization implementation, the momentum controls how fast the running
+mean and variance are updated.
+
+dilation is the dilation-rate parameter of the convolution.
+"""
+class ASPP(nn.Module):
+    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
+        super(ASPP, self).__init__()
+        self.branch1 = nn.Sequential(
+            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
+            nn.BatchNorm2d(dim_out, momentum=bn_mom),
+            nn.ReLU(inplace=True),
+        )
+        self.branch2 = nn.Sequential(
+            nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
+            nn.BatchNorm2d(dim_out, momentum=bn_mom),
+            nn.ReLU(inplace=True),
+        )
+        self.branch3 = nn.Sequential(
+            nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
+            nn.BatchNorm2d(dim_out, momentum=bn_mom),
+            nn.ReLU(inplace=True),
+        )
+        self.branch4 = nn.Sequential(
+            nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
+            nn.BatchNorm2d(dim_out, momentum=bn_mom),
+            nn.ReLU(inplace=True),
+        )
+        # branch5 is a pooling path
+        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
+        self.branch5_bn   = nn.BatchNorm2d(dim_out, momentum=bn_mom)
+        self.branch5_relu = nn.ReLU(inplace=True)
+
+        self.conv_cat = nn.Sequential(
+            nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
+            nn.BatchNorm2d(dim_out, momentum=bn_mom),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x):
+        [b, c, row, col] = x.size()
+        #-----------------------------------------#
+        #   Five branches in total
+        #-----------------------------------------#
+        conv1x1   = self.branch1(x)
+        conv3x3_1 = self.branch2(x)
+        conv3x3_2 = self.branch3(x)
+        conv3x3_3 = self.branch4(x)
+        #-----------------------------------------#
+        #   Fifth branch: global average pooling + conv
+        #-----------------------------------------#
+        global_feature = torch.mean(x, 2, True)
+        global_feature = torch.mean(global_feature, 3, True)   # two mean-poolings
+        global_feature = self.branch5_conv(global_feature)
+        global_feature = self.branch5_bn(global_feature)
+        global_feature = self.branch5_relu(global_feature)     # channel integration
+        global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
+        """
+        The first line uses torch.mean to average x along dimension 2 (the height). The True
+        argument keeps the dimension in the result, producing a tensor of shape
+        [batch_size, num_channels, 1, width].
+        global_feature = torch.mean(global_feature, 3, True) then averages along dimension 3
+        (the width), again keeping the dimension, which yields a tensor of shape
+        [batch_size, num_channels, 1, 1] where each channel holds the mean of the whole feature
+        map for that channel.
+
+        Together this performs global average pooling over every channel of the feature map,
+        producing a [batch_size, num_channels, 1, 1] global feature vector.
+
+        F.interpolate then resizes the global feature back to the target size (row, col).
+        """
+        #-----------------------------------------#
+        #   Stack the five branches, then integrate
+        #   the features with a 1x1 convolution.
+        #-----------------------------------------#
+        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
+        result = self.conv_cat(feature_cat)   # the 1x1 convolution (the green block in the diagram)
+        return result
+# ASPP simply extracts features with atrous convolutions at different dilation rates
+
+class DeepLab(nn.Module):
+    def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
+        super(DeepLab, self).__init__()
+        if backbone == "xception":
+            #----------------------------------#
+            #   Obtain two feature maps:
+            #   shallow features [128,128,256]
+            #   backbone output  [30,30,2048]
+            #----------------------------------#
+            self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
+            in_channels        = 2048
+            low_level_channels = 256
+        elif backbone == "mobilenet":
+            #----------------------------------#
+            #   Obtain two feature maps:
+            #   shallow features [128,128,24]
+            #   backbone output  [30,30,320]
+            #----------------------------------#
+            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
+            in_channels        = 320
+            low_level_channels = 24
+        else:
+            raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))
+
+        #-----------------------------------------#
+        #   ASPP feature-extraction module:
+        #   extracts features with atrous convolutions
+        #   at different dilation rates
+        #-----------------------------------------#
+        self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16 // downsample_factor)
+
+        #----------------------------------#
+        #   Shallow-feature branch
+        #----------------------------------#
+        self.shortcut_conv = nn.Sequential(
+            nn.Conv2d(low_level_channels, 48, 1),
+            nn.BatchNorm2d(48),
+            nn.ReLU(inplace=True)
+        )
+
+        self.cat_conv = nn.Sequential(
+            nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1),
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.5),
+
+            nn.Conv2d(256, 256, 3, stride=1, padding=1),
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+        )
+        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)
+        """
+        In each training batch, dropout randomly selects some neurons with a certain probability
+        (usually between 0.2 and 0.5) and zeroes their outputs. This means that in each forward
+        pass only some neurons pass their outputs to the next layer, while the outputs of the
+        others are set to zero.
+
+        In each backward pass, only the neurons that were not zeroed update their weights, so
+        every neuron gets a chance to be trained instead of the network over-relying on
+        particular neurons.
+        """
+    def forward(self, x):
+        H, W = x.size(2), x.size(3)
+        #-----------------------------------------#
+        #   Obtain two feature maps:
+        #   low_level_features: shallow features, processed by a convolution
+        #   x                 : backbone output, refined by the ASPP module
+        #-----------------------------------------#
+        low_level_features, x = self.backbone(x)
+        x = self.aspp(x)
+        low_level_features = self.shortcut_conv(low_level_features)
+
+        #-----------------------------------------#
+        #   Upsample the ASPP-refined features (the green
+        #   block in the diagram), concatenate them with the
+        #   shallow features, then extract features with
+        #   convolutions.
+        #-----------------------------------------#
+        x = F.interpolate(x,
+                          size=(low_level_features.size(2), low_level_features.size(3)),
+                          mode='bilinear',        # upsampling via bilinear interpolation
+                          align_corners=True)
+        x = self.cat_conv(torch.cat((x, low_level_features), dim=1))
+        x = self.cls_conv(x)
+        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
+        return x
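A quick sanity check of the decoder path as described by the shape comments above. This is a minimal sketch; `pretrained=False` is used so no weight download is attempted:

```python
import torch
from nets.deeplabv3_plus import DeepLab

# Build the model with the mobilenet backbone and run a dummy 512x512 input.
model = DeepLab(num_classes=21, backbone="mobilenet", pretrained=False, downsample_factor=16)
model.eval()
with torch.no_grad():
    y = model(torch.randn(1, 3, 512, 512))
print(y.shape)  # torch.Size([1, 21, 512, 512]): logits at full input resolution
```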
diff --git a/nets/deeplabv3_training.py b/nets/deeplabv3_training.py
new file mode 100644
index 0000000..26d5cc1
--- /dev/null
+++ b/nets/deeplabv3_training.py
@@ -0,0 +1,113 @@
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def CE_Loss(inputs, target, cls_weights, num_classes=21):
+    n, c, h, w = inputs.size()
+    nt, ht, wt = target.size()
+    if h != ht or w != wt:
+        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
+
+    temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
+    temp_target = target.view(-1)
+
+    CE_loss = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target)
+    return CE_loss
+
+def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2):
+    n, c, h, w = inputs.size()
+    nt, ht, wt = target.size()
+    if h != ht or w != wt:
+        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
+
+    temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
+    temp_target = target.view(-1)
+
+    logpt = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target)
+    pt = torch.exp(logpt)
+    if alpha is not None:
+        logpt *= alpha
+    loss = -((1 - pt) ** gamma) * logpt
+    loss = loss.mean()
+    return loss
+
+def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
+    n, c, h, w = inputs.size()
+    nt, ht, wt, ct = target.size()
+    if h != ht or w != wt:
+        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
+
+    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
+    temp_target = target.view(n, -1, ct)
+
+    #--------------------------------------------#
+    #   Compute the dice loss
+    #--------------------------------------------#
+    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
+    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
+    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp
+
+    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
+    dice_loss = 1 - torch.mean(score)
+    return dice_loss
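Note the target layout `Dice_loss` consumes: raw logits of shape (n, c, h, w) against a one-hot map of shape (n, h, w, c + 1), whose trailing channel is dropped by the `[..., :-1]` slices. A minimal illustrative call follows; the reading that the dataloader reserves the trailing channel for ignored pixels is an assumption inferred from that slicing, not stated in this diff:

```python
import torch
from nets.deeplabv3_training import Dice_loss

n, c, h, w = 2, 3, 4, 4
inputs = torch.randn(n, c, h, w)                  # raw logits
labels = torch.randint(0, c, (n, h, w))           # per-pixel class indices
# One-hot with one extra trailing channel, which Dice_loss drops via [..., :-1].
target = torch.nn.functional.one_hot(labels, num_classes=c + 1).float()
print(Dice_loss(inputs, target).item())
```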
+
+def weights_init(net, init_type='normal', init_gain=0.02):
+    def init_func(m):
+        classname = m.__class__.__name__
+        if hasattr(m, 'weight') and classname.find('Conv') != -1:
+            if init_type == 'normal':
+                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
+            elif init_type == 'xavier':
+                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
+            elif init_type == 'kaiming':
+                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+            elif init_type == 'orthogonal':
+                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
+            else:
+                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+        elif classname.find('BatchNorm2d') != -1:
+            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+            torch.nn.init.constant_(m.bias.data, 0.0)
+    print('initialize network with %s type' % init_type)
+    net.apply(init_func)
+
+def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.1, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.3, step_num = 10):
+    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
+        if iters <= warmup_total_iters:
+            # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
+            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
+        elif iters >= total_iters - no_aug_iter:
+            lr = min_lr
+        else:
+            lr = min_lr + 0.5 * (lr - min_lr) * (
+                1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
+            )
+        return lr
+
+    def step_lr(lr, decay_rate, step_size, iters):
+        if step_size < 1:
+            raise ValueError("step_size must be above 1.")
+        n      = iters // step_size
+        out_lr = lr * decay_rate ** n
+        return out_lr
+
+    if lr_decay_type == "cos":
+        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
+        warmup_lr_start    = max(warmup_lr_ratio * lr, 1e-6)
+        no_aug_iter        = min(max(no_aug_iter_ratio * total_iters, 1), 15)
+        func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
+    else:
+        decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
+        step_size  = total_iters / step_num
+        func = partial(step_lr, lr, decay_rate, step_size)
+
+    return func
+
+def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
+    lr = lr_scheduler_func(epoch)
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
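An illustrative use of `get_lr_scheduler` with the cosine schedule. The hyperparameters here are arbitrary; the warm-up and minimum-lr windows come from the heuristics in the function above:

```python
from nets.deeplabv3_training import get_lr_scheduler

# Build a cosine schedule over 100 epochs and sample it at a few points.
lr_func = get_lr_scheduler("cos", lr=7e-3, min_lr=7e-5, total_iters=100)
for epoch in [0, 1, 10, 50, 99]:
    print(epoch, lr_func(epoch))
# In the training loop this pairs with set_optimizer_lr(optimizer, lr_func, epoch).
```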
diff --git a/nets/img.png b/nets/img.png
new file mode 100644
index 0000000..3464b35
Binary files /dev/null and b/nets/img.png differ
diff --git a/nets/img_1.png b/nets/img_1.png
new file mode 100644
index 0000000..8daaef3
Binary files /dev/null and b/nets/img_1.png differ
diff --git a/nets/mobilenetv2.py b/nets/mobilenetv2.py
new file mode 100644
index 0000000..13ae643
--- /dev/null
+++ b/nets/mobilenetv2.py
@@ -0,0 +1,164 @@
+import math
+import os
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+
+BatchNorm2d = nn.BatchNorm2d
+
+def conv_bn(inp, oup, stride):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+        BatchNorm2d(oup),
+        nn.ReLU6(inplace=True)
+    )
+
+def conv_1x1_bn(inp, oup):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+        BatchNorm2d(oup),
+        nn.ReLU6(inplace=True)
+    )
+
+class InvertedResidual(nn.Module):
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        hidden_dim = round(inp * expand_ratio)
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        if expand_ratio == 1:
+            self.conv = nn.Sequential(
+                #--------------------------------------------#
+                #   3x3 depthwise convolution: feature
+                #   extraction across spatial positions
+                #--------------------------------------------#
+                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+                BatchNorm2d(hidden_dim),
+                nn.ReLU6(inplace=True),
+                #-----------------------------------#
+                #   1x1 convolution to adjust the
+                #   number of channels
+                #-----------------------------------#
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                BatchNorm2d(oup),
+            )
+        else:
+            self.conv = nn.Sequential(
+                #-----------------------------------#
+                #   1x1 convolution to expand the
+                #   number of channels
+                #-----------------------------------#
+                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+                BatchNorm2d(hidden_dim),
+                nn.ReLU6(inplace=True),
+                #--------------------------------------------#
+                #   3x3 depthwise convolution: feature
+                #   extraction across spatial positions
+                #--------------------------------------------#
+                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+                BatchNorm2d(hidden_dim),
+                nn.ReLU6(inplace=True),
+                #-----------------------------------#
+                #   1x1 convolution to reduce the
+                #   number of channels
+                #-----------------------------------#
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                BatchNorm2d(oup),
+            )
+
+    def forward(self, x):
+        if self.use_res_connect:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+class MobileNetV2(nn.Module):
+    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
+        super(MobileNetV2, self).__init__()
+        block = InvertedResidual
+        input_channel = 32
+        last_channel  = 1280
+        interverted_residual_setting = [
+            # t, c, n, s: t = expand_ratio (channel expansion of the 1x1 conv),
+            # c = output_channel, n = number of repeats, s = stride
+            [1, 16, 1, 1],   # 256, 256, 32 -> 256, 256, 16
+            [6, 24, 2, 2],   # 256, 256, 16 -> 128, 128, 24    2
+            [6, 32, 3, 2],   # 128, 128, 24 -> 64, 64, 32      4
+            [6, 64, 4, 2],   # 64, 64, 32   -> 32, 32, 64      7
+            [6, 96, 3, 1],   # 32, 32, 64   -> 32, 32, 96
+            [6, 160, 3, 2],  # 32, 32, 96   -> 16, 16, 160     14
+            [6, 320, 1, 1],  # 16, 16, 160  -> 16, 16, 320
+        ]
+
+        assert input_size % 32 == 0
+        input_channel = int(input_channel * width_mult)
+        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
+        # 512, 512, 3 -> 256, 256, 32
+        self.features = [conv_bn(3, input_channel, 2)]
+
+        for t, c, n, s in interverted_residual_setting:
+            output_channel = int(c * width_mult)
+            for i in range(n):
+                if i == 0:
+                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
+                else:
+                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
+                input_channel = output_channel
+
+        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
+        self.features = nn.Sequential(*self.features)
+
+        self.classifier = nn.Sequential(
+            nn.Dropout(0.2),
+            nn.Linear(self.last_channel, n_class),
+        )
+
+        self._initialize_weights()
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.mean(3).mean(2)
+        x = self.classifier(x)
+        return x
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                n = m.weight.size(1)
+                m.weight.data.normal_(0, 0.01)
+                m.bias.data.zero_()
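The `down_idx = [2, 4, 7, 14]` hard-coded in nets/deeplabv3_plus.py corresponds to the stride-2 stages of this feature list. A small illustrative check that enumerates them (assumes no pretrained download):

```python
import torch.nn as nn
from nets.mobilenetv2 import mobilenetv2

model = mobilenetv2(pretrained=False)
for i, layer in enumerate(model.features):
    # Report every stage that contains a stride-2 convolution.
    if any(isinstance(m, nn.Conv2d) and m.stride == (2, 2) for m in layer.modules()):
        print(i)  # prints 0, 2, 4, 7, 14; index 0 is the stem conv, the rest match down_idx
```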
+
+
+def load_url(url, model_dir='./model_data', map_location=None):
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    filename    = url.split('/')[-1]
+    cached_file = os.path.join(model_dir, filename)
+    if os.path.exists(cached_file):
+        return torch.load(cached_file, map_location=map_location)
+    else:
+        return model_zoo.load_url(url, model_dir=model_dir)
+
+def mobilenetv2(pretrained=False, **kwargs):
+    model = MobileNetV2(n_class=1000, **kwargs)
+    if pretrained:
+        model.load_state_dict(load_url('https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/mobilenet_v2.pth.tar'), strict=False)
+    return model
+
+if __name__ == "__main__":
+    model = mobilenetv2()
+    for i, layer in enumerate(model.features):
+        print(i, layer)
diff --git a/nets/test_net.py b/nets/test_net.py
new file mode 100644
index 0000000..f1c641b
--- /dev/null
+++ b/nets/test_net.py
@@ -0,0 +1,30 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nets.xception import xception
+from nets.mobilenetv2 import mobilenetv2
+
+class ASPP(nn.Module):
+    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
+        super(ASPP, self).__init__()
+        self.branch1 = nn.Sequential(
+            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
+            nn.BatchNorm2d(dim_out, momentum=bn_mom),
+            nn.ReLU(inplace=True),
+        )
+        self.branch2 = nn.Sequential(
+            nn.Conv2d(dim_in, dim_out, 3, 1, padding=6*rate, dilation=6*rate, bias=True),
+            nn.BatchNorm2d(dim_out, momentum=bn_mom),
+            nn.ReLU(inplace=True),
+        )
+        self.branch3 = nn.Sequential(
+            nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
+            nn.BatchNorm2d(dim_out, momentum=bn_mom),
+            nn.ReLU(inplace=True),
+        )
+        self.branch4 = nn.Sequential(
+            nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
+            nn.BatchNorm2d(dim_out, momentum=bn_mom),
+            nn.ReLU(inplace=True),
+        )
+        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
diff --git a/nets/xception.py b/nets/xception.py
new file mode 100644
index 0000000..3f536c7
--- /dev/null
+++ b/nets/xception.py
@@ -0,0 +1,298 @@
+import math
+import os
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+
+bn_mom = 0.0003
+
+
+class SeparableConv2d(nn.Module):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 bias=False,
+                 activate_first=True,
+                 inplace=True):
+        super(SeparableConv2d, self).__init__()
+        self.relu0     = nn.ReLU(inplace=inplace)
+        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias)
+        self.bn1       = nn.BatchNorm2d(in_channels, momentum=bn_mom)
+        self.relu1     = nn.ReLU(inplace=True)
+        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
+        self.bn2       = nn.BatchNorm2d(out_channels, momentum=bn_mom)
+        self.relu2     = nn.ReLU(inplace=True)
+        self.activate_first = activate_first
+
+    def forward(self, x):
+        if self.activate_first:
+            x = self.relu0(x)
+        x = self.depthwise(x)
+        x = self.bn1(x)
+        if not self.activate_first:
+            x = self.relu1(x)
+        x = self.pointwise(x)
+        x = self.bn2(x)
+        if not self.activate_first:
+            x = self.relu2(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self,
+                 in_filters,
+                 out_filters,
+                 strides=1,
+                 atrous=None,
+                 grow_first=True,
+                 activate_first=True,
+                 inplace=True):
+        super(Block, self).__init__()
+        if atrous is None:
+            atrous = [1] * 3
+        elif isinstance(atrous, int):
+            atrous_list = [atrous] * 3
+            atrous = atrous_list
+        idx = 0
+        self.head_relu = True
+        if out_filters != in_filters or strides != 1:
+            self.skip   = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
+            self.skipbn = nn.BatchNorm2d(out_filters, momentum=bn_mom)
+            self.head_relu = False
+        else:
+            self.skip = None
+
+        self.hook_layer = None
+        if grow_first:
+            filters = out_filters
+        else:
+            filters = in_filters
+        self.sepconv1 = SeparableConv2d(in_filters, filters, 3, stride=1, padding=1 * atrous[0], dilation=atrous[0], bias=False, activate_first=activate_first, inplace=self.head_relu)
+        self.sepconv2 = SeparableConv2d(filters, out_filters, 3, stride=1, padding=1 * atrous[1], dilation=atrous[1], bias=False, activate_first=activate_first)
+        self.sepconv3 = SeparableConv2d(out_filters, out_filters, 3, stride=strides, padding=1 * atrous[2], dilation=atrous[2], bias=False, activate_first=activate_first, inplace=inplace)
+
+    def forward(self, inp):
+
+        if self.skip is not None:
+            skip = self.skip(inp)
+            skip = self.skipbn(skip)
+        else:
+            skip = inp
+
+        x = self.sepconv1(inp)
+        x = self.sepconv2(x)
+        self.hook_layer = x
+        x = self.sepconv3(x)
+
+        x += skip
+        return x
+
+
+class Xception(nn.Module):
+    """
+    Xception optimized for the ImageNet dataset, as specified in
+    https://arxiv.org/pdf/1610.02357.pdf
+    """
+    def __init__(self, downsample_factor):
+        """ Constructor
+        Args:
+            downsample_factor: overall downsampling factor of the backbone (8 or 16)
+        """
+        super(Xception, self).__init__()
+
+        stride_list = None
+        if downsample_factor == 8:
+            stride_list = [2, 1, 1]
+        elif downsample_factor == 16:
+            stride_list = [2, 2, 1]
+        else:
+            raise ValueError('xception.py: output stride=%d is not supported.' % downsample_factor)
+        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
+        self.bn1   = nn.BatchNorm2d(32, momentum=bn_mom)
+        self.relu  = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
+        self.bn2   = nn.BatchNorm2d(64, momentum=bn_mom)
+        # do relu here
+
+        self.block1 = Block(64, 128, 2)
+        self.block2 = Block(128, 256, stride_list[0], inplace=False)
+        self.block3 = Block(256, 728, stride_list[1])
+
+        rate = 16 // downsample_factor
+        self.block4 = Block(728, 728, 1, atrous=rate)
+        self.block5 = Block(728, 728, 1, atrous=rate)
+        self.block6 = Block(728, 728, 1, atrous=rate)
+        self.block7 = Block(728, 728, 1, atrous=rate)
+
+        self.block8  = Block(728, 728, 1, atrous=rate)
+        self.block9  = Block(728, 728, 1, atrous=rate)
+        self.block10 = Block(728, 728, 1, atrous=rate)
+        self.block11 = Block(728, 728, 1, atrous=rate)
+
+        self.block12 = Block(728, 728, 1, atrous=rate)
+        self.block13 = Block(728, 728, 1, atrous=rate)
+        self.block14 = Block(728, 728, 1, atrous=rate)
+        self.block15 = Block(728, 728, 1, atrous=rate)
+
+        self.block16 = Block(728, 728, 1, atrous=[1 * rate, 1 * rate, 1 * rate])
+        self.block17 = Block(728, 728, 1, atrous=[1 * rate, 1 * rate, 1 * rate])
+        self.block18 = Block(728, 728, 1, atrous=[1 * rate, 1 * rate, 1 * rate])
+        self.block19 = Block(728, 728, 1, atrous=[1 * rate, 1 * rate, 1 * rate])
+
+        self.block20 = Block(728, 1024, stride_list[2], atrous=rate, grow_first=False)
+        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1 * rate, dilation=rate, activate_first=False)
+
+        self.conv4 = SeparableConv2d(1536, 1536, 3, 1, 1 * rate, dilation=rate, activate_first=False)
+
+        self.conv5 = SeparableConv2d(1536, 2048, 3, 1, 1 * rate, dilation=rate, activate_first=False)
+        self.layers = []
+
+        #------- init weights --------
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+        #-----------------------------
+
+    def forward(self, input):
+        self.layers = []
+        x = self.conv1(input)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+
+        x = self.block1(x)
+        x = self.block2(x)
+        low_feature_layer = self.block2.hook_layer
+        x = self.block3(x)
+        x = self.block4(x)
+        x = self.block5(x)
+        x = self.block6(x)
+        x = self.block7(x)
+        x = self.block8(x)
+        x = self.block9(x)
+        x = self.block10(x)
+        x = self.block11(x)
+        x = self.block12(x)
+        x = self.block13(x)
+        x = self.block14(x)
+        x = self.block15(x)
+        x = self.block16(x)
+        x = self.block17(x)
+        x = self.block18(x)
+        x = self.block19(x)
+        x = self.block20(x)
+
+        x = self.conv3(x)
+
+        x = self.conv4(x)
+
+        x = self.conv5(x)
+        return low_feature_layer, x
+
+
+def load_url(url, model_dir='./model_data', map_location=None):
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    filename    = url.split('/')[-1]
+    cached_file = os.path.join(model_dir, filename)
+    if os.path.exists(cached_file):
+        return torch.load(cached_file, map_location=map_location)
+    else:
+        return model_zoo.load_url(url, model_dir=model_dir)
+
+
+def xception(pretrained=True, downsample_factor=16):
+    model = Xception(downsample_factor=downsample_factor)
+    if pretrained:
+        model.load_state_dict(load_url(
+            'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/xception_pytorch_imagenet.pth'
+        ),
+                              strict=False)
+    return model
diff --git a/new_predict.py b/new_predict.py
new file mode 100644
index 0000000..001311a
--- /dev/null
+++ b/new_predict.py
@@ -0,0 +1,82 @@
+import os
+import time
+import json
+
+import torch
+from torchvision import transforms
+import numpy as np
+from PIL import Image
+
+from src import deeplabv3_resnet50
+
+
+def time_synchronized():
+    torch.cuda.synchronize() if torch.cuda.is_available() else None
+    return time.time()
+
+
+def main():
+    aux = False  # the aux_classifier is not needed at inference time
+    classes = 20
+    weights_path = "./save_weights/model_29.pth"
+    img_path = "./test.jpg"
+    palette_path = "./palette.json"
+    assert os.path.exists(weights_path), f"weights {weights_path} not found."
+    assert os.path.exists(img_path), f"image {img_path} not found."
+    assert os.path.exists(palette_path), f"palette {palette_path} not found."
+    with open(palette_path, "rb") as f:
+        palette_dict = json.load(f)
+        palette = []
+        for v in palette_dict.values():
+            palette += v
+
+    # get devices
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    print("using {} device.".format(device))
+
+    # create model
+    model = deeplabv3_resnet50(aux=aux, num_classes=classes+1)
+
+    # delete weights about aux_classifier
+    weights_dict = torch.load(weights_path, map_location='cpu')['model']
+    for k in list(weights_dict.keys()):
+        if "aux" in k:
+            del weights_dict[k]
+
+    # load weights
+    model.load_state_dict(weights_dict)
+    model.to(device)
+
+    # load image
+    original_img = Image.open(img_path)
+
+    # from pil image to tensor and normalize
+    data_transform = transforms.Compose([transforms.Resize(520),
+                                         transforms.ToTensor(),
+                                         transforms.Normalize(mean=(0.485, 0.456, 0.406),
+                                                              std=(0.229, 0.224, 0.225))])
+    img = data_transform(original_img)
+    # expand batch dimension
+    img = torch.unsqueeze(img, dim=0)
+
+    model.eval()  # switch to evaluation mode
+    with torch.no_grad():
+        # init model
+        img_height, img_width = img.shape[-2:]
+        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
+        model(init_img)
+
+        t_start = time_synchronized()
+        output = model(img.to(device))
+        t_end = time_synchronized()
+        print("inference time: {}".format(t_end - t_start))
+
+        prediction = output['out'].argmax(1).squeeze(0)
+        prediction = prediction.to("cpu").numpy().astype(np.uint8)
+        mask = Image.fromarray(prediction)
+        mask.putpalette(palette)
+        mask.save("test_result.png")
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
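palette.json is not part of this commit. Based on the loading code above, it is assumed to map class indices to RGB triples, flattened into the list that PIL's `putpalette` consumes (one triple per class, in index order). A hypothetical example of producing such a file:

```python
import json

# Assumed palette.json layout: class index (as a string key) -> [R, G, B].
palette_dict = {"0": [0, 0, 0], "1": [128, 0, 0], "2": [0, 128, 0]}
with open("palette.json", "w") as f:
    json.dump(palette_dict, f)

# Flattening, as done in main() above, yields the list putpalette expects.
palette = []
for v in palette_dict.values():
    palette += v
print(palette)  # [0, 0, 0, 128, 0, 0, 0, 128, 0]
```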
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..70d92dc
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,165 @@
+#----------------------------------------------------#
+#   Integrates single-image prediction, camera/video
+#   detection and FPS testing into one py file;
+#   switch between them by setting `mode`.
+#----------------------------------------------------#
+import time
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from deeplab import DeeplabV3
+
+if __name__ == "__main__":
+    #-------------------------------------------------------------------------#
+    #   To change the color of a class, modify self.colors in the __init__ function
+    #-------------------------------------------------------------------------#
+    deeplab = DeeplabV3()
+    #----------------------------------------------------------------------------------------------------------#
+    #   mode specifies the test mode:
+    #   'predict'       single-image prediction. To modify the prediction process (save the image,
+    #                   crop objects, etc.), read the detailed comments below first.
+    #   'video'         video detection; works with a camera or a video file, see the comments below.
+    #   'fps'           FPS test, using the image given by fps_image_path, see the comments below.
+    #   'dir_predict'   iterate over a folder and save the results. By default it reads the img folder
+    #                   and saves to img_out, see the comments below.
+    #   'export_onnx'   export the model to onnx; requires pytorch 1.7.1 or higher.
+    #----------------------------------------------------------------------------------------------------------#
+    mode = "predict"
+    #-------------------------------------------------------------------------#
+    #   count           whether to count target pixels (i.e. area) and ratios
+    #   name_classes    the classes to distinguish, same as in json_to_dataset,
+    #                   used for printing classes and counts
+    #
+    #   count and name_classes are only valid when mode='predict'
+    #-------------------------------------------------------------------------#
+    count           = False
+    name_classes    = ["background", "pl5", "pl20", "pl30", "pl40", "pl50", "pl60", "pl70", "pl80", "pl100", "pl120", "pm20", "pm55","pr40","p11", "pn", "pne", "p26", "i2", "i4", "i5", "ip", "il60", "il80", "il100", "p5", "p10", "p23", "p3", "pg", "p19", "p12", "p6", "p27", "ph4", "ph4.5", "ph5", "pm30", "w55", "w59", "w13", "w57", "w32", "wo", "io", "po", "indicative"]
+    # name_classes = ["background","cat","dog"]
+    #----------------------------------------------------------------------------------------------------------#
+    #   video_path        path of the video; video_path=0 means read from the camera.
+    #                     To detect a video, set e.g. video_path = "xxx.mp4" to read xxx.mp4 in the root directory.
+    #   video_save_path   where to save the video; video_save_path="" means do not save.
+    #                     To save, set e.g. video_save_path = "yyy.mp4" to save yyy.mp4 in the root directory.
+    #   video_fps         fps of the saved video
+    #
+    #   video_path, video_save_path and video_fps are only valid when mode='video'.
+    #   When saving a video, exit with ctrl+c or run to the last frame to complete the save.
+    #----------------------------------------------------------------------------------------------------------#
+    video_path      = 0
+    video_save_path = ""
+    video_fps       = 25.0
+    #----------------------------------------------------------------------------------------------------------#
+    #   test_interval     how many detections to run when measuring fps;
+    #                     in theory, a larger test_interval gives a more accurate fps.
+    #   fps_image_path    the image used for the fps test
+    #
+    #   test_interval and fps_image_path are only valid when mode='fps'
+    #----------------------------------------------------------------------------------------------------------#
+    test_interval   = 100
+    fps_image_path  = "img/73473.jpg"
+    #-------------------------------------------------------------------------#
+    #   dir_origin_path   folder of images to detect
+    #   dir_save_path     where the detected images are saved
+    #
+    #   dir_origin_path and dir_save_path are only valid when mode='dir_predict'
+    #-------------------------------------------------------------------------#
+    dir_origin_path = "imgs/"
+    dir_save_path   = "img_out/"
+    #-------------------------------------------------------------------------#
+    #   simplify          run Simplify on the onnx model
+    #   onnx_save_path    where the onnx model is saved
+    #-------------------------------------------------------------------------#
+    simplify        = True
+    onnx_save_path  = "model_data/models.onnx"
+
+    if mode == "predict":
+        '''
+        Notes on predict.py:
+        1. This code cannot run batch prediction directly. To predict in batches, use os.listdir()
+           to walk the folder and Image.open to open the image files for prediction.
+           See get_miou_prediction.py for a reference implementation of walking a folder.
+        2. To save the result, use r_image.save("img.jpg").
+        3. To keep the original image and the segmentation map separate, set the blend parameter to False.
+        4. To extract regions from the mask, see the drawing section of detect_image: check the class of
+           each pixel and select the corresponding region, e.g.
+           seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3))
+           for c in range(self.num_classes):
+               seg_img[:, :, 0] += ((pr == c)*( self.colors[c][0] )).astype('uint8')
+               seg_img[:, :, 1] += ((pr == c)*( self.colors[c][1] )).astype('uint8')
+               seg_img[:, :, 2] += ((pr == c)*( self.colors[c][2] )).astype('uint8')
+        '''
+        while True:
+            img = input('Input image filename:')
+            try:
+                image = Image.open(img)
+            except:
+                print('Open Error! Try again!')
+                continue
+            else:
+                r_image = deeplab.detect_image(image, count=count, name_classes=name_classes)
+                r_image.show()
Try again!')
+                continue
+            else:
+                r_image = deeplab.detect_image(image, count=count, name_classes=name_classes)
+                r_image.show()
+
+    elif mode == "video":
+        capture = cv2.VideoCapture(video_path)
+        if video_save_path != "":
+            fourcc = cv2.VideoWriter_fourcc(*'XVID')
+            size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
+            out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
+
+        ref, frame = capture.read()
+        if not ref:
+            raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")
+
+        fps = 0.0
+        while True:
+            t1 = time.time()
+            # 读取某一帧
+            ref, frame = capture.read()
+            if not ref:
+                break
+            # 格式转变,BGRtoRGB
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            # 转变成Image
+            frame = Image.fromarray(np.uint8(frame))
+            # 进行检测
+            frame = np.array(deeplab.detect_image(frame))
+            # RGBtoBGR满足opencv显示格式
+            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+
+            fps = (fps + (1. / (time.time() - t1))) / 2
+            print("fps= %.2f" % (fps))
+            frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+
+            cv2.imshow("video", frame)
+            c = cv2.waitKey(1) & 0xff
+            if video_save_path != "":
+                out.write(frame)
+
+            if c == 27:
+                capture.release()
+                break
+        print("Video Detection Done!")
+        capture.release()
+        if video_save_path != "":
+            print("Save processed video to the path: " + video_save_path)
+            out.release()
+        cv2.destroyAllWindows()
+
+    elif mode == "fps":
+        img = Image.open(fps_image_path)
+        tact_time = deeplab.get_FPS(img, test_interval)
+        print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + ' FPS, @batch_size 1')
+
+    elif mode == "dir_predict":
+        import os
+        from tqdm import tqdm
+
+        img_names = os.listdir(dir_origin_path)
+        for img_name in tqdm(img_names):
+            if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
+                image_path = os.path.join(dir_origin_path, img_name)
+                image = Image.open(image_path)
+                r_image = deeplab.detect_image(image)
+                if not os.path.exists(dir_save_path):
+                    os.makedirs(dir_save_path)
+                r_image.save(os.path.join(dir_save_path, img_name))
+    elif mode == "export_onnx":
+        deeplab.convert_to_onnx(simplify, onnx_save_path)
+
+    else:
+        raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'dir_predict' or 'export_onnx'.")
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a4e6b7d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+scipy==1.2.1
+numpy==1.17.0
+matplotlib==3.1.2
+opencv_python==4.1.2.30
+torch==1.2.0
+torchvision==0.4.0
+tqdm==4.60.0
+Pillow==8.2.0
+h5py==2.10.0
diff --git a/summary.py b/summary.py
new file mode 100644
index 0000000..3c2cb7c
--- /dev/null
+++ b/summary.py
@@ -0,0 +1,30 @@
+#--------------------------------------------#
+#   该部分代码用于看网络结构
+#--------------------------------------------#
+import torch
+from thop import clever_format, profile
+from torchsummary import summary
+
+from nets.deeplabv3_plus import DeepLab
+
+if __name__ == "__main__":
+    input_shape = [512, 512]
+    num_classes = 47
+    backbone = 'mobilenet'
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = DeepLab(num_classes=num_classes, backbone=backbone, downsample_factor=16, pretrained=False).to(device)
+    summary(model, (3, input_shape[0], input_shape[1]))
+
+    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
+    flops, params = profile(model.to(device), (dummy_input, ), verbose=False)
+    #--------------------------------------------------------#
+    #   flops * 
2是因为profile没有将卷积作为两个operations + # 有些论文将卷积算乘法、加法两个operations。此时乘2 + # 有些论文只考虑乘法的运算次数,忽略加法。此时不乘2 + # 本代码选择乘2,参考YOLOX。 + #--------------------------------------------------------# + flops = flops * 2 + flops, params = clever_format([flops, params], "%.3f") + print('Total GFLOPS: %s' % (flops)) + print('Total params: %s' % (params)) diff --git a/train.py b/train.py new file mode 100644 index 0000000..9df77b8 --- /dev/null +++ b/train.py @@ -0,0 +1,522 @@ +import os +import datetime + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim as optim +from torch.utils.data import DataLoader + +from nets.deeplabv3_plus import DeepLab +from nets.deeplabv3_training import (get_lr_scheduler, set_optimizer_lr, + weights_init) +from utils.callbacks import LossHistory, EvalCallback +from utils.dataloader import DeeplabDataset, deeplab_dataset_collate +from utils.utils import download_weights, show_config +from utils.utils_fit import fit_one_epoch + +''' +训练自己的语义分割模型一定需要注意以下几点: +1、训练前仔细检查自己的格式是否满足要求,该库要求数据集格式为VOC格式,需要准备好的内容有输入图片和标签 + 输入图片为.jpg图片,无需固定大小,传入训练前会自动进行resize。 + 灰度图会自动转成RGB图片进行训练,无需自己修改。 + 输入图片如果后缀非jpg,需要自己批量转成jpg后再开始训练。 + + 标签为png图片,无需固定大小,传入训练前会自动进行resize。 + 由于许多同学的数据集是网络上下载的,标签格式并不符合,需要再度处理。一定要注意!标签的每个像素点的值就是这个像素点所属的种类。 + 网上常见的数据集总共对输入图片分两类,背景的像素点值为0,目标的像素点值为255。这样的数据集可以正常运行但是预测是没有效果的! + 需要改成,背景的像素点值为0,目标的像素点值为1。 + 如果格式有误,参考:https://github.com/bubbliiiing/segmentation-format-fix + +2、损失值的大小用于判断是否收敛,比较重要的是有收敛的趋势,即验证集损失不断下降,如果验证集损失基本上不改变的话,模型基本上就收敛了。 + 损失值的具体大小并没有什么意义,大和小只在于损失的计算方式,并不是接近于0才好。如果想要让损失好看点,可以直接到对应的损失函数里面除上10000。 + 训练过程中的损失值会保存在logs文件夹下的loss_%Y_%m_%d_%H_%M_%S文件夹中 + +3、训练好的权值文件保存在logs文件夹中,每个训练世代(Epoch)包含若干训练步长(Step),每个训练步长(Step)进行一次梯度下降。 + 如果只是训练了几个Step是不会保存的,Epoch和Step的概念要捋清楚一下。 +''' +if __name__ == "__main__": + #---------------------------------# + # Cuda 是否使用Cuda + # 没有GPU可以设置成False + #---------------------------------# + Cuda = True + #---------------------------------------------------------------------# + # distributed 用于指定是否使用单机多卡分布式运行 + # 终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。 + # Windows系统下默认使用DP模式调用所有显卡,不支持DDP。 + # DP模式: + # 设置 distributed = False + # 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python train.py + # DDP模式: + # 设置 distributed = True + # 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py + #---------------------------------------------------------------------# + distributed = False + #---------------------------------------------------------------------# + # sync_bn 是否使用sync_bn,DDP模式多卡可用 + #---------------------------------------------------------------------# + sync_bn = False + #---------------------------------------------------------------------# + # fp16 是否使用混合精度训练 + # 可减少约一半的显存、需要pytorch1.7.1以上 + #---------------------------------------------------------------------# + fp16 = False + #-----------------------------------------------------# + # num_classes 训练自己的数据集必须要修改的 + # 自己需要的分类个数+1,如2+1 + #-----------------------------------------------------# + num_classes = 46 + #---------------------------------# + # 所使用的的主干网络: + # mobilenet + # xception + #---------------------------------# + backbone = "mobilenet" + #----------------------------------------------------------------------------------------------------------------------------# + # pretrained 是否使用主干网络的预训练权重,此处使用的是主干的权重,因此是在模型构建的时候进行加载的。 + # 如果设置了model_path,则主干的权值无需加载,pretrained的值无意义。 + # 如果不设置model_path,pretrained = True,此时仅加载主干开始训练。 + # 如果不设置model_path,pretrained = False,Freeze_Train = 
False,此时从0开始训练,且没有冻结主干的过程。
+#----------------------------------------------------------------------------------------------------------------------------#
+    pretrained = False
+    #----------------------------------------------------------------------------------------------------------------------------#
+    #   权值文件的下载请看README,可以通过网盘下载。模型的 预训练权重 对不同数据集是通用的,因为特征是通用的。
+    #   模型的 预训练权重 比较重要的部分是 主干特征提取网络的权值部分,用于进行特征提取。
+    #   预训练权重对于99%的情况都必须要用,不用的话主干部分的权值太过随机,特征提取效果不明显,网络训练的结果也不会好
+    #   训练自己的数据集时提示维度不匹配正常,预测的东西都不一样了自然维度不匹配
+    #
+    #   如果训练过程中存在中断训练的操作,可以将model_path设置成logs文件夹下的权值文件,将已经训练了一部分的权值再次载入。
+    #   同时修改下方的 冻结阶段 或者 解冻阶段 的参数,来保证模型epoch的连续性。
+    #
+    #   当model_path = ''的时候不加载整个模型的权值。
+    #
+    #   此处使用的是整个模型的权重,因此是在train.py进行加载的,pretrained不影响此处的权值加载。
+    #   如果想要让模型从主干的预训练权值开始训练,则设置model_path = '',pretrained = True,此时仅加载主干。
+    #   如果想要让模型从0开始训练,则设置model_path = '',pretrained = False,Freeze_Train = False,此时从0开始训练,且没有冻结主干的过程。
+    #
+    #   一般来讲,网络从0开始的训练效果会很差,因为权值太过随机,特征提取效果不明显,因此非常、非常、非常不建议大家从0开始训练!
+    #   如果一定要从0开始,可以了解imagenet数据集,首先训练分类模型,获得网络的主干部分权值,分类模型的 主干部分 和该模型通用,基于此进行训练。
+    #----------------------------------------------------------------------------------------------------------------------------#
+    model_path = "model_data/deeplab_mobilenetv2.pth"
+    #---------------------------------------------------------#
+    #   downsample_factor   下采样的倍数8、16
+    #                       8下采样的倍数较小、理论上效果更好。
+    #                       但也要求更大的显存
+    #---------------------------------------------------------#
+    downsample_factor = 16
+    #------------------------------#
+    #   输入图片的大小
+    #------------------------------#
+    input_shape = [1024, 1024]
+
+    #----------------------------------------------------------------------------------------------------------------------------#
+    #   训练分为两个阶段,分别是冻结阶段和解冻阶段。设置冻结阶段是为了满足机器性能不足的同学的训练需求。
+    #   冻结训练需要的显存较小,显卡非常差的情况下,可设置Freeze_Epoch等于UnFreeze_Epoch,此时仅仅进行冻结训练。
+    #
+    #   在此提供若干参数设置建议,各位训练者根据自己的需求进行灵活调整:
+    #   (一)从整个模型的预训练权重开始训练:
+    #       Adam:
+    #           Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 5e-4,weight_decay = 0。(冻结)
+    #           Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 5e-4,weight_decay = 0。(不冻结)
+    #       SGD:
+    #           Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 7e-3,weight_decay = 1e-4。(冻结)
+    #           Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 7e-3,weight_decay = 1e-4。(不冻结)
+    #       其中:UnFreeze_Epoch可以在100-300之间调整。
+    #   (二)从主干网络的预训练权重开始训练:
+    #       Adam:
+    #           Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 5e-4,weight_decay = 0。(冻结)
+    #           Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 5e-4,weight_decay = 0。(不冻结)
+    #       SGD:
+    #           Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 120,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 7e-3,weight_decay = 1e-4。(冻结)
+    #           Init_Epoch = 0,UnFreeze_Epoch = 120,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 7e-3,weight_decay = 1e-4。(不冻结)
+    #       其中:由于从主干网络的预训练权重开始训练,主干的权值不一定适合语义分割,需要更多的训练跳出局部最优解。
+    #             UnFreeze_Epoch可以在120-300之间调整。
+    #             Adam相较于SGD收敛的快一些。因此UnFreeze_Epoch理论上可以小一点,但依然推荐更多的Epoch。
+    #   (三)batch_size的设置:
+    #       在显卡能够接受的范围内,以大为好。显存不足与数据集大小无关,提示显存不足(OOM或者CUDA out of memory)请调小batch_size。
+    #       受到BatchNorm层影响,batch_size最小为2,不能为1。
+    #       正常情况下Freeze_batch_size建议为Unfreeze_batch_size的1-2倍。不建议设置的差距过大,因为关系到学习率的自动调整。
+    #----------------------------------------------------------------------------------------------------------------------------#
+
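#------------------------------------------------------------------#
+    #   补充示例(示意,非原仓库代码):文件开头docstring中提到的标签
+    #   格式问题,网上常见数据集目标像素值为255,需要转换成1,一个最小
+    #   的转换写法如下(假设标签为单通道png,路径仅为示意):
+    #       from PIL import Image
+    #       label = np.array(Image.open("SegmentationClass/xxx.png"))
+    #       label[label == 255] = 1
+    #       Image.fromarray(np.uint8(label)).save("SegmentationClass/xxx.png")
+    #------------------------------------------------------------------#
+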
#------------------------------------------------------------------# + # 冻结阶段训练参数 + # 此时模型的主干被冻结了,特征提取网络不发生改变 + # 占用的显存较小,仅对网络进行微调 + # Init_Epoch 模型当前开始的训练世代,其值可以大于Freeze_Epoch,如设置: + # Init_Epoch = 60、Freeze_Epoch = 50、UnFreeze_Epoch = 100 + # 会跳过冻结阶段,直接从60代开始,并调整对应的学习率。 + # (断点续练时使用) + # Freeze_Epoch 模型冻结训练的Freeze_Epoch + # (当Freeze_Train=False时失效) + # Freeze_batch_size 模型冻结训练的batch_size + # (当Freeze_Train=False时失效) + #------------------------------------------------------------------# + Init_Epoch = 0 + Freeze_Epoch = 400 + Freeze_batch_size = 8 + #------------------------------------------------------------------# + # 解冻阶段训练参数 + # 此时模型的主干不被冻结了,特征提取网络会发生改变 + # 占用的显存较大,网络所有的参数都会发生改变 + # UnFreeze_Epoch 模型总共训练的epoch + # Unfreeze_batch_size 模型在解冻后的batch_size + #------------------------------------------------------------------# + UnFreeze_Epoch = 200 + Unfreeze_batch_size = 4 + #------------------------------------------------------------------# + # Freeze_Train 是否进行冻结训练 + # 默认先冻结主干训练后解冻训练。 + #------------------------------------------------------------------# + Freeze_Train = True + + #------------------------------------------------------------------# + # 其它训练参数:学习率、优化器、学习率下降有关 + #------------------------------------------------------------------# + #------------------------------------------------------------------# + # Init_lr 模型的最大学习率 + # 当使用Adam优化器时建议设置 Init_lr=5e-4 + # 当使用SGD优化器时建议设置 Init_lr=7e-3 + # Min_lr 模型的最小学习率,默认为最大学习率的0.01 + #------------------------------------------------------------------# + Init_lr = 7e-3 + Min_lr = Init_lr * 0.01 + #------------------------------------------------------------------# + # optimizer_type 使用到的优化器种类,可选的有adam、sgd + # 当使用Adam优化器时建议设置 Init_lr=5e-4 + # 当使用SGD优化器时建议设置 Init_lr=7e-3 + # momentum 优化器内部使用到的momentum参数 + # weight_decay 权值衰减,可防止过拟合 + # adam会导致weight_decay错误,使用adam时建议设置为0。 + #------------------------------------------------------------------# + optimizer_type = "sgd" + momentum = 0.9 + weight_decay = 1e-4 + #------------------------------------------------------------------# + # lr_decay_type 使用到的学习率下降方式,可选的有'step'、'cos' + #------------------------------------------------------------------# + lr_decay_type = 'cos' + #------------------------------------------------------------------# + # save_period 多少个epoch保存一次权值 + #------------------------------------------------------------------# + save_period = 5 + #------------------------------------------------------------------# + # save_dir 权值与日志文件保存的文件夹 + #------------------------------------------------------------------# + save_dir = 'logs' + #------------------------------------------------------------------# + # eval_flag 是否在训练时进行评估,评估对象为验证集 + # eval_period 代表多少个epoch评估一次,不建议频繁的评估 + # 评估需要消耗较多的时间,频繁评估会导致训练非常慢 + # 此处获得的mAP会与get_map.py获得的会有所不同,原因有二: + # (一)此处获得的mAP为验证集的mAP。 + # (二)此处设置评估参数较为保守,目的是加快评估速度。 + #------------------------------------------------------------------# + eval_flag = True + eval_period = 5 + + #------------------------------------------------------------------# + # VOCdevkit_path 数据集路径 + #------------------------------------------------------------------# + VOCdevkit_path = 'VOCdevkit' + #------------------------------------------------------------------# + # 建议选项: + # 种类少(几类)时,设置为True + # 种类多(十几类)时,如果batch_size比较大(10以上),那么设置为True + # 种类多(十几类)时,如果batch_size比较小(10以下),那么设置为False + #------------------------------------------------------------------# + dice_loss = False + #------------------------------------------------------------------# + # 是否使用focal loss来防止正负样本不平衡 + 
#------------------------------------------------------------------# + focal_loss = False + #------------------------------------------------------------------# + # 是否给不同种类赋予不同的损失权值,默认是平衡的。 + # 设置的话,注意设置成numpy形式的,长度和num_classes一样。 + # 如: + # num_classes = 3 + # cls_weights = np.array([1, 2, 3], np.float32) + #------------------------------------------------------------------# + cls_weights = np.ones([num_classes], np.float32) + #------------------------------------------------------------------# + # num_workers 用于设置是否使用多线程读取数据,1代表关闭多线程 + # 开启后会加快数据读取速度,但是会占用更多内存 + # keras里开启多线程有些时候速度反而慢了许多 + # 在IO为瓶颈的时候再开启多线程,即GPU运算速度远大于读取图片的速度。 + #------------------------------------------------------------------# + num_workers = 4 + + #------------------------------------------------------# + # 设置用到的显卡 + #------------------------------------------------------# + ngpus_per_node = torch.cuda.device_count() + if distributed: + dist.init_process_group(backend="nccl") + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + device = torch.device("cuda", local_rank) + if local_rank == 0: + print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...") + print("Gpu Device Count : ", ngpus_per_node) + else: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + local_rank = 0 + + #----------------------------------------------------# + # 下载预训练权重 + #----------------------------------------------------# + if pretrained: + if distributed: + if local_rank == 0: + download_weights(backbone) + dist.barrier() + else: + download_weights(backbone) + + model = DeepLab(num_classes=num_classes, backbone=backbone, downsample_factor=downsample_factor, pretrained=pretrained) + if not pretrained: + weights_init(model) + if model_path != '': + #------------------------------------------------------# + # 权值文件请看README,百度网盘下载 + #------------------------------------------------------# + if local_rank == 0: + print('Load weights {}.'.format(model_path)) + + #------------------------------------------------------# + # 根据预训练权重的Key和模型的Key进行加载 + #------------------------------------------------------# + model_dict = model.state_dict() + pretrained_dict = torch.load(model_path, map_location = device) + load_key, no_load_key, temp_dict = [], [], {} + for k, v in pretrained_dict.items(): + if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v): + temp_dict[k] = v + load_key.append(k) + else: + no_load_key.append(k) + model_dict.update(temp_dict) + model.load_state_dict(model_dict) + #------------------------------------------------------# + # 显示没有匹配上的Key + #------------------------------------------------------# + if local_rank == 0: + print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key)) + print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key)) + print("\n\033[1;33;44m温馨提示,head部分没有载入是正常现象,Backbone部分没有载入是错误的。\033[0m") + + #----------------------# + # 记录Loss + #----------------------# + if local_rank == 0: + time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S') + log_dir = os.path.join(save_dir, "loss_" + str(time_str)) + loss_history = LossHistory(log_dir, model, input_shape=input_shape) + else: + loss_history = None + + #------------------------------------------------------------------# + # torch 1.2不支持amp,建议使用torch 1.7.1及以上正确使用fp16 + # 因此torch1.2这里显示"could not be resolve" + #------------------------------------------------------------------# + if 
fp16: + from torch.cuda.amp import GradScaler as GradScaler + scaler = GradScaler() + else: + scaler = None + + model_train = model.train() + #----------------------------# + # 多卡同步Bn + #----------------------------# + if sync_bn and ngpus_per_node > 1 and distributed: + model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train) + elif sync_bn: + print("Sync_bn is not support in one gpu or not distributed.") + + if Cuda: + if distributed: + #----------------------------# + # 多卡平行运行 + #----------------------------# + model_train = model_train.cuda(local_rank) + model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=True) + else: + model_train = torch.nn.DataParallel(model) + cudnn.benchmark = True + model_train = model_train.cuda() + + #---------------------------# + # 读取数据集对应的txt + #---------------------------# + with open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/train.txt"),"r") as f: + train_lines = f.readlines() + with open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),"r") as f: + val_lines = f.readlines() + num_train = len(train_lines) + num_val = len(val_lines) + + if local_rank == 0: + show_config( + num_classes = num_classes, backbone = backbone, model_path = model_path, input_shape = input_shape, + Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, + Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, + save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val + ) + #---------------------------------------------------------# + # 总训练世代指的是遍历全部数据的总次数 + # 总训练步长指的是梯度下降的总次数 + # 每个训练世代包含若干训练步长,每个训练步长进行一次梯度下降。 + # 此处仅建议最低训练世代,上不封顶,计算时只考虑了解冻部分 + #----------------------------------------------------------# + wanted_step = 1.5e4 if optimizer_type == "sgd" else 0.5e4 + total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch + if total_step <= wanted_step: + if num_train // Unfreeze_batch_size == 0: + raise ValueError('数据集过小,无法进行训练,请扩充数据集。') + wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1 + print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m"%(optimizer_type, wanted_step)) + print("\033[1;33;44m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m"%(num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step)) + print("\033[1;33;44m[Warning] 由于总训练步长为%d,小于建议总步长%d,建议设置总世代为%d。\033[0m"%(total_step, wanted_step, wanted_epoch)) + + #------------------------------------------------------# + # 主干特征提取网络特征通用,冻结训练可以加快训练速度 + # 也可以在训练初期防止权值被破坏。 + # Init_Epoch为起始世代 + # Interval_Epoch为冻结训练的世代 + # Epoch总训练世代 + # 提示OOM或者显存不足请调小Batch_size + #------------------------------------------------------# + if True: + UnFreeze_flag = False + #------------------------------------# + # 冻结一定部分训练 + #------------------------------------# + if Freeze_Train: + for param in model.backbone.parameters(): + param.requires_grad = False + + #-------------------------------------------------------------------# + # 如果不冻结训练的话,直接设置batch_size为Unfreeze_batch_size + #-------------------------------------------------------------------# + batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size + + #-------------------------------------------------------------------# + # 
判断当前batch_size,自适应调整学习率 + #-------------------------------------------------------------------# + nbs = 16 + lr_limit_max = 5e-4 if optimizer_type == 'adam' else 1e-1 + lr_limit_min = 3e-4 if optimizer_type == 'adam' else 5e-4 + if backbone == "xception": + lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1 + lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4 + Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) + Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) + + #---------------------------------------# + # 根据optimizer_type选择优化器 + #---------------------------------------# + optimizer = { + 'adam' : optim.Adam(model.parameters(), Init_lr_fit, betas = (momentum, 0.999), weight_decay = weight_decay), + 'sgd' : optim.SGD(model.parameters(), Init_lr_fit, momentum = momentum, nesterov=True, weight_decay = weight_decay) + }[optimizer_type] + + #---------------------------------------# + # 获得学习率下降的公式 + #---------------------------------------# + lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) + + #---------------------------------------# + # 判断每一个世代的长度 + #---------------------------------------# + epoch_step = num_train // batch_size + epoch_step_val = num_val // batch_size + + if epoch_step == 0 or epoch_step_val == 0: + raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") + + train_dataset = DeeplabDataset(train_lines, input_shape, num_classes, True, VOCdevkit_path) + val_dataset = DeeplabDataset(val_lines, input_shape, num_classes, False, VOCdevkit_path) + + if distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True,) + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False,) + batch_size = batch_size // ngpus_per_node + shuffle = False + else: + train_sampler = None + val_sampler = None + shuffle = True + + gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, + drop_last = True, collate_fn = deeplab_dataset_collate, sampler=train_sampler) + gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, + drop_last = True, collate_fn = deeplab_dataset_collate, sampler=val_sampler) + + #----------------------# + # 记录eval的map曲线 + #----------------------# + if local_rank == 0: + eval_callback = EvalCallback(model, input_shape, num_classes, val_lines, VOCdevkit_path, log_dir, Cuda, eval_flag=eval_flag, period=eval_period) + else: + eval_callback = None + + #---------------------------------------# + # 开始模型训练 + #---------------------------------------# + for epoch in range(Init_Epoch, UnFreeze_Epoch): + #---------------------------------------# + # 如果模型有冻结学习部分 + # 则解冻,并设置参数 + #---------------------------------------# + if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train: + batch_size = Unfreeze_batch_size + + #-------------------------------------------------------------------# + # 判断当前batch_size,自适应调整学习率 + #-------------------------------------------------------------------# + nbs = 16 + lr_limit_max = 5e-4 if optimizer_type == 'adam' else 1e-1 + lr_limit_min = 3e-4 if optimizer_type == 'adam' else 5e-4 + if backbone == "xception": + lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1 + lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4 + Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) + Min_lr_fit = 
min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) + #---------------------------------------# + # 获得学习率下降的公式 + #---------------------------------------# + lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) + + for param in model.backbone.parameters(): + param.requires_grad = True + + epoch_step = num_train // batch_size + epoch_step_val = num_val // batch_size + + if epoch_step == 0 or epoch_step_val == 0: + raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") + + if distributed: + batch_size = batch_size // ngpus_per_node + + gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, + drop_last = True, collate_fn = deeplab_dataset_collate, sampler=train_sampler) + gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, + drop_last = True, collate_fn = deeplab_dataset_collate, sampler=val_sampler) + + UnFreeze_flag = True + + if distributed: + train_sampler.set_epoch(epoch) + + set_optimizer_lr(optimizer, lr_scheduler_func, epoch) + + fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, + epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank) + + if distributed: + dist.barrier() + + if local_rank == 0: + loss_history.writer.close() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..4287ca8 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..61e8e95 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..cbb20cc Binary files /dev/null and b/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/utils/__pycache__/callbacks.cpython-38.pyc b/utils/__pycache__/callbacks.cpython-38.pyc new file mode 100644 index 0000000..a56fde5 Binary files /dev/null and b/utils/__pycache__/callbacks.cpython-38.pyc differ diff --git a/utils/__pycache__/callbacks.cpython-39.pyc b/utils/__pycache__/callbacks.cpython-39.pyc new file mode 100644 index 0000000..96cc47c Binary files /dev/null and b/utils/__pycache__/callbacks.cpython-39.pyc differ diff --git a/utils/__pycache__/dataloader.cpython-38.pyc b/utils/__pycache__/dataloader.cpython-38.pyc new file mode 100644 index 0000000..37f2fab Binary files /dev/null and b/utils/__pycache__/dataloader.cpython-38.pyc differ diff --git a/utils/__pycache__/dataloader.cpython-39.pyc b/utils/__pycache__/dataloader.cpython-39.pyc new file mode 100644 index 0000000..a605f51 Binary files /dev/null and b/utils/__pycache__/dataloader.cpython-39.pyc differ diff --git a/utils/__pycache__/utils.cpython-38.pyc b/utils/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000..ca964bd Binary files /dev/null and b/utils/__pycache__/utils.cpython-38.pyc differ diff --git a/utils/__pycache__/utils.cpython-39.pyc b/utils/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000..7eadcbc Binary files /dev/null and b/utils/__pycache__/utils.cpython-39.pyc differ diff --git a/utils/__pycache__/utils_fit.cpython-38.pyc b/utils/__pycache__/utils_fit.cpython-38.pyc new file mode 100644 index 
0000000..cdaaeb0 Binary files /dev/null and b/utils/__pycache__/utils_fit.cpython-38.pyc differ diff --git a/utils/__pycache__/utils_fit.cpython-39.pyc b/utils/__pycache__/utils_fit.cpython-39.pyc new file mode 100644 index 0000000..7835bfc Binary files /dev/null and b/utils/__pycache__/utils_fit.cpython-39.pyc differ diff --git a/utils/__pycache__/utils_metrics.cpython-38.pyc b/utils/__pycache__/utils_metrics.cpython-38.pyc new file mode 100644 index 0000000..c2df3e3 Binary files /dev/null and b/utils/__pycache__/utils_metrics.cpython-38.pyc differ diff --git a/utils/__pycache__/utils_metrics.cpython-39.pyc b/utils/__pycache__/utils_metrics.cpython-39.pyc new file mode 100644 index 0000000..00f27d6 Binary files /dev/null and b/utils/__pycache__/utils_metrics.cpython-39.pyc differ diff --git a/utils/callbacks.py b/utils/callbacks.py new file mode 100644 index 0000000..16ca6e5 --- /dev/null +++ b/utils/callbacks.py @@ -0,0 +1,200 @@ +import os + +import matplotlib +import torch +import torch.nn.functional as F + +matplotlib.use('Agg') +from matplotlib import pyplot as plt +import scipy.signal + +import cv2 +import shutil +import numpy as np + +from PIL import Image +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter +from .utils import cvtColor, preprocess_input, resize_image +from .utils_metrics import compute_mIoU + + +class LossHistory(): + def __init__(self, log_dir, model, input_shape): + self.log_dir = log_dir + self.losses = [] + self.val_loss = [] + + os.makedirs(self.log_dir) + self.writer = SummaryWriter(self.log_dir) + try: + dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1]) + self.writer.add_graph(model, dummy_input) + except: + pass + + def append_loss(self, epoch, loss, val_loss): + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir) + + self.losses.append(loss) + self.val_loss.append(val_loss) + + with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: + f.write(str(loss)) + f.write("\n") + with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: + f.write(str(val_loss)) + f.write("\n") + + self.writer.add_scalar('loss', loss, epoch) + self.writer.add_scalar('val_loss', val_loss, epoch) + self.loss_plot() + + def loss_plot(self): + iters = range(len(self.losses)) + + plt.figure() + plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') + plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') + try: + if len(self.losses) < 25: + num = 5 + else: + num = 15 + + plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') + plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss') + except: + pass + + plt.grid(True) + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.legend(loc="upper right") + + plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) + + plt.cla() + plt.close("all") + +class EvalCallback(): + def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, + miou_out_path=".temp_miou_out", eval_flag=True, period=1): + super(EvalCallback, self).__init__() + + self.net = net + self.input_shape = input_shape + self.num_classes = num_classes + self.image_ids = image_ids + self.dataset_path = dataset_path + self.log_dir = log_dir + self.cuda = cuda + self.miou_out_path = miou_out_path + self.eval_flag = eval_flag + self.period = period + + self.image_ids = [image_id.split()[0] for image_id 
in image_ids] + self.mious = [0] + self.epoches = [0] + if self.eval_flag: + with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f: + f.write(str(0)) + f.write("\n") + + def get_miou_png(self, image): + #---------------------------------------------------------# + # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 + # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB + #---------------------------------------------------------# + image = cvtColor(image) + orininal_h = np.array(image).shape[0] + orininal_w = np.array(image).shape[1] + #---------------------------------------------------------# + # 给图像增加灰条,实现不失真的resize + # 也可以直接resize进行识别 + #---------------------------------------------------------# + image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0])) + #---------------------------------------------------------# + # 添加上batch_size维度 + #---------------------------------------------------------# + image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0) + + with torch.no_grad(): + images = torch.from_numpy(image_data) + if self.cuda: + images = images.cuda() + + #---------------------------------------------------# + # 图片传入网络进行预测 + #---------------------------------------------------# + pr = self.net(images)[0] + #---------------------------------------------------# + # 取出每一个像素点的种类 + #---------------------------------------------------# + pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy() + #--------------------------------------# + # 将灰条部分截取掉 + #--------------------------------------# + pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), + int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)] + #---------------------------------------------------# + # 进行图片的resize + #---------------------------------------------------# + pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR) + #---------------------------------------------------# + # 取出每一个像素点的种类 + #---------------------------------------------------# + pr = pr.argmax(axis=-1) + + image = Image.fromarray(np.uint8(pr)) + return image + + def on_epoch_end(self, epoch, model_eval): + if epoch % self.period == 0 and self.eval_flag: + self.net = model_eval + gt_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/") + pred_dir = os.path.join(self.miou_out_path, 'detection-results') + if not os.path.exists(self.miou_out_path): + os.makedirs(self.miou_out_path) + if not os.path.exists(pred_dir): + os.makedirs(pred_dir) + print("Get miou.") + for image_id in tqdm(self.image_ids): + #-------------------------------# + # 从文件中读取图像 + #-------------------------------# + image_path = os.path.join(self.dataset_path, "VOC2007/JPEGImages/"+image_id+".jpg") + image = Image.open(image_path) + #------------------------------# + # 获得预测txt + #------------------------------# + image = self.get_miou_png(image) + image.save(os.path.join(pred_dir, image_id + ".png")) + + print("Calculate miou.") + _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None) # 执行计算mIoU的函数 + temp_miou = np.nanmean(IoUs) * 100 + + self.mious.append(temp_miou) + self.epoches.append(epoch) + + with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f: + f.write(str(temp_miou)) + f.write("\n") + + plt.figure() + plt.plot(self.epoches, self.mious, 'red', linewidth = 2, label='train miou') + + plt.grid(True) + plt.xlabel('Epoch') + plt.ylabel('Miou') + plt.title('A Miou Curve') + 
plt.legend(loc="upper right") + + plt.savefig(os.path.join(self.log_dir, "epoch_miou.png")) + plt.cla() + plt.close("all") + + print("Get miou done.") + shutil.rmtree(self.miou_out_path) diff --git a/utils/dataloader.py b/utils/dataloader.py new file mode 100644 index 0000000..9ab78b3 --- /dev/null +++ b/utils/dataloader.py @@ -0,0 +1,169 @@ +import os + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data.dataset import Dataset + +from utils.utils import cvtColor, preprocess_input + + +class DeeplabDataset(Dataset): + def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path): + super(DeeplabDataset, self).__init__() + self.annotation_lines = annotation_lines + self.length = len(annotation_lines) + self.input_shape = input_shape + self.num_classes = num_classes + self.train = train + self.dataset_path = dataset_path + + def __len__(self): + return self.length + + def __getitem__(self, index): + annotation_line = self.annotation_lines[index] + name = annotation_line.split()[0] + + #-------------------------------# + # 从文件中读取图像 + #-------------------------------# + jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".jpg")) + png = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png")) + #-------------------------------# + # 数据增强 + #-------------------------------# + jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train) + + jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1]) + png = np.array(png) + png[png >= self.num_classes] = self.num_classes + #-------------------------------------------------------# + # 转化成one_hot的形式 + # 在这里需要+1是因为voc数据集有些标签具有白边部分 + # 我们需要将白边部分进行忽略,+1的目的是方便忽略。 + #-------------------------------------------------------# + seg_labels = np.eye(self.num_classes + 1)[png.reshape([-1])] + seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1)) + + return jpg, png, seg_labels + + def rand(self, a=0, b=1): + return np.random.rand() * (b - a) + a + + def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True): + image = cvtColor(image) + label = Image.fromarray(np.array(label)) + #------------------------------# + # 获得图像的高宽与目标高宽 + #------------------------------# + iw, ih = image.size + h, w = input_shape + + if not random: + iw, ih = image.size + scale = min(w/iw, h/ih) + nw = int(iw*scale) + nh = int(ih*scale) + + image = image.resize((nw,nh), Image.BICUBIC) + new_image = Image.new('RGB', [w, h], (128,128,128)) + new_image.paste(image, ((w-nw)//2, (h-nh)//2)) + + label = label.resize((nw,nh), Image.NEAREST) + new_label = Image.new('L', [w, h], (0)) + new_label.paste(label, ((w-nw)//2, (h-nh)//2)) + return new_image, new_label + + #------------------------------------------# + # 对图像进行缩放并且进行长和宽的扭曲 + #------------------------------------------# + new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter) + scale = self.rand(0.25, 2) + if new_ar < 1: + nh = int(scale*h) + nw = int(nh*new_ar) + else: + nw = int(scale*w) + nh = int(nw/new_ar) + image = image.resize((nw,nh), Image.BICUBIC) + label = label.resize((nw,nh), Image.NEAREST) + + #------------------------------------------# + # 翻转图像 + #------------------------------------------# + flip = self.rand()<.5 + if flip: + image = image.transpose(Image.FLIP_LEFT_RIGHT) + label = label.transpose(Image.FLIP_LEFT_RIGHT) + + 
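#------------------------------------------#
+        #   补充注释:几何类增强(缩放、扭曲、
+        #   翻转、平移、旋转)必须对image与label
+        #   同步施加,且label统一使用NEAREST最近
+        #   邻插值,避免插值产生不存在的类别值;
+        #   后面的色域变换则只作用于image。
+        #------------------------------------------#
+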
#------------------------------------------# + # 将图像多余的部分加上灰条 + #------------------------------------------# + dx = int(self.rand(0, w-nw)) + dy = int(self.rand(0, h-nh)) + new_image = Image.new('RGB', (w,h), (128,128,128)) + new_label = Image.new('L', (w,h), (0)) + new_image.paste(image, (dx, dy)) + new_label.paste(label, (dx, dy)) + image = new_image + label = new_label + + image_data = np.array(image, np.uint8) + + #------------------------------------------# + # 高斯模糊 + #------------------------------------------# + blur = self.rand() < 0.25 + if blur: + image_data = cv2.GaussianBlur(image_data, (5, 5), 0) + + #------------------------------------------# + # 旋转 + #------------------------------------------# + rotate = self.rand() < 0.25 + if rotate: + center = (w // 2, h // 2) + rotation = np.random.randint(-10, 11) + M = cv2.getRotationMatrix2D(center, -rotation, scale=1) + image_data = cv2.warpAffine(image_data, M, (w, h), flags=cv2.INTER_CUBIC, borderValue=(128,128,128)) + label = cv2.warpAffine(np.array(label, np.uint8), M, (w, h), flags=cv2.INTER_NEAREST, borderValue=(0)) + + #---------------------------------# + # 对图像进行色域变换 + # 计算色域变换的参数 + #---------------------------------# + r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1 + #---------------------------------# + # 将图像转到HSV上 + #---------------------------------# + hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV)) + dtype = image_data.dtype + #---------------------------------# + # 应用变换 + #---------------------------------# + x = np.arange(0, 256, dtype=r.dtype) + lut_hue = ((x * r[0]) % 180).astype(dtype) + lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) + lut_val = np.clip(x * r[2], 0, 255).astype(dtype) + + image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) + image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB) + + return image_data, label + + +# DataLoader中collate_fn使用 +def deeplab_dataset_collate(batch): + images = [] + pngs = [] + seg_labels = [] + for img, png, labels in batch: + images.append(img) + pngs.append(png) + seg_labels.append(labels) + images = torch.from_numpy(np.array(images)).type(torch.FloatTensor) + pngs = torch.from_numpy(np.array(pngs)).long() + seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor) + return images, pngs, seg_labels diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..63f2be1 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,64 @@ +import numpy as np +from PIL import Image + +#---------------------------------------------------------# +# 将图像转换成RGB图像,防止灰度图在预测时报错。 +# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB +#---------------------------------------------------------# +def cvtColor(image): + if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: + return image + else: + image = image.convert('RGB') + return image + +#---------------------------------------------------# +# 对输入图像进行resize +#---------------------------------------------------# +def resize_image(image, size): + iw, ih = image.size + w, h = size + + scale = min(w/iw, h/ih) + nw = int(iw*scale) + nh = int(ih*scale) + + image = image.resize((nw,nh), Image.BICUBIC) + new_image = Image.new('RGB', size, (128,128,128)) + new_image.paste(image, ((w-nw)//2, (h-nh)//2)) + + return new_image, nw, nh + +#---------------------------------------------------# +# 获得学习率 +#---------------------------------------------------# +def get_lr(optimizer): + for param_group in optimizer.param_groups: + return param_group['lr'] + +def 
preprocess_input(image): + image /= 255.0 + return image + +def show_config(**kwargs): + print('Configurations:') + print('-' * 70) + print('|%25s | %40s|' % ('keys', 'values')) + print('-' * 70) + for key, value in kwargs.items(): + print('|%25s | %40s|' % (str(key), str(value))) + print('-' * 70) + +def download_weights(backbone, model_dir="./model_data"): + import os + from torch.hub import load_state_dict_from_url + + download_urls = { + 'mobilenet' : 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/mobilenet_v2.pth.tar', + 'xception' : 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/xception_pytorch_imagenet.pth', + } + url = download_urls[backbone] + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + load_state_dict_from_url(url, model_dir) \ No newline at end of file diff --git a/utils/utils_fit.py b/utils/utils_fit.py new file mode 100644 index 0000000..f113705 --- /dev/null +++ b/utils/utils_fit.py @@ -0,0 +1,174 @@ +import os + +import torch +from nets.deeplabv3_training import (CE_Loss, Dice_loss, Focal_Loss, + weights_init) +from tqdm import tqdm + +from utils.utils import get_lr +from utils.utils_metrics import f_score + + +def fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0): + total_loss = 0 + total_f_score = 0 + + val_loss = 0 + val_f_score = 0 + + if local_rank == 0: + print('Start Train') + pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) + model_train.train() + for iteration, batch in enumerate(gen): + if iteration >= epoch_step: + break + imgs, pngs, labels = batch + + with torch.no_grad(): + weights = torch.from_numpy(cls_weights) + if cuda: + imgs = imgs.cuda(local_rank) + pngs = pngs.cuda(local_rank) + labels = labels.cuda(local_rank) + weights = weights.cuda(local_rank) + #----------------------# + # 清零梯度 + #----------------------# + optimizer.zero_grad() + if not fp16: + #----------------------# + # 前向传播 + #----------------------# + outputs = model_train(imgs) + #----------------------# + # 计算损失 + #----------------------# + if focal_loss: + loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) + else: + loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) + + if dice_loss: + main_dice = Dice_loss(outputs, labels) + loss = loss + main_dice + + with torch.no_grad(): + #-------------------------------# + # 计算f_score + #-------------------------------# + _f_score = f_score(outputs, labels) + + #----------------------# + # 反向传播 + #----------------------# + loss.backward() + optimizer.step() + else: + from torch.cuda.amp import autocast + with autocast(): + #----------------------# + # 前向传播 + #----------------------# + outputs = model_train(imgs) + #----------------------# + # 计算损失 + #----------------------# + if focal_loss: + loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) + else: + loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) + + if dice_loss: + main_dice = Dice_loss(outputs, labels) + loss = loss + main_dice + + with torch.no_grad(): + #-------------------------------# + # 计算f_score + #-------------------------------# + _f_score = f_score(outputs, labels) + + #----------------------# + # 反向传播 + #----------------------# + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + 
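#----------------------------------------------------#
+            #   补充注释:上方即torch.cuda.amp的标准流程:
+            #   scaler.scale(loss).backward()先放大损失再反传,
+            #   scaler.step(optimizer)会先对梯度unscale、若检测
+            #   到溢出则跳过本次参数更新,scaler.update()再动态
+            #   调整缩放系数。
+            #----------------------------------------------------#
+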
total_loss += loss.item() + total_f_score += _f_score.item() + + if local_rank == 0: + pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1), + 'f_score' : total_f_score / (iteration + 1), + 'lr' : get_lr(optimizer)}) + pbar.update(1) + + if local_rank == 0: + pbar.close() + print('Finish Train') + print('Start Validation') + pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) + + model_train.eval() + for iteration, batch in enumerate(gen_val): + if iteration >= epoch_step_val: + break + imgs, pngs, labels = batch + with torch.no_grad(): + weights = torch.from_numpy(cls_weights) + if cuda: + imgs = imgs.cuda(local_rank) + pngs = pngs.cuda(local_rank) + labels = labels.cuda(local_rank) + weights = weights.cuda(local_rank) + + #----------------------# + # 前向传播 + #----------------------# + outputs = model_train(imgs) + #----------------------# + # 计算损失 + #----------------------# + if focal_loss: + loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) + else: + loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) + + if dice_loss: + main_dice = Dice_loss(outputs, labels) + loss = loss + main_dice + #-------------------------------# + # 计算f_score + #-------------------------------# + _f_score = f_score(outputs, labels) + + val_loss += loss.item() + val_f_score += _f_score.item() + + if local_rank == 0: + pbar.set_postfix(**{'val_loss' : val_loss / (iteration + 1), + 'f_score' : val_f_score / (iteration + 1), + 'lr' : get_lr(optimizer)}) + pbar.update(1) + + if local_rank == 0: + pbar.close() + print('Finish Validation') + loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val) + eval_callback.on_epoch_end(epoch + 1, model_train) + print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch)) + print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val)) + + #-----------------------------------------------# + # 保存权值 + #-----------------------------------------------# + if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: + torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth' % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val))) + + if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss): + print('Save best model to best_epoch_weights.pth') + torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth")) + + torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth")) \ No newline at end of file diff --git a/utils/utils_metrics.py b/utils/utils_metrics.py new file mode 100644 index 0000000..84a22a2 --- /dev/null +++ b/utils/utils_metrics.py @@ -0,0 +1,182 @@ +import csv +import os +from os.path import join + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + + +def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5): + n, c, h, w = inputs.size() + nt, ht, wt, ct = target.size() + if h != ht and w != wt: + inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) + + temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1) + temp_target = target.view(n, -1, ct) + + #--------------------------------------------# + # 计算dice系数 + #--------------------------------------------# + temp_inputs = torch.gt(temp_inputs, threhold).float() + tp = torch.sum(temp_target[...,:-1] * temp_inputs, 
axis=[0,1]) + fp = torch.sum(temp_inputs , axis=[0,1]) - tp + fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp + + score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) + score = torch.mean(score) + return score + +# 设标签宽W,长H +def fast_hist(a, b, n): + #--------------------------------------------------------------------------------# + # a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,) + #--------------------------------------------------------------------------------# + k = (a >= 0) & (a < n) + #--------------------------------------------------------------------------------# + # np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n) + # 返回中,写对角线上的为分类正确的像素点 + #--------------------------------------------------------------------------------# + return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) + +def per_class_iu(hist): + return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1) + +def per_class_PA_Recall(hist): + return np.diag(hist) / np.maximum(hist.sum(1), 1) + +def per_class_Precision(hist): + return np.diag(hist) / np.maximum(hist.sum(0), 1) + +def per_Accuracy(hist): + return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1) + +def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None): + print('Num classes', num_classes) + #-----------------------------------------# + # 创建一个全是0的矩阵,是一个混淆矩阵 + #-----------------------------------------# + hist = np.zeros((num_classes, num_classes)) + + #------------------------------------------------# + # 获得验证集标签路径列表,方便直接读取 + # 获得验证集图像分割结果路径列表,方便直接读取 + #------------------------------------------------# + gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list] + pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list] + + #------------------------------------------------# + # 读取每一个(图片-标签)对 + #------------------------------------------------# + for ind in range(len(gt_imgs)): + #------------------------------------------------# + # 读取一张图像分割结果,转化成numpy数组 + #------------------------------------------------# + pred = np.array(Image.open(pred_imgs[ind])) + #------------------------------------------------# + # 读取一张对应的标签,转化成numpy数组 + #------------------------------------------------# + label = np.array(Image.open(gt_imgs[ind])) + + # 如果图像分割结果与标签的大小不一样,这张图片就不计算 + if len(label.flatten()) != len(pred.flatten()): + print( + 'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format( + len(label.flatten()), len(pred.flatten()), gt_imgs[ind], + pred_imgs[ind])) + continue + + #------------------------------------------------# + # 对一张图片计算21×21的hist矩阵,并累加 + #------------------------------------------------# + hist += fast_hist(label.flatten(), pred.flatten(), num_classes) + # 每计算10张就输出一下目前已计算的图片中所有类别平均的mIoU值 + if name_classes is not None and ind > 0 and ind % 10 == 0: + print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format( + ind, + len(gt_imgs), + 100 * np.nanmean(per_class_iu(hist)), + 100 * np.nanmean(per_class_PA_Recall(hist)), + 100 * per_Accuracy(hist) + ) + ) + #------------------------------------------------# + # 计算所有验证集图片的逐类别mIoU值 + #------------------------------------------------# + IoUs = per_class_iu(hist) + PA_Recall = per_class_PA_Recall(hist) + Precision = per_class_Precision(hist) + #------------------------------------------------# + # 逐类别输出一下mIoU值 + #------------------------------------------------# + if name_classes is not None: + for ind_class in range(num_classes): + print('===>' + 
name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) + + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2))) + + #-----------------------------------------------------------------# + # 在所有验证集图像上求所有类别平均的mIoU值,计算时忽略NaN值 + #-----------------------------------------------------------------# + print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2))) + return np.array(hist, np.int), IoUs, PA_Recall, Precision + +def adjust_axes(r, t, fig, axes): + bb = t.get_window_extent(renderer=r) + text_width_inches = bb.width / fig.dpi + current_fig_width = fig.get_figwidth() + new_fig_width = current_fig_width + text_width_inches + propotion = new_fig_width / current_fig_width + x_lim = axes.get_xlim() + axes.set_xlim([x_lim[0], x_lim[1] * propotion]) + +def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True): + fig = plt.gcf() + axes = plt.gca() + plt.barh(range(len(values)), values, color='royalblue') + plt.title(plot_title, fontsize=tick_font_size + 2) + plt.xlabel(x_label, fontsize=tick_font_size) + plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size) + r = fig.canvas.get_renderer() + for i, val in enumerate(values): + str_val = " " + str(val) + if val < 1.0: + str_val = " {0:.2f}".format(val) + t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold') + if i == (len(values)-1): + adjust_axes(r, t, fig, axes) + + fig.tight_layout() + fig.savefig(output_path) + if plt_show: + plt.show() + plt.close() + +def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size = 12): + draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs)*100), "Intersection over Union", + os.path.join(miou_out_path, "mIoU.png"), tick_font_size = tick_font_size, plt_show = True) + print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png")) + + draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Pixel Accuracy", + os.path.join(miou_out_path, "mPA.png"), tick_font_size = tick_font_size, plt_show = False) + print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png")) + + draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Recall", + os.path.join(miou_out_path, "Recall.png"), tick_font_size = tick_font_size, plt_show = False) + print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png")) + + draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision)*100), "Precision", + os.path.join(miou_out_path, "Precision.png"), tick_font_size = tick_font_size, plt_show = False) + print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png")) + + with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f: + writer = csv.writer(f) + writer_list = [] + writer_list.append([' '] + [str(c) for c in name_classes]) + for i in range(len(hist)): + writer_list.append([name_classes[i]] + [str(x) for x in hist[i]]) + writer.writerows(writer_list) + print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv")) + \ No newline at end of file diff --git a/voc_annotation.py b/voc_annotation.py new file mode 100644 index 0000000..c04a46d --- /dev/null +++ b/voc_annotation.py @@ 
-0,0 +1,100 @@
+import os
+import random
+
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+#-------------------------------------------------------#
+#   想要增加测试集修改trainval_percent
+#   修改train_percent用于改变验证集的比例 9:1
+#
+#   当前该库将测试集当作验证集使用,不单独划分测试集
+#-------------------------------------------------------#
+trainval_percent = 1
+train_percent = 0.9
+#-------------------------------------------------------#
+#   指向VOC数据集所在的文件夹
+#   默认指向根目录下的VOC数据集
+#-------------------------------------------------------#
+VOCdevkit_path = 'VOCdevkit'
+
+if __name__ == "__main__":
+    random.seed(0)
+    print("Generate txt in ImageSets.")
+    segfilepath = os.path.join(VOCdevkit_path, 'VOC2007/SegmentationClass')
+    saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Segmentation')
+
+    temp_seg = os.listdir(segfilepath)
+    total_seg = []
+    for seg in temp_seg:
+        if seg.endswith(".png"):
+            total_seg.append(seg)
+
+    num = len(total_seg)
+    indices = range(num)
+    tv = int(num*trainval_percent)
+    tr = int(tv*train_percent)
+    trainval = random.sample(indices, tv)
+    train = random.sample(trainval, tr)
+
+    print("train and val size", tv)
+    print("train size", tr)
+    ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')
+    ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')
+    ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')
+    fval = open(os.path.join(saveBasePath,'val.txt'), 'w')
+
+    for i in indices:
+        name = total_seg[i][:-4]+'\n'
+        if i in trainval:
+            ftrainval.write(name)
+            if i in train:
+                ftrain.write(name)
+            else:
+                fval.write(name)
+        else:
+            ftest.write(name)
+
+    ftrainval.close()
+    ftrain.close()
+    fval.close()
+    ftest.close()
+    print("Generate txt in ImageSets done.")
+
+    print("Check datasets format, this may take a while.")
+    print("检查数据集格式是否符合要求,这可能需要一段时间。")
+    # classes_nums = np.zeros([256], np.int)
+    classes_nums = np.zeros([256], dtype=int)  # 使用内置int
+
+    for i in tqdm(indices):
+        name = total_seg[i]
+        png_file_name = os.path.join(segfilepath, name)
+        if not os.path.exists(png_file_name):
+            raise ValueError("未检测到标签图片%s,请查看具体路径下文件是否存在以及后缀是否为png。"%(png_file_name))
+
+        png = np.array(Image.open(png_file_name), np.uint8)
+        if len(np.shape(png)) > 2:
+            print("标签图片%s的shape为%s,不属于灰度图或者八位彩图,请仔细检查数据集格式。"%(name, str(np.shape(png))))
+            # print("标签图片需要为灰度图或者八位彩图,标签的每个像素点的值就是这个像素点所属的种类。")
+
+        classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
+
+    print("打印像素点的值与数量。")
+    print('-' * 37)
+    print("| %15s | %15s |"%("Key", "Value"))
+    print('-' * 37)
+    for i in range(256):
+        if classes_nums[i] > 0:
+            print("| %15s | %15s |"%(str(i), str(classes_nums[i])))
+    print('-' * 37)
+
+    if classes_nums[255] > 0 and classes_nums[0] > 0 and np.sum(classes_nums[1:255]) == 0:
+        print("检测到标签中像素点的值仅包含0与255,数据格式有误。")
+        print("二分类问题需要将标签修改为背景的像素点值为0,目标的像素点值为1。")
+    elif classes_nums[0] > 0 and np.sum(classes_nums[1:]) == 0:
+        print("检测到标签中仅仅包含背景像素点,数据格式有误,请仔细检查数据集格式。")
+
+    print("JPEGImages中的图片应当为.jpg文件、SegmentationClass中的图片应当为.png文件。")
+    print("如果格式有误,参考:")
+    print("https://github.com/bubbliiiing/segmentation-format-fix")
\ No newline at end of file
diff --git a/常见问题汇总.md b/常见问题汇总.md
new file mode 100644
index 0000000..7586fe3
--- /dev/null
+++ b/常见问题汇总.md
@@ -0,0 +1,554 @@
+问题汇总的博客地址为[https://blog.csdn.net/weixin_44791964/article/details/107517428](https://blog.csdn.net/weixin_44791964/article/details/107517428)。
+
+# 问题汇总
+## 1、下载问题
+### a、代码下载
+**问:up主,可以给我发一份代码吗,代码在哪里下载啊?
+答:Github上的地址就在视频简介里。复制一下就能进去下载了。**
+
+**问:up主,为什么我下载的代码提示压缩包损坏?
\ No newline at end of file
diff --git a/常见问题汇总.md b/常见问题汇总.md
new file mode 100644
index 0000000..7586fe3
--- /dev/null
+++ b/常见问题汇总.md
@@ -0,0 +1,554 @@
+The blog version of this FAQ lives at [https://blog.csdn.net/weixin_44791964/article/details/107517428](https://blog.csdn.net/weixin_44791964/article/details/107517428).
+
+# FAQ
+## 1. Download problems
+### a. Code download
+**Q: Could you send me a copy of the code? Where do I download it?
+A: The GitHub address is in the video description. Copy it into your browser and download from there.**
+
+**Q: Why does the code I downloaded report a corrupted archive?
+A: Download it again from GitHub.**
+
+**Q: Why is the code I downloaded different from the code in your videos and blog posts?
+A: I update the code frequently; the actual code in the repo is authoritative.**
+
+### b. Weight download
+**Q: Why is there no .pth or .h5 file under model_data in the code I downloaded?
+A: I usually upload the weights to GitHub and Baidu Netdisk; the links are in the README on GitHub.**
+
+### c. Dataset download
+**Q: Where do I download the XXXX dataset?
+A: Dataset download links are usually in the README. If one is missing, open a GitHub issue and I will add it.**
+
+## 2. Environment problems
+### a. Environments currently used by these repos
+**The pytorch code targets pytorch 1.2; the matching blog post is** [https://blog.csdn.net/weixin_44791964/article/details/106037141](https://blog.csdn.net/weixin_44791964/article/details/106037141).
+
+**The keras code targets tensorflow 1.13.2 with keras 2.1.5; the matching blog post is** [https://blog.csdn.net/weixin_44791964/article/details/104702142](https://blog.csdn.net/weixin_44791964/article/details/104702142).
+
+**The tf2 code targets tensorflow 2.2.0 and does not need keras installed; the matching blog post is** [https://blog.csdn.net/weixin_44791964/article/details/109161493](https://blog.csdn.net/weixin_44791964/article/details/109161493).
+
+**Q: Does your code work with tensorflow/pytorch version X?
+A: Preferably use the configurations I recommend; there are setup tutorials for them. I have not tried other versions; problems are possible but usually minor, needing only small code changes.**
+
+### b. Environment setup for 30-series GPUs
+Because of framework updates, 30-series GPUs cannot use the setup tutorials above.
+The 30-series configurations I have tested so far are:
+**pytorch code: pytorch 1.7.0, cuda 11.0, cudnn 8.0.5.**
+
+**keras code: cuda 11 cannot be configured under win10; under ubuntu it can (search online for instructions), with tensorflow 1.15.4 and keras 2.1.5 or 2.3.1 (a few function interfaces differ, so small code adjustments may be needed).**
+
+**tf2 code: tensorflow 2.4.0, cuda 11.0, cudnn 8.0.5.**
+
+### c. GPU utilization and environment usage problems
+**Q: Why is the GPU not used for training even though I installed tensorflow-gpu?
+A: Confirm tensorflow-gpu is properly installed: check the version with pip list, then check Task Manager or the nvidia command-line tools to see whether the GPU is actually being used; in Task Manager look at GPU memory usage.**
+
+**Q: I don't seem to be training on the GPU. How do I check whether the GPU is being used?
+A: Usually with NVIDIA's command-line tool (nvidia-smi). If you use Task Manager instead, look at GPU memory usage under Performance, or at the Cuda graph rather than the Copy graph. A quick in-Python check is sketched below.**
+![screenshot](https://img-blog.csdnimg.cn/20201013234241524.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDc5MTk2NA==,size_16,color_FFFFFF,t_70#pic_center)
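+
+For the pytorch repos, a quick way to confirm from inside Python that a GPU is actually visible (standard pytorch calls; run it in the same environment you train in):
+```python
+import torch
+
+# True means the GPU build of pytorch can see a CUDA device.
+print(torch.cuda.is_available())
+if torch.cuda.is_available():
+    print(torch.cuda.get_device_name(0))    # name of the first GPU
+```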
+
+**Q: Why can't I run the code even after following your environment setup?
+A: Send me your GPU, CUDA, CUDNN, TF and PYTORCH versions via private message on Bilibili.**
+
+**Q: I get the following error:**
+```python
+Traceback (most recent call last):
+  File "C:\Users\focus\Anaconda3\ana\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\pywrap_tensorflow.py", line 58, in <module>
+    from tensorflow.python.pywrap_tensorflow_internal import *
+  File "C:\Users\focus\Anaconda3\ana\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 28, in <module>
+    pywrap_tensorflow_internal = swig_import_helper()
+  File "C:\Users\focus\Anaconda3\ana\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 24, in swig_import_helper
+    _mod = imp.load_module('_pywrap_tensorflow_internal', fp, pathname, description)
+  File "C:\Users\focus\Anaconda3\ana\envs\tensorflow-gpu\lib\imp.py", line 243, in load_module
+    return load_dynamic(name, filename, file)
+  File "C:\Users\focus\Anaconda3\ana\envs\tensorflow-gpu\lib\imp.py", line 343, in load_dynamic
+    return _load(spec)
+ImportError: DLL load failed: The specified module could not be found.
+```
+**A: Reboot if you haven't yet; otherwise reinstall following the steps. If that still fails, send me your GPU, CUDA, CUDNN, TF and PYTORCH versions via private message.**
+
+### d. "No module" problems
+**Q: Why do I get "no module named utils.utils" (or no module named nets.yolo, no module named nets.ssd, and so on)?
+A: utils does not need to be installed with pip; it sits in the root directory of the repo I uploaded. This error means your working directory is wrong; read up on relative paths versus the root directory and it should become clear.**
+
+**Q: Why do I get "no module named matplotlib" (no module named PIL, no module named cv2, etc.)?
+A: The library simply isn't installed; open a command line and install it, e.g. pip install matplotlib.**
+
+**Q: I already installed opencv (pillow, matplotlib, ...) with pip, so why do I still get "no module named cv2"?
+A: You installed it without activating the environment; activate the matching conda environment first, then install.**
+
+**Q: Why do I get "No module named 'torch'"?
+A: Honestly, I'd like to know that too... pytorch simply isn't installed. There are generally two cases: it really isn't installed, or it was installed into a different environment than the one currently activated.**
+
+**Q: Why do I get "No module named 'tensorflow'"?
+A: Same as above.**
+
+### e. CUDA installation failures
+CUDA generally requires Visual Studio to be installed first; the 2017 edition is enough.
+
+### f. Ubuntu
+**All of the code works under Ubuntu; I have tried both systems.**
+
+### g. VSCODE error squiggles
+**Q: Why does VSCODE show a pile of errors?
+A: I get a pile of them too; they don't affect anything. It's a VSCODE issue; if you don't want to see them, use Pycharm instead.**
+
+### h. Training and predicting on CPU
+**For the keras and tf2 code, to train and predict on CPU just install the CPU build of tensorflow.**
+
+**For the pytorch code, to train and predict on CPU change cuda=True to cuda=False.**
+
+### i. tqdm "no attribute pos" problem
+**Q: Running the code reports 'tqdm' object has no attribute 'pos'.
+A: Reinstall tqdm with a different version.**
+
+### j. decode("utf-8") problem
+**Because of h5py updates, installation automatically pulls in h5py >= 3.0.0, which breaks decode("utf-8")!
+After installing tensorflow, be sure to install h5py 2.10.0:**
+```
+pip install h5py==2.10.0
+```
+
+### k. TypeError: __array__() takes 1 positional argument but 2 were given
+Pin the pillow version to fix it.
+```
+pip install pillow==8.2.0
+```
+
+### l. Other problems
+**Q: Why do I get TypeError: cat() got an unexpected keyword argument 'axis', Traceback (most recent call last), AttributeError: 'Tensor' object has no attribute 'bool'?
+A: These are version problems; use torch 1.2 or above.**
+**Many other odd problems are also version problems; follow my video tutorials to install Keras and tensorflow. For example, if you installed tensorflow2, don't ask me why Keras-yolo won't run; of course it won't.**
+
+## 3. Object detection FAQ (the face detection and classification repos can also refer to this)
+### a. Shape mismatch
+#### 1) Shape mismatch during training
+**Q: Why does running train.py report a shape mismatch?
+A: In keras, because you train with a different number of classes than the original, the network structure changes, so the shapes at the very end will have a small mismatch.**
+
+#### 2) Shape mismatch during prediction
+**Q: Why does running predict.py report a shape mismatch?
+In Pytorch it looks like this:**
+![screenshot](https://img-blog.csdnimg.cn/20200722171631901.png)
+In Keras it looks like this:
+![screenshot](https://img-blog.csdnimg.cn/20200722171523380.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDc5MTk2NA==,size_16,color_FFFFFF,t_70)
+**A: There are three main causes:
+1. In ssd and FasterRCNN, num_classes in train.py may not have been changed.
+2. model_path was not changed.
+3. classes_path was not changed.
+Check carefully! Make sure the model_path and classes_path you use correspond to each other, and also check the num_classes or classes_path used during training!**
+
+### b. Out-of-memory problems
+**Q: Why does the console window running train.py flash by and report OOM or similar?
+A: That happens in keras when GPU memory is exhausted. Reduce batch_size. SSD has the smallest memory footprint, so SSD is recommended:
+2 GB VRAM: SSD, YOLOV4-TINY
+4 GB VRAM: YOLOV3
+6 GB VRAM: YOLOV4, Retinanet, M2det, Efficientdet, Faster RCNN, etc.
+8 GB+ VRAM: anything you like.**
+**Note that because of BatchNorm2d, batch_size cannot be 1; it must be at least 2.**
+
+**Q: Why do I get RuntimeError: CUDA out of memory. Tried to allocate 52.00 MiB (GPU 0; 15.90 GiB total capacity; 14.85 GiB already allocated; 51.88 MiB free; 15.07 GiB reserved in total by PyTorch)?
+A: That is pytorch running out of GPU memory; same fix as above.**
+
+**Q: Why does it run out of memory immediately while GPU utilization stays at zero?
+A: Once memory is exhausted the model never starts training, so of course utilization stays at zero.**
+### c. Training problems (freeze training, LOSS, training results, etc.)
+**Q: Why freeze and then unfreeze during training?
+A: This is transfer learning: the features a network's backbone extracts are generic, so freezing the backbone speeds up training and prevents the weights from being destroyed early on.**
+During the freeze phase the backbone is frozen and the feature-extraction network does not change; memory use is small and only the rest of the network is fine-tuned.
+During the unfreeze phase the backbone is unfrozen and the feature-extraction network changes; memory use is larger and all of the network's parameters are updated.
+
+**Q: Why doesn't my network converge? My LOSS is XXXX.
+A: LOSS differs between networks; it is only a reference for whether the network is converging, not a measure of how good it is. My yolo code does not normalize the loss, so it looks high. The absolute value doesn't matter; what matters is whether it keeps decreasing and whether predictions work.**
+
+**Q: Why are my training results bad? Prediction gives no boxes (or inaccurate boxes).
+A:**
+
+Consider these points:
+1. Target information: check whether 2007_train.txt contains target annotations; if not, fix voc_annotation.py.
+2. The dataset: with fewer than 500 images, consider enlarging it; also test different models to confirm the dataset itself is sound.
+3. Unfreezing: if your data distribution differs a lot from everyday images, unfreeze training further to adapt the backbone and strengthen feature extraction.
+4. The network: e.g. SSD is poor for small objects because its anchor boxes are fixed.
+5. Training time: some people train a few epochs and conclude it doesn't work; train to completion with the default schedule.
+6. Whether you actually followed the steps, e.g. whether classes in voc_annotation.py was modified.
+7. LOSS differs between networks and is only a convergence indicator; its value doesn't matter, convergence does.
+
+**Q: Why do I get a gbk codec error:**
+```python
+UnicodeDecodeError: 'gbk' codec can't decode byte 0xa6 in position 446: illegal multibyte sequence
+```
+**A: Do not use Chinese in labels or paths. If you must, handle encodings carefully: open files with encoding='utf-8'.**
+
+**Q: My images are xxx*xxx resolution; can I use them?**
+**A: Yes; the code resizes and augments them automatically.**
+
+**Q: How do I train with multiple GPUs?
+A: Most of the pytorch code can train on GPU directly; for keras just search online, it isn't complicated (see the sketch below). I don't own multiple GPUs so I can't test in detail; you'll have to work it out yourselves.**
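+
+For the pytorch repos, one common approach is torch.nn.DataParallel; a minimal sketch (the model below is a placeholder, not the repo's network, and each repo may wire this up differently):
+```python
+import torch
+import torch.nn as nn
+
+model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())    # placeholder model
+
+if torch.cuda.device_count() > 1:
+    # Replicate the model across all visible GPUs and split each batch between them.
+    model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
+if torch.cuda.is_available():
+    model = model.cuda()
+```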
+### d. Grayscale images
+**Q: Can I train on (and predict from) grayscale images?
+A: Most of my repos convert grayscale images to RGB for both training and prediction. If a repo fails on grayscale images, try converting the result of Image.open to RGB inside get_random_data, and do the same when predicting. (For reference only.)**
+
+### e. Resuming training
+**Q: I have already trained several epochs; can I continue from there?
+A: Yes. Before training, load the previously trained weights exactly as you would load pretrained weights. Trained weights are saved in the logs folder; set model_path to the weight file you want to resume from.**
+
+### f. Pretrained weights
+**Q: If I want to train on another dataset, what about the pretrained weights?**
+**A: Pretrained weights are transferable between datasets, because the features are generic. You should use them in 99% of cases; without them the initial weights are too random, feature extraction is poor, and training results suffer.**
+
+**Q: I modified the network; can I still use the pretrained weights?
+A: If you modified the backbone to something that isn't an existing supported network, the pretrained weights basically cannot be used; either match the conv-kernel shapes in the weight file yourself or pretrain from scratch. If you only modified the second half, the backbone's pretrained weights are still usable: for pytorch code, change the weight-loading logic to load only shape-matched entries; for keras code, simply use by_name=True, skip_mismatch=True.**
+Weight matching can be done like this:
+```python
+import numpy as np
+import torch
+
+# model and model_path come from your training script. Keep only the
+# pretrained entries whose name and shape match the current model, which
+# speeds up training of the modified network.
+print('Loading weights into state dict...')
+device          = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model_dict      = model.state_dict()
+pretrained_dict = torch.load(model_path, map_location=device)
+pretrained_dict = {k: v for k, v in pretrained_dict.items()
+                   if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
+model_dict.update(pretrained_dict)
+model.load_state_dict(model_dict)
+print('Finished!')
+```
+
+**Q: How do I train without pretrained weights?
+A: Comment out the code that loads them.**
+
+**Q: Why are my results so bad without pretrained weights?
+A: Randomly initialized weights extract poor features, so training suffers; voc07+12 and coco+voc07+12 pretraining give different results, and pretrained weights really matter.**
+
+### g. Video and webcam detection
+**Q: How do I detect from a webcam?
+A: Change the parameters in predict.py; there is also a video explaining the webcam-detection approach in detail.**
+
+**Q: How do I detect on a video?
+A: Same as above.**
+### h. Training from scratch
+**Q: How do I train the model from scratch?
+A: Training from scratch is pointless without sufficient compute and tuning skill. With randomly initialized parameters the model's feature extraction is very poor, and without strong tuning ability and compute the network will not converge properly.**
+If you insist on training from scratch, note the following:
+ - Do not load pretrained weights.
+ - Do not freeze training; comment out the code that freezes the model.
+
+**Q: Why are my results so bad without pretrained weights?
+A: Randomly initialized weights extract poor features, so training suffers; voc07+12 and coco+voc07+12 pretraining give different results, and pretrained weights really matter.**
+
+### i. Saving results
+**Q: How do I save detected images?
+A: Object detection generally uses PIL's Image, so look up how to save with PIL Image. See the comments in predict.py for details.**
+
+**Q: How do I save videos?
+A: See the comments in predict.py for details.**
+
+### j. Iterating over a folder
+**Q: How do I iterate over all images in a folder?
+A: Generally use os.listdir to collect the images, then detect each one following the flow in predict.py; see its comments for details.**
+
+**Q: How do I iterate over all images in a folder and save the results?
+A: Iterate with os.listdir and detect following predict.py's flow. For saving, object detection generally uses PIL's Image, so look up how to save with it; some repos use cv2, in which case look up how cv2 saves images. See the comments in predict.py for details.**
+
+### k. Path problems (No such file or directory)
+**Q: Why do I get this error:**
+```python
+FileNotFoundError: [Errno 2] No such file or directory
+...
+```
+**A: Check the folder paths and whether the files exist, and check the file paths inside 2007_train.txt for mistakes.**
+A few important points about paths:
+**Never put spaces in folder names.
+Mind the difference between relative and absolute paths.
+Read up on how paths work.**
+
+**Almost all path problems are root-directory problems; study the concept of relative paths!**
+### l. Comparison with the original implementations
+**Q: How does your code compare with the original? Can it reach the original results?
+A: Basically yes. I have tested everything with voc data; I don't have good GPUs, so I can't train or test on coco.**
+
+**Q: Did you implement all of yolov4's tricks, and how big is the gap to the original?
+A: Not all of them; YOLOV4 uses so many improvements that it is hard to implement and list them all. I implemented the ones I found interesting and effective. The SAM attention module mentioned in the paper isn't even used by the authors' own code. Not every trick helps, and I can't implement all of them. As for the comparison, I can't train on coco, but users report the gap is small.**
+
+### m. FPS (detection speed)
+**Q: What FPS can this reach? Can it reach XX FPS?
+A: FPS depends on the machine; high-spec machines are fast, low-spec machines are slow.**
+
+**Q: Why do I only get a dozen FPS testing yolov4 (or others) on a server?
+A: Check that the GPU builds of tensorflow-gpu or pytorch are installed correctly. If so, use time.time() inside detect_image to find which part of the code takes longest (the network is not the only cost; other processing, such as drawing, takes time too).**
+
+**Q: The paper says speed reaches XX; why not here?
+A: Check the GPU builds as above, then profile detect_image with time.time() to find the slow part (not just the network; drawing and other processing cost time too). Some papers also predict with multi-image batches, which I have not implemented.**
+
+### n. Predicted image not displayed
+**Q: Why doesn't your code display the image after prediction? It only prints the targets in the console.
+A: Install an image viewer on the system.**
+
+### o. Evaluation (detection mAP, PR curve, Recall, Precision)
+**Q: How do I compute map?
+A: Watch the map video; it is all one workflow.**
+
+**Q: When computing map, what is MINOVERLAP in get_map.py? Is it iou?
+A: Yes, it is an iou threshold: it measures how much a predicted box overlaps a ground-truth box, and a prediction counts as correct when the overlap exceeds MINOVERLAP (see the sketch below).**
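+
+To make the MINOVERLAP criterion concrete, a minimal sketch of box IoU (the helper name is a placeholder, not the exact function in get_map.py):
+```python
+def box_iou(a, b):
+    # a, b: boxes as (x1, y1, x2, y2) with x1 < x2 and y1 < y2
+    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
+    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
+    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
+    union = ((a[2] - a[0]) * (a[3] - a[1])
+             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
+    return inter / union if union > 0 else 0.0
+
+MINOVERLAP = 0.5                                 # typical VOC-style threshold
+pred, gt   = (10, 10, 50, 50), (12, 12, 48, 52)
+print(box_iou(pred, gt) >= MINOVERLAP)           # True: this prediction counts as correct
+```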
+**Q: Why is self.confidence (self.score) in get_map.py set so low?
+A: See the theory part of the map video: the PR curve must be drawn over all detection results, so the threshold has to be low enough to keep them.**
+
+**Q: Can you explain how to draw the PR curve and so on?
+A: See the mAP video; the PR curve is part of the results.**
+
+**Q: How do I compute Recall and Precision?
+A: These two metrics are relative to a specific confidence threshold; they are also produced when computing map.**
+
+### p. Training on the COCO dataset
+**Q: How do I train object detection on the COCO dataset?
+A: The txt files needed for coco training can be generated following qqwweee's yolo3 repo; the format is the same.**
+
+### q. Model optimization (model modification)
+**Q: Do you have code for the YOLO series with Focal LOSS? Does it help?
+A: Many people have tried it and the gain is small (it can even get worse); YOLO has its own way of balancing positive and negative samples.**
+
+**Q: I modified the network; can I still use the pretrained weights?
+A: If you modified the backbone to something that isn't an existing supported network, the pretrained weights basically cannot be used; either match the conv-kernel shapes in the weight file yourself or pretrain from scratch. If you only modified the second half, the backbone's pretrained weights are still usable: for pytorch code, change the weight-loading logic to load only shape-matched entries, using the same weight-matching snippet shown in section 3-f above; for keras code, simply use by_name=True, skip_mismatch=True.**
+
+**Q: How do I modify the model? I want to publish a small paper!
+A: Look at the differences between yolov3 and yolov4, then read the yolov4 paper; as a giant tuning exercise it is a very useful reference with many tricks. My advice is to study classic models and borrow and recombine their strong components.**
+
+### r. Deployment
+I have never deployed to phones or similar devices, so I don't know much about deployment problems...
+
+## 4. Semantic segmentation FAQ
+### a. Shape mismatch
+#### 1) Shape mismatch during training
+**Q: Why does running train.py report a shape mismatch?
+A: In keras, because you train with a different number of classes than the original, the network structure changes, so the shapes at the very end will have a small mismatch.**
+
+#### 2) Shape mismatch during prediction
+**Q: Why does running predict.py report a shape mismatch?
+In Pytorch it looks like this:**
+![screenshot](https://img-blog.csdnimg.cn/20200722171631901.png)
+In Keras it looks like this:
+![screenshot](https://img-blog.csdnimg.cn/20200722171523380.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDc5MTk2NA==,size_16,color_FFFFFF,t_70)
+**A: There are two main causes:
+1. num_classes in train.py was not changed.
+2. num_classes was not changed for prediction.
+Check carefully! num_classes must be checked for both training and prediction!**
+
+### b. Out-of-memory problems
+**Q: Why does the console window running train.py flash by and report OOM or similar?
+A: That is keras running out of GPU memory; reduce batch_size.**
+
+**Note that because of BatchNorm2d, batch_size cannot be 1; it must be at least 2.**
+
+**Q: Why do I get RuntimeError: CUDA out of memory. Tried to allocate 52.00 MiB (GPU 0; 15.90 GiB total capacity; 14.85 GiB already allocated; 51.88 MiB free; 15.07 GiB reserved in total by PyTorch)?
+A: That is pytorch running out of GPU memory; same fix as above.**
+
+**Q: Why does it run out of memory immediately while GPU utilization stays at zero?
+A: Once memory is exhausted the model never starts training, so of course utilization stays at zero.**
+
+### c. Training problems (freeze training, LOSS, training results, etc.)
+**Q: Why freeze and then unfreeze during training?
+A: This is transfer learning: the features a network's backbone extracts are generic, so freezing the backbone speeds up training and prevents the weights from being destroyed early on.**
+**During the freeze phase the backbone is frozen and the feature-extraction network does not change; memory use is small and only the rest of the network is fine-tuned.**
+**During the unfreeze phase the backbone is unfrozen and the feature-extraction network changes; memory use is larger and all of the network's parameters are updated.**
+
+**Q: Why doesn't my network converge? My LOSS is XXXX.
+A: LOSS differs between networks; it is only a reference for whether the network is converging, not a measure of how good it is. My yolo code does not normalize the loss, so it looks high. The absolute value doesn't matter; what matters is whether it keeps decreasing and whether predictions work.**
+
+**Q: Why are my training results bad? Prediction finds nothing and the result is all black.
+A:**
+**Consider these points:
+1. The dataset: this is the most important one. With fewer than 500 images, consider enlarging it. Be sure to check the labels: the videos explain the VOC format in detail, but having input images and output labels is not enough; you must also confirm that every pixel value of a label is the class index it belongs to. Many people's labels are wrong; the most common mistake is a black background with white targets, where target pixels are 255 and training cannot work, because targets need the value 1.
+2. Unfreezing: if your data distribution differs a lot from everyday images, unfreeze training further to adapt the backbone and strengthen feature extraction.
+3. The network: try different networks.
+4. Training time: some people train a few epochs and conclude it doesn't work; train to completion with the default schedule.
+5. Whether you actually followed the steps.
+6. LOSS differs between networks and is only a convergence indicator; its value doesn't matter, convergence does.**
+
+**Q: Why are my training results bad? Small objects are predicted poorly.
+A: For deeplab and pspnet, adjust downsample_factor: at 16 the downsampling is too aggressive and results suffer; set it to 8.**
+
+**Q: Why do I get a gbk codec error:**
+```python
+UnicodeDecodeError: 'gbk' codec can't decode byte 0xa6 in position 446: illegal multibyte sequence
+```
+**A: Do not use Chinese in labels or paths. If you must, handle encodings carefully: open files with encoding='utf-8', as sketched below.**
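+
+A minimal sketch of the encoding fix (the file name is a placeholder):
+```python
+# Open annotation/label lists with an explicit encoding so that paths containing
+# non-ASCII characters do not trip the default gbk codec on Chinese Windows.
+with open('2007_train.txt', 'r', encoding='utf-8') as f:
+    lines = f.readlines()
+```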
+
+**Q: My images are xxx*xxx resolution; can I use them?**
+**A: Yes; the code resizes and augments them automatically.**
+
+**Q: How do I train with multiple GPUs?
+A: Most of the pytorch code can train on GPU directly; for keras just search online, it isn't complicated (the DataParallel sketch in section 3-c above also applies here). I don't own multiple GPUs so I can't test in detail; you'll have to work it out yourselves.**
+
+### d. Grayscale images
+**Q: Can I train on (and predict from) grayscale images?
+A: Most of my repos convert grayscale images to RGB for both training and prediction. If a repo fails on grayscale images, try converting the result of Image.open to RGB inside get_random_data, and do the same when predicting. (For reference only.)**
+
+### e. Resuming training
+**Q: I have already trained several epochs; can I continue from there?
+A: Yes. Before training, load the previously trained weights exactly as you would load pretrained weights. Trained weights are saved in the logs folder; set model_path to the weight file you want to resume from.**
+
+### f. Pretrained weights
+
+**Q: If I want to train on another dataset, what about the pretrained weights?**
+**A: Pretrained weights are transferable between datasets, because the features are generic. You should use them in 99% of cases; without them the initial weights are too random, feature extraction is poor, and training results suffer.**
+
+**Q: I modified the network; can I still use the pretrained weights?
+A: If you modified the backbone to something that isn't an existing supported network, the pretrained weights basically cannot be used; either match the conv-kernel shapes in the weight file yourself or pretrain from scratch. If you only modified the second half, the backbone's pretrained weights are still usable: for pytorch code, change the weight-loading logic to load only shape-matched entries, using the same weight-matching snippet shown in section 3-f above; for keras code, simply use by_name=True, skip_mismatch=True.**
+
+**Q: How do I train without pretrained weights?
+A: Comment out the code that loads them.**
+
+**Q: Why are my results so bad without pretrained weights?
+A: Randomly initialized weights extract poor features, so training suffers; pretrained weights really matter.**
+
+### g. Video and webcam detection
+**Q: How do I detect from a webcam?
+A: Change the parameters in predict.py; there is also a video explaining the webcam-detection approach in detail.**
+
+**Q: How do I detect on a video?
+A: Same as above.**
+
+### h. Training from scratch
+**Q: How do I train the model from scratch?
+A: Training from scratch is pointless without sufficient compute and tuning skill. With randomly initialized parameters the model's feature extraction is very poor, and without strong tuning ability and compute the network will not converge properly.**
+If you insist on training from scratch, note the following:
+ - Do not load pretrained weights.
+ - Do not freeze training; comment out the code that freezes the model.
+
+**Q: Why are my results so bad without pretrained weights?
+A: Randomly initialized weights extract poor features, so training suffers; pretrained weights really matter.**
+
+### i. Saving results
+**Q: How do I save detected images?
+A: Object detection generally uses PIL's Image, so look up how to save with PIL Image. See the comments in predict.py for details.**
+
+**Q: How do I save videos?
+A: See the comments in predict.py for details.**
+
+### j. Iterating over a folder
+**Q: How do I iterate over all images in a folder?
+A: Generally use os.listdir to collect the images, then detect each one following the flow in predict.py; see its comments for details.**
+
+**Q: How do I iterate over all images in a folder and save the results?
+A: Iterate with os.listdir and detect following predict.py's flow. For saving, object detection generally uses PIL's Image, so look up how to save with it; some repos use cv2, in which case look up how cv2 saves images. See the comments in predict.py for details.**
+
+### k. Path problems (No such file or directory)
+**Q: Why do I get this error:**
+```python
+FileNotFoundError: [Errno 2] No such file or directory
+...
+```
+
+**A: Check the folder paths and whether the files exist, and check the file paths inside 2007_train.txt for mistakes.**
+A few important points about paths:
+**Never put spaces in folder names.
+Mind the difference between relative and absolute paths.
+Read up on how paths work.**
+
+**Almost all path problems are root-directory problems; study the concept of relative paths!**
+
+### l. FPS (detection speed)
+**Q: What FPS can this reach? Can it reach XX FPS?
+A: FPS depends on the machine; high-spec machines are fast, low-spec machines are slow.**
+
+**Q: The paper says speed reaches XX; why not here?
+A: Check that the GPU builds of tensorflow-gpu or pytorch are installed correctly. If so, use time.time() inside detect_image to find which part of the code takes longest (the network is not the only cost; other processing, such as drawing, takes time too). Some papers also predict with multi-image batches, which I have not implemented.**
+
+### m. Predicted image not displayed
+**Q: Why doesn't your code display the image after prediction? It only prints the targets in the console.
+A: Install an image viewer on the system.**
+
+### n. Evaluation (miou)
+**Q: How do I compute miou?
+A: See the miou measurement part of the video.**
+
+**Q: How do I compute Recall and Precision?
+A: The current code does not produce them; you need to understand the confusion matrix and compute them yourself (a sketch follows below).**
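+
+A minimal sketch of deriving per-class Recall and Precision from the confusion matrix that the miou code already builds (the row/column layout is an assumption; swap the axes if your hist is transposed):
+```python
+import numpy as np
+
+def recall_precision(hist):
+    # hist[i, j]: pixels with ground-truth class i predicted as class j (assumed layout)
+    tp        = np.diag(hist)
+    recall    = tp / np.maximum(hist.sum(axis=1), 1)   # per-class Recall (= per-class pixel accuracy)
+    precision = tp / np.maximum(hist.sum(axis=0), 1)   # per-class Precision
+    return recall, precision
+
+hist = np.array([[50, 2], [3, 45]])                    # toy 2-class confusion matrix
+print(recall_precision(hist))
+```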
+### o. Model optimization (model modification)
+**Q: I modified the network; can I still use the pretrained weights?
+A: If you modified the backbone to something that isn't an existing supported network, the pretrained weights basically cannot be used; either match the conv-kernel shapes in the weight file yourself or pretrain from scratch. If you only modified the second half, the backbone's pretrained weights are still usable: for pytorch code, change the weight-loading logic to load only shape-matched entries; for keras code, simply use by_name=True, skip_mismatch=True.**
+Weight matching works the same way as the snippet shown in section 3-f above.
+
+**Q: How do I modify the model? I want to publish a small paper!
+A: Read the yolov4 paper from object detection; as a giant tuning exercise it is a very useful reference with many tricks. My advice is to study classic models and borrow and recombine their strong components. Common tricks such as attention mechanisms are worth trying.**
+
+### p. Deployment
+I have never deployed to phones or similar devices, so I don't know much about deployment problems...
+
+## 5. Chat group
+**Q: Is there a QQ group or something similar?
+A: No; I don't have time to manage a QQ group...**
+
+## 6. How to learn
+**Q: What was your learning path? I'm a beginner; how should I learn?
+A: A few caveats first:
+1. I am not an expert; there is plenty I don't know, and my path may not suit everyone.
+2. My lab doesn't do deep learning, so I taught myself most of this by trial and error, and I can't vouch that it is the right way.
+3. Personally, I think learning mostly comes down to self-study.**
+As for the path: I first worked through Mofan's python tutorials to get started with tensorflow, keras and pytorch, then learned SSD and YOLO, then studied many classic convolutional networks, and after that read lots of different codebases. My method is to read code line by line, understanding the whole execution flow and how the feature-map shapes change. It took a lot of time; there is no shortcut, you simply have to put in the hours.
\ No newline at end of file