import json
import logging
import warnings
from abc import ABC, abstractmethod
from inspect import getattr_static, getmembers
from typing import Any, Dict, Optional

import networkx as nx
import numpy as np
from nfp.frameworks import tf
from nfp.preprocessing.tokenizer import Tokenizer
# Module-level logger; used by `load_from_json` to report attributes that are
# missing from a saved JSON file.
logger = logging.getLogger(__name__)
class Preprocessor(ABC):
    """A base class for graph preprocessing

    Subclasses implement the abstract hooks (`create_nx_graph`,
    `get_node_features`, `get_edge_features`, `get_graph_features`,
    `output_signature` and `padding_values`); calling the instance then runs
    them in sequence and merges their outputs into a single dictionary of
    arrays.

    Parameters
    ----------
    output_dtype
        A parameter used in child classes for determining the datatype of
        the returned arrays
    """

    def __init__(self, output_dtype: str = "int32"):
        self.output_dtype = output_dtype

    @abstractmethod
    def create_nx_graph(self, structure: Any, *args, **kwargs) -> nx.DiGraph:
        """Given an input structure object, convert it to a networkx digraph
        with node, edge, and graph features assigned.

        Parameters
        ----------
        structure
            Any input graph object
        args
            positional arguments passed from `__call__`
        kwargs
            keyword arguments passed from `__call__`, useful for specifying
            additional features in addition to the graph object.

        Returns
        -------
        nx.DiGraph
            A networkx graph with the node, edge, and graph features set
        """
        pass

    @abstractmethod
    def get_edge_features(
        self, edge_data: list, max_num_edges: int
    ) -> Dict[str, np.ndarray]:
        """Given a list of edge features from the nx.Graph, processes and
        concatenates them to an array.

        Parameters
        ----------
        edge_data
            A list of edge data generated by `nx_graph.edges(data=True)`
        max_num_edges
            If desired, this function should pad to a maximum number of edges
            passed from the `__call__` function.

        Returns
        -------
        Dict[str, np.ndarray]
            a dictionary of feature, array pairs, where array contains features
            for all edges in the graph.
        """
        pass

    @abstractmethod
    def get_node_features(
        self, node_data: list, max_num_nodes: int
    ) -> Dict[str, np.ndarray]:
        """Given a list of node features from the nx.Graph, processes and
        concatenates them to an array.

        Parameters
        ----------
        node_data
            A list of node data generated by `nx_graph.nodes(data=True)`
        max_num_nodes
            If desired, this function should pad to a maximum number of nodes
            passed from the `__call__` function.

        Returns
        -------
        Dict[str, np.ndarray]
            a dictionary of feature, array pairs, where array contains features
            for all nodes in the graph.
        """
        pass

    @abstractmethod
    def get_graph_features(self, graph_data: dict) -> Dict[str, np.ndarray]:
        """Process the nx.graph features into a dictionary of arrays.

        Parameters
        ----------
        graph_data
            A dictionary of graph data generated by `nx_graph.graph`

        Returns
        -------
        Dict[str, np.ndarray]
            a dictionary of features for the graph
        """
        pass

    @staticmethod
    def get_connectivity(
        graph: nx.DiGraph, max_num_edges: int
    ) -> Dict[str, np.ndarray]:
        """Get the graph connectivity from the networkx graph

        Parameters
        ----------
        graph
            The input graph
        max_num_edges
            len(graph.edges), or the specified maximum number of graph edges

        Returns
        -------
        Dict[str, np.ndarray]
            A dictionary with the single 'connectivity' key, containing a
            (max_num_edges, 2) int64 array of (node_index, node_index) pairs
            indicating the start and end nodes for each edge. Rows beyond
            `len(graph.edges)` are left as zeros.
        """
        connectivity = np.zeros((max_num_edges, 2), dtype="int64")
        num_edges = len(graph.edges)
        if num_edges > 0:  # Handle odd case with no edges
            # NOTE(review): assumes node labels are integer indices -- confirm
            # against the `create_nx_graph` implementations.
            connectivity[:num_edges] = np.asarray(graph.edges)
        return {"connectivity": connectivity}

    def _iter_tokenizers(self):
        """Yield every `Tokenizer` (or subclass) instance stored on this object.

        Attributes are looked up with `inspect.getattr_static` so that
        properties are never evaluated: `tfrecord_features` raises
        `NotImplementedError` by default, and `output_signature` /
        `padding_values` would otherwise be recomputed for every sample.
        """
        for name in dir(self):
            try:
                member = getattr_static(self, name)
            except AttributeError:
                # `dir` may report names that cannot be resolved statically
                continue
            if isinstance(member, Tokenizer):
                yield member

    def __call__(
        self,
        structure: Any,
        *args,
        train: bool = False,
        max_num_nodes: Optional[int] = None,
        max_num_edges: Optional[int] = None,
        **kwargs,
    ) -> Dict[str, np.ndarray]:
        """Convert an input graph structure into a featurized set of node, edge,
        and graph-level features.

        Parameters
        ----------
        structure
            An input graph structure (i.e., molecule, crystal, etc.)
        args
            Additional positional arguments passed to `create_nx_graph`
        train
            A training flag passed to `Tokenizer` member attributes
        max_num_nodes
            A size attribute passed to `get_node_features`, defaults to the
            number of nodes in the current graph
        max_num_edges
            A size attribute passed to `get_edge_features`, defaults to the
            number of edges in the current graph
        kwargs
            Additional features or parameters passed to `create_nx_graph`

        Returns
        -------
        Dict[str, np.ndarray]
            A dictionary of key, array pairs as a single sample. Node, edge,
            graph and connectivity entries are merged in that order, so a
            later group overrides an earlier one on a key collision.

        Raises
        ------
        AssertionError
            If the graph has more edges than `max_num_edges` or more nodes
            than `max_num_nodes`.
        """
        nx_graph = self.create_nx_graph(structure, *args, **kwargs)

        # Raised explicitly (rather than via `assert`) so that the size checks
        # are still enforced when Python runs with optimizations (-O).
        num_edges = len(nx_graph.edges)
        if max_num_edges is None:
            max_num_edges = num_edges
        elif num_edges > max_num_edges:
            raise AssertionError("max_num_edges too small for given input")

        num_nodes = len(nx_graph.nodes)
        if max_num_nodes is None:
            max_num_nodes = num_nodes
        elif num_nodes > max_num_nodes:
            raise AssertionError("max_num_nodes too small for given input")

        # Make sure that Tokenizer classes are correctly initialized
        for tokenizer in self._iter_tokenizers():
            tokenizer.train = train

        node_features = self.get_node_features(
            nx_graph.nodes(data=True), max_num_nodes
        )
        edge_features = self.get_edge_features(
            nx_graph.edges(data=True), max_num_edges
        )
        graph_features = self.get_graph_features(nx_graph.graph)
        connectivity = self.get_connectivity(nx_graph, max_num_edges)

        return {**node_features, **edge_features, **graph_features, **connectivity}

    def construct_feature_matrices(
        self, *args, train=False, **kwargs
    ) -> Dict[str, np.ndarray]:
        """
        .. deprecated:: 0.3.0
            `construct_feature_matrices` will be removed in 0.4.0, use
            `__call__` instead
        """
        warnings.warn(
            "construct_feature_matrices is deprecated as of nfp 0.3.0 and will "
            "be removed in 0.4.0, use `__call__` instead",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
        return self(*args, train=train, **kwargs)

    def to_json(self, filename: str) -> None:
        """Serialize the class's data to a json file

        Objects that `json` cannot encode natively (including this instance)
        are written out as their `__dict__`.

        Parameters
        ----------
        filename
            Path of the file to write
        """
        with open(filename, "w") as f:
            json.dump(self, f, default=lambda x: x.__dict__)

    def from_json(self, filename: str) -> None:
        """Sets the class's data with attributes taken from the save file

        Parameters
        ----------
        filename
            Path of a json file, e.g. one written by `to_json`
        """
        with open(filename, "r") as f:
            json_data = json.load(f)
        load_from_json(self, json_data)

    @property
    @abstractmethod
    def output_signature(self) -> Dict[str, tf.TensorSpec]:
        """A dictionary of `tf.TensorSpec` entries, to be provided by
        subclasses."""
        pass

    @property
    @abstractmethod
    def padding_values(self) -> Dict[str, tf.constant]:
        """A dictionary of padding values, to be provided by subclasses."""
        pass

    @property
    def tfrecord_features(self) -> Dict[str, tf.io.FixedLenFeature]:
        """Useful feature for storing preprocessed outputs in tfrecord files

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError
class PreprocessorMultiGraph(Preprocessor, ABC):
    """Class to handle graphs with parallel edges and self-loops"""

    @abstractmethod
    def create_nx_graph(self, structure: Any, *args, **kwargs) -> nx.MultiDiGraph:
        """Given an input structure object, convert it to a networkx
        multi-digraph.

        Accepts `*args` as well as `**kwargs` so that the signature matches
        `Preprocessor.create_nx_graph`, which `__call__` invokes with both.

        Parameters
        ----------
        structure
            Any input graph object
        args
            positional arguments passed from `__call__`
        kwargs
            keyword arguments passed from `__call__`

        Returns
        -------
        nx.MultiDiGraph
            A networkx multigraph built from `structure`
        """
        pass

    @staticmethod
    def get_connectivity(
        graph: nx.DiGraph, max_num_edges: int
    ) -> Dict[str, np.ndarray]:
        """Get the graph connectivity from the networkx multigraph

        Parameters
        ----------
        graph
            The input graph
        max_num_edges
            len(graph.edges), or the specified maximum number of graph edges

        Returns
        -------
        Dict[str, np.ndarray]
            A dictionary with the single 'connectivity' key, containing a
            (max_num_edges, 2) int64 array holding the first two columns of
            each edge entry. Rows beyond `len(graph.edges)` are left as zeros.
        """
        connectivity = np.zeros((max_num_edges, 2), dtype="int64")
        num_edges = len(graph.edges)
        if num_edges > 0:  # Handle odd case with no edges
            # Don't include keys in the connectivity matrix
            connectivity[:num_edges] = np.asarray(graph.edges)[:, :2]
        return {"connectivity": connectivity}
def load_from_json(obj: Any, data: Dict):
    """Function to set member attributes from json data recursively.

    For every attribute currently in `obj.__dict__`:

    * if the existing value is an instance of the saved value's type, the
      attribute is overwritten with the saved value;
    * otherwise, if the existing value itself has a `__dict__`, the saved
      value is treated as nested attribute data and loaded recursively;
    * otherwise the attribute is left unchanged.

    Attributes missing from `data` are left unchanged and reported with a
    warning through the module logger. Keys in `data` that do not correspond
    to an existing attribute are ignored.

    Parameters
    ----------
    obj
        the class to initialize
    data
        a dictionary of potentially nested attribute - value pairs

    Returns
    -------
    Any
        The object, with attributes set to those from the data file.
    """
    # Snapshot the items so that assigning into `obj.__dict__` below never
    # interferes with the iteration.
    for key, current in list(obj.__dict__.items()):
        try:
            saved = data[key]
        except KeyError:
            logger.warning(
                "%s not found in JSON file, it may have been created with"
                " an older nfp version",
                key,
            )
            continue

        if isinstance(current, type(saved)):
            obj.__dict__[key] = saved
        elif hasattr(current, "__dict__"):
            load_from_json(current, saved)

    return obj