Source code for nfp.preprocessing.features

# The rest of the methods in this module are specific functions for computing
# atom and bond features. New ones can be easily added though, and these are
# passed directly to the Preprocessor class.


[docs]def get_ring_size(obj, max_size=12): if not obj.IsInRing(): return 0 else: for i in range(max_size): if obj.IsInRingSize(i): return i else: return "max"
[docs]def atom_features_v1(atom): """Return an integer hash representing the atom type""" return str( ( atom.GetSymbol(), atom.GetDegree(), atom.GetTotalNumHs(), atom.GetImplicitValence(), atom.GetIsAromatic(), ) )
[docs]def atom_features_v2(atom): props = [ "GetChiralTag", "GetDegree", "GetExplicitValence", "GetFormalCharge", "GetHybridization", "GetImplicitValence", "GetIsAromatic", "GetNoImplicit", "GetNumExplicitHs", "GetNumImplicitHs", "GetNumRadicalElectrons", "GetSymbol", "GetTotalDegree", "GetTotalNumHs", "GetTotalValence", ] atom_type = [getattr(atom, prop)() for prop in props] atom_type += [get_ring_size(atom)] return str(tuple(atom_type))
[docs]def bond_features_v1(bond, **kwargs): """Return an integer hash representing the bond type. flipped : bool Only valid for 'v3' version, whether to swap the begin and end atom types """ return str( ( bond.GetBondType(), bond.GetIsConjugated(), bond.IsInRing(), sorted([bond.GetBeginAtom().GetSymbol(), bond.GetEndAtom().GetSymbol()]), ) )
[docs]def bond_features_v2(bond, **kwargs): return str( ( bond.GetBondType(), bond.GetIsConjugated(), bond.GetStereo(), get_ring_size(bond), sorted([bond.GetBeginAtom().GetSymbol(), bond.GetEndAtom().GetSymbol()]), ) )
[docs]def bond_features_v3(bond, flipped=False): if not flipped: start_atom = atom_features_v1(bond.GetBeginAtom()) end_atom = atom_features_v1(bond.GetEndAtom()) else: start_atom = atom_features_v1(bond.GetEndAtom()) end_atom = atom_features_v1(bond.GetBeginAtom()) return str( ( bond.GetBondType(), bond.GetIsConjugated(), bond.GetStereo(), get_ring_size(bond), bond.GetEndAtom().GetSymbol(), start_atom, end_atom, ) )
[docs]def bond_features_wbo(start_atom, end_atom, bondatoms): start_atom_symbol = bondatoms[0].GetSymbol() end_atom_symbol = bondatoms[1].GetSymbol() return str((start_atom_symbol, end_atom_symbol))