Using BertClient with tf.data API

Note

The complete example can be found in example4.py. There is also an example using Keras.

The tf.data API enables you to build complex input pipelines from simple, reusable pieces. You can also use BertClient to encode sentences on the fly and feed the resulting vectors into a downstream model. Here is an example:

import json

import tensorflow as tf
from bert_serving.client import ConcurrentBertClient

batch_size = 256
num_parallel_calls = 4

# start a thread-safe client pool to support num_parallel_calls in the tf.data API
bc = ConcurrentBertClient(num_parallel_calls)


def get_encodes(x):
    # x is `batch_size` of lines, each of which is a json object
    samples = [json.loads(l) for l in x]
    text = [s['raw_text'] for s in samples]  # List[List[str]]
    labels = [s['label'] for s in samples]  # List[str]
    features = bc.encode(text)
    return features, labels


ds = (tf.data.TextLineDataset(train_fp)
        .batch(batch_size)
        .map(lambda x: tf.py_func(get_encodes, [x], [tf.float32, tf.string]),
             num_parallel_calls=num_parallel_calls)
        .map(lambda x, y: {'feature': x, 'label': y})
        .make_one_shot_iterator().get_next())

The trick here is to start a pool of BertClient instances and reuse them one by one. In this way, we can fully harness the power of num_parallel_calls in the Dataset.map() API.
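To make the pooling idea concrete, here is a minimal, self-contained sketch of the round-robin reuse pattern described above. Note that `RoundRobinPool` is a hypothetical class written for illustration only — it is not part of bert-as-service; `ConcurrentBertClient` handles this internally. The sketch assumes the workers are plain strings standing in for real clients.

```python
import threading

# Hypothetical illustration of the round-robin pool idea: a fixed set of
# workers is handed out one at a time, so each parallel map call gets its
# own worker and no two calls share one at the same moment.
class RoundRobinPool:
    def __init__(self, workers):
        self.workers = list(workers)
        self.lock = threading.Lock()
        self.idx = 0

    def acquire(self):
        # Pick the next worker in round-robin order; the lock keeps the
        # index update safe when called from num_parallel_calls threads.
        with self.lock:
            w = self.workers[self.idx % len(self.workers)]
            self.idx += 1
            return w

# stand-in names for four clients, matching num_parallel_calls = 4
pool = RoundRobinPool(['client-0', 'client-1', 'client-2', 'client-3'])
assigned = [pool.acquire() for _ in range(8)]
```

With four workers and eight requests, each worker is handed out exactly twice, which is why `num_parallel_calls` in `Dataset.map()` can be fully utilized without contention on a single client.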