nfp.layers.batched_segment_op
- batched_segment_op(data, segment_ids, num_segments, data_mask=None, reduction='sum')
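Flattens data and segment_ids, which carry a leading batch dimension, so that a tf.math.unsorted_segment_* reduction can be applied to every batch entry at once. Positions can optionally be excluded from the reduction with data_mask.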
- Parameters
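data – tensor of shape [B, L, F], where B is the batch size, L is the length, and F is the feature dimension
segment_ids – integer tensor of shape [B, L] assigning each position to one of N segments (ids in the range [0, N))
num_segments – N, the number of segments per batch entry (integer)
data_mask – optional boolean tensor of shape [B, L]; positions where it is False are excluded from the reduction. If None, no masking is applied
reduction – name of the tf.math.unsorted_segment_* function to apply, for example 'sum', 'mean' or 'max'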
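The flattening idea can be illustrated without TensorFlow. The following is a minimal NumPy sketch, not nfp's implementation: batched_segment_op_np is a hypothetical name, only the 'sum' and 'mean' reductions are reproduced, and the offset trick shown (adding b * N to the ids in batch row b so that segments stay distinct after flattening) is an assumed way of preserving batches. It also assumes segment ids lie in [0, N) and that masked-out positions are simply dropped before reducing, which yields an output of shape [B, N, F].

```python
import numpy as np


def batched_segment_op_np(data, segment_ids, num_segments, data_mask=None,
                          reduction="sum"):
    """Hypothetical NumPy sketch of a batched, masked unsorted-segment reduction.

    data: [B, L, F] float array; segment_ids: [B, L] ints in [0, num_segments);
    data_mask: optional [B, L] bool array. Returns a [B, num_segments, F] array.
    """
    batch_size, _, num_features = data.shape
    if data_mask is None:
        # No mask given: keep every position.
        data_mask = np.ones(data.shape[:-1], dtype=bool)

    # Offset each row's ids by b * N so segments from different batch
    # entries never collide once the batch dimension is flattened away.
    offsets = np.arange(batch_size)[:, None] * num_segments
    flat_ids = (segment_ids + offsets)[data_mask]  # [M], masked positions dropped
    flat_data = data[data_mask]                    # [M, F]

    total = batch_size * num_segments
    sums = np.zeros((total, num_features), dtype=data.dtype)
    np.add.at(sums, flat_ids, flat_data)  # unbuffered scatter-add per segment

    if reduction == "sum":
        out = sums
    elif reduction == "mean":
        counts = np.bincount(flat_ids, minlength=total)[:, None]
        # Empty segments are left at 0 in this sketch.
        out = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)
    else:
        raise ValueError(f"unsupported reduction: {reduction!r}")

    # Undo the flattening: [B * N, F] -> [B, N, F].
    return out.reshape(batch_size, num_segments, num_features)


data = np.arange(12, dtype=float).reshape(2, 3, 2)
ids = np.array([[0, 0, 1], [1, 1, 1]])
mask = np.array([[True, True, True], [True, False, True]])
out = batched_segment_op_np(data, ids, num_segments=2, data_mask=mask)
print(out.tolist())
# [[[2.0, 4.0], [4.0, 5.0]], [[0.0, 0.0], [16.0, 18.0]]]
```

In the usage above, segment 0 of the second batch entry receives no unmasked positions and therefore stays at zero, while the masked middle row of that entry is left out of segment 1's sum.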