dig.ggraph3D.method
Method classes under dig.ggraph3D.method
- class G_SphereNet
The method class for the G-SphereNet algorithm proposed in the paper An Autoregressive Flow Model for 3D Molecular Geometry Generation from Scratch. This class provides interfaces for running training and generation with the G-SphereNet algorithm. Please refer to the example code for usage examples.
- generate(model_conf_dict, checkpoint_path, n_mols=1000, chunk_size=100, num_min_node=7, num_max_node=25, temperature=[1.0, 1.0, 1.0, 1.0], focus_th=0.5)
Runs 3D molecular geometry generation for the random generation task.
- Parameters
model_conf_dict (dict) – The Python dict for configuring the model hyperparameters.
checkpoint_path (str) – The path to the saved model checkpoint file.
n_mols (int, optional) – The number of molecular geometries to generate. (default: 1000)
chunk_size (int, optional) – The maximum number of molecular geometries that are allowed to be generated in parallel. (default: 100)
num_min_node (int, optional) – The minimum number of nodes in the generated molecular geometries. (default: 7)
num_max_node (int, optional) – The maximum number of nodes in the generated molecular geometries. (default: 25)
temperature (list, optional) – A list of four floats giving the temperature parameters of the prior distribution. (default: [1.0, 1.0, 1.0, 1.0])
focus_th (float, optional) – The threshold for focus node classification. (default: 0.5)
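The relationship between n_mols and chunk_size can be pictured as splitting the requested molecules into consecutive batches of at most chunk_size geometries each. The helper chunk_schedule below is hypothetical and only illustrates that arithmetic; it is not part of DIG, and whether G_SphereNet.generate partitions its work exactly this way (for example, how it handles discarded or invalid samples) is an assumption.

```python
from typing import List


def chunk_schedule(n_mols: int, chunk_size: int) -> List[int]:
    """Split n_mols into consecutive batches of at most chunk_size molecules.

    Hypothetical helper: it mirrors the meaning of the n_mols / chunk_size
    parameters, not DIG's internal generation loop.
    """
    if n_mols < 0 or chunk_size <= 0:
        raise ValueError("n_mols must be >= 0 and chunk_size must be > 0")
    full, rest = divmod(n_mols, chunk_size)
    # `full` batches of chunk_size, plus one smaller batch for any remainder.
    return [chunk_size] * full + ([rest] if rest else [])


print(chunk_schedule(1000, 100))  # ten batches of 100 (the defaults)
print(chunk_schedule(250, 100))   # [100, 100, 50]
```

With the defaults, at most 100 geometries are in flight at once, so peak memory is governed by chunk_size rather than by n_mols.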
- Return type
mol_dicts, a Python dict whose keys are numbers of atoms. The value under each key is another Python dict storing the atomic number matrix (under the key ‘_atomic_numbers’) and the coordinate tensor (under the key ‘_positions’) of all generated molecular geometries with that number of atoms.
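A minimal sketch of consuming mol_dicts, assuming each inner value stacks its molecules along a leading axis, i.e. ‘_atomic_numbers’ has shape (num_mols, num_atoms) and ‘_positions’ has shape (num_mols, num_atoms, 3). The helper names count_by_size and iter_molecules and the toy data are hypothetical. Because the loop only iterates and converts elements with int() and float(), it should work the same whether the real values are nested lists, NumPy arrays, or tensors.

```python
from typing import Dict, Iterator, List, Sequence, Tuple

# Hypothetical alias: {num_atoms: {'_atomic_numbers': ..., '_positions': ...}}
MolDicts = Dict[int, Dict[str, Sequence]]


def count_by_size(mol_dicts: MolDicts) -> Dict[int, int]:
    """Number of generated molecules for each atom count."""
    return {n: len(group['_atomic_numbers']) for n, group in mol_dicts.items()}


def iter_molecules(mol_dicts: MolDicts) -> Iterator[Tuple[int, List[int], List[List[float]]]]:
    """Yield (num_atoms, atomic_numbers, positions) for every generated molecule."""
    for num_atoms in sorted(mol_dicts):
        group = mol_dicts[num_atoms]
        # Assumed shapes: (num_mols, num_atoms) and (num_mols, num_atoms, 3).
        for z, pos in zip(group['_atomic_numbers'], group['_positions']):
            yield num_atoms, [int(a) for a in z], [[float(c) for c in xyz] for xyz in pos]


# Made-up toy output: two 3-atom molecules and one 2-atom molecule.
toy = {
    3: {'_atomic_numbers': [[8, 1, 1], [6, 8, 1]],
        '_positions': [[[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
                       [[0.0, 0.0, 0.0], [1.2, 0.0, 0.0], [-0.5, 0.9, 0.0]]]},
    2: {'_atomic_numbers': [[6, 8]],
        '_positions': [[[0.0, 0.0, 0.0], [1.13, 0.0, 0.0]]]},
}

print(count_by_size(toy))  # {3: 2, 2: 1}
for n, z, pos in iter_molecules(toy):
    print(n, z)
```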
- train(loader, lr, wd, max_epochs, model_conf_dict, checkpoint_path, save_interval, save_dir)
Runs training for the random generation task.
- Parameters
loader – The data loader for loading training samples. It is expected to be built by wrapping a dig.ggraph3D.dataset.QM93DGEN dataset in torch.utils.data.DataLoader.
lr (float) – The learning rate for training.
wd (float) – The weight decay factor for training.
max_epochs (int) – The maximum number of training epochs.
model_conf_dict (dict) – The Python dict for configuring the model hyperparameters.
save_interval (int) – How often, in training epochs, to save the model parameters to .pth files; e.g., if save_interval=2, the model parameters are saved every 2 training epochs.
save_dir (str) – The directory to save the model parameters.
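The save_interval behaviour can be sketched as a schedule of checkpoint epochs. This sketch assumes epochs are counted from 1 and that parameters are written whenever the epoch number is divisible by save_interval; both helpers are hypothetical, and the file-name pattern in checkpoint_file is only an example, not the name DIG actually uses inside save_dir.

```python
import os
from typing import List


def checkpoint_epochs(max_epochs: int, save_interval: int) -> List[int]:
    """Epochs (counted from 1, an assumption) after which parameters would be saved."""
    if save_interval <= 0:
        raise ValueError("save_interval must be positive")
    return [epoch for epoch in range(1, max_epochs + 1) if epoch % save_interval == 0]


def checkpoint_file(save_dir: str, epoch: int) -> str:
    """Hypothetical .pth path for a given epoch; DIG's real naming may differ."""
    return os.path.join(save_dir, f"model_epoch_{epoch}.pth")


if __name__ == "__main__":
    # With save_interval=2 and 10 epochs: saves after epochs 2, 4, 6, 8, 10.
    for epoch in checkpoint_epochs(max_epochs=10, save_interval=2):
        print(checkpoint_file("checkpoints", epoch))
```

Note that if max_epochs is not a multiple of save_interval, the final epoch is not saved under this assumed schedule.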