import random
from torch_geometric.data import Batch
class RandomView():
    r"""Generate views by random transformation (augmentation) on given batched graphs,
    where each graph in the batch is treated independently. Class objects callable via
    method :meth:`views_fn`.

    Args:
        candidates (list): A list of callable view generation functions (classes).
            One of them is drawn with :func:`random.choice` for every graph.
    """

    def __init__(self, candidates):
        # Stored by reference; the list is consulted again on every call.
        self.candidates = candidates

    def __call__(self, data):
        r"""Forward ``data`` to :meth:`views_fn` and return its result."""
        return self.views_fn(data)

    def views_fn(self, batch_data):
        r"""Method to be called when :class:`RandomView` object is called.

        The batch is split into its individual graphs, each graph is passed
        through one candidate picked by :func:`random.choice`, and the
        results are collated back into a single batch.

        Args:
            batch_data (:class:`torch_geometric.data.Batch`): The input batched graphs.

        :rtype: :class:`torch_geometric.data.Batch`.
        """
        graphs = batch_data.to_data_list()
        # A fresh draw per graph, so graphs in one batch may be transformed
        # by different candidates.
        views = [random.choice(self.candidates)(graph) for graph in graphs]
        return Batch.from_data_list(views)
class Sequential():
    r"""Generate views by applying a sequence of transformations (augmentations) on
    given batched graphs. Class objects callable via method :meth:`views_fn`.

    Args:
        fn_sequence (list): A list of callable view generation functions (classes),
            applied in list order.
    """

    def __init__(self, fn_sequence):
        # Stored by reference; iterated front to back on every call.
        self.fn_sequence = fn_sequence

    def __call__(self, data):
        r"""Forward ``data`` to :meth:`views_fn` and return its result."""
        return self.views_fn(data)

    def views_fn(self, data):
        r"""Method to be called when :class:`Sequential` object is called.

        Each function in ``fn_sequence`` receives the output of the previous
        one; the first receives ``data``. With an empty sequence, ``data`` is
        returned unchanged.

        Args:
            data (:class:`torch_geometric.data.Data`): The input graph or batched graphs.

        :rtype: :class:`torch_geometric.data.Data`.
        """
        view = data
        for transform in self.fn_sequence:
            view = transform(view)
        return view