Skip to content

Commit

Permalink
in case @jstjohn needs it for his work
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 15, 2022
1 parent 96abc11 commit 9edf7f3
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,29 @@ seq = ds[0] # (196608,)
pred = model(seq, head = 'human') # (896, 5313)
```

To return the random shift value, as well as whether reverse complement was activated (in case you need to reverse the corresponding ChIP-seq target data), just set `return_augs = True` when initializing the `GenomeIntervalDataset`

```python
import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset

filter_train = lambda df: df.filter(pl.col('column_4') == 'train')

ds = GenomeIntervalDataset(
bed_file = './sequences.bed', # bed file - columns 0, 1, 2 must be <chromosome>, <start position>, <end position>
fasta_file = './hg38.ml.fa', # path to fasta file
filter_df_fn = filter_train, # filter dataframe function
return_seq_indices = True, # return nucleotide indices (ACGTN) or one hot encodings
shift_augs = (-2, 2), # random shift augmentations from -2 to +2 basepairs
rc_aug = True, # use reverse complement augmentation with 50% probability
context_length = 196_608,
return_augs = True # return the augmentation meta data
)

seq, rand_shift_val, rc_bool = ds[0] # (196608,), (1,), (1,)
```

## Appreciation

Special thanks go out to <a href="https://www.eleuther.ai/">EleutherAI</a> for providing the resources to retrain the model in an acceptable amount of time
Expand Down
24 changes: 19 additions & 5 deletions enformer_pytorch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
self.shift_augs = shift_augs
self.rc_aug = rc_aug

def __call__(self, chr_name, start, end):
def __call__(self, chr_name, start, end, return_augs = False):
interval_length = end - start
chromosome = self.seqs[chr_name]
chromosome_length = len(chromosome)
Expand Down Expand Up @@ -152,10 +152,21 @@ def __call__(self, chr_name, start, end):

one_hot = str_to_one_hot(seq)

if self.rc_aug and coin_flip():
rc_aug = self.rc_aug and coin_flip()

if rc_aug:
one_hot = one_hot_reverse_complement(one_hot)

return one_hot
if not return_augs:
return one_hot

# returns the shift integer as well as the bool (for whether reverse complement was activated)
# for this particular genomic sequence

rand_shift_tensor = torch.tensor([rand_shift])
rand_aug_bool_tensor = torch.tensor([rc_aug])

return one_hot, rand_shift_tensor, rand_aug_bool_tensor


class GenomeIntervalDataset(Dataset):
Expand All @@ -168,7 +179,8 @@ def __init__(
context_length = None,
return_seq_indices = False,
shift_augs = None,
rc_aug = False
rc_aug = False,
return_augs = False
):
super().__init__()
bed_path = Path(bed_file)
Expand All @@ -190,11 +202,13 @@ def __init__(
rc_aug = rc_aug
)

self.return_augs = return_augs

def __len__(self):
return len(self.df)

def __getitem__(self, ind):
interval = self.df.row(ind)
chr_name, start, end = (interval[0], interval[1], interval[2])
chr_name = self.chr_bed_to_fasta_map.get(chr_name, chr_name)
return self.fasta(chr_name, start, end)
return self.fasta(chr_name, start, end, return_augs = self.return_augs)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'enformer-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.4.4',
version = '0.4.5',
license='MIT',
description = 'Enformer - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 9edf7f3

Please sign in to comment.