From f60db54dfb48544b7b36c6ee495621668c364e29 Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Fri, 28 Jun 2024 15:25:18 +0200
Subject: [PATCH] added weighting of language along domains

---
 training/collect_data_and_weights_ablation.py | 67 ++++++++++---------
 training/domain_proportions.yml               | 49 +++++++++-----
 2 files changed, 67 insertions(+), 49 deletions(-)

diff --git a/training/collect_data_and_weights_ablation.py b/training/collect_data_and_weights_ablation.py
index a492041..bb47751 100644
--- a/training/collect_data_and_weights_ablation.py
+++ b/training/collect_data_and_weights_ablation.py
@@ -234,14 +234,13 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
             "domain_target_proportions": domain_target_proportions,
             "additional_weights": additional_weights,
         }
+        if not os.path.exists(args.save_weights_path):
+            os.makedirs(args.save_weights_path)
         with open(f"{args.save_weights_path}/proportion_args.json", "w") as f:
             f.write(json.dumps(proportion_dict, indent=4))
 
     stats_datasets = read_stats_datasets()
 
-    # import json
-    # print(json.dumps(stats_datasets, indent=4))
-
    not_tokenized_datasets = list(stats_datasets.keys())
 
     prefixes = []
@@ -278,18 +277,23 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
             if re.search(content, prefix):
                 additional_weight *= additional_weights[content]
 
-        domains = d["category"].split("-")
+        # get the corresponding domain of the dataset
+        category = d["category"]
+        language = d["language"]
+        domain = language + "--" + category
+
+        # get the number of tokens in the dataset
         count = d[args.count]
         count_weighted = additional_weight * count
 
-        for domain in domains:
-            num_tokens_per_domain_weighted[domain] = num_tokens_per_domain_weighted.get(
-                domain, 0
-            ) + (count_weighted // len(domains))
-            num_tokens_per_domain[domain] = num_tokens_per_domain.get(domain, 0) + (
-                count // len(domains)
-            )
-            if domain == "code":
+        # update the number of tokens and the number of tokens weighted for the domain
+        num_tokens_per_domain_weighted[domain] = num_tokens_per_domain_weighted.get(
+            domain, 0
+        ) + count_weighted
+        num_tokens_per_domain[domain] = num_tokens_per_domain.get(domain, 0) + count
+
+        # update the number of tokens for the programming language if it is a code dataset
+        if domain == "code--programming":
                 prog_lang = format_programming_language(name)
                 num_tokens_per_programming_language[prog_lang] = (
                     num_tokens_per_programming_language.get(prog_lang, 0) + count
@@ -318,23 +322,16 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
         ]
     )
 
-    domain_target_proportion_rest = 1 - sum(domain_target_proportions.values())
-    assert (
-        domain_target_proportion_rest >= 0 and domain_target_proportion_rest < 1
-    ), f"{domain_target_proportion_rest=}"
+    # normalize the domain_target_proportions
+    total_proportions = sum(domain_target_proportions.values())
+    assert total_proportions > 0
+    normalized_domain_target_proportions = {k: v/total_proportions for k,v in domain_target_proportions.items()}
 
     # Set the weights for domain (newspaper, book, code, ...)
    domain_weights = {}
    for domain, count_weighted in num_tokens_per_domain_weighted.items():
-        if domain in domain_target_proportions:
-            target_proportion = domain_target_proportions[domain]
-        else:
-            target_proportion = (
-                domain_target_proportion_rest
-                * count_weighted
-                / total_count_weighted_rest
-            )
-        domain_target_proportions[domain] = target_proportion
+        assert domain in normalized_domain_target_proportions, f"{domain=} not found in {domain_target_proportions}"
+        target_proportion = normalized_domain_target_proportions[domain]
         weight = target_proportion / (count_weighted / total_count_weighted)
         domain_weights[domain] = weight
 
@@ -351,7 +348,7 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
         ), f"{language=} not found"
         target_proportion = (
             programming_language_target_proportions[language]
-            * domain_target_proportions["code"]
+            * domain_target_proportions["code--programming"]
         )
         weight = target_proportion / (count_weighted / total_count_weighted)
         programming_language_weights[language] = weight
@@ -394,8 +391,6 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
     )
     print("```\n")
 
-    print("# Weights per sub-domain\n```")
-
     for second_pass in [False, True]:
         if not second_pass:
             all_weights = {}
@@ -414,12 +409,20 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
         final_weight_dict = {}
 
         for prefix, d in data.items():
-            domains = d["category"].split("-")
+            # get the corresponding domain of the dataset
+            language = d["language"]
+            category = d["category"]
+            domain = language + "--" + category
+
+            # get the number of tokens in the dataset
             count = d[args.count]
+            # get the proportion of the dataset in the total number of tokens
             ratio = count / total_count
 
-            domain_weight = max(domain_weights[domain] for domain in domains)
-            if d["category"] == "code":
+            domain_weight = domain_weights[domain]
+
+            # if the dataset is a code dataset, override the domain weight with the programming language weight
+            if d["category"] == "programming":
                 prog_language = format_programming_language(prefix)
                 domain_weight = programming_language_weights[prog_language]
 
@@ -459,5 +462,3 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
             print("```")
         else:
             print()
-
-# DATASET="$(python ~/Lucie-Training/training/collect_data_and_weights_alt.py /local_data/lucie_tokens_65k_grouped)"
diff --git a/training/domain_proportions.yml b/training/domain_proportions.yml
index c979ead..c8f4324 100644
--- a/training/domain_proportions.yml
+++ b/training/domain_proportions.yml
@@ -1,16 +1,33 @@
-newspaper: 1
-book: 2.5
-technical: 2
-wiki: 3
-legal: 2
-misc: 0.1
-dialogue: 2
-parlementary: 0.05
-forum: 1
-math: 3
-code: 1
-aligned: 3
-legi_dialogue: 2
-legi_spoken: 2
-legi_written: 2
-programming: 1
+fr--newspaper: 13.27
+fr--book: 5.9
+fr--technical: 9.47
+fr--wiki: 1.31
+fr--legi_written: 0.36
+fr--dialogue: 0.01
+fr--legi_spoken: 0.09
+fr--legi_dialogue: 0.08
+en--technical: 16.22
+en--newspaper: 1.62
+en--legi_written: 1.88
+en--wiki: 2.69
+en--forum: 1.41
+en--book: 1.93
+en--math: 4.95
+en--dialogue: 0.26
+en--legi_dialogue: 0.02
+de--wiki: 1.0
+de--legi_written: 0.41
+de--book: 0.06
+de--legi_dialogue: 0.02
+es--wiki: 0.8
+es--legi_written: 0.3
+es--legi_dialogue: 0.01
+es--book: 0.04
+it--wiki: 0.77
+it--legi_written: 0.3
+it--book: 0.04
+es-en--aligned: 0.05
+it-en--aligned: 0.05
+de-fr--aligned: 0.05
+fr-en--aligned: 8.67
+code--programming: 25.96
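
Note (not part of the patch): below is a minimal standalone sketch of the proportion-to-weight computation the patched script performs. The token counts are invented placeholders, and only three of the "language--category" keys from domain_proportions.yml are included, so the normalization runs over that subset; the actual script derives its counts from read_stats_datasets() and also applies additional per-dataset weights.

# Sketch of how target proportions are turned into per-domain sampling weights.
# Token counts are made-up placeholders; only three "language--category" domains
# from domain_proportions.yml are shown, so normalization is over this subset.

domain_target_proportions = {
    "fr--newspaper": 13.27,
    "en--technical": 16.22,
    "code--programming": 25.96,
}

# hypothetical weighted token counts per domain
num_tokens_per_domain_weighted = {
    "fr--newspaper": 40e9,
    "en--technical": 120e9,
    "code--programming": 250e9,
}
total_count_weighted = sum(num_tokens_per_domain_weighted.values())

# normalize the target proportions so they sum to 1, as the patch does
total_proportions = sum(domain_target_proportions.values())
normalized = {k: v / total_proportions for k, v in domain_target_proportions.items()}

# weight = target proportion / natural proportion of the domain in the corpus;
# a weight > 1 upsamples the domain, a weight < 1 downsamples it
domain_weights = {
    domain: normalized[domain] / (count / total_count_weighted)
    for domain, count in num_tokens_per_domain_weighted.items()
}

for domain, weight in sorted(domain_weights.items()):
    print(f"{domain}: {weight:.3f}")

With these placeholder counts, fr--newspaper comes out at roughly 2.45 (upsampled toward its target share) and code--programming at roughly 0.77 (downsampled), which is the rebalancing the per-domain weights are meant to achieve.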