zl程序教程

您现在的位置是:首页 >  其它

当前栏目

tf.lookup.StaticHashTable 用法

用法 TF lookup
2023-09-14 09:09:28 时间
`tf.lookup.StaticHashTable` 本质上是 TensorFlow 内置的字典（初始化后不可修改），在 YOLOv3 的 TensorFlow 实现中被多次使用。例如，下面的函数用它把类别名映射为整数类别 id：

def load_tfrecord_dataset(file_pattern, class_file, size=416):
    """Build a dataset of parsed examples from TFRecord files.

    Args:
        file_pattern: glob pattern handed to ``tf.data.Dataset.list_files``.
        class_file: text file with one class name per line; the 0-based line
            number of each name becomes its integer class id.
        size: forwarded unchanged to ``parse_tfrecord`` (presumably the
            target image side length -- TODO confirm against parse_tfrecord).

    Returns:
        A ``tf.data.Dataset`` whose elements are whatever ``parse_tfrecord``
        produces for each serialized record.
    """
    # Static name -> id table: the key is column 0 of each line split on "\n"
    # (i.e. the whole line), the value is that line's number. Names missing
    # from the file look up as the default value -1.
    class_table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
        class_file, tf.string, 0, tf.int64,
        tf.lookup.TextFileIndex.LINE_NUMBER, delimiter="\n"), -1)

    files = tf.data.Dataset.list_files(file_pattern)
    # Flatten every matching file into one stream of serialized records.
    dataset = files.flat_map(tf.data.TFRecordDataset)
    return dataset.map(lambda x: parse_tfrecord(x, class_table, size))

###### 在当前目录下新建文件 voc2012.names，每行一个类别名（行号从 0 开始，即对应的类别 id）：

aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
# Interactive example: run from the directory that contains voc2012.names.
import tensorflow as tf

# The class-name file, one name per line.
class_file = "voc2012.names"
# Key: column 0 of each line split on "\n" (the whole line); value: the line
# number (0-based). Names not present in the file look up as -1.
class_table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
        class_file, tf.string, 0, tf.int64,
        tf.lookup.TextFileIndex.LINE_NUMBER, delimiter="\n"), -1)
class_table.lookup(tf.constant(['cat', 'person']))
<tf.Tensor: shape=(2,), dtype=int64, numpy=array([ 7, 14])>

# Dump the table's contents as a (keys, values) pair of tensors. As the output
# below shows, the entries do not come back in the file's line order.
class_table.export()

(<tf.Tensor: shape=(20,), dtype=string, numpy=
 array([b'cat', b'chair', b'dog', b'person', b'bird', b'motorbike',
        b'bottle', b'car', b'bus', b'sheep', b'boat', b'train',
        b'aeroplane', b'pottedplant', b'sofa', b'tvmonitor', b'cow',
        b'diningtable', b'horse', b'bicycle'], dtype=object)>,
 <tf.Tensor: shape=(20,), dtype=int64, numpy=
 array([ 7,  8, 11, 14,  2, 13,  4,  6,  5, 16,  3, 18,  0, 15, 17, 19,  9,
        10, 12,  1])>)