From 76290ee6d6015a10143d04bb7b2af5339f0f8e65 Mon Sep 17 00:00:00 2001 From: Ridwan Amure <40931539+instabaines@users.noreply.github.com> Date: Tue, 31 Dec 2024 00:20:13 -0600 Subject: [PATCH 1/2] Update huggingface.py Modify input device to match with the model device --- swarm_models/huggingface.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swarm_models/huggingface.py b/swarm_models/huggingface.py index bef8057..c10b1b5 100644 --- a/swarm_models/huggingface.py +++ b/swarm_models/huggingface.py @@ -228,6 +228,7 @@ def run(self, task: str, *args, **kwargs): """ try: inputs = self.tokenizer.encode(task, return_tensors="pt") + inputs = inputs.to(self.model.device) if self.decoding: with torch.no_grad(): From 364da5d4ce3894d982cd6e5139542d97e07fc888 Mon Sep 17 00:00:00 2001 From: Ridwan Amure <40931539+instabaines@users.noreply.github.com> Date: Tue, 31 Dec 2024 18:59:49 -0600 Subject: [PATCH 2/2] Modify inputs device to match with the model device Fix device mismatch by placing inputs on the same device as the model to avoid errors --- swarm_models/huggingface.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/swarm_models/huggingface.py b/swarm_models/huggingface.py index c10b1b5..88a7300 100644 --- a/swarm_models/huggingface.py +++ b/swarm_models/huggingface.py @@ -227,8 +227,7 @@ def run(self, task: str, *args, **kwargs): - Generated text (str). """ try: - inputs = self.tokenizer.encode(task, return_tensors="pt") - inputs = inputs.to(self.model.device) + inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.model.device) if self.decoding: with torch.no_grad():