Source code for nfp.preprocessing.crystal_preprocessor

from typing import Dict

import networkx as nx
import numpy as np
from nfp.frameworks import tf
from nfp.preprocessing.preprocessor import PreprocessorMultiGraph
from nfp.preprocessing.tokenizer import Tokenizer


class PymatgenPreprocessor(PreprocessorMultiGraph):
    """Convert pymatgen ``Structure`` objects into nearest-neighbor graphs.

    Each site becomes a node whose feature is its tokenized element symbol.
    Each site gets a directed edge to each of its ``num_neighbors`` closest
    neighbors (as returned by ``Structure.get_all_neighbors``), and the edge
    carries the neighbor distance.

    Args:
        radius: Cutoff radius passed to ``Structure.get_all_neighbors``.
            When ``None``, a cutoff is estimated per structure from its
            volume per site (see ``create_nx_graph``).
        num_neighbors: Number of closest neighbors kept for every site.
        **kwargs: Forwarded to ``PreprocessorMultiGraph.__init__``.
    """

    def __init__(self, radius=None, num_neighbors=12, **kwargs):
        super().__init__(**kwargs)
        # Maps element symbols (from ``site_features``) to integer classes.
        self.site_tokenizer = Tokenizer()
        self.radius = radius
        self.num_neighbors = num_neighbors

    def create_nx_graph(self, crystal, **kwargs) -> nx.MultiDiGraph:
        """Build the neighbor graph for a single structure.

        crystal should be a pymatgen.core.Structure object.

        Args:
            crystal: The structure to convert.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``MultiDiGraph`` with these contents:

            * the structure stored as the ``crystal`` graph attribute;
            * one node per site, keyed by site index, with the site stored
              under ``"site"``;
            * one edge ``i -> j`` per retained neighbor, with its
              ``"distance"``.

        Raises:
            RuntimeError: If any site has fewer than ``num_neighbors``
                neighbors within the cutoff radius.
        """
        g = nx.MultiDiGraph(crystal=crystal)
        g.add_nodes_from(
            ((i, {"site": site}) for i, site in enumerate(crystal.sites))
        )

        if self.radius is None:
            # A sphere of volume ``desired_vol`` is expected to contain about
            # ``num_neighbors`` sites. Doubling that sphere's radius (8x the
            # volume, so roughly 8x as many expected neighbors) leaves a
            # margin so that ``num_neighbors`` neighbors are usually found.
            desired_vol = (crystal.volume / crystal.num_sites) * self.num_neighbors
            radius = 2 * (desired_vol / (4 * np.pi / 3)) ** (1 / 3)
        else:
            radius = self.radius

        for i, neighbors in enumerate(crystal.get_all_neighbors(radius)):
            if len(neighbors) < self.num_neighbors:
                raise RuntimeError(f"Only {len(neighbors)} neighbors for site {i}")

            # Neighbors are accessed tuple-style as (site, distance, index,
            # image); keep only the ``num_neighbors`` closest ones.
            sorted_neighbors = sorted(neighbors, key=lambda x: x[1])[
                : self.num_neighbors
            ]

            # Separate periodic images of the same site ``j`` may each be
            # retained, which can produce parallel edges -- presumably why a
            # MultiDiGraph is used.
            for _, distance, j, _ in sorted_neighbors:
                g.add_edge(i, j, distance=distance)

        return g

    def get_edge_features(
        self, edge_data: list, max_num_edges
    ) -> Dict[str, np.ndarray]:
        """Return edge distances as a NaN-padded float32 array.

        Args:
            edge_data: Iterable of ``(u, v, data)`` edge triples. Each
                ``data`` dict holds a ``"distance"`` entry.
            max_num_edges: Length of the output array. Slots beyond the
                number of edges are left as NaN.

        Returns:
            ``{"distance": array of shape (max_num_edges,)}``.
        """
        edge_feature_matrix = np.full(max_num_edges, np.nan, dtype="float32")
        for n, (_, _, edge_dict) in enumerate(edge_data):
            edge_feature_matrix[n] = edge_dict["distance"]
        return {"distance": edge_feature_matrix}

    def get_node_features(
        self, node_data: list, max_num_nodes
    ) -> Dict[str, np.ndarray]:
        """Return the tokenized element of each site, zero-padded.

        Args:
            node_data: Iterable of ``(node_index, data)`` pairs, where
                ``data`` holds the pymatgen site under ``"site"``.
            max_num_nodes: Length of the output array. Unused slots stay 0.

        Returns:
            ``{"site": array of shape (max_num_nodes,)}`` with dtype
            ``self.output_dtype``.
        """
        site_feature_matrix = np.zeros(max_num_nodes, dtype=self.output_dtype)
        for n, site_dict in node_data:
            # Node keys are the integer site indices used as array positions.
            site_feature_matrix[n] = self.site_tokenizer(
                self.site_features(site_dict["site"])
            )
        return {"site": site_feature_matrix}

    def get_graph_features(self, graph_data: dict) -> Dict[str, np.ndarray]:
        """Return graph-level features; this preprocessor defines none."""
        return {}

    @property
    def site_classes(self):
        """Number of site classes, including one extra class.

        The ``+ 1`` presumably accounts for the padding value ``0`` used in
        ``get_node_features`` / ``padding_values`` -- confirm against
        ``Tokenizer``.
        """
        return self.site_tokenizer.num_classes + 1

    @staticmethod
    def site_features(site):
        """Return the element symbol of a single-species site.

        Raises:
            AssertionError: If the site holds more than one species.
        """
        species = site.as_dict()["species"]
        # Raised explicitly rather than via ``assert`` so the check still runs
        # under ``python -O``; the AssertionError type is kept for callers.
        if len(species) != 1:
            raise AssertionError(
                f"Expected exactly one species per site, got {len(species)}"
            )
        return species[0]["element"]

    # Annotations mentioning ``tf`` are quoted so that defining this class
    # does not fail when TensorFlow is unavailable (``tf is None``);
    # otherwise the ImportError guards below could never be reached.
    @property
    def output_signature(self) -> Dict[str, "tf.TensorSpec"]:
        """TensorSpecs describing the arrays produced for each example.

        Raises:
            ImportError: If TensorFlow is not installed.
        """
        if tf is None:
            raise ImportError("Tensorflow was not found")
        return {
            "site": tf.TensorSpec(shape=(None,), dtype=self.output_dtype),
            "distance": tf.TensorSpec(shape=(None,), dtype="float32"),
            "connectivity": tf.TensorSpec(shape=(None, 2), dtype=self.output_dtype),
        }

    @property
    def padding_values(self) -> Dict[str, "tf.constant"]:
        """Values used when padding batches.

        Sites pad with 0, distances with NaN, and connectivity with 0.

        Raises:
            ImportError: If TensorFlow is not installed.
        """
        if tf is None:
            raise ImportError("Tensorflow was not found")
        return {
            "site": tf.constant(0, dtype=self.output_dtype),
            "distance": tf.constant(np.nan, dtype="float32"),
            "connectivity": tf.constant(0, dtype=self.output_dtype),
        }

    @property
    def tfrecord_features(self) -> Dict[str, "tf.io.FixedLenFeature"]:
        """TFRecord parsing spec: one serialized string feature per output key.

        Raises:
            ImportError: If TensorFlow is not installed.
        """
        if tf is None:
            raise ImportError("Tensorflow was not found")
        return {
            key: tf.io.FixedLenFeature([], dtype=tf.string)
            for key in self.output_signature.keys()
        }