nfp.layers.batched_segment_op

batched_segment_op(data, segment_ids, num_segments, data_mask=None, reduction='sum')

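Applies a tf.math.unsorted_segment_* reduction to data and segment_ids that carry a leading batch dimension. The batch axis is flattened before the reduction so that each batch entry is reduced over its own N segments. Includes optional support for masking.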
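The flattening step can be pictured with a small framework-free sketch. It assumes, as is usual for this pattern, that the segment ids of batch entry b are shifted by b * N, so that segments from different batch entries stay distinct once the batch axis is merged. `flatten_with_offsets` is a hypothetical helper written for illustration; it is not part of nfp.

```python
from typing import List, Sequence


def flatten_with_offsets(
    segment_ids: Sequence[Sequence[int]],  # [B, L], ids assumed in [0, N)
    num_segments: int,                     # N
) -> List[int]:
    """Hypothetical helper: flatten [B, L] segment ids into [B * L] ids.

    Adding b * N to every id in batch entry b keeps segment s of entry b
    (flat id b * N + s) separate from segment s of every other entry.
    """
    flat_ids: List[int] = []
    for b, ids in enumerate(segment_ids):
        for s in ids:
            flat_ids.append(b * num_segments + s)
    return flat_ids


# Two batch entries, N = 2: entry 1's segments become flat ids 2 and 3.
flat = flatten_with_offsets([[0, 0, 1], [1, 1, 0]], 2)
# divmod(flat_id, N) recovers the (batch entry, segment) pair.
pairs = [divmod(i, 2) for i in flat]
```

A single flat reduction over B * N segments (for example tf.math.unsorted_segment_sum with num_segments=B * N) followed by a reshape to [B, N, F] then yields one result per (batch entry, segment) pair.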

Parameters
  • data – tensor of shape [B, L, F], where B is the batch size, L is the length of each batch entry, and F is the feature dimension

  • segment_ids – integer tensor of shape [B, L] assigning each of the L entries to one of up to N segments, with ids in the range [0, N)

  • num_segments – integer N, the number of segments per batch entry

  • data_mask – optional boolean tensor of shape [B, L] masking the input data; entries where the mask is False are excluded from the reduction

  • reduction – name of the tf.math.unsorted_segment_* function to apply, e.g. 'sum' selects tf.math.unsorted_segment_sum
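To make the parameters concrete, the following pure-Python sketch spells out the semantics they suggest: masked entries are dropped, segment ids are offset per batch entry, each segment is reduced feature-wise, and the result is reshaped to [B, N, F]. It is an illustration under assumptions, not the nfp implementation. `batched_segment_reduce` is a hypothetical name, only 'sum', 'mean', and 'max' are supported, and the [B, N, F] output shape is assumed because this page does not document the return value. Empty segments yield 0.0 for 'sum' and 'mean' (matching tf.math.unsorted_segment_sum and tf.math.unsorted_segment_mean) and -inf for 'max', whereas tf.math.unsorted_segment_max returns the lowest finite value of the dtype.

```python
from typing import List, Optional, Sequence


def batched_segment_reduce(
    data: Sequence[Sequence[Sequence[float]]],             # [B, L, F]
    segment_ids: Sequence[Sequence[int]],                  # [B, L]
    num_segments: int,                                     # N
    data_mask: Optional[Sequence[Sequence[bool]]] = None,  # [B, L]
    reduction: str = "sum",
) -> List[List[List[float]]]:                              # [B, N, F] (assumed)
    """Hypothetical reference for the assumed semantics (not nfp's code)."""
    column_fns = {
        "sum": sum,
        "mean": lambda col: sum(col) / len(col),
        "max": max,
    }
    empty_values = {"sum": 0.0, "mean": 0.0, "max": float("-inf")}
    if reduction not in column_fns:
        raise ValueError(f"unsupported reduction: {reduction!r}")

    batch_size = len(data)
    # Feature width F, taken from the first row found (0 if there are none).
    num_features = next((len(row) for example in data for row in example), 0)

    # One bucket per flattened segment id b * N + s, i.e. B * N buckets.
    buckets: List[List[Sequence[float]]] = [
        [] for _ in range(batch_size * num_segments)
    ]
    for b in range(batch_size):
        for i, row in enumerate(data[b]):
            if data_mask is not None and not data_mask[b][i]:
                continue  # masked-out entries contribute nothing
            s = segment_ids[b][i]
            if not 0 <= s < num_segments:
                raise ValueError(f"segment id {s} outside [0, {num_segments})")
            buckets[b * num_segments + s].append(row)

    fn = column_fns[reduction]
    flat: List[List[float]] = []
    for rows in buckets:
        if rows:
            # zip(*rows) walks the F feature columns of this segment.
            flat.append([float(fn(col)) for col in zip(*rows)])
        else:
            flat.append([empty_values[reduction]] * num_features)

    # Undo the flattening: [B * N, F] -> [B, N, F].
    return [flat[b * num_segments:(b + 1) * num_segments] for b in range(batch_size)]
```

In TensorFlow the same steps would typically map onto tf.boolean_mask for the mask, a tf.math.unsorted_segment_* call with num_segments set to B * N, and tf.reshape back to [B, N, F].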