dig.auggraph.method
Graph Augmentation Methods
GraphAug
An augmentation method for graph datasets, available under dig.auggraph.method.GraphAug and implemented from the paper Automated Data Augmentations for Graph Classification.
RunnerAugCls – Runs the training of a graph classification model using the augmented data generated by the already trained generator.
RunnerGenerator – Runs the training of an augmented samples generator model which uses the already trained reward generation model.
RunnerRewardGen – Runs the training of a reward generation model which will be able to distinguish between graphs with different labels.
- class RunnerAugCls(data_root_path, dataset_name, conf)
Runs the training of a graph classification model using the augmented data generated by the already trained generator. Check examples.auggraph.GraphAug.run_aug_cls for examples on how to run this augmented classifier model.
- Parameters
data_root_path (string) – Directory where datasets should be saved.
dataset_name (dig.auggraph.method.GraphAug.constants.enums.DatasetName) – Name of the graph dataset.
conf (dict) – Hyperparameters for the model. Check examples.auggraph.GraphAug.conf.aug_cls_conf for examples on how to define the conf dictionary for the augmented classifier model.
- train_test(out_root_path, log_file='record.txt')
Trains the classification model on the augmented graph dataset and validates the model parameters obtained at each epoch.
- Parameters
out_root_path (string) – Directory where the results of this augmented classifier model will be saved.
log_file (string) – File where training and validation logs are written.
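A minimal sketch of the train-then-validate loop that train_test describes, assuming one log line is appended per epoch to os.path.join(out_root_path, log_file). The function train_test_sketch and the callables train_one_epoch and validate are hypothetical placeholders, not DIG APIs, and the real RunnerAugCls log format may differ.

```python
import os
from typing import Callable, List


def train_test_sketch(
    out_root_path: str,
    num_epochs: int,
    train_one_epoch: Callable[[int], float],  # hypothetical: returns training loss for an epoch
    validate: Callable[[int], float],         # hypothetical: returns validation accuracy for an epoch
    log_file: str = "record.txt",
) -> List[float]:
    """Illustrative epoch loop only; not the RunnerAugCls implementation."""
    os.makedirs(out_root_path, exist_ok=True)
    log_path = os.path.join(out_root_path, log_file)
    val_accs = []
    # Append mode so repeated runs accumulate in the same log file.
    with open(log_path, "a") as log:
        for epoch in range(1, num_epochs + 1):
            loss = train_one_epoch(epoch)
            acc = validate(epoch)
            val_accs.append(acc)
            log.write(f"epoch {epoch}: train_loss={loss:.4f} val_acc={acc:.4f}\n")
    return val_accs
```

Keeping the per-epoch validation scores in a list makes it easy to pick the best or last epoch afterwards, while the log file preserves a human-readable record.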
- class RunnerGenerator(data_root_path, dataset_name, conf)
Runs the training of an augmented-sample generator model that uses the already trained reward generation model. For a given graph, the model generates an augmented sample and a likelihood that this sample is a label-invariant augmentation. This prediction is then evaluated by the reward generation model, and a loss computed from these metrics is minimized through training. Check examples.auggraph.GraphAug.run_generator for examples on how to run the generator model.
- Parameters
data_root_path (string) – Directory where datasets should be saved.
dataset_name (dig.auggraph.method.GraphAug.constants.enums.DatasetName) – Name of the graph dataset.
conf (dict) – Hyperparameters for the model. Check examples.auggraph.GraphAug.conf.generator_conf for examples on how to define the conf dictionary for the generator.
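The generator objective described above can be illustrated with a REINFORCE-style (score-function) surrogate loss. This is a sketch under assumptions: it takes the generator's log-likelihood of each sampled augmentation, uses the log of the reward model's label-invariance probability as the reward, and averages -reward * log_likelihood over a batch. The helper reinforce_loss is hypothetical, and the exact loss used by RunnerGenerator may differ.

```python
import math
from typing import Sequence


def reinforce_loss(log_likelihoods: Sequence[float],
                   invariance_probs: Sequence[float]) -> float:
    """REINFORCE surrogate loss (hypothetical helper, not a DIG API).

    log_likelihoods:  generator's log-probability of each sampled augmentation.
    invariance_probs: reward model's probability that each augmented graph
                      keeps the original label, in (0, 1].
    """
    if len(log_likelihoods) != len(invariance_probs):
        raise ValueError("inputs must have the same length")
    # Reward for each sample is the log of its label-invariance probability.
    rewards = [math.log(p) for p in invariance_probs]
    # Minimizing -reward * log_likelihood favors augmentations with higher reward;
    # in an autodiff framework the reward would be treated as a constant.
    terms = [-r * ll for r, ll in zip(rewards, log_likelihoods)]
    return sum(terms) / len(terms)
```

Because log p is never positive, every reward here is at most zero; in practice a baseline (for example the batch-mean reward) is often subtracted to reduce gradient variance.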
- class RunnerRewardGen(data_root_path, dataset_name, conf)
Runs the training of a reward generation model which will be able to distinguish between graphs with different labels. Check examples.auggraph.GraphAug.run_reward_gen for examples on how to run the reward generation model.
- Parameters
data_root_path (string) – Directory where datasets should be saved.
dataset_name (dig.auggraph.method.GraphAug.constants.enums.DatasetName) – Name of the graph dataset.
conf (dict) – Hyperparameters for the model. Check examples.auggraph.GraphAug.conf.reward_gen_conf for examples on how to define the conf dictionary for the reward generator.
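One way to train a model that distinguishes graphs with different labels is to show it pairs of graphs with a binary target that is 1 when both graphs share a label. The sketch below only builds such index pairs from a list of graph labels. It assumes this pairwise setup, alternates positive and negative pairs as an illustrative balancing choice, and uses a hypothetical helper name; RunnerRewardGen may sample pairs differently.

```python
import random
from typing import Dict, List, Sequence, Tuple


def make_label_pairs(labels: Sequence[int], num_pairs: int,
                     seed: int = 0) -> List[Tuple[int, int, int]]:
    """Return (i, j, target) triples; target is 1 if labels[i] == labels[j]."""
    rng = random.Random(seed)
    by_label: Dict[int, List[int]] = {}
    for idx, y in enumerate(labels):
        by_label.setdefault(y, []).append(idx)
    classes = list(by_label)
    # Positive pairs need two distinct graphs from the same class.
    pos_classes = [y for y in classes if len(by_label[y]) >= 2]
    if len(classes) < 2 or not pos_classes:
        raise ValueError("need two labels and a class with at least two graphs")
    pairs = []
    for k in range(num_pairs):
        if k % 2 == 0:
            # Positive pair: two distinct graphs drawn from the same class.
            y = rng.choice(pos_classes)
            i, j = rng.sample(by_label[y], 2)
            pairs.append((i, j, 1))
        else:
            # Negative pair: one graph from each of two different classes.
            y1, y2 = rng.sample(classes, 2)
            pairs.append((rng.choice(by_label[y1]), rng.choice(by_label[y2]), 0))
    return pairs
```

For example, make_label_pairs([0, 0, 1, 1, 1], 10) yields ten triples, five with target 1 and five with target 0.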
- train_test(results_path, num_save=30)
Trains the reward generation model and validates the model parameters obtained at each epoch.
- Parameters
results_path (string) – Directory where the resulting optimal parameters of the reward generation model will be saved.
num_save (int) – Number of final epochs for which model parameters will be saved.
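Assuming training runs for a fixed number of epochs, num_save selects the trailing window of epochs whose parameters are written to results_path. A minimal sketch of that selection; the helper epochs_to_save, its max_epochs argument, and the 1-based epoch numbering are assumptions for illustration.

```python
from typing import List


def epochs_to_save(max_epochs: int, num_save: int = 30) -> List[int]:
    """Return the 1-based epoch indices whose parameters would be saved."""
    if max_epochs < 1 or num_save < 0:
        raise ValueError("max_epochs must be >= 1 and num_save >= 0")
    # Keep only the last num_save epochs, clipped to the start of training.
    first = max(1, max_epochs - num_save + 1)
    return list(range(first, max_epochs + 1))
```

With the default num_save=30 and 100 training epochs, epochs_to_save(100) returns epochs 71 through 100.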