這是一個Facebook的目標檢測Transformer (DETR)的完整指南。
介紹
DEtection TRansformer (DETR)是Facebook研究團隊巧妙地利用了Transformer 架構開發的一個目標檢測模型。在這篇文章中,我將通過分析DETR架構的內部工作方式來幫助提供一些關于它的直覺。
下面,我將解釋一些結構,但是如果你只是想了解如何使用模型,可以直接跳到代碼部分。
結構
DETR模型由一個預訓練的CNN骨干(如ResNet)組成,它產生一組低維特征集。這些特征被格式化為一個特征集合并添加位置編碼,輸入一個由Transformer組成的編碼器和解碼器中,和原始的Transformer論文中描述的Encoder-Decoder的使用方式非常的類似。解碼器的輸出然后被送入固定數量的預測頭,這些預測頭由預定義數量的前饋網絡組成。每個預測頭的輸出都包含一個類預測和一個預測框。損失是通過計算二分匹配損失來計算的。
該模型做出了預定義數量的預測,并且每個預測都是并行計算的。
CNN主干
假設我們的輸入圖像,有三個輸入通道。CNN backbone由一個(預訓練過的)CNN(通常是ResNet)組成,我們用它來生成C個具有寬度W和高度H的低維特征(在實踐中,我們設置C=2048, W=W?/32和H=H?/32)。
這留給我們的是C個二維特征,由于我們將把這些特征傳遞給一個transformer,每個特征必須允許編碼器將每個特征處理為一個序列的方式重新格式化。這是通過將特征矩陣扁平化為H?W向量,然后將每個向量連接起來來實現的。
扁平化的卷積特征再加上空間位置編碼,位置編碼既可以學習,也可以預定義。
The Transformer
Transformer幾乎與原始的編碼器-解碼器架構完全相同。不同之處在于,每個解碼器層并行解碼N個(預定義的數目)目標。該模型還學習了一組N個目標的查詢,這些查詢是(類似于編碼器)學習出來的位置編碼。
目標查詢
下圖描述了N=20個學習出來的目標查詢(稱為prediction slots)如何聚焦于一張圖像的不同區域。
“我們觀察到,在不同的操作模式下,每個slot 都會學習特定的區域和框大小?!?—— DETR的作者
理解目標查詢的直觀方法是想象每個目標查詢都是一個人。每個人都可以通過注意力來查看圖像的某個區域。一個目標查詢總是會問圖像中心是什么,另一個總是會問左下角是什么,以此類推。
使用PyTorch實現簡單的DETR
importtorch importtorch.nnasnn fromtorchvision.modelsimportresnet50 classSimpleDETR(nn.Module): """ MinimalExampleoftheDetectionTransformermodelwithlearnedpositionalembedding """ def__init__(self,num_classes,hidden_dim,num_heads, num_enc_layers,num_dec_layers): super(SimpleDETR,self).__init__() self.num_classes=num_classes self.hidden_dim=hidden_dim self.num_heads=num_heads self.num_enc_layers=num_enc_layers self.num_dec_layers=num_dec_layers #CNNBackbone self.backbone=nn.Sequential( *list(resnet50(pretrained=True).children())[:-2]) self.conv=nn.Conv2d(2048,hidden_dim,1) #Transformer self.transformer=nn.Transformer(hidden_dim,num_heads, num_enc_layers,num_dec_layers) #PredictionHeads self.to_classes=nn.Linear(hidden_dim,num_classes+1) self.to_bbox=nn.Linear(hidden_dim,4) #PositionalEncodings self.object_query=nn.Parameter(torch.rand(100,hidden_dim)) self.row_embed=nn.Parameter(torch.rand(50,hidden_dim//2) self.col_embed=nn.Parameter(torch.rand(50,hidden_dim//2)) defforward(self,X): X=self.backbone(X) h=self.conv(X) H,W=h.shape[-2:] pos_enc=torch.cat([ self.col_embed[:W].unsqueeze(0).repeat(H,1,1), self.row_embed[:H].unsqueeze(1).repeat(1,W,1)], dim=-1).flatten(0,1).unsqueeze(1) h=self.transformer(pos_enc+h.flatten(2).permute(2,0,1), self.object_query.unsqueeze(1)) class_pred=self.to_classes(h) bbox_pred=self.to_bbox(h).sigmoid() returnclass_pred,bbox_pred
二分匹配損失 (Optional)
讓為預測的集合,其中是包括了預測類別(可以是空類別)和包圍框的二元組,其中上劃線表示框的中心點,和表示框的寬和高。
設y為ground truth集合。假設y和?之間的損失為L,每一個y?和??之間的損失為L?。由于我們是在集合的層次上工作,損失L必須是排列不變的,這意味著無論我們如何排序預測,我們都將得到相同的損失。因此,我們想找到一個排列,它將預測的索引映射到ground truth目標的索引上。在數學上,我們求解:
計算的過程稱為尋找最優的二元匹配。這可以用匈牙利算法找到。但為了找到最優匹配,我們需要實際定義一個損失函數,計算和之間的匹配成本。
回想一下,我們的預測包含一個邊界框和一個類?,F在讓我們假設類預測實際上是一個類集合上的概率分布。那么第i個預測的總損失將是類預測產生的損失和邊界框預測產生的損失之和。作者在http://arxiv.org/abs/1906.05909中將這種損失定義為邊界框損失和類預測概率的差異:
其中,是的argmax,是是來自包圍框的預測的損失,如果,則表示匹配損失為0。
框損失的計算為預測值與ground truth的L?損失和的GIOU損失的線性組合。同樣,如果你想象兩個不相交的框,那么框的錯誤將不會提供任何有意義的上下文(我們可以從下面的框損失的定義中看到)。
其中,λ???和是超參數。注意,這個和也是面積和距離產生的誤差的組合。為什么會這樣呢?
可以把上面的等式看作是與預測相關聯的總損失,其中面積誤差的重要性是λ???,距離誤差的重要性是。
現在我們來定義GIOU損失函數。定義如下:
由于我們從已知的已知類的數目來預測類,那么類預測就是一個分類問題,因此我們可以使用交叉熵損失來計算類預測誤差。我們將損失函數定義為每N個預測損失的總和:
為目標檢測使用DETR
在這里,你可以學習如何加載預訓練的DETR模型,以便使用PyTorch進行目標檢測。
加載模型
首先導入需要的模塊。
#Importrequiredmodules importtorch fromtorchvisionimporttransformsasTimportrequests#forloadingimagesfromweb fromPILimportImage#forviewingimages importmatplotlib.pyplotasplt
下面的代碼用ResNet50作為CNN骨干從torch hub加載預訓練的模型。其他主干請參見DETR github:https://github.com/facebookresearch/detr
detr=torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
加載一張圖像
要從web加載圖像,我們使用requests庫:
url='https://www.tempetourism.com/wp-content/uploads/Postino-Downtown-Tempe-2.jpg'#Sampleimageimage=Image.open(requests.get(url,stream=True).raw)plt.imshow(image) plt.show()
設置目標檢測的Pipeline
為了將圖像輸入到模型中,我們需要將PIL圖像轉換為張量,這是通過使用torchvision的transforms庫來完成的。
transform=T.Compose([T.Resize(800), T.ToTensor(), T.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])])
上面的變換調整了圖像的大小,將PIL圖像進行轉換,并用均值-標準差對圖像進行歸一化。其中[0.485,0.456,0.406]為各顏色通道的均值,[0.229,0.224,0.225]為各顏色通道的標準差。
我們裝載的模型是預先在COCO Dataset上訓練的,有91個類,還有一個表示空類(沒有目標)的附加類。我們用下面的代碼手動定義每個標簽:
CLASSES= ['N/A','Person','Bicycle','Car','Motorcycle','Airplane','Bus','Train','Truck','Boat','Traffic-Light','Fire-Hydrant','N/A','Stop-Sign','ParkingMeter','Bench','Bird','Cat','Dog','Horse','Sheep','Cow','Elephant','Bear','Zebra','Giraffe','N/A','Backpack','Umbrella','N/A','N/A','Handbag','Tie','Suitcase','Frisbee','Skis','Snowboard','Sports-Ball','Kite','BaseballBat','BaseballGlove','Skateboard','Surfboard','TennisRacket','Bottle','N/A','WineGlass','Cup','Fork','Knife','Spoon','Bowl','Banana','Apple','Sandwich','Orange','Broccoli','Carrot','Hot-Dog','Pizza','Donut','Cake','Chair','Couch','PottedPlant','Bed','N/A','DiningTable','N/A','N/A','Toilet','N/A','TV','Laptop','Mouse','Remote','Keyboard','Cell-Phone','Microwave','Oven','Toaster','Sink','Refrigerator','N/A','Book','Clock','Vase','Scissors','Teddy-Bear','Hair-Dryer','Toothbrush']
如果我們想輸出不同顏色的邊框,我們可以手動定義我們想要的RGB格式的顏色
COLORS=[ [0.000,0.447,0.741], [0.850,0.325,0.098], [0.929,0.694,0.125], [0.494,0.184,0.556], [0.466,0.674,0.188], [0.301,0.745,0.933] ]
格式化輸出
我們還需要重新格式化模型的輸出。給定一個轉換后的圖像,模型將輸出一個字典,包含100個預測類的概率和100個預測邊框。
每個包圍框的形式為(x, y, w, h),其中(x,y)為包圍框的中心(包圍框是單位正方形[0,1]×[0,1]), w, h為包圍框的寬度和高度。因此,我們需要將邊界框輸出轉換為初始和最終坐標,并重新縮放框以適應圖像的實際大小。
下面的函數返回邊界框端點:
#Getcoordinates(x0,y0,x1,y0)frommodeloutput(x,y,w,h)defget_box_coords(boxes): x,y,w,h=boxes.unbind(1) x0,y0=(x-0.5*w),(y-0.5*h) x1,y1=(x+0.5*w),(y+0.5*h) box=[x0,y0,x1,y1] returntorch.stack(box,dim=1)
我們還需要縮放了框的大小。下面的函數為我們做了這些:
#Scaleboxfrom[0,1]x[0,1]to[0,width]x[0,height]defscale_boxes(output_box,width,height): box_coords=get_box_coords(output_box) scale_tensor=torch.Tensor( [width,height,width,height]).to( torch.cuda.current_device())returnbox_coords*scale_tensor
現在我們需要一個函數來封裝我們的目標檢測pipeline。下面的detect函數為我們完成了這項工作。
#ObjectDetectionPipelinedefdetect(im,model,transform): device=torch.cuda.current_device() width=im.size[0] height=im.size[1] #mean-stdnormalizetheinputimage(batch-size:1) img=transform(im).unsqueeze(0) img=img.to(device) #demomodelonlysupportbydefaultimageswithaspectratiobetween0.5and2assertimg.shape[-2]<=?1600?and?img.shape[-1]?<=?1600,????#?propagate?through?the?model ????outputs?=?model(img)????#?keep?only?predictions?with?0.7+?confidence ????probas?=?outputs['pred_logits'].softmax(-1)[0,?:,?:-1] ????keep?=?probas.max(-1).values?>0.85 #convertboxesfrom[0;1]toimagescales bboxes_scaled=scale_boxes(outputs['pred_boxes'][0,keep],width,height)returnprobas[keep],bboxes_scaled
現在,我們需要做的是運行以下程序來獲得我們想要的輸出:
probs,bboxes=detect(image,detr,transform)
繪制結果
現在我們有了檢測到的目標,我們可以使用一個簡單的函數來可視化它們。
#PlotPredictedBoundingBoxesdefplot_results(pil_img,prob,boxes,labels=True): plt.figure(figsize=(16,10)) plt.imshow(pil_img) ax=plt.gca() forprob,(x0,y0,x1,y1),colorinzip(prob,boxes.tolist(),COLORS*100):ax.add_patch(plt.Rectangle((x0,y0),x1-x0,y1-y0, fill=False,color=color,linewidth=2)) cl=prob.argmax() text=f'{CLASSES[cl]}:{prob[cl]:0.2f}' iflabels: ax.text(x0,y0,text,fontsize=15, bbox=dict(facecolor=color,alpha=0.75)) plt.axis('off') plt.show()
現在可以可視化結果:
plot_results(image,probs,bboxes,labels=True)
審核編輯:彭菁
-
Facebook
+關注
關注
3文章
1429瀏覽量
54720 -
代碼
+關注
關注
30文章
4779瀏覽量
68524 -
檢測模型
+關注
關注
0文章
17瀏覽量
7306 -
Transformer
+關注
關注
0文章
143瀏覽量
5995
原文標題:Transformer (DETR) 對象檢測實操!
文章出處:【微信號:vision263com,微信公眾號:新機器視覺】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論