This is the code from the book:
```python
def extract_hidden_states(batch):
    # Place model inputs on the GPU
    inputs = {k:v.to(device) for k,v in batch.items()
              if k in tokenizer.model_input_names}
    # Extract last hidden states
    with torch.no_grad():
        last_hidden_state = model(**inputs).last_hidden_state
    # Return vector for [CLS] token
    return {"hidden_state": last_hidden_state[:,0].cpu().numpy()}
```
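For anyone wondering what `last_hidden_state[:,0]` does: the model returns one vector per token, and the slice keeps only the first position of every sequence, which is where the [CLS] token sits. A quick sketch with a random tensor; the batch size and sequence length are made up, and 768 is assumed as the hidden size of `distilbert-base-uncased`:

```python
import torch

# Illustrative shapes (assumption): 4 sequences, 10 tokens each, hidden size 768
last_hidden_state = torch.randn(4, 10, 768)

# [:, 0] keeps every sequence in the batch but only token position 0 ([CLS])
cls_vectors = last_hidden_state[:, 0]
print(cls_vectors.shape)  # torch.Size([4, 768])
```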
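This can give you an error about a list feature, which happens only when your data has not been correctly converted to tensors: the batch values are then plain Python lists, and lists have no `.to()` method.
In that case, you can use this instead: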
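To see where the error comes from, here is a minimal sketch with a hypothetical batch (the token IDs are made up) shaped like what `map(..., batched=True)` passes in when no torch format has been set on the dataset. In the book, the encoded dataset is switched to PyTorch format with `set_format("torch", columns=[...])` before `map` is called, which is what normally avoids this.

```python
import torch

# Hypothetical batch of plain Python lists (no torch format set on the dataset)
batch = {"input_ids": [[101, 11, 102], [101, 12, 102]],
         "attention_mask": [[1, 1, 1], [1, 1, 1]]}

try:
    batch["input_ids"].to("cpu")
except AttributeError as err:
    print(err)  # 'list' object has no attribute 'to'

# Converting each column to a tensor first makes .to(device) work
tensors = {k: torch.as_tensor(v) for k, v in batch.items()}
print(tensors["input_ids"].shape)  # torch.Size([2, 3])
```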
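```python
import torch

def extract_hidden_state(batch):
    # Convert lists to tensors if needed, then place the inputs on the device
    inputs = {k: torch.as_tensor(batch[k]).to(device)
              for k in ['input_ids', 'attention_mask']}
    with torch.no_grad():
        outputs = model(**inputs)  # `model` is the DistilBERT model loaded earlier
    last_hidden_state = outputs.last_hidden_state[:, 0].cpu().numpy()
    return {"hidden_state": last_hidden_state}
```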
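If you want to check the function without downloading DistilBERT, here is a sketch that swaps in a hypothetical `TinyEncoder` (not part of the book or of `transformers`) returning an object with a `last_hidden_state` attribute. This version wraps each column in `torch.as_tensor`, so it accepts plain lists as well as tensors, and it ignores extra columns such as `text` and `label`. `device` is assumed to be the CPU here.

```python
import torch
from types import SimpleNamespace

device = torch.device("cpu")  # assumption: CPU only for this check

class TinyEncoder(torch.nn.Module):
    """Hypothetical stand-in for DistilBERT: returns .last_hidden_state
    with shape (batch, seq_len, hidden)."""
    def __init__(self, vocab_size=1000, hidden=8):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids, attention_mask):
        # Zero out padded positions so the mask is actually used
        hidden = self.emb(input_ids) * attention_mask.unsqueeze(-1)
        return SimpleNamespace(last_hidden_state=hidden)

model = TinyEncoder().to(device).eval()

def extract_hidden_state(batch):
    # torch.as_tensor turns lists into tensors and leaves tensors unchanged
    inputs = {k: torch.as_tensor(batch[k]).to(device)
              for k in ["input_ids", "attention_mask"]}
    with torch.no_grad():
        outputs = model(**inputs)
    return {"hidden_state": outputs.last_hidden_state[:, 0].cpu().numpy()}

# Plain-list batch with extra columns the model must not receive
batch = {"text": ["a", "b"], "label": [0, 1],
         "input_ids": [[1, 5, 2], [1, 7, 2]],
         "attention_mask": [[1, 1, 1], [1, 1, 0]]}
out = extract_hidden_state(batch)
print(out["hidden_state"].shape)  # (2, 8)
```

With the real model, the output would have one 768-dimensional row per example instead of 8.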