Skip to content

Commit

Permalink
added weighting of language along domains
Browse files Browse the repository at this point in the history
  • Loading branch information
lucashervier authored and thib-s committed Jun 28, 2024
1 parent fb5b136 commit f60db54
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 49 deletions.
67 changes: 34 additions & 33 deletions training/collect_data_and_weights_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,13 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
"domain_target_proportions": domain_target_proportions,
"additional_weights": additional_weights,
}
if not os.path.exists(args.save_weights_path):
os.makedirs(args.save_weights_path)
with open(f"{args.save_weights_path}/proportion_args.json", "w") as f:
f.write(json.dumps(proportion_dict, indent=4))

stats_datasets = read_stats_datasets()

# import json
# print(json.dumps(stats_datasets, indent=4))

not_tokenized_datasets = list(stats_datasets.keys())

prefixes = []
Expand Down Expand Up @@ -278,18 +277,23 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
if re.search(content, prefix):
additional_weight *= additional_weights[content]

domains = d["category"].split("-")
# get the corresponding domain of the dataset
category = d["category"]
language = d["language"]
domain = language + "--" + category

# get the number of tokens in the dataset
count = d[args.count]
count_weighted = additional_weight * count
for domain in domains:
num_tokens_per_domain_weighted[domain] = num_tokens_per_domain_weighted.get(
domain, 0
) + (count_weighted // len(domains))
num_tokens_per_domain[domain] = num_tokens_per_domain.get(domain, 0) + (
count // len(domains)
)

if domain == "code":
# update the number of tokens and the number of tokens weighted for the domain
num_tokens_per_domain_weighted[domain] = num_tokens_per_domain_weighted.get(
domain, 0
) + count_weighted
num_tokens_per_domain[domain] = num_tokens_per_domain.get(domain, 0) + count

# update the number of tokens for the programming language if it is a code dataset
if domain == "code--programming":
prog_lang = format_programming_language(name)
num_tokens_per_programming_language[prog_lang] = (
num_tokens_per_programming_language.get(prog_lang, 0) + count
Expand Down Expand Up @@ -318,23 +322,16 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
]
)

domain_target_proportion_rest = 1 - sum(domain_target_proportions.values())
assert (
domain_target_proportion_rest >= 0 and domain_target_proportion_rest < 1
), f"{domain_target_proportion_rest=}"
# normalize the domain_target_proportions
total_proportions = sum(domain_target_proportions.values())
assert total_proportions > 0
normalized_domain_target_proportions = {k: v/total_proportions for k,v in domain_target_proportions.items()}

# Set the weights for domain (newspaper, book, code, ...)
domain_weights = {}
for domain, count_weighted in num_tokens_per_domain_weighted.items():
if domain in domain_target_proportions:
target_proportion = domain_target_proportions[domain]
else:
target_proportion = (
domain_target_proportion_rest
* count_weighted
/ total_count_weighted_rest
)
domain_target_proportions[domain] = target_proportion
assert domain in normalized_domain_target_proportions, f"{domain=} not found in {domain_target_proportions}"
target_proportion = normalized_domain_target_proportions[domain]
weight = target_proportion / (count_weighted / total_count_weighted)
domain_weights[domain] = weight

Expand All @@ -351,7 +348,7 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
), f"{language=} not found"
target_proportion = (
programming_language_target_proportions[language]
* domain_target_proportions["code"]
* domain_target_proportions["code--programming"]
)
weight = target_proportion / (count_weighted / total_count_weighted)
programming_language_weights[language] = weight
Expand Down Expand Up @@ -394,8 +391,6 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
)
print("```\n")

print("# Weights per sub-domain\n```")

for second_pass in [False, True]:
if not second_pass:
all_weights = {}
Expand All @@ -414,12 +409,20 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
final_weight_dict = {}

for prefix, d in data.items():
domains = d["category"].split("-")
# get the corresponding domain of the dataset
language = d["language"]
category = d["category"]
domain = language + "--" + category

# get the number of tokens in the dataset
count = d[args.count]
# get the proportion of the dataset in the total number of tokens
ratio = count / total_count

domain_weight = max(domain_weights[domain] for domain in domains)
if d["category"] == "code":
domain_weight = domain_weights[domain]

# override the domain weight if the dataset is a code dataset by the programming language weight
if d["category"] == "programming":
prog_language = format_programming_language(prefix)
domain_weight = programming_language_weights[prog_language]

Expand Down Expand Up @@ -459,5 +462,3 @@ def prefix_to_canonical_name(name, possible_names): # noqa # C901 `...` is too
print("```")
else:
print()

# DATASET="$(python ~/Lucie-Training/training/collect_data_and_weights_alt.py /local_data/lucie_tokens_65k_grouped)"
49 changes: 33 additions & 16 deletions training/domain_proportions.yml
Original file line number Diff line number Diff line change
@@ -1,16 +1,33 @@
newspaper: 1
book: 2.5
technical: 2
wiki: 3
legal: 2
misc: 0.1
dialogue: 2
parlementary: 0.05
forum: 1
math: 3
code: 1
aligned: 3
legi_dialogue: 2
legi_spoken: 2
legi_written: 2
programming: 1
fr--newspaper: 13.27
fr--book: 5.9
fr--technical: 9.47
fr--wiki: 1.31
fr--legi_written: 0.36
fr--dialogue: 0.01
fr--legi_spoken: 0.09
fr--legi_dialogue: 0.08
en--technical: 16.22
en--newspaper: 1.62
en--legi_written: 1.88
en--wiki: 2.69
en--forum: 1.41
en--book: 1.93
en--math: 4.95
en--dialogue: 0.26
en--legi_dialogue: 0.02
de--wiki: 1.0
de--legi_written: 0.41
de--book: 0.06
de--legi_dialogue: 0.02
es--wiki: 0.8
es--legi_written: 0.3
es--legi_dialogue: 0.01
es--book: 0.04
it--wiki: 0.77
it--legi_written: 0.3
it--book: 0.04
es-en--aligned: 0.05
it-en--aligned: 0.05
de-fr--aligned: 0.05
fr-en--aligned: 8.67
code--programming: 25.96

0 comments on commit f60db54

Please sign in to comment.