dig.ggraph3D.dataset

Dataset interfaces under dig.ggraph3D.dataset.

QM93DGEN

A PyTorch Geometric data interface for datasets used in molecule generation.

collate_fn

The collate function used to form a mini-batch of data when creating data loaders with the dataset class dig.ggraph3D.dataset.QM93DGEN.

class QM93DGEN(root='./qm9_3Dgen', subset_idxs=None, transform=None, pre_transform=None, pre_filter=None)

A PyTorch Geometric data interface for datasets used in molecule generation.

Note

When creating data loaders for this dataset class, dig.ggraph3D.dataset.collate_fn must be used as the collate function; no other collate function is supported.

Parameters
  • root (string, optional) – Root directory where the dataset should be saved. (default: ./qm9_3Dgen)

  • subset_idxs (list, optional) – If not None, only the data at the indices listed in subset_idxs will be sampled. (default: None)

  • transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)

  • pre_transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)

  • pre_filter (callable, optional) – A function that takes in a torch_geometric.data.Data object and returns a boolean value indicating whether the data object should be included in the final dataset. (default: None)
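The constructor hooks above take effect at different times: pre_filter and pre_transform act once, before data is saved to disk; transform acts on every access; and subset_idxs restricts which items can be sampled. The toy class below is a hypothetical, stdlib-only sketch of that timing. ToyGenDataset, its plain-dict items, applying pre_filter before pre_transform, and remapping idx through subset_idxs are all assumptions made for illustration, not DIG's implementation.

```python
from typing import Callable, List, Optional, Sequence

Item = dict
Hook = Optional[Callable[[Item], Item]]


class ToyGenDataset:
    """Hypothetical stand-in illustrating the constructor hooks; not DIG code."""

    def __init__(
        self,
        raw_items: Sequence[Item],
        subset_idxs: Optional[List[int]] = None,
        transform: Hook = None,
        pre_transform: Hook = None,
        pre_filter: Optional[Callable[[Item], bool]] = None,
    ) -> None:
        # "Processing" step: runs once, where a real dataset would save to disk.
        items = [it for it in raw_items if pre_filter is None or pre_filter(it)]
        if pre_transform is not None:
            items = [pre_transform(it) for it in items]
        self._items = items
        self.subset_idxs = subset_idxs
        self.transform = transform

    def __len__(self) -> int:
        # Assumption: only the positions listed in subset_idxs can be sampled.
        return len(self._items) if self.subset_idxs is None else len(self.subset_idxs)

    def __getitem__(self, idx: int) -> Item:
        # Assumption: idx is remapped through subset_idxs when it is given.
        real_idx = idx if self.subset_idxs is None else self.subset_idxs[idx]
        item = self._items[real_idx]
        # transform runs on every access; its result is never cached.
        return item if self.transform is None else self.transform(item)


ds = ToyGenDataset(
    [{"n": 3}, {"n": 12}, {"n": 5}, {"n": 9}],
    subset_idxs=[2, 0],
    transform=lambda it: {**it, "double": 2 * it["n"]},
    pre_transform=lambda it: {**it, "seen": True},
    pre_filter=lambda it: it["n"] < 10,
)
```

Under these assumptions, subset_idxs indexes the items that survived pre_filter, so its values must be valid positions after filtering.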

download()

Downloads the dataset to the self.raw_dir folder.

get(idx)

Gets the data object at index idx.

Parameters

idx – The index of the data object to retrieve.

Return type

dict: a Python dict with the following items.

  • “atom_type” – the atom types of the previously generated geometry at each generation step.

  • “position” – the atom coordinates of the previously generated geometry at each generation step.

  • “batch” – the indices that identify which generation step each previously generated atom belongs to.

  • “focus” – the index of the focus atom at each generation step.

  • “c1_focus” – the indices of the c1 atom and the focus atom at each generation step.

  • “c2_c1_focus” – the indices of the c2, c1, and focus atoms at each generation step.

  • “new_atom_type” – the atom types to be generated at each generation step.

  • “new_dist” – the distances to be generated at each generation step.

  • “new_angle” – the angles to be generated at each generation step.

  • “new_torsion” – the torsion angles to be generated at each generation step.

  • “cannot_focus” – labels denoting, for each atom of the previously generated geometry at each generation step, whether it can serve as the focus atom.
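The new_dist, new_angle, and new_torsion targets place a new atom relative to the focus atom and its reference atoms c1 and c2. The sketch below computes such values from 3D coordinates using the standard bond-angle formula and the signed atan2 dihedral formula. Whether DIG uses exactly these conventions (angle measured as new-focus-c1, torsion about the focus-c1 axis, its sign, radians as units) is an assumption, and new_atom_targets is a hypothetical helper.

```python
import math
from typing import Sequence, Tuple

Vec = Sequence[float]


def _sub(a: Vec, b: Vec) -> Tuple[float, float, float]:
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def _dot(a: Vec, b: Vec) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def _cross(a: Vec, b: Vec) -> Tuple[float, float, float]:
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def _norm(a: Vec) -> float:
    return math.sqrt(_dot(a, a))


def new_atom_targets(new: Vec, focus: Vec, c1: Vec, c2: Vec) -> Tuple[float, float, float]:
    """Hypothetical: (distance, angle, torsion) of `new` w.r.t. focus, c1, c2."""
    # Distance from the focus atom to the new atom.
    dist = _norm(_sub(new, focus))
    # Bond angle new-focus-c1; clamp the cosine to avoid acos domain errors.
    u, v = _sub(new, focus), _sub(c1, focus)
    cos_angle = _dot(u, v) / (_norm(u) * _norm(v))
    angle = math.acos(max(-1.0, min(1.0, cos_angle)))
    # Signed dihedral new-focus-c1-c2: atan2(|b2| b1.(b2 x b3), (b1 x b2).(b2 x b3)).
    b1, b2, b3 = _sub(focus, new), _sub(c1, focus), _sub(c2, c1)
    n1, n2 = _cross(b1, b2), _cross(b2, b3)
    torsion = math.atan2(_norm(b2) * _dot(b1, n2), _dot(n1, n2))
    return dist, angle, torsion
```

For example, with the focus at the origin, new at (1, 0, 0), c1 at (0, 1, 0), and c2 at (1, 1, 0), the distance is 1, the angle is π/2, and the torsion is 0 because new and c2 lie on the same side of the focus-c1 axis.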

get_idx_split(task)

Gets the training/validation split indices of the dataset for a given task.

Parameters

task – The name of the task the dataset will be used for: ‘rand_gen’ for random molecular geometry generation, ‘gap_opt’ for discovering molecular geometries with low HOMO-LUMO gaps, or ‘alpha_opt’ for discovering molecular geometries with high isotropic polarizabilities.

Return type

dict: a Python dict for the training/validation split, with keys “train” and “valid” holding the training and validation indices, respectively.
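The returned dictionary holds two index lists, which are expected to be disjoint. The hypothetical make_split helper below only illustrates that shape with a seeded random split; it does not reproduce how DIG chooses indices for ‘rand_gen’, ‘gap_opt’, or ‘alpha_opt’.

```python
import random
from typing import Dict, List


def make_split(num_items: int, num_valid: int, seed: int = 0) -> Dict[str, List[int]]:
    """Hypothetical seeded train/valid split shaped like get_idx_split's output."""
    if not 0 <= num_valid <= num_items:
        raise ValueError("num_valid must lie in [0, num_items]")
    idxs = list(range(num_items))
    random.Random(seed).shuffle(idxs)  # deterministic for a fixed seed
    return {"train": sorted(idxs[num_valid:]), "valid": sorted(idxs[:num_valid])}


split = make_split(num_items=10, num_valid=3)
```

One natural way to use such a split, assumed here rather than documented above, is to pass split["train"] as the subset_idxs argument of QM93DGEN.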

len()

Gets the number of molecular geometries that can be sampled in total.

process()

Processes the raw data file and saves the processed dataset to the self.processed_dir folder.

property processed_file_names

The names of the files in the self.processed_dir folder that must be present in order to skip processing.

property raw_file_names

The names of the files in the self.raw_dir folder that must be present in order to skip downloading.
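Both properties drive a skip check: when every listed file already exists, the corresponding download or processing step is skipped. The stdlib sketch below mimics that check, modeled loosely on PyTorch Geometric's dataset behavior, where an empty file list never counts as complete. files_exist and prepare are hypothetical helpers written for this illustration, not functions of DIG.

```python
import os
from typing import Callable, Iterable


def files_exist(folder: str, names: Iterable[str]) -> bool:
    """True only if `names` is non-empty and every file exists in `folder`."""
    names = list(names)
    return len(names) > 0 and all(os.path.exists(os.path.join(folder, n)) for n in names)


def prepare(raw_dir: str, raw_names: Iterable[str],
            processed_dir: str, processed_names: Iterable[str],
            download: Callable[[], None], process: Callable[[], None]) -> None:
    """Hypothetical: run download/process only when their expected files are missing."""
    if not files_exist(raw_dir, raw_names):
        download()
    if not files_exist(processed_dir, processed_names):
        process()
```

Calling prepare a second time after both steps have written their files triggers neither step.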

collate_fn(data_batch_list)

The collate function used to form a mini-batch of data when creating data loaders with the dataset class dig.ggraph3D.dataset.QM93DGEN.

Parameters

data_batch_list – a list of Python dicts, each returned by the get method of dig.ggraph3D.dataset.QM93DGEN.

Return type

dict: a Python dict with the same keys as each dict in data_batch_list.
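Merging the per-sample dicts is not a plain concatenation: index-valued entries must be shifted so they still point at the right rows after merging. The stdlib sketch below shows that bookkeeping under stated assumptions: every value is a list; “batch” holds generation-step indices, so it is offset by the number of steps (taken as len(new_atom_type)) in earlier samples; and “focus”, “c1_focus”, and “c2_c1_focus” index rows of “atom_type”/“position”, so they are offset by the number of those rows in earlier samples. collate_sketch is hypothetical and is not DIG's collate_fn.

```python
from typing import Any, Dict, List

# Assumption: these keys hold indices into the rows of "atom_type"/"position".
ATOM_INDEX_KEYS = ("focus", "c1_focus", "c2_c1_focus")


def _shift(value: Any, offset: int) -> Any:
    """Add `offset` to every integer in a (possibly nested) list."""
    if isinstance(value, list):
        return [_shift(v, offset) for v in value]
    return value + offset


def collate_sketch(data_batch_list: List[Dict[str, list]]) -> Dict[str, list]:
    """Hypothetical: merge per-sample dicts, re-basing index-valued keys."""
    keys = list(data_batch_list[0].keys())
    out: Dict[str, list] = {key: [] for key in keys}
    atom_offset = 0  # context-atom rows contributed by earlier samples
    step_offset = 0  # generation steps contributed by earlier samples
    for sample in data_batch_list:
        for key in keys:
            value = sample[key]
            if key in ATOM_INDEX_KEYS:
                value = _shift(value, atom_offset)
            elif key == "batch":
                value = _shift(value, step_offset)
            out[key].extend(value)
        atom_offset += len(sample["atom_type"])
        step_offset += len(sample["new_atom_type"])
    return out
```

Without these offsets, the indices of the second and later samples would point into the first sample's atoms and steps. This is one plausible reason a generic default collate function cannot be used here.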