import zipfile
import os
import numpy as np
import os.path as osp
from torch.utils.data import Dataset
# Dataset names that the ``LEAF`` base class accepts for its ``name`` argument;
# any other name is rejected with a ``ValueError`` at construction time.
LEAF_NAMES = [
    'femnist',
    'celeba',
    'synthetic',
    'shakespeare',
    'twitter',
    'subreddit',
]
def is_exists(path, names):
    """Return whether every entry of ``names`` exists inside ``path``.

    Arguments:
        path (str): directory that the names are resolved against.
        names (iterable of str): file or directory names to look for.

    Returns:
        bool: ``True`` when ``osp.join(path, name)`` exists for every name
        (also ``True`` when ``names`` is empty), ``False`` otherwise.
    """
    # ``all`` stops at the first missing entry instead of probing every path
    # and then searching the collected results for ``False``.
    return all(osp.exists(osp.join(path, name)) for name in names)
class LEAF(Dataset):
    """
    Base class for LEAF dataset from "LEAF: A Benchmark for Federated Settings"

    The constructor validates ``name`` and then runs ``process_file``.
    ``download``, ``process`` and ``__getitem__`` raise
    ``NotImplementedError`` here, so a concrete subclass has to provide them.

    Arguments:
        root (str): root path.
        name (str): name of dataset, in `LEAF_NAMES`.
        transform: transform for x. Only stored on the instance; this base
            class does not apply it itself.
        target_transform: transform for y. Only stored on the instance; this
            base class does not apply it itself.

    Raises:
        ValueError: if ``name`` is not one of ``LEAF_NAMES``.
    """
    def __init__(self, root, name, transform, target_transform):
        self.root = root
        self.name = name
        # Starts empty; ``__len__`` and ``__iter__`` are driven by this dict.
        # Presumably filled by a subclass's ``process`` -- confirm there.
        self.data_dict = {}
        if name not in LEAF_NAMES:
            raise ValueError(f'No leaf dataset named {self.name}')
        self.transform = transform
        self.target_transform = target_transform
        self.process_file()

    @property
    def raw_file_names(self):
        """Names of the archives expected inside ``raw_dir``."""
        return ['all_data.zip']

    @property
    def extracted_file_names(self):
        """Names expected inside ``raw_dir`` once the archives are unpacked."""
        return ['all_data']

    @property
    def raw_dir(self):
        """Path ``<root>/<name>/raw``."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        """Path ``<root>/<name>/processed``."""
        return osp.join(self.root, self.name, 'processed')

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)})'

    def __len__(self):
        """Return the number of entries in ``data_dict``."""
        return len(self.data_dict)

    def __getitem__(self, index):
        """Return the item at ``index``; must be implemented by subclasses."""
        raise NotImplementedError

    def __iter__(self):
        """Yield ``self[0]`` ... ``self[len(self) - 1]`` in order."""
        for index in range(len(self)):
            yield self[index]

    def download(self):
        """Fetch the raw archives; must be implemented by subclasses."""
        raise NotImplementedError

    def extract(self):
        """Unpack every archive in ``raw_file_names`` into ``raw_dir``."""
        for name in self.raw_file_names:
            with zipfile.ZipFile(osp.join(self.raw_dir, name), 'r') as f:
                f.extractall(self.raw_dir)

    def process_file(self):
        """Prepare the processed data unless ``processed_dir`` has content.

        Creates ``processed_dir`` when missing. If it is empty, the raw data
        are extracted first (downloading them when the archive itself is also
        absent) and ``process`` is then called. A non-empty ``processed_dir``
        is left untouched.
        """
        os.makedirs(self.processed_dir, exist_ok=True)
        if os.listdir(self.processed_dir):
            # Something is already there: skip download/extract/process.
            return
        if not is_exists(self.raw_dir, self.extracted_file_names):
            if not is_exists(self.raw_dir, self.raw_file_names):
                self.download()
            self.extract()
        self.process()

    def process(self):
        """Build the processed data; must be implemented by subclasses."""
        raise NotImplementedError
class LocalDataset(Dataset):
    """
    Convert data list to torch Dataset to save memory usage.

    Samples and labels are stored as numpy arrays and indexed along their
    first axis.

    Arguments:
        Xs: sequence of samples; converted with ``np.array``.
        targets: sequence of labels, same length as ``Xs``; converted with
            ``np.array``.
        pre_process (callable, optional): applied to a sample first.
        transform (callable, optional): applied to the sample after
            ``pre_process``.
        target_transform (callable, optional): applied to the label.
    """
    def __init__(self,
                 Xs,
                 targets,
                 pre_process=None,
                 transform=None,
                 target_transform=None):
        assert len(Xs) == len(
            targets), "The number of data and labels are not equal."
        self.Xs = np.array(Xs)
        self.targets = np.array(targets)
        self.pre_process = pre_process
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.Xs)

    def __getitem__(self, idx):
        """Return ``(data, target)`` at ``idx`` after the optional hooks."""
        data, target = self.Xs[idx], self.targets[idx]
        if self.pre_process:
            data = self.pre_process(data)
        if self.transform:
            data = self.transform(data)
        if self.target_transform:
            target = self.target_transform(target)
        return data, target

    def extend(self, dataset):
        """Append the samples and labels of ``dataset`` to this dataset.

        Arguments:
            dataset: object exposing ``Xs`` and ``targets`` arrays, e.g.
                another ``LocalDataset``.
        """
        other_xs = np.asarray(dataset.Xs)
        other_targets = np.asarray(dataset.targets)
        if len(other_xs) == 0:
            # Nothing to append.
            return
        if len(self.Xs) == 0:
            # An empty array built from ``[]`` has shape (0,) and a float
            # dtype; adopting the other side's arrays avoids a shape mismatch
            # and keeps their dtype. Copy so the two datasets do not alias.
            self.Xs = other_xs.copy()
            self.targets = other_targets.copy()
            return
        # Always join along axis 0, the sample axis used by ``__len__`` and
        # ``__getitem__``. ``np.vstack`` would turn 1-D samples into rows of a
        # 2-D matrix and ``np.hstack`` would join 2-D targets along axis 1.
        self.Xs = np.concatenate((self.Xs, other_xs), axis=0)
        self.targets = np.concatenate((self.targets, other_targets), axis=0)