import logging
import os
from typing import Dict, Iterable, List, Literal, Union


class Config(object):
text_datasets_dir = None
output_format = "jsonl"
output_compression = None
raw_datasets_dir = None
shuffled_datasets_dir = None
composed_dataset_dir = None # composed dataset (train/val split) is saved into this directory
local_dirs_by_dataset_id = {}
local_dirs_by_source_id = {}
sampling_factor_by_dataset_id = {}
sampling_factor_by_source_id = {}
sampling_factor_by_language = {}
only_selected_datasets: bool = False
selected_dataset_ids: List[str] = []
selected_source_ids: List[str] = []
    validation_ratio = 0.005  # size of the validation split: len(dataset) * ratio documents
    validation_min_total_docs = 1_000  # a dataset must have at least this many docs to contribute a validation split
    validation_max_split_docs = 1_000  # the number of documents in the validation split is capped at this number
    validation_min_split_docs = 10  # the split must have at least this many documents, otherwise it is discarded
    tokenizer_train_ratio = 0.1  # fraction of the train data used for tokenizer training
    # Vocab size should be divisible by 8
# - Jan's recommendation: 250680
# - NVIDIA recommendation for multilingual models: 256000
tokenizer_vocab_size: int = 256000
    tokenizer_model_type: Literal["bpe", "unigram", "word", "char"] = "bpe"  # SentencePiece model types
seed: int = 0
extra_dataset_registries: Union[None, str, List[str]] = None
extra_dataset_classes: Union[None, List] = None
use_default_dataset_registry: bool = True
# Datasets are initialized with these kwargs
    extra_dataset_kwargs: Dict[str, dict] = {}
use_documents: bool = False
workers: int = 0
limit: int = 0
skip_items = 0
job_id = None
save_stats = True
verbose = False
log_file = None
override = False

    def __init__(self, **entries):
        # Override the class-level defaults with the provided entries
        self.__dict__.update(entries)

    def init_logger(self, logger_name):
        """Set up logging (stream handler and optional file handler) and return a named logger."""
log_handlers = [logging.StreamHandler()]
if self.log_file:
log_handlers.append(logging.FileHandler(self.log_file))
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.DEBUG if self.verbose else logging.INFO,
handlers=log_handlers,
)
logger = logging.getLogger(logger_name)
return logger

    def get_extra_dataset_kwargs(self, dataset_id) -> dict:
        return self.extra_dataset_kwargs.get(dataset_id, {})

    def get_selected_dataset_ids(self, mode: Literal["all", "exact", "fnmatch"] = "all"):
if mode == "exact":
# only ids for exact match
return [s for s in self.selected_dataset_ids if "*" not in s and "?" not in s]
elif mode == "fnmatch":
# only ids for fnmatch
return [s for s in self.selected_dataset_ids if "*" in s or "?" in s]
else:
# all
return self.selected_dataset_ids

    def get_job_id(self) -> Union[None, str]:
        """Return the manually set job ID, or fall back to the SLURM_JOBID environment variable (default: "0")."""
if self.job_id is None:
self.job_id = os.environ.get("SLURM_JOBID", "0")
return self.job_id

    def get_key_value_pairs(self, keys: Iterable) -> Dict:
        """Return a dict mapping each of the given config attribute names to its value."""
        return {k: getattr(self, k) for k in keys}
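

# Minimal usage sketch of the Config class; the dataset IDs/patterns below are
# hypothetical placeholders chosen only to illustrate the selection and logging helpers.
if __name__ == "__main__":
    config = Config(
        selected_dataset_ids=["wiki_en", "news_*"],  # hypothetical dataset IDs and fnmatch patterns
        only_selected_datasets=True,
        verbose=True,
        log_file=None,  # set to a file path to also log into a file
    )
    logger = config.init_logger(__name__)
    logger.info("exact ids: %s", config.get_selected_dataset_ids(mode="exact"))
    logger.info("fnmatch patterns: %s", config.get_selected_dataset_ids(mode="fnmatch"))
    logger.info("job id: %s", config.get_job_id())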