|
from unittest import result |
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import DataLoader, Subset |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
from datasets import load_dataset |
|
|
|
import os |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
from utils.inference_utils import InferenceResults, saveResults |
|
|
|
|
|
|
|
def load_session(dataset, model, split, batch_size=256):
    """Load a HuggingFace dataset split plus the matching model/tokenizer.

    Parameters
    ----------
    dataset : str
        Dataset identifier on the HuggingFace hub (e.g. ``'yelp_polarity'``).
    model : str
        Model identifier on the HuggingFace hub.
    split : str
        Which split of the dataset to load (e.g. ``'test'``).
    batch_size : int, optional
        DataLoader batch size (default 256, matching the original
        hard-coded value).

    Returns
    -------
    tuple
        ``(tokenizer, dataloader, model)``.
    """
    data = load_dataset(dataset, split=split)
    dataloader = DataLoader(
        data,
        batch_size=batch_size,
        drop_last=True,
    )
    # Bug fix: load the tokenizer from the hub *name* before rebinding
    # `model` to the loaded model object. The original loaded the model
    # first and then passed the model object to
    # AutoTokenizer.from_pretrained, which expects a string identifier.
    tokenizer = AutoTokenizer.from_pretrained(model)
    model = AutoModelForSequenceClassification.from_pretrained(model)
    return tokenizer, dataloader, model
|
|
|
|
|
def get_input(name, model):
    """Build a forward hook that captures a module's most recent input.

    Parameters
    ----------
    name : str
        Key under which the captured input tensor is stored.
    model : object
        Unused; kept for backward compatibility with existing callers.

    Returns
    -------
    tuple
        ``(hook, hidden_layers)`` where ``hook`` is a callable suitable for
        ``module.register_forward_hook`` and ``hidden_layers`` is the dict
        it writes into. Each forward pass overwrites the previous capture.
    """
    hidden_layers = {}

    def hook(module, input, output):
        # Plain assignment already replaces any prior entry; the original
        # `del` before reassigning was redundant.
        hidden_layers[name] = input[0].detach()

    return hook, hidden_layers
|
|
|
def run_inference(dataset='yelp_polarity', model='textattack/albert-base-v2-yelp-polarity', split='test', output_path='./assets/data/inference_results'):
    """Run CPU inference over a dataset split and persist the results.

    For every batch this records the classifier head's input embeddings
    (via a forward hook), the per-example cross-entropy losses, and the
    target labels, then bundles them into an InferenceResults pickle at
    ``<output_path>/<dataset>.pkl``.

    Parameters
    ----------
    dataset : str
        Dataset identifier on the HuggingFace hub.
    model : str
        Model identifier on the HuggingFace hub.
    split : str
        Dataset split to run over.
    output_path : str
        Directory into which the ``.pkl`` results file is written.
    """
    tokenizer, dataloader, model = load_session(dataset, model, split)
    model.eval()
    model.to('cpu')
    # Bug fix: get_input returns (hook_fn, storage_dict). The original
    # passed that tuple directly to register_forward_hook and tried to
    # tuple-unpack the returned RemovableHandle — both fail at runtime.
    hook, hidden_layers = get_input('last_layer', model)
    handle = model.classifier.register_forward_hook(hook)

    hidden_list = []
    loss_list = []
    output_list = []
    example = []
    labels = []
    # reduction='none' keeps one loss value per example instead of a mean.
    criterion = nn.CrossEntropyLoss(reduction='none')
    softmax = nn.Softmax(dim=1)
    with torch.no_grad():
        for batch_num, batch in tqdm(enumerate(dataloader), total=len(dataloader), position=0, leave=True):
            # Cheap character-level pre-truncation; truncation=True below
            # enforces the model's actual token limit (512 chars can still
            # tokenize to more than 512 tokens once specials are added).
            batch_ex = [ex[:512] for ex in batch['text']]
            inputs = tokenizer(batch_ex, padding=True, truncation=True, return_tensors='pt').to('cpu')
            targets = batch['label']

            outputs = model(**inputs)['logits']
            loss = criterion(outputs, targets)
            predictions = softmax(outputs)

            # The hook stored the classifier head's input for this batch.
            hidden_list.append(hidden_layers['last_layer'].cpu())
            loss_list.append(loss.cpu())

            # torch.argmax instead of np.argmax — keeps everything in torch.
            output_list.append(torch.argmax(predictions, dim=1))
            labels.append(targets)
            example.append(inputs['input_ids'])
    # Detach the forward hook now that inference is finished.
    handle.remove()

    embeddings = torch.vstack(hidden_list)
    losses = torch.hstack(loss_list)
    targets = torch.hstack(labels)

    results = save_results(embeddings, losses, targets)
    saveResults(os.path.join(output_path, dataset + '.pkl'), results)
|
|
|
|
|
|
|
def save_results(embeddings, losses, labels):
    """Bundle inference outputs into an InferenceResults record.

    The embedding tensor is defensively cloned so later in-place edits to
    the caller's tensor cannot corrupt the stored record; losses and labels
    are kept by reference.

    Parameters
    ----------
    embeddings : torch.Tensor
        Stacked per-example embedding tensor.
    losses : torch.Tensor
        Per-example loss values.
    labels : torch.Tensor
        Ground-truth labels.

    Returns
    -------
    InferenceResults
        The assembled results container.
    """
    return InferenceResults(
        embeddings=torch.clone(embeddings),
        losses=losses,
        labels=labels,
    )
|
|
|
|
|
|