diff --git a/dataset/dataset_test.py b/dataset/dataset_test.py index da41eb4..6ae8765 100644 --- a/dataset/dataset_test.py +++ b/dataset/dataset_test.py @@ -94,18 +94,17 @@ def read_smiles(data_path, target, task): with open(data_path) as csv_file: csv_reader = csv.DictReader(csv_file, delimiter=',') for i, row in enumerate(csv_reader): - if i != 0: - smiles = row['smiles'] - label = row[target] - mol = Chem.MolFromSmiles(smiles) - if mol != None and label != '': - smiles_data.append(smiles) - if task == 'classification': - labels.append(int(label)) - elif task == 'regression': - labels.append(float(label)) - else: - ValueError('task must be either regression or classification') + smiles = row['smiles'] + label = row[target] + mol = Chem.MolFromSmiles(smiles) + if mol != None and label != '': + smiles_data.append(smiles) + if task == 'classification': + labels.append(int(label)) + elif task == 'regression': + labels.append(float(label)) + else: + ValueError('task must be either regression or classification') print(len(smiles_data)) return smiles_data, labels