# seal/run_inference.py
from unittest import result
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
import os
import numpy as np
from tqdm import tqdm
from utils.inference_utils import InferenceResults, saveResults
# Load the requested dataset split, a DataLoader over it, and the model/tokenizer
def load_session(dataset, model, split):
    """Load a dataset split, wrap it in a DataLoader, and load the model and tokenizer.

    Args:
        dataset: HuggingFace dataset name (e.g. 'yelp_polarity').
        model: HuggingFace model checkpoint name or path.
        split: dataset split to load (e.g. 'test').

    Returns:
        Tuple of (tokenizer, dataloader, model).
    """
    data = load_dataset(dataset, split=split)
    dataloader = DataLoader(
        data,
        batch_size=256, drop_last=True
    )
    # BUG FIX: load the tokenizer from the checkpoint *name* before the name
    # is shadowed by the model object. The original rebound `model` to the
    # loaded model first and then passed that object to
    # AutoTokenizer.from_pretrained, which expects a string/path and raises.
    tokenizer = AutoTokenizer.from_pretrained(model)
    model = AutoModelForSequenceClassification.from_pretrained(model)
    return tokenizer, dataloader, model
# Add hook to capture hidden layer
def get_input(name, model):
    """Build a forward hook that captures the *input* tensor fed to a module.

    Args:
        name: key under which the captured activation is stored.
        model: unused; kept for backward compatibility with existing callers.

    Returns:
        Tuple of (hook, hidden_layers): the hook function to pass to
        `register_forward_hook`, and the dict it writes the detached first
        input tensor into. The entry is overwritten on every forward pass,
        so the dict always holds only the most recent batch.
    """
    hidden_layers = {}

    def hook(module, input, output):
        # Plain assignment already replaces any previous value, so the
        # original's `del hidden_layers[name]` guard was redundant.
        # `module` is named explicitly to avoid shadowing the outer `model`.
        hidden_layers[name] = input[0].detach()

    return hook, hidden_layers
def run_inference(dataset='yelp_polarity', model='textattack/albert-base-v2-yelp-polarity', split='test', output_path='./assets/data/inference_results'):
    """Run the model over an entire dataset split and pickle the results.

    Captures, per example: the input embedding to the classifier head, the
    per-example cross-entropy loss, and the gold label. The stacked tensors
    are packaged via `save_results` and written to
    `<output_path>/<dataset>.pkl` with `saveResults`.

    Args:
        dataset: HuggingFace dataset name.
        model: HuggingFace sequence-classification checkpoint name.
        split: dataset split to evaluate.
        output_path: directory the results pickle is written into.
    """
    tokenizer, dataloader, model = load_session(dataset, model, split)
    model.eval()
    model.to('cpu')
    # BUG FIX: get_input returns a (hook_fn, dict) tuple. The original passed
    # that tuple straight to register_forward_hook and then tried to unpack
    # the returned RemovableHandle into two names, which raises at runtime.
    hook_fn, hidden_layers = get_input('last_layer', model)
    handle = model.classifier.register_forward_hook(hook_fn)

    hidden_list = []
    loss_list = []
    output_list = []
    example = []
    labels = []
    criterion = nn.CrossEntropyLoss(reduction='none')
    softmax = nn.Softmax(dim=1)
    with torch.no_grad():
        for batch_num, batch in tqdm(enumerate(dataloader), total=len(dataloader), position=0, leave=True):
            # Cheap pre-trim on characters; `truncation=True` below enforces
            # the model's real token limit (char-slicing alone cannot).
            batch_ex = [ex[:512] for ex in batch['text']]
            inputs = tokenizer(batch_ex, padding=True, truncation=True, return_tensors='pt').to('cpu')
            targets = batch['label']
            outputs = model(**inputs)['logits']
            loss = criterion(outputs, targets)
            predictions = softmax(outputs)
            # Populated by the forward hook during model(**inputs) above.
            hidden_list.append(hidden_layers['last_layer'].cpu())
            loss_list.append(loss.cpu())
            # torch.argmax keeps everything in torch (original mixed in numpy).
            output_list.append(torch.argmax(predictions, dim=1))
            labels.append(targets)
            example.append(inputs['input_ids'])
    # Detach the hook so the model is left clean for any further use.
    handle.remove()

    embeddings = torch.vstack(hidden_list)
    losses = torch.hstack(loss_list)
    targets = torch.hstack(labels)
    results = save_results(embeddings, losses, targets)
    saveResults(os.path.join(output_path, dataset + '.pkl'), results)
def save_results(embeddings, losses, labels):
    """Package inference outputs into an InferenceResults record.

    The embeddings tensor is cloned so the stored copy is independent of any
    buffer the caller may reuse; losses and labels are stored as given.
    """
    return InferenceResults(
        embeddings=torch.clone(embeddings),
        losses=losses,
        labels=labels,
    )