diff --git a/examples/scripts/chat.py b/examples/scripts/chat.py index 34f4ecd0ce..12e7c448d4 100644 --- a/examples/scripts/chat.py +++ b/examples/scripts/chat.py @@ -15,7 +15,7 @@ import copy import json import os -import pwd +import platform import re import sys import time @@ -32,6 +32,10 @@ from trl.trainer.utils import get_quantization_config +if platform.system() != "Windows": + import pwd + + init_zero_verbose() HELP_STRING = """\ @@ -138,7 +142,10 @@ def print_help(self): def get_username(): - return pwd.getpwuid(os.getuid())[0] + if platform.system() == "Windows": + return os.getlogin() + else: + return pwd.getpwuid(os.getuid()).pw_name def create_default_filename(model_name):