Source code for federatedscope.core.auxiliaries.transform_builder
from importlib import import_module
import federatedscope.register as register
def get_transform(config, package):
    """
    Build the transforms applied to the train, val and test datasets.

    For each split the entries ``transform``, ``target_transform`` and
    ``pre_transform`` are read from ``config.data`` (prefixed with ``val_``
    or ``test_`` for the val and test splits). Entries whose value is falsy
    (e.g. ``[]``) are skipped.

    Args:
        config: ``CN`` from ``federatedscope/core/configs/config.py``
        package: one of package from
            ``['torchvision', 'torch_geometric', 'torchtext', 'torchaudio']``

    Returns:
        Tuple ``(transform_funcs, val_transform_funcs,
        test_transform_funcs)``. Each element is a dict keyed by the
        un-prefixed names ``'transform'``, ``'target_transform'`` and
        ``'pre_transform'`` (only the configured ones). If no val (resp.
        test) transform is configured, the very same dict object as
        ``transform_funcs`` is returned for that split. If nothing is
        configured at all, ``({}, {}, {})`` is returned and ``package`` is
        not imported.
    """
    base_names = ('transform', 'target_transform', 'pre_transform')

    def collect(prefix):
        # Map un-prefixed name -> raw config expression for one split,
        # keeping only the entries that are actually configured.
        collected = {}
        for name in base_names:
            value = config.data[prefix + name]
            if value:
                collected[name] = value
        return collected

    train_cfgs = collect('')
    val_cfgs = collect('val_')
    test_cfgs = collect('test_')

    # Transform are all `[]`, do not import package and return empty dicts
    if not (train_cfgs or val_cfgs or test_cfgs):
        return {}, {}, {}

    transforms = getattr(import_module(package), 'transforms')

    def convert(trans):
        # Recursively converting expressions to functions
        if isinstance(trans[0], str):
            # Leaf expression: ``[type]`` or ``[type, kwargs]``. Default the
            # kwargs without appending to ``trans`` so that the caller's
            # config is left untouched.
            transform_type, transform_args = \
                trans if len(trans) > 1 else (trans[0], {})
            # Registered builders take precedence; the first one returning
            # something other than None wins.
            for func in register.transform_dict.values():
                transform_func = func(transform_type, transform_args)
                if transform_func is not None:
                    return transform_func
            # Otherwise look the transform up in ``<package>.transforms``
            return getattr(transforms, transform_type)(**transform_args)

        # A list of expressions: convert each of them, then compose
        converted = [convert(x) for x in trans]
        if hasattr(transforms, 'Compose'):
            return transforms.Compose(converted)
        if hasattr(transforms, 'Sequential'):
            return transforms.Sequential(converted)
        # return list of transform when the package cannot compose them
        return converted

    def build(cfgs):
        return {name: convert(expr) for name, expr in cfgs.items()}

    transform_funcs = build(train_cfgs)
    # Convert the split-specific expressions (``val_*`` / ``test_*``), not
    # the train ones; fall back to the train transforms when a split has no
    # transform of its own.
    val_transform_funcs = build(val_cfgs) if val_cfgs else transform_funcs
    test_transform_funcs = build(test_cfgs) if test_cfgs else transform_funcs

    return transform_funcs, val_transform_funcs, test_transform_funcs