0x0. 背景
我也是偶然在知乎的一個問題下看到這個問題,大概就是說在使用apex的LayerNorm/RMSNorm的時候可以打開這個api的memory_efficient開關,這個開關可以在速度和精度無損的情況下節省網絡訓練的顯存占用。感覺比較有趣,我就研究了一下,因此也就有了這篇文章。
我去實測了一下,單機8卡A100訓練LLama7B,純數據并行的情況下打開memory_efficient開關相比于不打開節省了大約2個G的顯存,如果模型繼續scale up,那么省掉的顯存也會更多。因此,本文就是對這個memory_efficient開關的背后實現做一個解讀,另外也會對apex里面LayerNorm/RMSNorm本身的cuda kernel實現做一個細節解讀。
apex的LayerNorm/RMSNorm被實現成一個fuse kernel,然后上層使用torch.autograd.Function來封裝,本文的講解主要以LayerNorm為例子
實際上RMSNorm和LayerNorm的實現是共享的,只不過在kernel內部會區分一下縮放策略是2個參數(LayerNorm的gamma和beta)還是一個參數。
classFusedLayerNormAffineFunction(torch.autograd.Function): @staticmethod defforward(ctx,input,weight,bias,normalized_shape,eps,memory_efficient=False): globalfused_layer_norm_cuda iffused_layer_norm_cudaisNone: fused_layer_norm_cuda=importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape=normalized_shape ctx.eps=eps ctx.memory_efficient=memory_efficient input_=input.contiguous() weight_=weight.contiguous() bias_=bias.contiguous() output,mean,invvar=fused_layer_norm_cuda.forward_affine( input_,ctx.normalized_shape,weight_,bias_,ctx.eps ) ifctx.memory_efficient: ctx.save_for_backward(output,weight_,bias_,None,invvar) else: ctx.save_for_backward(input_,weight_,bias_,mean,invvar) returnoutput
可以看到在非memory_efficient模式下面,ctx.save_for_backward(output, weight_, bias_, None, invvar)保存了用于backward的tensor,包括輸入,權重,偏置,均值和方差的逆。但在memory_efficient模式下面ctx.save_for_backward(output, weight_, bias_, None, invvar),則是保存了輸出,權重偏置以及方差的逆。
這個地方看下你是否會掉入誤區?從表面上看,這里也就只省掉了一個gamma,因為輸入和輸出tensor的形狀是一樣的,那么這樣還有什么收益呢?背景是,在pre-ln的transformer架構里面LayerNorm/RMSNorm之后緊接著是一個線性投影,無論是在注意力機制還是在多層感知機(mlp)中都是如此,所以輸出Tensor一定要被保存下來。而在post-ln架構中,輸出還會直接用于殘差連接。然而,在這兩種情況下,LayerNorm/RMSNorm的輸入都不再被使用,所以這里原本的輸入保存變得相當多余,因為我們可以保存無論如何都會被保存的輸出張量。這樣就可以達到節省顯存的目的了。
接下來就詳細解讀下實現。
0x1. Apex的LayerNorm前向cuda實現
https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda.cpp 這個文件是基于實現的LayerNorm cuda kernel使用torch extension模塊導出python接口。
同時這個文件還寫了幾個工具函數,比如compute_n1_n2用來計算LayerNorm中非歸一化和歸一化部分的大小:https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda.cpp#L7C31-L7C51 ,check_args函數對LayerNorm的參數進行檢查:https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda.cpp#L32C22-L143 。
此外,這個cpp預定義了cuda_layer_norm的函數接口,并且考慮了gamma/beta是否為空。
接下來就正式對LayerNorm的前向cuda實現進行解析。
0x1.1 工具函數
LayerNorm使用Welford算法統計均值方差,在 https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu 寫了一系列kernel實現中需要用到的工具函數,這些函數是gpu上用到的。下面對其簡單解析一下,另外Welford算法可以看這篇博客的介紹:用Welford算法實現LN的方差更新(感嘆一下,zzk寫這篇文章的時候還是萌新,經過2年時間已經成長為國內頂級的工程師了,開掛般學習能力) 。工具函數包含:cuWelfordOnlineSum,cuChanOnlineSum,cuRMSOnlineSum,cuChanRMSOnlineSum這些,我把自己的原始注釋使用gpt4進行了潤色,這樣會顯得更加通俗一些。具體解釋如下:
//這段代碼是個CUDA函數,名叫cuWelfordOnlineSum,擅長用Welford算法來邊收數據邊算這些數據的平均值和變化范圍(就是均值和方差)。 //用Welford算法來算這個,特別穩,不會因為數據太多而出錯,而且每加一個數據就能更新一次均值和方差。 // const U curr:這個是新來的數據點。 // U& mu:這個是我們到現在為止算出來的所有數據的平均值。 // U& sigma2:這個是我們到現在為止算出來的方差,可以告訴你數據變化有多大。 // U& count:這個記錄了我們到現在處理了多少數據點。 template__device__ voidcuWelfordOnlineSum( constUcurr, U&mu, U&sigma2, U&count) { count=count+U(1);//每次調用這個函數,就把處理的數據數量加一。 Udelta=curr-mu;//看看新數據和現有平均值差多少。 Ulmean=mu+delta/count;//用這個差值和數據總量來算一個新的平均值。 mu=lmean;//把這個新算的平均值記下來。 Udelta2=curr-lmean;//現在再算一下新數據和新平均值的差。 sigma2=sigma2+delta*delta2;//利用這個新舊平均值的差來更新方差。 } //這段代碼是個CUDA函數,名叫cuChanOnlineSum。它用于處理一種特殊的情況: //當你有兩堆數據,想要快速算出它們合并后的平均值和方差時,這個函數就派上用場了。 // const U muB, sigma2B, countB:這三個是你新加入的那堆數據的平均值、方差和數據點數量。 // U& mu, sigma2, count:這三個是你之前已經有的數據的平均值、方差和數據點數量。 //這個函數會更新這些值,讓它們反映出兩堆數據合并后的情況。 template __device__ voidcuChanOnlineSum( constUmuB, constUsigma2B, constUcountB, U&mu, U&sigma2, U&count) { Udelta=muB-mu;//先算算新數據堆和老數據堆的平均值差了多少。 UnA=count;//記下當前數據堆(我們叫它A堆)的大小。 UnB=countB;//看看新來的那堆數據(B堆)有多少個點。 count=count+countB;//把兩堆數據的數量加起來。 UnX=count;//這就是合并后總數據量的大小。 if(nX>U(0)){ nA=nA/nX;//算一下A堆數據在總數據中占的比例。 nB=nB/nX;//同理,算一下B堆的比例。 mu=nA*mu+nB*muB;//利用這些比例和各自的平均值,算出總的平均值。 sigma2=sigma2+sigma2B+delta*delta*nA*nB*nX;//然后用一點復雜的公式,把方差也算出來,這個公式考慮了兩堆數據的方差和它們平均值的差異。 }else{ //如果合并后的總數是0,那就說明兩堆數據其實都是空的,所以把平均值和方差都設為0。 mu=U(0); sigma2=U(0); } } //這里定義了一個名叫cuRMSOnlineSum的CUDA函數,它的主要任務就是在線實時計算一串數據的平方和。 //你可能會問,為什么要算平方和呢?這是因為我們可以用它來算出均方根(RMS, Root Mean Square), //均方根是一種描述數據波動大小的指標,特別常用于信號處理領域。 template __device__ voidcuRMSOnlineSum( constUcurr, U&sigma2) { sigma2=sigma2+curr*curr;//每次函數被調用,就把當前值的平方加到累計平方和中。 } //又定義了一個名叫cuChanRMSOnlineSum的CUDA函數,這個家伙的工作就是幫你算兩組數據的平方和總和。 //當你有兩組數據,想要快速合并它們的均方根(RMS)時,這個函數就能派上用場。 //它其實是均方根計算過程中的一個環節,用于處理兩個獨立數據集的情況。 template __device__ voidcuChanRMSOnlineSum( constUsigma2B, U&sigma2) { sigma2=sigma2+sigma2B;//這里就簡單直接了,把第二組數據的平方和加到當前的累計值上。 }
這里還有一個函數cuWelfordMuSigma2是用來計算張量某一維度上的均值(mu)和方差(sigma2)的,它調用了上面的工具函數,但是這個函數我們在kernel實現階段解析,因為它需要一些kernel啟動的背景。
0x1.2 啟動邏輯
先對kernel啟動這部分的代碼進行注釋,首先是共享內存的結構體定義。
//這段代碼定義了一個叫做SharedMemory的模板結構體,專門用在CUDA設備函數里來訪問所謂的“共享內存”。 //在CUDA編程里,共享內存是一種特別高效的內存類型,非常適合用來在CUDA的一個塊(block)內的不同線程間共享數據。 //這里還包括了針對float和double類型數據的SharedMemory結構體的特化版本。 namespace{ //這是通用的SharedMemory結構體模板。注意,我們通過在函數體內使用一個未定義的符號來阻止這個結構體被實例化, //這樣如果嘗試用未特化的類型來編譯這個結構體,編譯器就會報錯。 //template//structSharedMemory //{ ////確保我們不會編譯任何未特化的類型 //__device__T*getPointer() //{ //extern__device__voiderror(void); //error(); //returnNULL; //} //}; template structSharedMemory; //這是SharedMemory結構體針對float類型的特化版本。 template<> structSharedMemory { //這個函數返回一個指向共享內存的float類型指針。 __device__float*getPointer() { //這里聲明了一個名為s_float的外部共享內存數組,用于存儲float類型的數據。 // extern和__shared__關鍵字表明這個數組是在共享內存中定義的。 extern__shared__floats_float[]; returns_float; } }; //下面是針對double類型的特化版本,工作方式和float版本相似。 template<> structSharedMemory { __device__double*getPointer() { extern__shared__doubles_double[]; returns_double; } }; }
然后是Kernel啟動的具體邏輯部分:
//這段代碼里,我們定義了一個CUDA設備函數叫做cuApplyLayerNorm_,它的主要任務是執行LayerNorm(層歸一化)。 //層歸一化是深度學習中的一個技巧,用來讓每一層的輸出更加標準化,有助于模型訓練。 //我們定義了三種模板參數:T是輸入數據類型,U是中間計算(比如均值和方差)的類型,V是輸出數據類型。 // output_vals, mean, invvar, vals, gamma, beta 這些都是指向不同數據的指針。 //在層歸一化中,我們通常把一個多維數據(張量)分為兩部分:一部分用來做標準化,另一部分保持原樣。 //比如,如果你有一個[batch_size,channels,height,width]形狀的4D張量, //而你只想對最后兩個維度進行層歸一化,那么n1是batch_size * channels,n2是height * width。 template__device__ voidcuApplyLayerNorm_( V*__restrict__output_vals, U*__restrict__mean, U*__restrict__invvar, constT*__restrict__vals, constintn1, constintn2, constUepsilon, constV*__restrict__gamma, constV*__restrict__beta, boolrms_only ) { //基本假設: // 1) blockDim.x 是 warp 的大小(這是一個CUDA的技術細節)。 // 2)輸入的張量數據在內存中是連續的。 // //這段代碼遍歷n1維度,每次處理一個i1索引。 //假設每個CUDA線程塊的x維度等于warp大小,確保數據處理是高效的。 //這里一個線程可能要處理多行,所以我們用gridDim.y來控制步長。(因為gridDim.x=1) for(autoi1=blockIdx.y;i1shared; U*buf=shared.getPointer();//創建一個 SharedMemory 實例用于處理類型 U 的數據。 Umu,sigma2;//這里mu和sigma2分別代表均值和方差,我們接下來要計算它們。 //調用 cuWelfordMuSigma2 函數計算給定索引 i1 處的均值(mu)和方差(sigma2)。 cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf,rms_only); //定位到當前 i1 索引處的輸入和輸出的起始位置。 constT*lvals=vals+i1*n2; V*ovals=output_vals+i1*n2; //計算逆方差 c_invvar,這是層歸一化中一個關鍵的步驟。 Uc_invvar=rsqrt(sigma2+epsilon); //計算每個 CUDA 塊的線程總數(numx)和當前線程的一維索引(thrx)。 constintnumx=blockDim.x*blockDim.y; constintthrx=threadIdx.x+threadIdx.y*blockDim.x; //如果提供了gamma和beta參數,或者我們只是在做RMS計算,我們會用一種特別的方式來計算輸出值。 if(gamma!=NULL&&(beta!=NULL||rms_only)){ for(inti=thrx;i(lvals[i]); if(!rms_only){ //標準化當前值,然后用gamma和beta進行調整。 ovals[i]=gamma[i]*static_cast (c_invvar*(curr-mu))+beta[i]; }else{ ////如果是RMS模式,我們稍微簡化計算過程。 ovals[i]=gamma[i]*static_cast (c_invvar*curr); } } } //否則,如果沒有提供gamma和beta,我們就直接用計算出的均值和逆方差來進行標準化。 else{ for(inti=thrx;i(lvals[i]); if(!rms_only){ //直接進行標準化計算。 ovals[i]=static_cast (c_invvar*(curr-mu)); }else{ //// RMS模式下的簡化計算。 ovals[i]=static_cast (c_invvar*curr); } } } //在每個 CUDA 塊中,僅由一個線程(線程(0,0))更新均值和逆方差。 if(threadIdx.x==0&&threadIdx.y==0){ if(!rms_only){ mean[i1]=mu; } invvar[i1]=c_invvar; } //用于同步塊內的所有線程。 __syncthreads(); } } //對上個函數的參數透傳,不過rms_only設為False template __global__ voidcuApplyLayerNorm( V*__restrict__output_vals, U*__restrict__mean, U*__restrict__invvar, constT*__restrict__vals, constintn1, constintn2, constUepsilon, constV*__restrict__gamma, constV*__restrict__beta ) { cuApplyLayerNorm_ (output_vals,mean,invvar,vals,n1,n2,epsilon,gamma,beta,false); } //kernel啟動代碼,設置線程塊和線程數 template voidHostApplyLayerNorm( V*output, U*mean, U*invvar, constT*input, intn1, intn2, doubleepsilon, constV*gamma, constV*beta ) { // threads和blocks定義了CUDA內核的線程和塊的維度。這里,每個線程塊有32×4的線程,而塊的數量由n1和GPU設備的最大網格大小限制決定。 autostream=at::getCurrentCUDAStream().stream(); constdim3threads(32,4,1); constuint64_tmaxGridY=at::getCurrentDeviceProperties()->maxGridSize[1]; constdim3blocks(1,std::min((uint64_t)n1,maxGridY),1); //這段代碼計算內核函數需要多少共享內存。如果threads.y大于1,它會根據U類型的大小分配足夠的內存。 intnshared= threads.y>1? threads.y*sizeof(U)+(threads.y/2)*sizeof(U): 0; //最后,函數使用cuApplyLayerNorm kernel來執行實際的LayerNorm操作。 // kernel函數的調用使用了之前計算的線程塊和線程配置,以及共享內存大小和CUDA流。 cuApplyLayerNorm<< >>( output,mean,invvar,input,n1,n2,U(epsilon),gamma,beta); }
這段代碼包含了kernel的啟動邏輯,包括設置block的個數以及每個block中的線程排布方式,然后在cuApplyLayerNorm_里面有一個跨線程網格的大循環作用在n1維度,每個線程可能會處理多行數據。而在每一行數據的處理上,調用了cuWelfordMuSigma2 函數計算給定索引 i1 處的均值(mu)和方差(sigma2),并隨后在n2維度上來計算LayerNorm的輸出,同時會在每個Block的線程(0, 0)更新cuWelfordMuSigma2算出來的均值和方差(這里的記錄的實際上是方差的逆)。
0x1.3 kernel實現
從上面的分析可知,整個LayerNorm實現的核心就是cuWelfordMuSigma2函數,下面對這個函數進行解析。
//`cuWelfordMuSigma2`是一個CUDA設備函數,旨在高效計算張量某一特定維度上的均值(mu)和方差(sigma2)。 //它基于Welford算法實現,以提高數值穩定性。此外,該函數支持僅計算均方根(RMS)作為一種操作模式。 //模板參數:定義了處理張量值(T)和執行計算(U)時使用的數據類型。 // const T*__restrict__ vals:指向張量數據的指針。 // const int n1, n2:指定張量的維度,其中n1是參與計算的維度的大小,n2是被約減的維度的大小。 // const int i1:當前正在處理的n1維度上的特定索引。 // U& mu, sigma2:用于存儲計算得出的均值和方差。 // U* buf:指向用于線程間通訊的共享內存緩沖區的指針。 // bool rms_only:一個標志,用于指示是否僅計算RMS(為true時)或同時計算均值和方差(為false時)。 template __device__ voidcuWelfordMuSigma2( constT*__restrict__vals, constintn1, constintn2, constinti1, U&mu, U&sigma2, U*buf, boolrms_only) { //前提條件: // 1) blockDim.x 等于 warp 的大小。 // 2)輸入的張量在內存中連續存儲。 // 3)有足夠的共享內存可用,大小為 2*blockDim.y*sizeof(U)+ blockDim.y*sizeof(int)。 // //在 n2 維度上計算方差和均值。 //初始化 count, mu, 和 sigma2 為零。 Ucount=U(0); mu=U(0); sigma2=U(0); //確保處理的 i1 索引在張量的有效范圍內。 if(i1(lvals[l+k]); //根據 rms_only 標志調用相應的函數來更新均值和方差或僅更新平方和(用于計算 RMS)。 if(!rms_only){ cuWelfordOnlineSum(curr,mu,sigma2,count); }else{ cuRMSOnlineSum(curr,sigma2); } } } //這個循環處理了之前在步長為 4*numx 的循環中未處理的張量元素。每個線程獨立處理它們剩余的部分。 for(;l(lvals[l]); if(!rms_only){ cuWelfordOnlineSum(curr,mu,sigma2,count); }else{ cuRMSOnlineSum(curr,sigma2); } } //在同一個warp內進行歸約操作。 for(intl=0;l<=?4;??++l)?{ ??????//?是在 CUDA 設備上進行 warp 內部數據交換的關鍵部分。 ??????//?這行代碼用于確定在一個 warp(32個線程)內,每個線程應該從哪個“lane”(即其他線程)獲取數據。 ??????//?(1< (muB,sigma2B,countB,mu,sigma2,count); }else{ cuChanRMSOnlineSum(sigma2B,sigma2); } } //threadIdx.x==0hascorrectvaluesforeachwarp //inter-warpreductions //檢查是否有多個 warp。如果 blockDim.y 大于 1,則表示塊中有多個 warp 需要進行reduce操作。 if(blockDim.y>1){ //為方差和均值的reduce操作分配共享內存。ubuf 用于存儲方差和均值,ibuf 用于存儲計數。 U*ubuf=(U*)buf; U*ibuf=(U*)(ubuf+blockDim.y); //這個循環是對 warp 間的reduce操作進行分層合并。 for(intoffset=blockDim.y/2;offset>0;offset/=2){ //upperhalfofwarpswritetoshared //確保只有部分線程(warp 的上半部分)將其計算的結果寫入共享內存。 if(threadIdx.x==0&&threadIdx.y>=offset&&threadIdx.y2*offset)?{ ??????????const?int?wrt_y?=?threadIdx.y?-?offset; ??????????if?(!rms_only)?{ ????????????ubuf[2*wrt_y]?=?mu; ????????????ibuf[wrt_y]?=?count; ??????????} ??????????ubuf[2*wrt_y+1]?=?sigma2; ????????} ????????//?同步以等待共享內存存儲完畢 ????????__syncthreads(); ????????//?lower?half?merges ????????//?此部分是對 warp 間數據的合并操作。 ????????//?確保只有部分線程(warp 的下半部分)從共享內存中讀取數據并進行合并。 ????????if?(threadIdx.x?==?0?&&?threadIdx.y?(muB,sigma2B,countB,mu,sigma2,count); }else{ cuChanRMSOnlineSum(sigma2B,sigma2); } } __syncthreads(); } //threadIdx.x=0&&threadIdx.y==0onlythreadthathascorrectvalues //最終的結果由塊內的第一個線程(threadIdx.x ==0&& threadIdx.y ==0)計算并寫入共享內存。 if(threadIdx.x==0&&threadIdx.y==0){ if(!rms_only){ ubuf[0]=mu; } ubuf[1]=sigma2; } __syncthreads(); //如果不是只計算 RMS,則還需要更新均值 mu。 if(!rms_only){ mu=ubuf[0]; } //計算最終的方差。 sigma2=ubuf[1]/U(n2); //don'tcareaboutfinalvalueofcount,weknowcount==n2 } //如果塊中只有一個 warp(blockDim.y == 1),則通過 WARP_SHFL 直接在 warp 內進行數據交換和更新。 else{ if(!rms_only){ mu=WARP_SHFL(mu,0); } sigma2=WARP_SHFL(sigma2/U(n2),0); } }
cuWelfordMuSigma2函數就是在n2維度上使用工具函數章節的Weleford方法來完成均值和方差的計算,然后這里還借助了共享內存來做warp內和warp間的reduce,最終得到全局的均值和方差。
前向的kernel就分析到這里,大家如果想對LayerNorm的優化做進一步的了解,推薦看一下OneFlow的SoftMax和LayerNorm優化文章。CUDA優化之LayerNorm性能優化實踐(https://zhuanlan.zhihu.com/p/443026261) ,這篇文章也是講解了LayerNorm的前向優化流程,文章開頭有一張性能的圖:
實際上在大模型時代,我們的隱藏層維度已經越來越大了,所以我們在實際訓練的時候,OneFlow版本的kernel相比于apex的layerNorm在13B之類的模型訓練里就拿不到明顯收益了。而在CV中,由于做LayerNorm的維度可能相對小一些,所以相比于apex的LayerNorm就可以取得明顯加速。
0x2. Apex的LayerNorm反向cuda實現(memory_efficient相關計算)
在apex的LayerNorm反向實現時我們不僅要關注它的cuda kernel是怎么寫的,還要關注memory_efficient打開時是如何根據輸出來計算梯度的。我們知道LayerNorm需要對輸入,gamma,beta都計算梯度,介于篇幅原因,這里對實現得最復雜的gamma/beta的反向過程進行走讀。
0x2.1 啟動邏輯
這里從kernel的啟動邏輯開始梳理:
//這是一個模板函數,支持不同的數據類型:T(輸入數據類型)、 // U(通常用于中間計算的數據類型,默認為float)、V(輸出數據類型,默認與T相同)。 //參數包括輸出梯度(dout)、均值(mean)、方差倒數(invvar)、輸入或輸出的PyTorch張量(input_or_output)、 //兩個維度參數(n1、n2)、gamma和beta參數、用于數值穩定的epsilon、輸入梯度(grad_input)、 // gamma梯度(grad_gamma)和beta梯度(grad_beta)、以及一個指示是否優化內存使用的布爾值(memory_efficient)。 templatevoidHostLayerNormGradient( constV*dout, constU*mean, constU*invvar, at::Tensor*input_or_output, intn1, intn2, constV*gamma, constV*beta, doubleepsilon, T*grad_input, V*grad_gamma, V*grad_beta, boolmemory_efficient ) { //獲取當前CUDA流以用于后續的CUDA內核調用。 autostream=at::getCurrentCUDAStream().stream(); //如果gamma和beta不為NULL,函數會計算它們的梯度。 //這涉及兩個CUDA內核的調用:cuComputePartGradGammaBeta和cuComputeGradGammaBeta。 if(gamma!=NULL&&beta!=NULL){ //computegrad_gamma(j)andgrad_beta(j) // part_size是分塊計算梯度時的部分大小。 constintpart_size=16; // threads2定義了每個CUDA線程塊中的線程數量(32×4×1)。 constdim3threads2(32,4,1); // blocks2定義了CUDA網格中的塊數量,其中,n2維度被分成多個塊,以確保每個塊可以處理n2中的一部分。 constdim3blocks2((n2+threads2.x-1)/threads2.x,part_size,1); //這部分代碼計算用于CUDA內核的共享內存大小。nshared2_a和nshared2_b是基于線程和塊維度的兩種不同共享內存大小估算。 constintnshared2_a=2*sizeof(U)*threads2.y*threads2.y*(threads2.x+1); constintnshared2_b=threads2.x*threads2.y*sizeof(U); //最終選擇較大的一個估算值作為實際的共享內存大小(nshared2)。 constintnshared2=nshared2_a>nshared2_b?nshared2_a:nshared2_b; //note(mkozuki):Icanhardcodepart_grad_gamma'sdtypeasfloatgiventhat //the`cuda_layer_norm_gradient`doesn'tsupportdouble. //根據輸入或輸出張量的數據類型決定局部梯度張量part_grad_gamma和part_grad_beta的數據類型。 //如果輸入或輸出是半精度浮點數(Half)或BFloat16,則使用單精度浮點數(Float);否則,使用輸入或輸出的相同數據類型。 constautopart_grad_dtype= (input_or_output->scalar_type()==at::Half||input_or_output->scalar_type()==at::BFloat16)? at::Float: input_or_output->scalar_type(); //創建兩個新的PyTorch張量part_grad_gamma和part_grad_beta,用于存儲gamma和beta的局部梯度計算結果。 at::Tensorpart_grad_gamma=at::empty({part_size,n2},input_or_output->options().dtype(part_grad_dtype)); at::Tensorpart_grad_beta=at::empty_like(part_grad_gamma); //使用BOOL_SWITCH宏處理memory_efficient參數,以決定是否使用內存高效版本的CUDA內核。 //調用cuComputePartGradGammaBeta內核計算gamma和beta的梯度。 //這個內核函數接收必要的輸入參數,并將梯度結果寫入part_grad_gamma和part_grad_beta張量。 BOOL_SWITCH(memory_efficient,MemoryEfficient,[&]{ autokernel=&cuComputePartGradGammaBeta ; kernel<< >>( dout, input_or_output->DATA_PTR (), n1,n2, mean, invvar, U(epsilon), gamma, beta, part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR(), epsilon, false); }); //定義了每個CUDA線程塊中的線程數量(32×8×1)。 constdim3threads3(32,8,1); //定義了CUDA網格中的塊數量。在這里,n2維度被分成多個塊,每個塊的大小由threads2.x(之前定義的線程數量)確定。 constdim3blocks3((n2+threads2.x-1)/threads2.x,1,1); //這行代碼計算了cuComputeGradGammaBeta內核所需的共享內存大小。它基于threads3線程塊的維度和數據類型U的大小。 constintnshared3=threads3.x*threads3.y*sizeof(U); //kernel接收局部梯度張量(part_grad_gamma和part_grad_beta)、塊大小(part_size)、 //維度參數(n1和n2)和指向梯度輸出的指針(grad_gamma和grad_beta)。 cuComputeGradGammaBeta<< >>( part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR(), part_size, n1,n2, grad_gamma, grad_beta, false); } ... }
這里省略了計算輸入梯度的啟動代碼,只看計算gamma和beta梯度的代碼。可以發現,這里對gamma和beta的梯度進行計算時使用了分塊計算的方式,首先會調用cuComputePartGradGammaBeta這個kernel計算出一個部分gamma和部分beta,也就是part_grad_gamma和part_grad_beta,需要注意這個kernel開啟的線程塊為:const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1),其中part_size=16,此外每個線程塊中的線程排布為:const dim3 threads2(32,4,1),即每個線程塊有128個線程。我們可以簡單算一下block2的大小,threads2.x=32,那么blocks2=(n2/32,16,1),也就是一共會有n2/2個線程塊。
使用cuComputePartGradGammaBeta計算完局部gamma和beta的grad之后,會調用cuComputeGradGammaBeta這個kernel來匯總全局的gamma和beta的梯度。這里開啟的線程塊為:const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1),而每個線程塊里面有256個線程,排布為const dim3 threads3(32,8,1)。
現在了解了線程塊的組織方式就需要去kernel實現里面對應看一下具體是怎么計算的。
0x2.2 kernel計算邏輯
首先來看分段計算gamma和beta梯度的kernel實現,注釋如下:
// part_size是分塊計算梯度時的部分大小。 //constintpart_size=16; // threads2定義了每個CUDA線程塊中的線程數量(32×4×1)。 //constdim3threads2(32,4,1); // blocks2定義了CUDA網格中的塊數量,其中,n2維度被分成多個塊,以確保每個塊可以處理n2中的一部分。 //constdim3blocks2((n2+threads2.x-1)/threads2.x,part_size,1); //-> //blockDim.x=32,blockDim.y=4,gridDim.y=16 //假設n1=4,n2=256,并且當前是第一個線程塊 template__global__ voidcuComputePartGradGammaBeta( constV*__restrict__dout, constT*__restrict__input_or_output, constintn1, constintn2, constU*__restrict__mean, constU*__restrict__invvar, Uepsilon, constV*__restrict__gamma, constV*__restrict__beta, U*part_grad_gamma, U*part_grad_beta, constdoubleeps, boolrms_only) { // numsegs_n1計算n1維度(4)被分成多少段。使用blockDim.y*blockDim.y(16)作為分段大小。 //帶入值:numsegs_n1 =(4 + 16 - 1)/ 16 = 1。 constintnumsegs_n1=(n1+blockDim.y*blockDim.y-1)/(blockDim.y*blockDim.y); // segs_per_block計算每個線程塊要處理的段數。 //帶入值:segs_per_block =(1 + 16 - 1)/ 16 = 1。 constintsegs_per_block=(numsegs_n1+gridDim.y-1)/gridDim.y; //這些行計算當前線程塊開始和結束處理n1維度的索引 //i1_beg和i1_beg_plus_one相差segs_per_block*blockDim.y*blockDim.y=1*4*4=16 //帶入blockIdx.y =0:i1_beg =0* 1 * 4 * 4 =0, i1_beg_plus_one = 1 * 1 * 4 * 4 = 16,i1_end = min(16, 4)= 4 constinti1_beg=blockIdx.y*segs_per_block*blockDim.y*blockDim.y; constinti1_beg_plus_one=(blockIdx.y+1)*segs_per_block*blockDim.y*blockDim.y; constinti1_end=i1_beg_plus_oneshared; U*buf=shared.getPointer();//bufhasatleastblockDim.x*blockDim.y*blockDim.y+(blockDim.y-1)*(blockDim.x/blockDim.y)elements U*warp_buf1=(U*)buf;//大小是31*4*4=496 U*warp_buf2=warp_buf1+blockDim.y*blockDim.y*row_stride;//大小是3*(32/4)=24 //computepartialsumsfromstridedinputs //dothistoincreasenumberofloadsinflight cuLoadWriteStridedInputs (i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps,rms_only); // for循環處理每個數據塊(由i1_beg和i1_end確定)。 //它在數據塊之間以步幅blockDim.y*blockDim.y迭代,允許不同的線程塊處理不同的數據區域。 for(inti1_block=i1_beg+blockDim.y*blockDim.y;i1_block(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps,rms_only); } //確保在所有線程完成其加載和處理操作之前,沒有線程會繼續執行后續的操作。 __syncthreads(); //inter-warpreductions //sumwithineachwarp //這部分代碼執行內部歸約,計算每個warp內部的部分和。 // acc1和acc2分別用于累積來自warp_buf1和warp_buf2的值。這些緩沖區包含之前步驟計算的中間結果。 Uacc1=U(0); Uacc2=U(0); //內部循環對于blockDim.y內的每一行進行累加,if (!rms_only)條件檢查是否需要執行特定的分支邏輯。 //需要特別注意,這個累加實際上是在列方向上也就是n2維度,在n2維度上一個線程負責計算blockDim.y列 for(intk=0;k1;offset/=2){ //在每次迭代中,只有threadIdx.y小于當前offset的線程會參與計算,這樣可以避免重復的工作。 if(threadIdx.y
在理解這段代碼之前,有一個大前提,那就是這里的訪問方式是n1是和blockDim.y綁定的,而n2是和blockDim.x綁定的,而且以二維矩陣的角度來看,n1是在列方向,而n2是在行的方向。然后const int row_stride = blockDim.x+1這一行是對共享內存進行padding避免Bank Conflict的,而在計算時對共享內存的訪問就是按照列來訪問,徹底避免bank conflict。
這也是為什么cuLoadWriteStridedInputs和cuLoadAddStridedInputs函數名中有一個Strided,這也暗示了它們的訪問模式是跨stride的。剩下的部分其實和前向就比較類似了,做warp內和warp間的reduce。
另外一個值得注意的點是在cuLoadWriteStridedInputs和cuLoadAddStridedInputs計算時,會根據memory_efficient開關選擇不同的計算公式,分別從輸入和輸出來計算出梯度,達到kernel內部重計算的目的。
//這段代碼定義了一個名為cuLoadWriteStridedInputs的CUDA設備函數模板,用于在計算LayerNorm的梯度時, //從輸入張量中加載數據并進行必要的計算,將結果存儲在 warp 緩沖區中。這個函數支持內存高效模式(MemoryEfficient)。 //模板參數 T, U, V 代表不同的數據類型。 // bool MemoryEfficient 用于選擇是否采用內存高效的方式處理數據。 //__device__表明這是一個 CUDA 設備函數。 //函數參數包括各種用于LayerNorm梯度計算的數據, //如輸入/輸出張量、梯度張量 dout、均值 mean、逆方差 invvar、縮放參數 gamma、偏移參數 beta 等。 template__device__ voidcuLoadWriteStridedInputs( constinti1_block, constintthr_load_row_off, constintthr_load_col_off, constinti2_off, constintrow_stride, U*warp_buf1, U*warp_buf2, constT*input_or_output, constV*dout, constinti1_end, constintn2, constU*__restrict__mean, constU*__restrict__invvar, constV*__restrict__gamma, constV*__restrict__beta, constdoubleeps, boolrms_only ) { //計算 i1,表示當前處理的行索引。 inti1=i1_block+thr_load_row_off; if(i1(input_or_output[load_idx]); Ucurr_dout=static_cast(dout[load_idx]); //根據 rms_only 和 MemoryEfficient 的值,使用不同的公式計算梯度,并將結果存儲在 warp 緩沖區中。 if(!rms_only){ warp_buf1[write_idx]=curr_dout; if(MemoryEfficient){ Ucurr_beta=static_cast(beta[i2]); warp_buf2[write_idx]=curr_dout*(c_h-curr_beta)/static_cast(clamp_by_magnitude(gamma[i2],eps)); }else{ warp_buf2[write_idx]=curr_dout*(c_h-mean[i1])*invvar[i1]; } }else{ if(MemoryEfficient){ warp_buf2[write_idx]=curr_dout*(c_h)/static_cast(clamp_by_magnitude(gamma[i2],eps)); }else{ warp_buf2[write_idx]=curr_dout*(c_h)*invvar[i1]; } } }else{ //對于超出 n2 范圍的索引,將相應的 warp 緩沖區位置設置為0。 if(!rms_only){ warp_buf1[write_idx]=U(0); } warp_buf2[write_idx]=U(0); } } }else{ //對于超出 n1 范圍的索引,也將相應的 warp 緩沖區位置設置為0。 for(intk=0;k
執行完cuComputePartGradGammaBeta這個kernel之后,它的輸出part_grad_gamma和part_grad_beta分別以行為n2列為n1的內存視角保存了LayerNorm的局部均值和方差的梯度。
接下來會使用cuComputeGradGammaBeta這個kernel來計算全局的均值和方差的梯度,由于局部計算的時候分塊大小是16,而每個線程負責了4行的計算,那么這里還需要累積16/4=4次,以得到當前行的所有局部梯度的和。
//blockDim.x=n2/32,blockDim.y=1 //threadDim.x=32,threadDim.y=8 template__global__ voidcuComputeGradGammaBeta( constU*part_grad_gamma, constU*part_grad_beta, constintpart_size, constintn1, constintn2, V*grad_gamma, V*grad_beta, boolrms_only) { //sumpartialgradientsforgammaandbeta SharedMemoryshared; U*buf=shared.getPointer(); //計算每個線程的全局索引i2,用于確定它在n2維度上的位置。 inti2=blockIdx.x*blockDim.x+threadIdx.x; //如果線程索引i2小于n2的大小,該線程會參與計算。 if(i2=1;offset/=2){ //tophalfwritetosharedmemory //在這個歸約階段,線程首先將其累加結果寫入共享內存,然后從共享內存讀取并繼續累加。 if(threadIdx.y>=offset&&threadIdx.y2*offset)?{ ??????????const?int?write_idx?=?(threadIdx.y?-?offset)?*?blockDim.x?+?threadIdx.x; ??????????buf[write_idx]?=?sum_gamma; ??????????if?(!rms_only)?{ ????????????buf[write_idx+nbsize3]?=?sum_beta; ??????????} ????????} ????????//?__syncthreads()在每次迭代結束時同步所有線程,確保共享內存的一致性。 ????????__syncthreads(); ????????//?bottom?half?sums ????????if?(threadIdx.y?
注意,for (int offset = blockDim.y/2; offset >= 1; offset /= 2) 這個循環包起來的代碼在這里不會工作,因為這個kernel的啟動設置中 blockDim.y=1。另外,我們知道輸入的數據已經是寫到全局內存里面的了,已經是同步之后的了,然后每個線程累積4次這個過程也是從global memory里面先讀再計算最后寫回全局內存,所以確實不需要再reduce了。
關于memory_efficient開關打開時的梯度計算公式,按照 https://github.com/NVIDIA/apex/pull/1715 這個pr 來看應該就是把原始的輸入用重計算的輸入替換之后再代入到之前的梯度計算公式中算出來的。
?
https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda_kernel.cu#L579 這里就對應了對gamma的梯度,https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda_kernel.cu#L582C5-L582C5 這里則對應了對beta的梯度。這里的就等于,公式和代碼實現都能完整對應上。
0x3. 總結
這篇文章記錄了筆者在研究大模型訓練中偶然見到的一個Trick的代碼解密過程,希望對學習cuda的小伙伴有所幫助,謝謝大家。
審核編輯:劉清
-
NVIDIA
+關注
關注
14文章
4978瀏覽量
102988 -
RMS
+關注
關注
2文章
138瀏覽量
35787 -
python
+關注
關注
56文章
4792瀏覽量
84628 -
CUDA
+關注
關注
0文章
121瀏覽量
13620 -
GPU芯片
+關注
關注
1文章
303瀏覽量
5804
原文標題:【BBuf的CUDA筆記】十二,LayerNorm/RMSNorm的重計算實現
文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論