-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathhuggingface_auth.py
171 lines (133 loc) · 6.01 KB
/
huggingface_auth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import asyncio
import base64
import os
import time
from datetime import datetime, timedelta
from getpass import getpass

import requests
from huggingface_hub import HfApi

from hivemind.proto.auth_pb2 import AccessToken
from hivemind.utils.auth import TokenAuthorizerBase
from hivemind.utils.crypto import RSAPublicKey
from hivemind.utils.logging import get_logger
# Module-level logger from hivemind's logging utilities, shared by every helper in this file.
logger = get_logger(__name__)
class NonRetriableError(Exception):
    """Base class for errors that `call_with_retries` re-raises immediately instead of retrying."""
def call_with_retries(func, n_retries=10, initial_delay=1.0):
    """
    Call ``func()`` and return its result, retrying failed calls with exponential backoff.

    :param func: zero-argument callable to invoke
    :param n_retries: maximum total number of calls (the first call included); must be at least 1
    :param initial_delay: seconds to wait after the first failure; the wait doubles after each further failure
    :returns: the value returned by the first successful call
    :raises ValueError: if ``n_retries`` is smaller than 1
    :raises NonRetriableError: immediately, without further attempts, if ``func`` raises one
    :raises Exception: whatever the final attempt raised, once all attempts are exhausted
    """
    if n_retries < 1:
        # Without this check the loop below would not run and the function would silently return None.
        raise ValueError(f'n_retries must be at least 1, got {n_retries}')

    for attempt in range(n_retries):
        try:
            return func()
        except NonRetriableError:
            raise
        except Exception as e:
            if attempt == n_retries - 1:
                raise
            delay = initial_delay * (2 ** attempt)
            # Not every callable has __name__ (e.g. functools.partial); don't let logging mask the real error.
            func_name = getattr(func, '__name__', repr(func))
            logger.warning(f'Failed to call `{func_name}` with exception: {e}. Retrying in {delay:.1f} sec')
            time.sleep(delay)
class InvalidCredentialsError(NonRetriableError):
    """Raised when the HuggingFace Hub login rejects the supplied username/password (HTTP 401)."""
class NotInAllowlistError(NonRetriableError):
    """Raised when the auth server answers the experiment-join request with HTTP 401 (account not allowlisted)."""
class HuggingFaceAuthorizer(TokenAuthorizerBase):
    """
    Token authorizer that obtains hivemind access tokens from the HuggingFace collaborative-training auth server.

    Joining an experiment logs in to the HuggingFace Hub with the given username/password, sends this peer's
    public key to the auth server, and stores the returned access token, the auth server's public key and the
    coordinator address. The Hub session is logged out again after every join attempt.
    """

    _AUTH_SERVER_URL = 'https://collaborative-training-auth.huggingface.co'

    # The local token is considered due for refresh once it expires within this margin.
    _MAX_LATENCY = timedelta(minutes=1)

    def __init__(self, experiment_id: int, username: str, password: str):
        """
        :param experiment_id: experiment identifier, interpolated into the join URL
            (note: ``authorize_with_huggingface`` passes it as a string)
        :param username: HuggingFace Hub username
        :param password: HuggingFace Hub password
        """
        super().__init__()
        self.experiment_id = experiment_id
        self.username = username
        self.password = password

        # Populated by a successful join_experiment() call.
        self._authority_public_key = None
        self.coordinator_ip = None
        self.coordinator_port = None

        self._hf_api = HfApi()

    async def get_token(self) -> AccessToken:
        """
        Hivemind calls this method to refresh the token when necessary.

        Joining performs blocking HTTP requests and ``time.sleep``-based retries, so it is run in the
        event loop's default executor instead of blocking the event loop thread.
        """
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.join_experiment)
        return self._local_access_token

    def join_experiment(self) -> None:
        """
        Join the experiment, retrying transient failures via `call_with_retries`.

        :raises InvalidCredentialsError: if the HuggingFace Hub rejects the username/password
        :raises NotInAllowlistError: if the auth server refuses to let this account join
        """
        call_with_retries(self._join_experiment)

    def _join_experiment(self) -> None:
        """
        Perform a single join attempt (no retries).

        Logs in to the HuggingFace Hub, sends this peer's public key to the auth server's join endpoint, and
        stores the auth server public key, coordinator address and hivemind access token from the reply.
        The Hub session token is logged out afterwards, whether or not the join succeeded.

        :raises InvalidCredentialsError: if the Hub login fails with HTTP 401
        :raises NotInAllowlistError: if the join request fails with HTTP 401
        :raises requests.exceptions.HTTPError: for any other HTTP error status
        """
        try:
            token = self._hf_api.login(self.username, self.password)
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:  # Unauthorized
                raise InvalidCredentialsError()
            raise

        try:
            url = f'{self._AUTH_SERVER_URL}/api/experiments/join/{self.experiment_id}/'
            headers = {'Authorization': f'Bearer {token}'}
            response = requests.put(url, headers=headers, json={
                'experiment_join_input': {
                    'peer_public_key': self.local_public_key.to_bytes().decode(),
                },
            })
            response.raise_for_status()
            payload = response.json()

            self._authority_public_key = RSAPublicKey.from_bytes(payload['auth_server_public_key'].encode())
            self.coordinator_ip = payload['coordinator_ip']
            self.coordinator_port = payload['coordinator_port']

            token_dict = payload['hivemind_access']
            access_token = AccessToken()
            access_token.username = token_dict['username']
            access_token.public_key = token_dict['peer_public_key'].encode()
            # Parse and re-serialize so the stored string is an ISO timestamp in str(datetime) format.
            access_token.expiration_time = str(datetime.fromisoformat(token_dict['expiration_time']))
            access_token.signature = token_dict['signature'].encode()
            self._local_access_token = access_token
            logger.info(f'Access for user {access_token.username} '
                        f'has been granted until {access_token.expiration_time} UTC')
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:  # Unauthorized
                raise NotInAllowlistError()
            raise
        finally:
            self._hf_api.logout(token)

    def is_token_valid(self, access_token: AccessToken) -> bool:
        """
        Check that ``access_token`` is signed by the auth server and has not expired.

        The expiration time must be a naive ISO timestamp, compared against the current UTC time.

        :returns: True if the token is valid; otherwise logs the reason and returns False
        """
        if self._authority_public_key is None:
            logger.error('Cannot verify access token: the auth server public key is unknown '
                         '(join_experiment() has not succeeded yet)')
            return False

        data = self._token_to_bytes(access_token)
        if not self._authority_public_key.verify(data, access_token.signature):
            # logger.error, not logger.exception: there is no active exception to attach here.
            logger.error('Access token has invalid signature')
            return False

        try:
            expiration_time = datetime.fromisoformat(access_token.expiration_time)
        except ValueError:
            logger.exception(
                f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
            return False

        if expiration_time.tzinfo is not None:
            logger.error(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
            return False

        if expiration_time < datetime.utcnow():
            logger.error('Access token has expired')
            return False

        return True

    def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
        """Return True if the token expires within ``_MAX_LATENCY`` of the current UTC time."""
        expiration_time = datetime.fromisoformat(access_token.expiration_time)
        return expiration_time < datetime.utcnow() + self._MAX_LATENCY

    @staticmethod
    def _token_to_bytes(access_token: AccessToken) -> bytes:
        """
        Serialize the signed fields of a token as ``"<username> <public_key> <expiration_time>"``.

        NOTE(review): presumably this must match how the auth server builds the payload it signs — confirm.
        """
        return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
def authorize_with_huggingface() -> HuggingFaceAuthorizer:
    """
    Collect credentials and join the collaborative experiment, retrying until it succeeds.

    The experiment ID, username and password are read from ``HF_EXPERIMENT_ID``, ``HF_USERNAME`` and
    ``HF_PASSWORD`` when set, and requested interactively otherwise. After a failed attempt
    (invalid credentials or account not in the allowlist), username and password are always requested
    interactively: re-using the same environment-provided credentials would fail again forever.

    :returns: an authorizer that has successfully joined the experiment
    :raises Exception: any error from `HuggingFaceAuthorizer.join_experiment` other than
        InvalidCredentialsError / NotInAllowlistError (e.g. once its retries are exhausted)
    """
    use_env_credentials = True
    while True:
        experiment_id = os.getenv('HF_EXPERIMENT_ID')
        if experiment_id is None:
            experiment_id = input('HuggingFace experiment ID: ')

        username = os.getenv('HF_USERNAME') if use_env_credentials else None
        if username is None:
            while True:
                username = input('HuggingFace username: ')
                if '@' not in username:
                    break
                print('Please enter your Huggingface _username_ instead of the email address!')

        password = os.getenv('HF_PASSWORD') if use_env_credentials else None
        if password is None:
            password = getpass('HuggingFace password: ')

        authorizer = HuggingFaceAuthorizer(experiment_id, username, password)
        try:
            authorizer.join_experiment()
            return authorizer
        except InvalidCredentialsError:
            print('Invalid username or password, please try again')
        except NotInAllowlistError:
            print('This account is not specified in the allowlist. '
                  'Please ask a moderator to add you to the allowlist and try again')

        # The credentials just failed; ask the user next time instead of re-reading the same env values.
        use_env_credentials = False