inference.py
import os

import numpy as np
import pandas as pd
import pytorch_lightning as pl

from utils import utils


def inference(args, config):
    """Run prediction with a pretrained model and write a submission CSV."""
    trainer = pl.Trainer(gpus=1, max_epochs=config.train.max_epoch, log_every_n_steps=1, deterministic=True)
    dataloader, model = utils.init_modules(config)
    model = utils.load_pretrained(model, config)

    # trainer.predict returns one (pred_answer, output_prob) tuple per batch.
    # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html
    output = trainer.predict(model=model, datamodule=dataloader)

    # Unzip the per-batch outputs and flatten them across batches.
    pred_answer, output_prob = zip(*output)
    pred_answer = np.concatenate(pred_answer).tolist()
    output_prob = np.concatenate(output_prob, axis=0).tolist()

    # Convert numeric class indices back to their label strings.
    pred_answer = utils.num_to_label(pred_answer)

    # Build the submission table; each row holds the predicted label and the
    # full probability vector for that example (the list is serialized as a
    # string when written to CSV).
    output = pd.DataFrame(
        {
            "id": range(len(pred_answer)),
            "pred_label": pred_answer,
            "probs": output_prob,
        }
    )

    # Write the submission under ./prediction, naming it after the model and
    # the checkpoint that produced it.
    os.makedirs("prediction", exist_ok=True)
    path = args.saved_model if args.saved_model is not None else config.path.best_model_path
    run_name = f'{config.model.name}-{path.split("/")[-1]}'
    run_name = run_name.replace("/", "-")  # model names such as "org/model" contain slashes
    output.to_csv(f"./prediction/submission_{run_name}.csv", index=False)
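

# --- Usage sketch (illustrative only) ---
# A minimal example of how inference() might be driven from the command line.
# The flag names (--config, --saved_model), the YAML config path, and the use
# of OmegaConf for attribute-style config access are assumptions; the
# repository's actual entry point may differ.
if __name__ == "__main__":
    import argparse

    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config.yaml", help="path to a YAML config (assumed)")
    parser.add_argument("--saved_model", default=None, help="optional checkpoint path overriding config.path.best_model_path")
    args = parser.parse_args()

    # OmegaConf gives attribute-style access matching config.train.max_epoch, config.model.name, etc.
    config = OmegaConf.load(args.config)
    inference(args, config)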