在TensorFlow中讀數據一般有三種方法:
使用placeholder讀內存中的數據
使用queue讀硬盤中的數據
使用Dataset讀內存個硬盤中的數據
基本概率
由于第三種方法在語法上更簡潔,因此本文主要介紹第三種方法。官方給出的Dataset API類圖:
image.png
其中終于重要的兩個基礎類:Dateset和Iterator。Dateset是具有相同類型的“元素”的有序表,元素可以是向量、字符串、圖片等。
從內存中創建Dataset
以數字元素為例:
例1
從Dataset中實例化一個Iterator,然后對Iterator進行迭代。
iterator = dataset.make_one_shot_iterator()
從dataset中實例化一個iterator,是“one shot iterator”,即只能從頭到尾讀取一次。
one_element = iterator.get_next()
從iterator中取出一個元素, one_element是一個tensor,因此需要調用sess.run(one_element)取出值。
如果元素被讀取完了,再sess.run(one_element)會拋出tf.errors.OutOfRangeError異常。解決方法:使用 dataset.repeat()
更復雜的輸入形式,例如,在圖像識別的應用中,一個元素可以使{“image”:image_tensor, “label”:lable_tensor}
dataset = tf.data.Dataset.from_tensor_slices( { "a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]), "b": np.random.uniform(size=(5, 2)) } )
最終dataset中的一個元素為{"a": 1.0, "b": [0.9, 0.1]}的形式。或者
dataset = tf.data.Dataset.from_tensor_slices( (np.array([1.0, 2.0, 3.0, 4.0, 5.0]), np.random.uniform(size=(5, 2))) )
對Dataset中的元素做變換:Transformation
一個Dataset通過Transformation變成一個新的Dataset。常用的操作有:
map
batch
shuffle
repeat
下面分別來介紹以上幾個操作。(1)mapmap接收一個函數,dataset中的每個元素都可以作為這個函數的輸入,并將函數的返回值作為新的dataset,例如:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0
(2)batch將多個元素組合成batch,例如:
dataset = dataset.batch(32)
(3)shuffle打亂dataset中的元素,參數buffersize表示打亂時buffer的大小。
dataset = dataset.shuffle(buffer_size=10000)
(4)repeat將整個序列重復多次,只用用來處理epoch。如果直接調用repeat()的話,生成的序列就會無限重復下去,沒有結束,因此也不會拋出。tf.errors.OutOfRangeError異常:
dataset = dataset.repeat(5)
例子:讀磁盤圖片與對應的label
讀入磁盤中的圖片和圖片相應的label,并將其打亂,組成batch_size=32的訓練樣本。在訓練時重復10個epoch。
# 函數的功能時將filename對應的圖片文件讀進來,并縮放到統一的大小def _parse_function(filename, label): image_string = tf.read_file(filename) image_decoded = tf.image.decode_image(image_string) image_resized = tf.image.resize_images(image_decoded, [28, 28]) return image_resized, label# 圖片文件的列表filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])# label[i]就是圖片filenames[i]的labellabels = tf.constant([0, 37, ...])# 此時dataset中的一個元素是(filename, label)dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))# 此時dataset中的一個元素是(image_resized, label)dataset = dataset.map(_parse_function)# 此時dataset中的一個元素是(image_resized_batch, label_batch)dataset = dataset.shuffle(buffersize=1000).batch(32).repeat(10)# 此時dataset中的一個元素是(image_resized_batch, label_batch)# image_resized_batch的形狀為(32, 28, 28, 3), label_batch的形狀為(32, )
-
函數
+關注
關注
3文章
4333瀏覽量
62708 -
tensorflow
+關注
關注
13文章
329瀏覽量
60537 -
DataSet
+關注
關注
0文章
5瀏覽量
2208
原文標題:TensorFlow讀數據
文章出處:【微信號:C_Expert,微信公眾號:C語言專家集中營】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論