Skip to content

Commit

Permalink
Update s2d CrossEncoder.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chungimungi authored Nov 25, 2023
1 parent e09e46c commit ab45da2
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions scripts/s2d CrossEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sklearn.model_selection import train_test_split
import numpy as np

# Set the device to use two GPUs
# Set the device to use two GPUs if available, otherwise use CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load and preprocess the data
Expand All @@ -30,6 +30,7 @@

# Create a custom dataloader
class CustomDataset(Dataset):
"""Custom PyTorch dataset for handling input sequences and labels."""
def __init__(self, X, y, max_seq_length):
self.X = X
self.y = y
Expand All @@ -48,8 +49,9 @@ def __getitem__(self, index):

return torch.LongTensor(padded_input_seq), torch.LongTensor([label])

# model architecture
# Define the model architecture
class CustomCrossEncoder(nn.Module):
"""Custom PyTorch model for a cross-encoder with LSTM and attention."""
def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes, dropout_prob=0.5):
super(CustomCrossEncoder, self).__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
Expand Down Expand Up @@ -86,7 +88,7 @@ def forward(self, input):

# Define loss function and optimizer
model = CustomCrossEncoder(vocab_size, embed_dim, hidden_dim, num_classes).to(device)
model = nn.DataParallel(model) #comment this line of code if multiple GPUs are not available
model = nn.DataParallel(model) # Comment this line of code if multiple GPUs are not available
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Expand Down Expand Up @@ -119,7 +121,7 @@ def forward(self, input):

with torch.no_grad():
for inputs, labels in test_dataloader:
inputs = inputs.to(device)
inputs = inputs.to(device)
labels = labels.to(device)

outputs = model(inputs)
Expand All @@ -129,8 +131,9 @@ def forward(self, input):

print(f"Test Accuracy: {100 * correct / total}%")

#Inference function
# Inference function
def predict_disease_from_input():
"""Predict disease from user input."""
input_text = input("Enter symptoms : ")
input_text = input_text.split()
input_ids = [symptom2id[word] for word in input_text]
Expand All @@ -147,4 +150,5 @@ def predict_disease_from_input():

import torchinfo

# Display model summary
torchinfo.summary(model.cuda())

0 comments on commit ab45da2

Please sign in to comment.