文本基于深度學(xué)習(xí)和遷移學(xué)習(xí)方法,對(duì)瘧疾等傳染病檢測(cè)問(wèn)題進(jìn)行了研究。作者對(duì)瘧疾的檢測(cè)原理以及遷移學(xué)習(xí)理論進(jìn)行了介紹,并使用VGG-19預(yù)訓(xùn)練模型,進(jìn)行了基于特征提取和基于參數(shù)微調(diào)的遷移學(xué)習(xí)實(shí)踐。
前言
"健康就是財(cái)富",這是一個(gè)老生常談的話(huà)題,但不得不說(shuō)這是一個(gè)真理。在這篇文章中,我們將研究如何利用AI技術(shù)來(lái)檢測(cè)一種致命的疾病——瘧疾。本文將提出一個(gè)低成本、高效率和高準(zhǔn)確率的開(kāi)源解決方案。本文有兩個(gè)目的:1.了解瘧疾的傳染原因和其致命性;2、介紹如何運(yùn)用深度學(xué)習(xí)有效檢測(cè)瘧疾。本章的主要內(nèi)容如下:
開(kāi)展本項(xiàng)目的動(dòng)機(jī)
瘧疾檢測(cè)的方法
用深度學(xué)習(xí)檢測(cè)瘧疾
從頭開(kāi)始訓(xùn)練卷積神經(jīng)網(wǎng)絡(luò)(CNN)
利用預(yù)訓(xùn)練模型進(jìn)行遷移學(xué)習(xí)
本文不是為了宣揚(yáng) AI 將要取代人類(lèi)的工作,或者接管世界等論調(diào),而是僅僅展示 AI 是如何用一種低成本、高效率和高準(zhǔn)確率的方案,來(lái)幫助人類(lèi)去檢測(cè)和診斷瘧疾,并盡量減少人工操作。
Python and TensorFlow?—?A great combo to build open-source deep learning solutions
在本文中,我們將使用 Python 和 tensorflow ,來(lái)構(gòu)建一個(gè)強(qiáng)大的、可擴(kuò)展的、有效的深度學(xué)習(xí)解決方案。這些工具都是免費(fèi)并且開(kāi)源的,這使得我們能夠構(gòu)建一個(gè)真正低成本、高效精準(zhǔn)的解決方案,而且可以讓每個(gè)人都可以輕松使用。讓我們開(kāi)始吧!
動(dòng)機(jī)
瘧疾是經(jīng)瘧蚊叮咬而感染瘧原蟲(chóng)所引起的蟲(chóng)媒傳染病,瘧疾最常通過(guò)受感染的雌性瘧蚊來(lái)傳播。雖然我們不必詳細(xì)了解這種疾病,但是我們需要知道瘧疾有五種常見(jiàn)的類(lèi)型。下圖展示了這種疾病的致死性在全球的分布情況。
Malaria Estimated Risk Heath Map (Source: treated.com)
從上圖中可以明顯看到,瘧疾遍布全球,尤其是在熱帶區(qū)域分布密集。本項(xiàng)目就是基于這種疾病的特性和致命性來(lái)開(kāi)展的,下面我們舉個(gè)例子來(lái)說(shuō)明。起初,如果你被一只受感染的蚊子叮咬了,那么蚊子所攜帶的寄生蟲(chóng)就會(huì)進(jìn)入你的血液,并且開(kāi)始摧毀你體內(nèi)的攜氧紅細(xì)胞。通常來(lái)講,你會(huì)在被瘧蚊叮咬后的幾天或幾周內(nèi)感到不適,一般會(huì)首先出現(xiàn)類(lèi)似流感或者病毒感染的癥狀。然而,這些致命的寄生蟲(chóng)可以在你身體里完好地存活超過(guò)一年的時(shí)間,并且不產(chǎn)生任何其他癥狀!延遲接受正確的治療,可能會(huì)導(dǎo)致并發(fā)癥甚至死亡。因此,早期并有效的瘧疾檢測(cè)和排查可以挽救這些生命。
世界衛(wèi)生組織(WHO)發(fā)布了幾個(gè)關(guān)于瘧疾的重要事實(shí),詳情見(jiàn)此。簡(jiǎn)而言之,世界上將近一半的人口面臨瘧疾風(fēng)險(xiǎn),每年有超過(guò)2億的瘧疾病例,以及有大約40萬(wàn)人死于瘧疾。這些事實(shí)讓我們認(rèn)識(shí)到,快速簡(jiǎn)單高效的瘧疾檢查是多么重要,這也是本文的動(dòng)機(jī)所在。
瘧疾檢查的方法
文章《 Pre-trained convolutional neural networks as feature extractors toward improved Malaria parasite detection in thin blood smear images》(本文的數(shù)據(jù)和分析也是基于這篇文章)簡(jiǎn)要介紹了瘧疾檢測(cè)的幾種方法,這些方法包括但是不限于厚薄血涂片檢查、聚合酶鏈?zhǔn)椒磻?yīng)(PCR)和快速診斷測(cè)試(RDT)。在本文中,我們沒(méi)有對(duì)這些方法進(jìn)行詳細(xì)介紹,但是需要注意的一點(diǎn)是,后兩種方法常常作為替代方案使用,尤其是在缺乏高質(zhì)量顯微鏡服務(wù)的情況下。
我們將簡(jiǎn)要討論基于血液涂片檢測(cè)流程的標(biāo)準(zhǔn)瘧疾診斷方法,首先感謝 Carlos Ariza 的博文,以及 Adrian Rosebrock 關(guān)于瘧疾檢查的文章,這兩篇文章讓我們對(duì)瘧疾檢查領(lǐng)域有了更為深入的了解。
A blood smear workflow for Malaria detection (Source)
根據(jù)上圖所示的 WHO 的血液涂片檢測(cè)流程,該工作包括在100倍放大倍數(shù)下對(duì)血涂片進(jìn)行深入檢查,其中人們需要從5000個(gè)細(xì)胞中,手動(dòng)檢測(cè)出含有寄生蟲(chóng)的紅細(xì)胞。Rajaraman 等人的論文中更加詳細(xì)的給出了相關(guān)的描述,如下所示:
厚血涂片有助于檢測(cè)寄生蟲(chóng)的存在,而薄血涂片有助于識(shí)別引起感染的寄生蟲(chóng)種類(lèi)(Centers for Disease Control and Prevention, 2012)。診斷準(zhǔn)確性在很大程度上取決于人類(lèi)的專(zhuān)業(yè)知識(shí),并且可能受到觀察者間的差異和觀察者的可靠性所帶來(lái)的不利影響,以及受到在疾病流行或資源受限的區(qū)域內(nèi)的大規(guī)模診斷造成的負(fù)擔(dān)所帶來(lái)的不利影響(Mitiku,Mengistu&Gelaw,2003)。替代技術(shù),例如聚合酶鏈?zhǔn)椒磻?yīng)(PCR)和快速診斷測(cè)試(RDT),也會(huì)被使用;但是PCR分析受到其性能的限制(Hommelsheim等,2014),而RDT在疾病流行地區(qū)的成本效益較低(Hawkes,Katsuva&Masumbuko,2009)。
因此,傳統(tǒng)的瘧疾檢測(cè)絕對(duì)是一個(gè)密集的手工過(guò)程,或許深度學(xué)習(xí)技術(shù)可以幫助它完成自動(dòng)化。上文提到的這些內(nèi)容為后文打下了基礎(chǔ)。
用深度學(xué)習(xí)檢測(cè)瘧疾
手工診斷血液涂片,是一項(xiàng)重復(fù)且規(guī)律的工作,而且需要一定的專(zhuān)業(yè)知識(shí)來(lái)區(qū)分和統(tǒng)計(jì)被寄生的和未感染的細(xì)胞。如果某些地區(qū)的工作人員沒(méi)有正確的專(zhuān)業(yè)知識(shí),那么這種方法就不能很好地推廣,并且會(huì)導(dǎo)致一些問(wèn)題。現(xiàn)有工作已經(jīng)取得了一些進(jìn)展,包括利用最先進(jìn)的圖像處理和分析技術(shù)來(lái)提取手工設(shè)計(jì)的特征,并利用這些特性構(gòu)建基于機(jī)器學(xué)習(xí)的分類(lèi)模型。但是,由于手工設(shè)計(jì)的部分需要花費(fèi)大量的時(shí)間,當(dāng)有更多的數(shù)據(jù)可供訓(xùn)練時(shí),模型卻無(wú)法及時(shí)的進(jìn)行擴(kuò)展。
深度學(xué)習(xí)模型,或更具體地說(shuō),卷積神經(jīng)網(wǎng)絡(luò)(CNN)在各種計(jì)算機(jī)視覺(jué)任務(wù)中獲得了非常好的效果。本文假設(shè)您已經(jīng)對(duì) CNN 有一定的了解,但是如果您并不了解 CNN ,可以通過(guò)這篇文章進(jìn)行深入了解。簡(jiǎn)單來(lái)講,CNN 最關(guān)鍵的層主要包括卷積層和池化層,如下圖所示。
A typical CNN architeture (Source: deeplearning.net)
卷積層從數(shù)據(jù)中學(xué)習(xí)空間層級(jí)模式,這些模式具有平移不變性,因此卷積層能夠?qū)W習(xí)圖像的不同方面。例如,第一卷積層將學(xué)習(xí)諸如邊緣和角落的微型局部模式,第二卷積層將基于第一層所提取的特征,來(lái)學(xué)習(xí)更大的圖像模式,如此循序漸進(jìn)。這使得 CNN 能夠自動(dòng)進(jìn)行特征工程,并且學(xué)習(xí)有效的特征,這些特征對(duì)新的數(shù)據(jù)具有很好的泛化能力。池化層常用于下采樣和降維。
因此,CNN 能夠幫助我們實(shí)現(xiàn)自動(dòng)化的和可擴(kuò)展的特征工程。此外,在模型的末端接入密集層,能夠使我們執(zhí)行圖像分類(lèi)等任務(wù)。使用像CNN這樣的深度學(xué)習(xí)模型,進(jìn)行自動(dòng)化的瘧疾檢測(cè),可能是一個(gè)高效、低成本、可擴(kuò)展的方案。特別是隨著遷移學(xué)習(xí)的發(fā)展和預(yù)訓(xùn)練模型的共享,在數(shù)據(jù)量較少等限制條件下,深度學(xué)習(xí)模型也能取得很好的效果。
Rajaraman 等人的論文《Pre-trained convolutional neural networks as feature extractors toward improved parasite detection in thin blood smear images》利用 6 個(gè)預(yù)訓(xùn)練模型,在進(jìn)行瘧疾檢測(cè)時(shí)取得了 95.9% 的準(zhǔn)確率。本文的重點(diǎn)是從頭開(kāi)始嘗試一些簡(jiǎn)單的 CNN 模型和一些預(yù)先訓(xùn)練的模型,并利用遷移學(xué)習(xí)來(lái)檢驗(yàn)我們?cè)谕粩?shù)據(jù)集下得到的結(jié)果。本文將使用 Python 和 TensorFlow 框架來(lái)構(gòu)建模型。
數(shù)據(jù)集的詳情
首先感謝 Lister Hill 國(guó)家生物醫(yī)學(xué)通信中心(LHNCBC)的研究人員(國(guó)家醫(yī)學(xué)圖書(shū)館(NLM)的部門(mén)),他們仔細(xì)收集并注釋了這個(gè)血涂片圖像的數(shù)據(jù)集,數(shù)據(jù)中包含健康和感染這兩種類(lèi)型的血涂片圖像。您可以從官方網(wǎng)站上下載這些圖像。
實(shí)際上,他們開(kāi)發(fā)了一款可以運(yùn)行在標(biāo)準(zhǔn)安卓智能手機(jī)上的應(yīng)用程序,該程序可以連接傳統(tǒng)的光學(xué)顯微鏡 (Poostchi et al., 2018) 。他們從孟加拉國(guó)吉大港醫(yī)學(xué)院附屬醫(yī)院進(jìn)行拍照記錄了樣本集,其中包括150個(gè)惡性瘧原蟲(chóng)感染的樣本和 50 個(gè)健康的樣本,每個(gè)樣本都是經(jīng)過(guò) Giemsa 染色的薄血涂片。智能手機(jī)的內(nèi)置攝像頭可以捕獲樣本的每一個(gè)局部微觀視圖。來(lái)自泰國(guó)曼谷的瑪希隆-牛津熱帶醫(yī)學(xué)研究所的專(zhuān)業(yè)人員為這些圖像進(jìn)行了手動(dòng)注釋。讓我們簡(jiǎn)要地看一下數(shù)據(jù)集結(jié)構(gòu)。首先根據(jù)本文所使用的操作系統(tǒng),我們需要安裝一些基本的依賴(lài)項(xiàng)。
本文所使用的系統(tǒng)是云上的 Debian 系統(tǒng),該系統(tǒng)配置有 GPU ,這能夠加速我們模型的訓(xùn)練。首先安裝依賴(lài)樹(shù),這能夠方便我們查看目錄結(jié)構(gòu)。(sudo apt install tree)
從上圖所示的目錄結(jié)構(gòu)中可以看到,我們的文件里包含兩個(gè)文件夾,分別包含受感染的和健康的細(xì)胞圖像。利用以下代碼,我們可以進(jìn)一步了解圖像的總數(shù)是多少。
import osimport globbase_dir = os.path.join('./cell_images')infected_dir = os.path.join(base_dir,'Parasitized')healthy_dir = os.path.join(base_dir,'Uninfected')infected_files = glob.glob(infected_dir+'/*.png')healthy_files = glob.glob(healthy_dir+'/*.png')len(infected_files), len(healthy_files)# Output(13779, 13779)
從上述結(jié)果可以看到, 瘧疾和非瘧疾(未感染)的細(xì)胞圖像的數(shù)據(jù)集均包含13779張圖片,兩個(gè)數(shù)據(jù)集的大小是相對(duì)平衡的。接下來(lái)我們將利用這些數(shù)據(jù)構(gòu)建一個(gè)基于pandas的dataframe類(lèi)型的數(shù)據(jù),這對(duì)我們后續(xù)構(gòu)建數(shù)據(jù)集很有幫助。
import numpy as npimport pandas as pdnp.random.seed(42)files_df = pd.DataFrame({ 'filename': infected_files + healthy_files, 'label': ['malaria'] * len(infected_files) + ['healthy'] * len(healthy_files)}).sample(frac=1, random_state=42).reset_index(drop=True)files_df.head()
構(gòu)建和探索圖像數(shù)據(jù)集
在構(gòu)建深度學(xué)習(xí)模型之前,我們不僅需要訓(xùn)練數(shù)據(jù),還需要未用于訓(xùn)練的數(shù)據(jù)來(lái)驗(yàn)證和測(cè)試模型的性能。本文采用 60:10:30 的比例來(lái)劃分訓(xùn)練集、驗(yàn)證集和測(cè)試集。我們將使用訓(xùn)練集和驗(yàn)證集來(lái)訓(xùn)練模型,并利用測(cè)試集來(lái)檢驗(yàn)?zāi)P偷男阅堋?/p>
from sklearn.model_selection import train_test_splitfrom collections import Countertrain_files, test_files, train_labels, test_labels = train_test_split(files_df['filename'].values, files_df['label'].values, test_size=0.3, random_state=42)train_files, val_files, train_labels, val_labels = train_test_split(train_files, train_labels, test_size=0.1, random_state=42)print(train_files.shape, val_files.shape, test_files.shape)print('Train:', Counter(train_labels), ' Val:', Counter(val_labels), ' Test:', Counter(test_labels))# Output(17361,) (1929,) (8268,)Train: Counter({'healthy': 8734, 'malaria': 8627}) Val: Counter({'healthy': 970, 'malaria': 959}) Test: Counter({'malaria': 4193, 'healthy': 4075})
可以發(fā)現(xiàn),由于血液來(lái)源、測(cè)試方法以及圖像拍攝的方向不同,血液涂片和細(xì)胞的圖像尺寸不盡相同。我們需要獲取一些訓(xùn)練數(shù)據(jù)的統(tǒng)計(jì)信息,從而確定最優(yōu)的圖像尺寸(請(qǐng)注意,在這里我們完全沒(méi)用到測(cè)試集!)。
import cv2from concurrent import futuresimport threadingdef get_img_shape_parallel(idx, img, total_imgs): if idx % 5000 == 0 or idx == (total_imgs - 1): print('{}: working on img num:{}'.format(threading.current_thread().name,idx)) return cv2.imread(img).shapeex = futures.ThreadPoolExecutor(max_workers=None)data_inp = [(idx, img, len(train_files)) for idx, img in enumerate(train_files)]print('Starting Img shape computation:')train_img_dims_map = ex.map(get_img_shape_parallel, [record[0] for record in data_inp], [record[1] for record in data_inp], [record[2] for record in data_inp])train_img_dims = list(train_img_dims_map)print('Min Dimensions:', np.min(train_img_dims, axis=0)) print('Avg Dimensions:', np.mean(train_img_dims, axis=0))print('Median Dimensions:', np.median(train_img_dims, axis=0))print('Max Dimensions:', np.max(train_img_dims, axis=0))# OutputStarting Img shape computation:ThreadPoolExecutor-0_0: working on img num: 0ThreadPoolExecutor-0_17: working on img num: 5000ThreadPoolExecutor-0_15: working on img num: 10000ThreadPoolExecutor-0_1: working on img num: 15000ThreadPoolExecutor-0_7: working on img num: 17360Min Dimensions: [46 46 3]Avg Dimensions: [132.77311215 132.45757733 3.]Median Dimensions: [130. 130. 3.]Max Dimensions: [385 394 3]
我們采用了并行處理的策略來(lái)加速圖像讀取操作。基于匯總的統(tǒng)計(jì)信息,我們決定將每張圖像的大小調(diào)整為125x125。現(xiàn)在讓我們加載所有的圖像,并把他們的大小都調(diào)整為上述固定的尺寸。
IMG_DIMS = (125, 125)def get_img_data_parallel(idx, img, total_imgs): if idx % 5000 == 0 or idx == (total_imgs - 1): print('{}: working on img num: {}'.format(threading.current_thread().name,idx)) img = cv2.imread(img) img = cv2.resize(img, dsize=IMG_DIMS, interpolation=cv2.INTER_CUBIC) img = np.array(img, dtype=np.float32) return imgex = futures.ThreadPoolExecutor(max_workers=None)train_data_inp = [(idx, img, len(train_files)) for idx, img in enumerate(train_files)]val_data_inp = [(idx, img, len(val_files)) for idx, img in enumerate(val_files)]test_data_inp = [(idx, img, len(test_files)) for idx, img in enumerate(test_files)]print('Loading Train Images:')train_data_map = ex.map(get_img_data_parallel, [record[0] for record in train_data_inp], [record[1] for record in train_data_inp], [record[2] for record in train_data_inp])train_data = np.array(list(train_data_map))print(' Loading Validation Images:')val_data_map = ex.map(get_img_data_parallel, [record[0] for record in val_data_inp], [record[1] for record in val_data_inp], [record[2] for record in val_data_inp])val_data = np.array(list(val_data_map))print(' Loading Test Images:')test_data_map = ex.map(get_img_data_parallel, [record[0] for record in test_data_inp], [record[1] for record in test_data_inp], [record[2] for record in test_data_inp])test_data = np.array(list(test_data_map))train_data.shape, val_data.shape, test_data.shape # OutputLoading Train Images:ThreadPoolExecutor-1_0: working on img num: 0ThreadPoolExecutor-1_12: working on img num: 5000ThreadPoolExecutor-1_6: working on img num: 10000ThreadPoolExecutor-1_10: working on img num: 15000ThreadPoolExecutor-1_3: working on img num: 17360Loading Validation Images:ThreadPoolExecutor-1_13: working on img num: 0ThreadPoolExecutor-1_18: working on img num: 1928Loading Test Images:ThreadPoolExecutor-1_5: working on img num: 0ThreadPoolExecutor-1_19: working on img num: 5000ThreadPoolExecutor-1_8: working on img num: 8267((17361, 125, 125, 3), (1929, 125, 125, 3), (8268, 125, 125, 3))
我們?cè)俅芜\(yùn)用了并行處理策略來(lái)加速圖像加載和尺寸調(diào)整的計(jì)算,如上面輸出結(jié)果中展示的,我們最終得到了所需尺寸的圖像張量。現(xiàn)在我們可以查看一些樣本的細(xì)胞圖像,從而從直觀上認(rèn)識(shí)一下我們的數(shù)據(jù)的情況。
import matplotlib.pyplot as plt%matplotlib inlineplt.figure(1 , figsize = (8 , 8))n = 0 for i in range(16): n += 1 r = np.random.randint(0 , train_data.shape[0] , 1) plt.subplot(4 , 4 , n) plt.subplots_adjust(hspace = 0.5 , wspace = 0.5) plt.imshow(train_data[r[0]]/255.) plt.title('{}'.format(train_labels[r[0]])) plt.xticks([]) , plt.yticks([])
從上面的樣本圖像可以看出,瘧疾和健康細(xì)胞圖像之間存在一些細(xì)微差別。我們將構(gòu)建深度學(xué)習(xí)模型,通過(guò)不斷訓(xùn)練來(lái)使模型嘗試學(xué)習(xí)這些模式。在開(kāi)始訓(xùn)練模型之前,我們先對(duì)模型的參數(shù)進(jìn)行一些基本的設(shè)置。
BATCH_SIZE = 64NUM_CLASSES = 2EPOCHS = 25INPUT_SHAPE = (125, 125, 3)train_imgs_scaled = train_data / 255.val_imgs_scaled = val_data / 255.# encode text category labelsfrom sklearn.preprocessing import LabelEncoderle = LabelEncoder()le.fit(train_labels)train_labels_enc = le.transform(train_labels)val_labels_enc = le.transform(val_labels)print(train_labels[:6], train_labels_enc[:6])# Output['malaria' 'malaria' 'malaria' 'healthy' 'healthy' 'malaria'][1 1 1 0 0 1]
上面的代碼設(shè)定了圖像的維度,批尺寸,epoch 的次數(shù),并且對(duì)我們的類(lèi)別標(biāo)簽進(jìn)行了編碼。TensorFLow 2.0 alpha 版本在2019年3月發(fā)布,它為我們項(xiàng)目的實(shí)施提供了一個(gè)完美的接口。
import tensorflow as tf# Load the TensorBoard notebook extension (optional)%load_ext tensorboard.notebooktf.random.set_seed(42)tf.__version__# Output'2.0.0-alpha0'
深度學(xué)習(xí)模型的訓(xùn)練階段
在模型訓(xùn)練階段,我們將構(gòu)建幾個(gè)深度學(xué)習(xí)模型,利用前面構(gòu)建的訓(xùn)練集進(jìn)行訓(xùn)練,并在驗(yàn)證集上比較它們的性能。然后,我們將保存這些模型,并在模型評(píng)估階段再次使用它們。
模型1:從頭開(kāi)始訓(xùn)練CNN
對(duì)于本文的第一個(gè)瘧疾檢測(cè)模型,我們將構(gòu)建并從頭開(kāi)始訓(xùn)練一個(gè)基本的卷積神經(jīng)網(wǎng)絡(luò)(CNN)。首先,我們需要定義模型的結(jié)構(gòu)。
inp = tf.keras.layers.Input(shape=INPUT_SHAPE)conv1 = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same')(inp)pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv1)conv2 = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same')(pool1)pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv2)conv3 = tf.keras.layers.Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same')(pool2)pool3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv3)flat = tf.keras.layers.Flatten()(pool3)hidden1 = tf.keras.layers.Dense(512, activation='relu')(flat)drop1 = tf.keras.layers.Dropout(rate=0.3)(hidden1)hidden2 = tf.keras.layers.Dense(512, activation='relu')(drop1)drop2 = tf.keras.layers.Dropout(rate=0.3)(hidden2)out = tf.keras.layers.Dense(1, activation='sigmoid')(drop2)model = tf.keras.Model(inputs=inp, outputs=out)model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])model.summary()# OutputModel: "model"_________________________________________________________________Layer (type) Output Shape Param # =================================================================input_1 (InputLayer) [(None, 125, 125, 3)] 0 _________________________________________________________________conv2d (Conv2D) (None, 125, 125, 32) 896 _________________________________________________________________max_pooling2d (MaxPooling2D) (None, 62, 62, 32) 0 _________________________________________________________________conv2d_1 (Conv2D) (None, 62, 62, 64) 18496 _________________________________________________________________......_________________________________________________________________dense_1 (Dense) (None, 512) 262656 _________________________________________________________________dropout_1 (Dropout) (None, 512) 0 _________________________________________________________________dense_2 (Dense) (None, 1) 513 =================================================================Total params: 15,102,529Trainable params: 15,102,529Non-trainable params: 0_________________________________________________________________
上述代碼所構(gòu)建的 CNN 模型,包含3個(gè)卷積層、1個(gè)池化層以及2個(gè)全連接層,并對(duì)全連接層設(shè)置 dropout 參數(shù)用于正則化。現(xiàn)在讓我們開(kāi)始訓(xùn)練模型吧!
import datetimelogdir = os.path.join('/home/dipanzan_sarkar/projects/tensorboard_logs', datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir,histogram_freq=1)reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',factor=0.5,patience=2, min_lr=0.000001)callbacks = [reduce_lr, tensorboard_callback]history = model.fit(x=train_imgs_scaled, y=train_labels_enc, batch_size=BATCH_SIZE, epochs=EPOCHS, validation_data=(val_imgs_scaled, val_labels_enc), callbacks=callbacks, verbose=1) # OutputTrain on 17361 samples, validate on 1929 samplesEpoch 1/2517361/17361 [====] - 32s 2ms/sample - loss: 0.4373 - accuracy: 0.7814 - val_loss: 0.1834 - val_accuracy: 0.9393Epoch 2/2517361/17361 [====] - 30s 2ms/sample - loss: 0.1725 - accuracy: 0.9434 - val_loss: 0.1567 - val_accuracy: 0.9513......Epoch 24/2517361/17361 [====] - 30s 2ms/sample - loss: 0.0036 - accuracy: 0.9993 - val_loss: 0.3693 - val_accuracy: 0.9565Epoch 25/2517361/17361 [====] - 30s 2ms/sample - loss: 0.0034 - accuracy: 0.9994 - val_loss: 0.3699 - val_accuracy: 0.9559
從上面的結(jié)果可以看到,我們的模型在驗(yàn)證集上的準(zhǔn)確率為 95.6% ,這是非常好的。我們注意到模型在訓(xùn)練集上的準(zhǔn)確率為 99.9% ,這看起來(lái)有一些過(guò)擬合。為了更加清晰地查看這個(gè)問(wèn)題,我們可以分別繪制在訓(xùn)練和驗(yàn)證階段的準(zhǔn)確度曲線(xiàn)和損失曲線(xiàn)。
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))t = f.suptitle('Basic CNN Performance', fontsize=12)f.subplots_adjust(top=0.85, wspace=0.3)max_epoch = len(history.history['accuracy'])+1epoch_list = list(range(1,max_epoch))ax1.plot(epoch_list, history.history['accuracy'], label='Train Accuracy')ax1.plot(epoch_list, history.history['val_accuracy'], label='Validation Accuracy')ax1.set_xticks(np.arange(1, max_epoch, 5))ax1.set_ylabel('Accuracy Value')ax1.set_xlabel('Epoch')ax1.set_title('Accuracy')l1 = ax1.legend(loc="best")ax2.plot(epoch_list, history.history['loss'], label='Train Loss')ax2.plot(epoch_list, history.history['val_loss'], label='Validation Loss')ax2.set_xticks(np.arange(1, max_epoch, 5))ax2.set_ylabel('Loss Value')ax2.set_xlabel('Epoch')ax2.set_title('Loss')l2 = ax2.legend(loc="best")
Learning Curves for BasicCNN
從圖中可以看出,在第5個(gè) epoch 之后,在驗(yàn)證集上的精度似乎不再提高。我們先將這個(gè)模型保存,在后面我們會(huì)再次用到它。
model.save('basic_cnn.h5')
深度遷移學(xué)習(xí)
就像人類(lèi)能夠運(yùn)用知識(shí)完成跨任務(wù)工作一樣,遷移學(xué)習(xí)使得我們能夠利用在先前任務(wù)中學(xué)習(xí)到的知識(shí),來(lái)處理新的任務(wù),在機(jī)器學(xué)習(xí)和深度學(xué)習(xí)的環(huán)境下也是如此。這些文章涵蓋了遷移學(xué)習(xí)的詳細(xì)介紹和討論,有興趣的讀者可以參考學(xué)習(xí)。
Ideas for deep transferlearning
我們能否采用遷移學(xué)習(xí)的思想,將預(yù)訓(xùn)練的深度學(xué)習(xí)模型(已在大型數(shù)據(jù)集上進(jìn)行過(guò)訓(xùn)練的模型——例如 ImageNet)的知識(shí)應(yīng)用到我們的問(wèn)題——進(jìn)行瘧疾檢測(cè)上呢?我們將采用兩種目前最主流的遷移學(xué)習(xí)策略。
將預(yù)訓(xùn)練模型作為特征提取器
對(duì)預(yù)訓(xùn)練模型進(jìn)行微調(diào)
我們將使用由牛津大學(xué)視覺(jué)幾何組(VGG)所開(kāi)發(fā)的預(yù)訓(xùn)練模型 VGG-19 進(jìn)行實(shí)驗(yàn)。像 VGG-19 這樣的預(yù)訓(xùn)練模型,一般已經(jīng)在大型數(shù)據(jù)集上進(jìn)行過(guò)訓(xùn)練,這些數(shù)據(jù)集涵蓋多種類(lèi)別的圖像。基于此,這些預(yù)訓(xùn)練模型應(yīng)該已經(jīng)使用CNN模型學(xué)習(xí)到了一個(gè)具有高度魯棒性的特征的層次結(jié)構(gòu),并且其應(yīng)具有尺度、旋轉(zhuǎn)和平移不變性。因此,這個(gè)已經(jīng)學(xué)習(xí)了超過(guò)一百萬(wàn)個(gè)圖像的具有良好特征表示的模型,可以作為一個(gè)很棒的圖像特征提取器,為包括瘧疾檢測(cè)問(wèn)題在內(nèi)的其他計(jì)算機(jī)視覺(jué)問(wèn)題服務(wù)。在引入強(qiáng)大的遷移學(xué)習(xí)之前,我們先簡(jiǎn)要討論一下 VGG-19 的結(jié)構(gòu)。
理解VGG-19模型
VGG-19 是一個(gè)具有 19 個(gè)層(包括卷積層和全連接層)的深度學(xué)習(xí)網(wǎng)絡(luò),該模型基于 ImageNet 數(shù)據(jù)集進(jìn)行訓(xùn)練,該數(shù)據(jù)集是專(zhuān)門(mén)為圖像識(shí)別和分類(lèi)所構(gòu)建的。VGG-19 是由 Karen Simonyan 和 Andrew Zisserman 提出的,該模型在他們的論文《Very Deep Convolutional Networks for Large-Scale Image Recognition》中有詳細(xì)介紹,建議有興趣的讀者可以去讀一讀這篇優(yōu)秀的論文。VGG-19 模型的結(jié)構(gòu)如下圖所示。
VGG-19 Model Architecture
從上圖可以清楚地看到,該模型具有 16 個(gè)使用 3x3 卷積核的卷積層,其中部分卷積層后面接了一個(gè)最大池化層,用于下采樣;隨后依次連接了兩個(gè)具有 4096 個(gè)隱層神經(jīng)元的全連接層,接著連接了一個(gè)具有 1000 個(gè)隱層神經(jīng)元的全連接層, 最后一個(gè)全連接層的每個(gè)神經(jīng)元都代表 ImageNet 數(shù)據(jù)集中的一個(gè)圖像類(lèi)別。由于我們需要使用新的全連接層來(lái)分類(lèi)瘧疾,因此我們不需要最后的三個(gè)全連接層。我們更關(guān)心的是前五個(gè)塊,以便我們可以利用 VGG 模型作為有效的特征提取器。
前文提到有兩種遷移學(xué)習(xí)的策略,對(duì)于第一種策略,我們將把 VGG 模型當(dāng)做一個(gè)特征提取器,這可以通過(guò)凍結(jié)前五個(gè)卷積塊,使得它們的權(quán)重參數(shù)不會(huì)隨著新的訓(xùn)練過(guò)程而更新來(lái)實(shí)現(xiàn)。對(duì)于第二種策略,我們將會(huì)解凍最后的兩個(gè)卷積塊(模塊4和模塊5),從而使得它們的參數(shù)會(huì)隨著新的訓(xùn)練過(guò)程而不斷更新。
模型2:將預(yù)訓(xùn)練模型作為特征提取機(jī)
為了構(gòu)建這個(gè)模型,我們將利用 TensorFlow 加載 VGG-19 模型,并凍結(jié)它的卷積塊,以便我們可以將其用作圖像特征提取器。我們將在該模型的末尾插入自己的全連接層,用于執(zhí)行本文的分類(lèi)任務(wù)。
vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet',input_shape=INPUT_SHAPE)vgg.trainable = False# Freeze the layersfor layer in vgg.layers: layer.trainable = Falsebase_vgg = vggbase_out = base_vgg.outputpool_out = tf.keras.layers.Flatten()(base_out)hidden1 = tf.keras.layers.Dense(512, activation='relu')(pool_out)drop1 = tf.keras.layers.Dropout(rate=0.3)(hidden1)hidden2 = tf.keras.layers.Dense(512, activation='relu')(drop1)drop2 = tf.keras.layers.Dropout(rate=0.3)(hidden2)out = tf.keras.layers.Dense(1, activation='sigmoid')(drop2)model = tf.keras.Model(inputs=base_vgg.input, outputs=out)model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=1e-4),loss='binary_crossentropy',metrics=['accuracy'])model.summary()# OutputModel: "model_1"_________________________________________________________________Layer (type) Output Shape Param # =================================================================input_2 (InputLayer) [(None, 125, 125, 3)] 0 _________________________________________________________________block1_conv1 (Conv2D) (None, 125, 125, 64) 1792 _________________________________________________________________block1_conv2 (Conv2D) (None, 125, 125, 64) 36928 _________________________________________________________________......_________________________________________________________________block5_pool (MaxPooling2D) (None, 3, 3, 512) 0 _________________________________________________________________flatten_1 (Flatten) (None, 4608) 0 _________________________________________________________________dense_3 (Dense) (None, 512) 2359808 _________________________________________________________________dropout_2 (Dropout) (None, 512) 0 _________________________________________________________________dense_4 (Dense) (None, 512) 262656 _________________________________________________________________dropout_3 (Dropout) (None, 512) 0 _________________________________________________________________dense_5 (Dense) (None, 1) 513 =================================================================Total params: 22,647,361Trainable params: 2,622,977Non-trainable params: 20,024,384
從上面代碼的輸出可以看到,我們的模型有很多層,并且我們僅僅只利用了 VGG-19 的凍結(jié)層來(lái)提取特征。下面的代碼可以驗(yàn)證本模型中有多少層用于訓(xùn)練,以及檢驗(yàn)本模型中一共有多少層。
print("Total Layers:", len(model.layers))print("Total trainable layers:",sum([1 for l in model.layers if l.trainable]))# OutputTotal Layers: 28Total trainable layers: 6
現(xiàn)在我們將訓(xùn)練該模型,在訓(xùn)練過(guò)程中所用到的配置和回調(diào)函數(shù)與模型1中的類(lèi)似,完整的代碼可以參考github鏈接。下圖展示了在訓(xùn)練過(guò)程中,模型的準(zhǔn)確度曲線(xiàn)和損失曲線(xiàn)。
Learning Curves for frozen pre-trained CNN
從上圖可以看出,該模型不像模型1中基本的 CNN 模型那樣存在過(guò)擬合的現(xiàn)象,但是性能并不是很好。事實(shí)上,它的性能還沒(méi)有基本的 CNN 模型好。現(xiàn)在我們將模型保存,用于后續(xù)的評(píng)估。
model.save( 'vgg_frozen.h5')
模型3:具有圖像增廣的微調(diào)的預(yù)訓(xùn)練模型
在這個(gè)模型中,我們將微調(diào)預(yù)訓(xùn)練 VGG-19 模型的最后兩個(gè)區(qū)塊中層的權(quán)重。除此之外,我們還將介紹圖像增廣的概念。圖像增廣背后的原理與它的名稱(chēng)聽(tīng)起來(lái)完全一樣。我們首先從訓(xùn)練數(shù)據(jù)集中加載現(xiàn)有的圖像,然后對(duì)它們進(jìn)行一些圖像變換的操作,例如旋轉(zhuǎn),剪切,平移,縮放等,從而生成現(xiàn)有圖像的新的、變化的版本。由于這些隨機(jī)變換的操作,我們每次都會(huì)得到不同的圖像。我們將使用 tf.keras 中的 ImageDataGenerator 工具,它能夠幫助我們實(shí)現(xiàn)圖像增廣。
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, zoom_range=0.05, rotation_range=25, width_shift_range=0.05, height_shift_range=0.05, shear_range=0.05, horizontal_flip=True, fill_mode='nearest')val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)# build image augmentation generatorstrain_generator = train_datagen.flow(train_data, train_labels_enc, batch_size=BATCH_SIZE, shuffle=True)val_generator = val_datagen.flow(val_data, val_labels_enc, batch_size=BATCH_SIZE, shuffle=False)
在驗(yàn)證集上,我們只會(huì)對(duì)圖像進(jìn)行縮放操作,而不進(jìn)行其他的轉(zhuǎn)換,這是因?yàn)槲覀冃枰诿總€(gè)訓(xùn)練的 epoch 結(jié)束后,用驗(yàn)證集來(lái)評(píng)估我們的模型。有關(guān)圖像增廣的詳細(xì)說(shuō)明,可以參考這篇文章。讓我們來(lái)看看進(jìn)行圖像增廣變換后的一些樣本結(jié)果。
img_id = 0sample_generator = train_datagen.flow(train_data[img_id:img_id+1], train_labels[img_id:img_id+1],batch_size=1)sample = [next(sample_generator) for i in range(0,5)]fig, ax = plt.subplots(1,5, figsize=(16, 6))print('Labels:', [item[1][0] for item in sample])l = [ax[i].imshow(sample[i][0][0]) for i in range(0,5)]
Sample Augmented Images
從上圖可以清楚的看到圖像發(fā)生了輕微的變化。現(xiàn)在我們將構(gòu)建新的深度模型,該模型需要確保 VGG-19 模型的最后兩個(gè)塊可以進(jìn)行訓(xùn)練。
vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet',input_shape=INPUT_SHAPE)# Freeze the layersvgg.trainable = Trueset_trainable = Falsefor layer in vgg.layers: if layer.name in ['block5_conv1', 'block4_conv1']: set_trainable = True if set_trainable: layer.trainable = True else: layer.trainable = Falsebase_vgg = vggbase_out = base_vgg.outputpool_out = tf.keras.layers.Flatten()(base_out)hidden1 = tf.keras.layers.Dense(512, activation='relu')(pool_out)drop1 = tf.keras.layers.Dropout(rate=0.3)(hidden1)hidden2 = tf.keras.layers.Dense(512, activation='relu')(drop1)drop2 = tf.keras.layers.Dropout(rate=0.3)(hidden2)out = tf.keras.layers.Dense(1, activation='sigmoid')(drop2)model = tf.keras.Model(inputs=base_vgg.input, outputs=out)model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=1e-5),loss='binary_crossentropy',metrics=['accuracy'])print("Total Layers:", len(model.layers))print("Total trainable layers:", sum([1 for l in model.layers if l.trainable]))# OutputTotal Layers: 28Total trainable layers: 16
由于我們不希望在微調(diào)過(guò)程中,對(duì)預(yù)訓(xùn)練的層進(jìn)行較大的權(quán)重更新,我們降低了模型的學(xué)習(xí)率。由于我們使用數(shù)據(jù)生成器來(lái)加載數(shù)據(jù),本模型的訓(xùn)練過(guò)程會(huì)和之前稍稍不同,在這里,我們需要用到函數(shù) fit_generator(…) 。
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir,histogram_freq=1)reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5,patience=2, min_lr=0.000001)callbacks = [reduce_lr, tensorboard_callback]train_steps_per_epoch = train_generator.n //train_generator.batch_sizeval_steps_per_epoch = val_generator.n //val_generator.batch_sizehistory = model.fit_generator(train_generator, steps_per_epoch=train_steps_per_epoch,epochs=EPOCHS,validation_data=val_generator,validation_steps=val_steps_per_epoch,verbose=1)# OutputEpoch 1/25271/271 [====] - 133s 489ms/step - loss: 0.2267 - accuracy: 0.9117 - val_loss: 0.1414 - val_accuracy: 0.9531Epoch 2/25271/271 [====] - 129s 475ms/step - loss: 0.1399 - accuracy: 0.9552 - val_loss: 0.1292 - val_accuracy: 0.9589......Epoch 24/25271/271 [====] - 128s 473ms/step - loss: 0.0815 - accuracy: 0.9727 - val_loss: 0.1466 - val_accuracy: 0.9682Epoch 25/25271/271 [====] - 128s 473ms/step - loss: 0.0792 - accuracy: 0.9729 - val_loss: 0.1127 - val_accuracy: 0.9641
下圖展示了該模型的訓(xùn)練曲線(xiàn),可以看出該模型是這三個(gè)模型中最好的模型,其驗(yàn)證準(zhǔn)確度幾乎達(dá)到了 96.5% ,而且從訓(xùn)練準(zhǔn)確度上看,我們的模型也沒(méi)有像第一個(gè)模型那樣出現(xiàn)過(guò)擬合。
Learning Curves for fine-tuned pre-trained CNN
現(xiàn)在讓我們保存這個(gè)模型,很快我們將在測(cè)試集上用到它進(jìn)行性能評(píng)估。
model.save( 'vgg_finetuned.h5')
至此,模型訓(xùn)練階段告一段落,我們即將在真實(shí)的測(cè)試集上去測(cè)試這些模型的性能。
深度學(xué)習(xí)模型的性能評(píng)估階段
現(xiàn)在,我們將對(duì)之前訓(xùn)練好的三個(gè)模型進(jìn)行評(píng)估。僅僅使用驗(yàn)證集來(lái)評(píng)估模型的好壞是不夠的, 因此,我們將使用測(cè)試集來(lái)進(jìn)一步評(píng)估模型的性能。我們構(gòu)建了一個(gè)實(shí)用的模塊 model_evaluation_utils,該模塊采用相關(guān)的分類(lèi)指標(biāo),用于評(píng)估深度學(xué)習(xí)模型的性能。首先我們需要將測(cè)試數(shù)據(jù)進(jìn)行縮放。
test_imgs_scaled = test_data / 255.test_imgs_scaled.shape, test_labels.shape# Output((8268, 125, 125, 3), (8268,))
第二步是加載之前所保存的深度學(xué)習(xí)模型,然后在測(cè)試集上進(jìn)行預(yù)測(cè)。
# Load Saved Deep Learning Modelsbasic_cnn = tf.keras.models.load_model('./basic_cnn.h5')vgg_frz = tf.keras.models.load_model('./vgg_frozen.h5')vgg_ft = tf.keras.models.load_model('./vgg_finetuned.h5')# Make Predictions on Test Databasic_cnn_preds = basic_cnn.predict(test_imgs_scaled, batch_size=512)vgg_frz_preds = vgg_frz.predict(test_imgs_scaled, batch_size=512)vgg_ft_preds = vgg_ft.predict(test_imgs_scaled, batch_size=512)basic_cnn_pred_labels = le.inverse_transform([1 if pred > 0.5 else 0 for pred in basic_cnn_preds.ravel()])vgg_frz_pred_labels = le.inverse_transform([1 if pred > 0.5 else 0 for pred in vgg_frz_preds.ravel()])vgg_ft_pred_labels=le.inverse_transform([1ifpred>0.5else0forpredinvgg_ft_preds.ravel()])
最后一步是利用 model_evaluation_utils 模塊,根據(jù)不同的分類(lèi)評(píng)價(jià)指標(biāo),來(lái)評(píng)估每個(gè)模型的性能。
import model_evaluation_utils as meuimport pandas as pdbasic_cnn_metrics = meu.get_metrics(true_labels=test_labels, predicted_labels=basic_cnn_pred_labels)vgg_frz_metrics = meu.get_metrics(true_labels=test_labels, predicted_labels=vgg_frz_pred_labels)vgg_ft_metrics = meu.get_metrics(true_labels=test_labels, predicted_labels=vgg_ft_pred_labels)pd.DataFrame([basic_cnn_metrics, vgg_frz_metrics, vgg_ft_metrics], index=['Basic CNN', 'VGG-19 Frozen', 'VGG-19 Fine-tuned'])
從圖中可以看到,第三個(gè)模型在測(cè)試集上的性能是最好的,其準(zhǔn)確度和 f1-score 都達(dá)到了96%,這是一個(gè)非常好的結(jié)果,而且這個(gè)結(jié)果和論文中提到的更為復(fù)雜的模型所得到的結(jié)果具有相當(dāng)?shù)?可比性!
結(jié)論
本文研究了一個(gè)有趣的醫(yī)學(xué)影像案例——瘧疾檢測(cè)。瘧疾檢測(cè)是一個(gè)復(fù)雜的過(guò)程,而且能夠進(jìn)行正確操作的醫(yī)療人員也很少,這是一個(gè)很?chē)?yán)重的問(wèn)題。本文利用 AI 技術(shù)構(gòu)建了一個(gè)開(kāi)源的項(xiàng)目,該項(xiàng)目在瘧疾檢測(cè)問(wèn)題上具有最高的準(zhǔn)確率,并使AI技術(shù)為社會(huì)帶來(lái)了效益。
-
python
+關(guān)注
關(guān)注
56文章
4792瀏覽量
84628 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5500瀏覽量
121113 -
ai技術(shù)
+關(guān)注
關(guān)注
1文章
1266瀏覽量
24287
原文標(biāo)題:醫(yī)生再添新助手!深度學(xué)習(xí)診斷傳染病 | 完整代碼+實(shí)操
文章出處:【微信號(hào):rgznai100,微信公眾號(hào):rgznai100】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論