dig.ggraph3D.method

Method classes under dig.ggraph3D.method.

G_SphereNet

The method class for the G-SphereNet algorithm proposed in the paper An Autoregressive Flow Model for 3D Molecular Geometry Generation from Scratch.

class G_SphereNet

The method class for the G-SphereNet algorithm proposed in the paper An Autoregressive Flow Model for 3D Molecular Geometry Generation from Scratch. This class provides interfaces for running training and generation with the G-SphereNet algorithm. Please refer to the example code for usage examples.

generate(model_conf_dict, checkpoint_path, n_mols=1000, chunk_size=100, num_min_node=7, num_max_node=25, temperature=[1.0, 1.0, 1.0, 1.0], focus_th=0.5)

Runs molecular geometry generation for the random generation task.

Parameters
  • model_conf_dict (dict) – The python dict for configuring the model hyperparameters.

  • checkpoint_path (str) – The path to the saved model checkpoint file.

  • n_mols (int, optional) – The number of molecular geometries to generate. (default: 1000)

  • chunk_size (int, optional) – The maximum number of molecular geometries that are allowed to be generated in parallel. (default: 100)

  • num_min_node (int, optional) – The minimum number of nodes in the generated molecular geometries. (default: 7)

  • num_max_node (int, optional) – The maximum number of nodes in the generated molecular geometries. (default: 25)

  • temperature (list, optional) – A list of four floats giving the temperature parameters of the prior distribution. (default: [1.0, 1.0, 1.0, 1.0])

  • focus_th (float, optional) – The threshold for focus node classification. (default: 0.5)
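A call to generate might look like the sketch below. It assumes that the dig package is installed, that G_SphereNet can be constructed without arguments (the class signature above lists none), and that a trained checkpoint already exists. The helper names plan_chunks and sample_geometries are hypothetical. plan_chunks only illustrates the arithmetic implied by n_mols and chunk_size; it is not the library's internal scheduling.

```python
def plan_chunks(n_mols=1000, chunk_size=100):
    """Split n_mols into batches of at most chunk_size (illustrative only)."""
    full, rest = divmod(n_mols, chunk_size)
    return [chunk_size] * full + ([rest] if rest else [])


def sample_geometries(model_conf_dict, checkpoint_path, **overrides):
    """Call G_SphereNet.generate with the documented defaults plus overrides."""
    # Imported lazily so plan_chunks works even without dig installed.
    from dig.ggraph3D.method import G_SphereNet

    kwargs = dict(
        n_mols=1000,
        chunk_size=100,
        num_min_node=7,
        num_max_node=25,
        temperature=[1.0, 1.0, 1.0, 1.0],
        focus_th=0.5,
    )
    kwargs.update(overrides)
    runner = G_SphereNet()  # assumption: no constructor arguments
    return runner.generate(model_conf_dict, checkpoint_path, **kwargs)


print(plan_chunks(1000, 100))  # ten chunks of 100
print(plan_chunks(250, 100))   # [100, 100, 50]
```

Lowering chunk_size reduces how many geometries are built in parallel, which is the usual lever when generation runs out of memory.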

Return type

mol_dicts, a Python dict keyed by the number of atoms. Each value is another Python dict that stores the atomic number matrix (under the key '_atomic_numbers') and the coordinate tensor (under the key '_positions') of all generated molecular geometries with that number of atoms.
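The nested structure of the return value can be walked as in the following sketch. The dict below is hand-built stand-in data with placeholder coordinates, not output from the library. Plain nested lists are used for illustration, and the layout (one row per molecule in '_atomic_numbers', one atoms-by-3 block per molecule in '_positions') is an assumption. The summarize helper is hypothetical.

```python
def zeros(n_atoms):
    """Placeholder coordinates: n_atoms rows of (x, y, z) set to 0.0."""
    return [[0.0, 0.0, 0.0] for _ in range(n_atoms)]


# Stand-in for the value returned by generate: keys are atom counts.
mol_dicts = {
    3: {
        "_atomic_numbers": [[8, 1, 1], [8, 1, 1]],
        "_positions": [zeros(3), zeros(3)],
    },
    5: {
        "_atomic_numbers": [[6, 1, 1, 1, 1]],
        "_positions": [zeros(5)],
    },
}


def summarize(mol_dicts):
    """Return {atom count: number of generated geometries}."""
    return {
        n_atoms: len(entry["_atomic_numbers"])
        for n_atoms, entry in mol_dicts.items()
    }


counts = summarize(mol_dicts)
for n_atoms, n_geoms in sorted(counts.items()):
    print(f"{n_atoms} atoms: {n_geoms} geometries")
```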

train(loader, lr, wd, max_epochs, model_conf_dict, checkpoint_path, save_interval, save_dir)

Runs training for the random generation task.

Parameters
  • loader – The data loader for loading training samples. It should be built by using dig.ggraph3D.dataset.QM93DGEN as the dataset class and wrapping the dataset in torch.utils.data.DataLoader.

  • lr (float) – The learning rate for training.

  • wd (float) – The weight decay factor for training.

  • max_epochs (int) – The maximum number of training epochs.

  • model_conf_dict (dict) – The python dict for configuring the model hyperparameters.

  • save_interval (int) – How often to save the model parameters to .pth files. For example, with save_interval=2 the model parameters are saved every 2 training epochs.

  • save_dir (str) – The directory to save the model parameters.
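Putting the training parameters together, a training run could be set up roughly as follows. This is a sketch under several assumptions: the dig and torch packages are installed, QM93DGEN and G_SphereNet can both be constructed with default arguments, and the keys of conf (batch_size, lr, wd, max_epochs, model, save_interval) are hypothetical names chosen for this example. checkpoint_path is passed through from the caller unchanged. The helper epochs_with_checkpoint illustrates the save_interval arithmetic with 1-based epoch counting, which is also an assumption.

```python
def epochs_with_checkpoint(max_epochs, save_interval):
    """Epochs (1-based) at which a save every save_interval epochs would occur."""
    return [e for e in range(1, max_epochs + 1) if e % save_interval == 0]


def run_training(conf, checkpoint_path, save_dir):
    """Build the data loader as described above and call G_SphereNet.train."""
    # Lazy imports: torch and dig are only needed when training actually runs.
    from torch.utils.data import DataLoader
    from dig.ggraph3D.dataset import QM93DGEN
    from dig.ggraph3D.method import G_SphereNet

    dataset = QM93DGEN()  # assumption: default constructor arguments suffice
    loader = DataLoader(dataset, batch_size=conf["batch_size"], shuffle=True)
    runner = G_SphereNet()  # assumption: no constructor arguments
    runner.train(
        loader,
        lr=conf["lr"],
        wd=conf["wd"],
        max_epochs=conf["max_epochs"],
        model_conf_dict=conf["model"],
        checkpoint_path=checkpoint_path,
        save_interval=conf["save_interval"],
        save_dir=save_dir,
    )


print(epochs_with_checkpoint(max_epochs=7, save_interval=2))  # [2, 4, 6]
```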