# Source code for nfp.preprocessing.tfrecord

import numpy as np
from nfp.frameworks import tf

# Code from https://www.tensorflow.org/tutorials/load_data/tfrecord


def _bytes_feature(value):
    """Wrap one string/bytes value in a ``tf.train.Feature`` with a BytesList.

    If ``value`` is an eager tensor it is first converted with ``.numpy()``,
    because ``tf.train.BytesList`` won't unpack a string from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=bytes_list)


def _float_feature(value):
    """Wrap one float / double in a ``tf.train.Feature`` holding a FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def _int64_feature(value):
    """Wrap one bool / enum / int / uint in a ``tf.train.Feature`` with an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def serialize_value(value):
    """Convert a single Python/NumPy value into a ``tf.train.Feature``.

    Dispatch rules:

    * ``np.ndarray`` -> serialized with ``tf.io.serialize_tensor`` and stored
      as a BytesList feature.
    * ``int`` (including ``bool``) or a NumPy integer scalar -> Int64List
      feature.
    * ``float`` or a NumPy floating scalar -> FloatList feature.

    Parameters
    ----------
    value : np.ndarray, int, float, np.integer or np.floating
        The value to serialize.

    Returns
    -------
    tf.train.Feature
        The feature wrapping ``value``.

    Raises
    ------
    TypeError
        If ``value`` is not one of the supported types.
    """
    # Arrays are checked first: they are serialized as whole tensors and
    # stored as bytes.
    if isinstance(value, np.ndarray):
        return _bytes_feature(tf.io.serialize_tensor(value))
    # isinstance (not an exact type comparison) so that bool, int subclasses
    # and NumPy integer scalars are accepted; int() normalizes them to a
    # plain Python int for Int64List.
    if isinstance(value, (int, np.integer)):
        return _int64_feature(int(value))
    # Likewise for NumPy floating scalars (e.g. np.float32, which is not a
    # subclass of Python float).
    if isinstance(value, (float, np.floating)):
        return _float_feature(float(value))
    raise TypeError(f"Didn't recognize type {type(value)}")