Skip to content

Commit

Permalink
Add torch-backed scikit-learn solver with array api dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
fcharras committed Oct 13, 2023
1 parent 434a7ba commit dea87bc
Show file tree
Hide file tree
Showing 12 changed files with 1,038 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ jobs:
- name: Check sanity of benchmark files
run: |
python ./benchmarks/kmeans/consolidate_result_csv.py ./benchmarks/kmeans/results.csv --check-csv
python ./benchmarks/pca/consolidate_result_csv.py ./benchmarks/pca/results.csv --check-csv
3 changes: 3 additions & 0 deletions .github/workflows/sync_benchmark_files_to_gsheet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ jobs:
GSPREAD_URL: ${{vars.GSPREAD_URL}}
run: |
python ./benchmarks/kmeans/consolidate_result_csv.py ./benchmarks/kmeans/results.csv --check-csv
python ./benchmarks/pca/consolidate_result_csv.py ./benchmarks/pca/results.csv --check-csv
echo "$GSPREAD_SERVICE_ACCOUNT_AUTH_KEY" > service_account.json
python ./benchmarks/kmeans/consolidate_result_csv.py ./benchmarks/kmeans/results.csv \
--sync-to-gspread --gspread-url $GSPREAD_URL --gspread-auth-key ./service_account.json
python ./benchmarks/pca/consolidate_result_csv.py ./benchmarks/pca/results.csv \
--sync-to-gspread --gspread-url $GSPREAD_URL --gspread-auth-key ./service_account.json
2 changes: 2 additions & 0 deletions .github/workflows/test_cpu_benchmarks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,5 @@ jobs:
run: |
cd benchmarks/kmeans
PYTHONPATH=$PYTHONPATH:$(realpath ../../kmeans_dpcpp/) benchopt run --no-plot -l -d Simulated_correlated_data[n_samples=1000,n_features=14]
cd ../pca
benchopt run --no-plot -l -d Simulated_correlated_data[n_samples=100,n_features=100]
2 changes: 1 addition & 1 deletion benchmarks/kmeans/solvers/sklearn_pytorch_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def set_objective(
device = self.device
# Copy the data before running the benchmark to ensure that no unfortunate side
# effects can happen
self.X = torch.asarray(X, copy=True, device=self.device)
self.X = torch.asarray(X, copy=True, device=device)

if hasattr(sample_weight, "copy"):
sample_weight = torch.asarray(sample_weight, copy=True, device=device)
Expand Down
Loading

0 comments on commit dea87bc

Please sign in to comment.