一、介紹
缺陷檢測被廣泛使用于布匹瑕疵檢測、工件表面質量檢測、航空航天領域等。傳統(tǒng)的算法對規(guī)則缺陷以及場景比較簡單的場合,能夠很好工作,但是對特征不明顯的、形狀多樣、場景比較混亂的場合,則不再適用。近年來,基于深度學習的識別算法越來越成熟,許多公司開始嘗試把深度學習算法應用到工業(yè)場合中。
二、缺陷數據
這里以布匹數據作為案例,常見的有以下三種缺陷,磨損、白點、多線。
如何制作訓練數據呢?這里是在原圖像上進行截取,截取到小圖像,比如上述圖像是512x512,這里我裁剪成64x64的小圖像。這里以第一類缺陷為例,下面是制作數據的方法。
注意:在制作缺陷數據的時候,缺陷面積至少占截取圖像的2/3,否則舍棄掉,不做為缺陷圖像。
一般來說,缺陷數據都要比背景數據少很多, 最后通過增強后的數據,缺陷:背景=1:1,每類在1000幅左右~~~
三、網絡結構
具體使用的網絡結構如下所示,輸入大小就是64x64x3,采用的是截取的小圖像的大小。每個Conv卷積層后都接BN層,具體層參數如下所示。
Conv1:64x3x3
Conv2:128x3x3 ResNetBlock和DenseNetBlock各兩個,具體細節(jié)請參考殘差網絡和DenseNet。
Add:把殘差模塊輸出的結果和DenseNetBlock輸出的結果在對應feature map上進行相加,相加方式和殘差模塊相同。
注意,其實這里是為了更好的提取特征,方式不一定就是殘差模塊+DenseNetBlock,也可以是inception,或者其它。
Conv3:128x3x3 Maxpool:stride=2,size=2x2 FC1:4096 Dropout1:0.5 FC2:1024 Dropout1:0.5 Softmax:對應的就是要分的類別,在這里我是二分類。
關于最后的損失函數,建議選擇Focal Loss,這是何凱明大神的杰作,源碼如下所示:
def focal_loss(y_true, y_pred): pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred)) return -K.sum(K.pow(1. - pt_1, 2) * K.log(pt_1))
數據做好,就可以開始訓練了~~~
四、整幅場景圖像的缺陷檢測
上述訓練的網絡,輸入是64x64x3的,但是整幅場景圖像卻是512x512的,這個輸入和模型的輸入對不上號,這怎么辦呢?其實,可以把訓練好的模型參數提取出來,然后賦值到另外一個新的模型中,然后把新的模型的輸入改成512x512就好,只是最后在conv3+maxpool層提取的feature map比較大,這個時候把feature map映射到原圖,比如原模型在最后一個maxpool層后,輸出的feature map尺寸是8x8x128,其中128是通道數。如果輸入改成512x512,那輸出的feature map就成了64x64x128,這里的每個8x8就對應原圖上的64x64,這樣就可以使用一個8x8的滑動窗口在64x64x128的feature map上進行滑動裁剪特征。然后把裁剪的特征進行fatten,送入到全連接層。具體如下圖所示。
全連接層也需要重新建立一個模型,輸入是flatten之后的輸入,輸出是softmax層的輸出。這是一個簡單的小模型。
在這里提供一個把訓練好的模型參數,讀取到另外一個模型中的代碼
#提取特征的大模型 def read_big_model(inputs): # 第一個卷積和最大池化層 X = Conv2D(16, (3, 3), name="conv2d_1")(inputs) X = BatchNormalization(name="batch_normalization_1")(X) X = Activation('relu', name="activation_1")(X) X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name="max_pooling2d_1")(X) # google_inception模塊 conv_1 = Conv2D(32, (1, 1), padding='same', name='conv2d_2')(X) conv_1 = BatchNormalization(name='batch_normalization_2')(conv_1) conv_1 = Activation('relu', name='activation_2')(conv_1) conv_2 = Conv2D(32, (3, 3), padding='same', name='conv2d_3')(X) conv_2 = BatchNormalization(name='batch_normalization_3')(conv_2) conv_2 = Activation('relu', name='activation_3')(conv_2) conv_3 = Conv2D(32, (5, 5), padding='same', name='conv2d_4')(X) conv_3 = BatchNormalization(name='batch_normalization_4')(conv_3) conv_3 = Activation('relu', name='activation_4')(conv_3) pooling_1 = MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same', name='max_pooling2d_2')(X) X = merge([conv_1, conv_2, conv_3, pooling_1], mode='concat', name='merge_1') X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='max_pooling2d_3')(X) # 這里的尺寸變成16x16x112 X = Conv2D(64, (3, 3), kernel_regularizer=regularizers.l2(0.01), padding='same', name='conv2d_5')(X) X = BatchNormalization(name='batch_normalization_5')(X) X = Activation('relu', name='activation_5')(X) X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='max_pooling2d_4')(X) # 這里尺寸變成8x8x64 X = Conv2D(128, (3, 3), padding='same', name='conv2d_6')(X) X = BatchNormalization(name='batch_normalization_6')(X) X = Activation('relu', name='activation_6')(X) X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='max_pooling2d_5')(X) # 這里尺寸變成4x4x128 return X def read_big_model_classify(inputs_sec): X_ = Flatten(name='flatten_1')(inputs_sec) X_ = Dense(256, activation='relu', name="dense_1")(X_) X_ = Dropout(0.5, name="dropout_1")(X_) predictions = Dense(2, activation='softmax', name="dense_2")(X_) return predictions #建立的小模型
inputs=Input(shape=(512,512,3)) X=read_big_model(inputs)#讀取訓練好模型的網絡參數 #建立第一個model model=Model(inputs=inputs, outputs=X) model.load_weights('model_halcon.h5', by_name=True)
五、識別定位結果
上述的滑窗方式可以定位到原圖像,8x8的滑窗定位到原圖就是64x64,同樣,在原圖中根據滑窗方式不同(在這里選擇的是左右和上下的步長為16個像素)識別定位到的缺陷位置也不止一個,這樣就涉及到定位精度了。在這里選擇投票的方式,其實就是對原圖像上每個被標記的像素位置進行計數,當數字大于指定的閾值,就被判斷為缺陷像素。
識別結果如下圖所示:
六、一些Trick
對上述案例來說,其實64x64大小的定位框不夠準確,可以考慮訓練一個32x32大小的模型,然后應用方式和64x64的模型相同,最后基于32x32的定位位置和64x64的定位位置進行投票,但是這會涉及到一個問題,就是時間上會增加很多,要慎用。
對背景和前景相差不大的時候,網絡盡量不要太深,因為太深的網絡到后面基本學到的東西都是相同的,沒有很好的區(qū)分能力,這也是我在這里為什么不用object detection的原因,這些檢測模型網絡,深度動輒都是50+,效果反而不好,雖然有殘差模塊作為backbone。
但是對背景和前景相差很大的時候,可以選擇較深的網絡,這個時候,object detection方式就派上用場了。
審核編輯:劉清
-
CMOS
+關注
關注
58文章
5710瀏覽量
235428 -
CCD
+關注
關注
32文章
880瀏覽量
142228 -
機器視覺
+關注
關注
161文章
4369瀏覽量
120293 -
工業(yè)相機
+關注
關注
5文章
322瀏覽量
23623 -
機器視覺系統(tǒng)
+關注
關注
1文章
83瀏覽量
18867
原文標題:基于深度學習識別模型的缺陷檢測
文章出處:【微信號:vision263com,微信公眾號:新機器視覺】歡迎添加關注!文章轉載請注明出處。
發(fā)布評論請先 登錄
相關推薦
評論