Skip to content

Commit

Permalink
updated for config by dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
joesharratt1229 committed Mar 3, 2025
1 parent 49db4ed commit aaf19e8
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions eval/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
--size SIZE Default dataset size (default: 100)
--seed SEED Default dataset seed (default: 42)
--include-params Include all configuration parameters (default: False)
--category CATEGORY Only include datasets from this category (default: None)
"""

import argparse
Expand All @@ -35,14 +36,27 @@ def extract_category(module_name):
return "other"


def generate_config(model, provider, size, seed, include_params):
"""Generate configuration with all registered datasets."""
def generate_config(model, provider, size, seed, include_params, category=None):
"""Generate configuration with all registered datasets.
Args:
model: Model name
provider: Provider name
size: Default dataset size
seed: Default dataset seed
include_params: Whether to include all configuration parameters
category: If specified, only include datasets from this category
"""
# Group datasets by category
categories = defaultdict(list)

for dataset_name, (dataset_cls, config_cls) in DATASETS.items():
# Extract category from module name
category = extract_category(dataset_cls.__module__)
dataset_category = extract_category(dataset_cls.__module__)

# Skip if a specific category was requested and this doesn't match
if category and dataset_category != category:
continue

# Create dataset entry
dataset_entry = {"dataset": dataset_name}
Expand All @@ -62,7 +76,7 @@ def generate_config(model, provider, size, seed, include_params):
dataset_entry["params"] = params

# Add to appropriate category
categories[category].append(dataset_entry)
categories[dataset_category].append(dataset_entry)

# Create configuration structure
config = {
Expand Down Expand Up @@ -90,12 +104,18 @@ def main():
parser.add_argument("--size", type=int, default=100, help="Default dataset size")
parser.add_argument("--seed", type=int, default=42, help="Default dataset seed")
parser.add_argument("--include-params", action="store_true", help="Include all configuration parameters")
parser.add_argument("--category", help="Only include datasets from this category")

args = parser.parse_args()

# Generate configuration
config = generate_config(
model=args.model, provider=args.provider, size=args.size, seed=args.seed, include_params=args.include_params
model=args.model,
provider=args.provider,
size=args.size,
seed=args.seed,
include_params=args.include_params,
category=args.category,
)

# Write to file
Expand Down

0 comments on commit aaf19e8

Please sign in to comment.