nfp.preprocessing.crystal_preprocessor.PymatgenPreprocessor

class PymatgenPreprocessor(radius=None, num_neighbors=12, **kwargs)

Bases: nfp.preprocessing.preprocessor.PreprocessorMultiGraph

Methods

construct_feature_matrices

Deprecated since version 0.3.0.

create_nx_graph

Build a networkx graph from crystal, which should be a pymatgen.core.Structure object.

from_json

Sets the class's data with attributes taken from the save file.

get_connectivity

Get the graph connectivity from the networkx graph

get_edge_features

Given a list of edge features from the nx.Graph, processes and concatenates them into an array.

get_graph_features

Process the graph-level features of the nx.Graph into a dictionary of arrays.

get_node_features

Given a list of node features from the nx.Graph, processes and concatenates them into an array.

site_features

to_json

Serialize the class's data to a JSON file.

Attributes

output_signature

padding_values

site_classes

tfrecord_features

Feature descriptions useful for storing preprocessed outputs in TFRecord files.

create_nx_graph(crystal, **kwargs)

Build a networkx graph from crystal, which should be a pymatgen.core.Structure object.

Return type

networkx.classes.multidigraph.MultiDiGraph
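The constructor's num_neighbors=12 default suggests that each site is connected to a fixed number of nearest neighbors. The standard-library sketch below illustrates that idea under stated assumptions: the lattice is given as three row vectors, sites as fractional coordinates, and only periodic images within one cell in each direction are searched. knn_edges is a hypothetical helper, not part of nfp, and it returns plain (source, target, distance, image) tuples instead of a networkx graph. Because a site can neighbor several periodic images of the same site (including itself), repeated (source, target) pairs occur, which may be why the method returns a MultiDiGraph.

```python
import itertools
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]
Edge = Tuple[int, int, float, Tuple[int, int, int]]


def frac_to_cart(lattice: Sequence[Vec3], frac: Sequence[float]) -> Vec3:
    # Cartesian position = sum_i frac[i] * lattice[i]; lattice rows are a, b, c.
    x, y, z = (sum(frac[i] * lattice[i][j] for i in range(3)) for j in range(3))
    return (x, y, z)


def knn_edges(lattice: Sequence[Vec3], frac_coords: Sequence[Vec3],
              num_neighbors: int = 12, image_range: int = 1) -> List[Edge]:
    """Hypothetical helper: directed edges from each site to its
    num_neighbors closest periodic neighbors."""
    images = list(itertools.product(range(-image_range, image_range + 1), repeat=3))
    edges: List[Edge] = []
    for i, fi in enumerate(frac_coords):
        ci = frac_to_cart(lattice, fi)
        candidates = []
        for j, fj in enumerate(frac_coords):
            for image in images:
                if i == j and image == (0, 0, 0):
                    continue  # a site is not its own neighbor in the home cell
                shifted = [fj[k] + image[k] for k in range(3)]
                dist = math.dist(ci, frac_to_cart(lattice, shifted))
                candidates.append((dist, j, image))
        candidates.sort()  # nearest first; ties broken by site index, then image
        for dist, j, image in candidates[:num_neighbors]:
            edges.append((i, j, dist, image))
    return edges


# Simple cubic cell (a = 3.0) with one site: its 6 nearest neighbors are
# periodic images of itself, all at distance 3.0.
cubic = [(3.0, 0.0, 0.0), (0.0, 3.0, 0.0), (0.0, 0.0, 3.0)]
edges = knn_edges(cubic, [(0.0, 0.0, 0.0)], num_neighbors=6)
print(len(edges), {e[2] for e in edges})  # 6 {3.0}
```

For small or strongly skewed cells, searching only image_range=1 can miss true nearest neighbors; pymatgen's own neighbor-finding routines on Structure handle this more carefully.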

get_edge_features(edge_data, max_num_edges)

Given a list of edge features from the nx.Graph, processes and concatenates them into an array.

Parameters
  • edge_data (list) – A list of edge data generated by nx_graph.edges(data=True)

  • max_num_edges – If given, this function should pad its output to this maximum number of edges, which is passed from the __call__ function.

Returns

A dictionary of (feature name, array) pairs, where each array contains that feature for all edges in the graph.

Return type

Dict[str, np.ndarray]
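The padding contract above can be sketched as follows, assuming each edge carries a single numeric attribute (the "distance" key is hypothetical) and that padded slots are filled with 0. pad_edge_features is illustrative rather than nfp's implementation, and it returns lists instead of np.ndarray so it runs with the standard library only. Its input has the (u, v, attrs) shape produced by nx_graph.edges(data=True).

```python
from typing import Any, Dict, List, Optional, Tuple

EdgeData = Tuple[int, int, Dict[str, Any]]


def pad_edge_features(edge_data: List[EdgeData],
                      max_num_edges: Optional[int] = None,
                      key: str = "distance",
                      pad_value: float = 0.0) -> Dict[str, List[float]]:
    """Hypothetical helper: collect one edge attribute and pad it to max_num_edges."""
    num_edges = len(edge_data) if max_num_edges is None else max_num_edges
    if num_edges < len(edge_data):
        raise ValueError("max_num_edges is smaller than the number of edges")
    values = [float(attrs[key]) for _, _, attrs in edge_data]
    values.extend([pad_value] * (num_edges - len(values)))
    return {key: values}


edges = [(0, 1, {"distance": 2.5}), (1, 0, {"distance": 2.5})]
print(pad_edge_features(edges, max_num_edges=4))
# {'distance': [2.5, 2.5, 0.0, 0.0]}
```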

get_node_features(node_data, max_num_nodes)

Given a list of node features from the nx.Graph, processes and concatenates them into an array.

Parameters
  • node_data (list) – A list of node data generated by nx_graph.nodes(data=True)

  • max_num_nodes – If given, this function should pad its output to this maximum number of nodes, which is passed from the __call__ function.

Returns

A dictionary of (feature name, array) pairs, where each array contains that feature for all nodes in the graph.

Return type

Dict[str, np.ndarray]
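This page does not document how sites are encoded. The sketch below assumes each node carries a string label under a hypothetical "site" key that a tokenizer maps to an integer class, with 0 reserved for padding and 1 for labels not seen during training; the train flag mirrors the one __call__ passes to Tokenizer attributes. SimpleTokenizer and node_features are hypothetical names, not nfp classes.

```python
from typing import Any, Dict, List, Optional, Tuple

NodeData = Tuple[int, Dict[str, Any]]


class SimpleTokenizer:
    """Hypothetical tokenizer: 0 = padding, 1 = unknown, 2.. = learned classes."""

    def __init__(self) -> None:
        self._data: Dict[str, int] = {"unk": 1}
        self.num_classes = 1
        self.train = False

    def __call__(self, item: str) -> int:
        if item not in self._data:
            if not self.train:
                return self._data["unk"]  # unseen label outside training
            self.num_classes += 1  # assign the next free class index
            self._data[item] = self.num_classes
        return self._data[item]


def node_features(node_data: List[NodeData], tokenizer: SimpleTokenizer,
                  max_num_nodes: Optional[int] = None) -> Dict[str, List[int]]:
    """Tokenize each node's "site" label and pad with 0 up to max_num_nodes."""
    num_nodes = len(node_data) if max_num_nodes is None else max_num_nodes
    tokens = [tokenizer(attrs["site"]) for _, attrs in node_data]
    tokens.extend([0] * (num_nodes - len(tokens)))
    return {"site": tokens}


tokenizer = SimpleTokenizer()
tokenizer.train = True
print(node_features([(0, {"site": "Na"}), (1, {"site": "Cl"})], tokenizer, 3))
# {'site': [2, 3, 0]}
tokenizer.train = False
print(tokenizer("K"))  # 1, because "K" was never seen while training
```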

get_graph_features(graph_data)

Process the graph-level features of the nx.Graph into a dictionary of arrays.

Parameters

graph_data (dict) – A dictionary of graph data generated by nx_graph.graph

Returns

a dictionary of features for the graph

Return type

Dict[str, np.ndarray]

property tfrecord_features: Dict[str, tensorflow.io.FixedLenFeature]

Feature descriptions useful for storing preprocessed outputs in TFRecord files.

__call__(structure, *args, train=False, max_num_nodes=None, max_num_edges=None, **kwargs)

Convert an input graph structure into a featurized set of node, edge, and graph-level features.

Parameters
  • structure (Any) – An input graph structure (e.g., a molecule or crystal)

  • train (bool) – A training flag passed to Tokenizer member attributes

  • max_num_nodes (Optional[int]) – A size attribute passed to get_node_features, defaults to the number of nodes in the current graph

  • max_num_edges (Optional[int]) – A size attribute passed to get_edge_features, defaults to the number of edges in the current graph

  • kwargs – Additional features or parameters passed to create_nx_graph

Returns

A dictionary of key, array pairs as a single sample.

Return type

Dict[str, np.ndarray]
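Passing explicit max_num_nodes and max_num_edges is useful when samples must share one shape, for example to stack them into a fixed-size batch. The sketch below shows that pattern with a toy graph stored as plain lists. featurize is a hypothetical stand-in for __call__, and padding with 0 and [0, 0] is an assumption; the class's padding_values attribute presumably governs the real pad values.

```python
from typing import Dict, List, Optional, Tuple

# A toy graph: integer node tokens plus a list of directed (start, end) edges.
Graph = Tuple[List[int], List[Tuple[int, int]]]


def featurize(graph: Graph, max_num_nodes: Optional[int] = None,
              max_num_edges: Optional[int] = None) -> Dict[str, List]:
    """Hypothetical stand-in for __call__: pad node and edge outputs to the
    requested sizes, defaulting to the current graph's own counts."""
    tokens, edges = graph
    if max_num_nodes is None:
        max_num_nodes = len(tokens)
    if max_num_edges is None:
        max_num_edges = len(edges)
    return {
        "site": tokens + [0] * (max_num_nodes - len(tokens)),
        "connectivity": [[u, v] for u, v in edges]
        + [[0, 0] for _ in range(max_num_edges - len(edges))],
    }


graphs: List[Graph] = [
    ([2, 3, 3], [(0, 1), (1, 0), (0, 2), (2, 0)]),
    ([4, 5], [(0, 1), (1, 0)]),
]
# Pad every sample to the largest graph so all samples share one shape.
max_nodes = max(len(tokens) for tokens, _ in graphs)
max_edges = max(len(edges) for _, edges in graphs)
samples = [featurize(g, max_nodes, max_edges) for g in graphs]
print(samples[1])
# {'site': [4, 5, 0], 'connectivity': [[0, 1], [1, 0], [0, 0], [0, 0]]}
```

Without explicit sizes, each sample keeps its own node and edge counts, which suits per-sample processing but not stacking.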

construct_feature_matrices(*args, train=False, **kwargs)

Deprecated since version 0.3.0: construct_feature_matrices will be removed in 0.4.0, use __call__ instead

Return type

Dict[str, numpy.ndarray]
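A common way to keep a deprecated alias working while steering callers toward its replacement is a thin wrapper that emits a DeprecationWarning and forwards its arguments. The class below is a hypothetical sketch of that pattern, not nfp's source; ToyPreprocessor and its placeholder __call__ body are illustrative.

```python
import warnings
from typing import Any, Dict


class ToyPreprocessor:
    def __call__(self, structure: Any, *args: Any, train: bool = False,
                 **kwargs: Any) -> Dict[str, Any]:
        # Placeholder featurization: echo the input and the training flag.
        return {"structure": structure, "train": train}

    def construct_feature_matrices(self, *args: Any, train: bool = False,
                                   **kwargs: Any) -> Dict[str, Any]:
        """Deprecated alias that forwards to __call__."""
        warnings.warn(
            "construct_feature_matrices will be removed in 0.4.0, use __call__ instead",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not this wrapper
        )
        return self(*args, train=train, **kwargs)


preprocessor = ToyPreprocessor()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = preprocessor.construct_feature_matrices("NaCl", train=True)
print(result, caught[0].category.__name__)
# {'structure': 'NaCl', 'train': True} DeprecationWarning
```

New code should call the preprocessor directly, e.g. preprocessor(structure, train=True).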

from_json(filename)

Sets the class’s data with attributes taken from the save file.

Parameters

filename (str) – Path to the JSON save file to read.

Return type

None

static get_connectivity(graph, max_num_edges)

Get the graph connectivity from the networkx graph

Parameters
  • graph (networkx.classes.digraph.DiGraph) – The input graph

  • max_num_edges (int) – len(graph.edges), or the specified maximum number of graph edges

Returns

A dictionary with the single ‘connectivity’ key, containing an (n, 2) array of (node_index, node_index) pairs indicating the start and end nodes of each edge.

Return type

Dict[str, np.ndarray]
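The connectivity layout described above can be sketched without networkx: an ordered list of directed (start, end) edges becomes an (n, 2) table of node indices, padded to max_num_edges. Padding with [0, 0] rows is an assumption made here, and connectivity_array is a hypothetical helper that returns nested lists rather than np.ndarray.

```python
from typing import Dict, List, Sequence, Tuple


def connectivity_array(edges: Sequence[Tuple[int, int]],
                       max_num_edges: int) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: (start, end) node-index pairs, one row per edge."""
    if max_num_edges < len(edges):
        raise ValueError("max_num_edges is smaller than the number of edges")
    rows = [[start, end] for start, end in edges]
    # Fill the remaining slots with fresh [0, 0] rows.
    rows.extend([0, 0] for _ in range(max_num_edges - len(edges)))
    return {"connectivity": rows}


# Two sites linked in both directions, padded to four edge slots.
print(connectivity_array([(0, 1), (1, 0)], max_num_edges=4))
# {'connectivity': [[0, 1], [1, 0], [0, 0], [0, 0]]}
```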

to_json(filename)

Serialize the class’s data to a JSON file.

Parameters

filename (str) – Path of the JSON file to write.

Return type

None
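to_json and from_json pair up to save and restore a preprocessor's learned data. The standard-library sketch below shows that round trip for a hypothetical class whose only state is a label-to-index mapping; the JSON layout ({"tokens": {...}}) is an assumption, not nfp's file format.

```python
import json
import os
import tempfile
from typing import Dict


class ToyPreprocessor:
    """Hypothetical stand-in whose only learned state is a token mapping."""

    def __init__(self) -> None:
        self.tokens: Dict[str, int] = {"unk": 1}

    def to_json(self, filename: str) -> None:
        # Serialize the learned state to a JSON file.
        with open(filename, "w") as f:
            json.dump({"tokens": self.tokens}, f)

    def from_json(self, filename: str) -> None:
        # Restore the learned state in place; returns None like the documented method.
        with open(filename) as f:
            self.tokens = json.load(f)["tokens"]


trained = ToyPreprocessor()
trained.tokens.update({"Na": 2, "Cl": 3})
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "preprocessor.json")
    trained.to_json(path)
    restored = ToyPreprocessor()
    restored.from_json(path)
print(restored.tokens == trained.tokens)  # True
```

Because from_json updates an existing instance and returns None, the pattern is to construct the preprocessor first and then load the saved data into it.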