Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
+
import torch
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
|
7 |
+
gr.Markdown(
|
8 |
+
"""
|
9 |
+
<style>
|
10 |
+
.center-btn button {
|
11 |
+
margin-left: auto;
|
12 |
+
margin-right: auto;
|
13 |
+
display: block;
|
14 |
+
}
|
15 |
+
</style>
|
16 |
+
"""
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
# Load model and tokenizer
|
21 |
+
model_name = "ale-dp/xlm-roberta-email-classifier"
|
22 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
23 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
24 |
+
|
25 |
+
# Label map
|
26 |
+
label_map = {
|
27 |
+
0: 'Billing and Payments',
|
28 |
+
1: 'Customer Service',
|
29 |
+
2: 'General Inquiry',
|
30 |
+
3: 'Human Resources',
|
31 |
+
4: 'IT Support',
|
32 |
+
5: 'Product Support',
|
33 |
+
6: 'Returns and Exchanges',
|
34 |
+
7: 'Sales and Pre-Sales',
|
35 |
+
8: 'Service Outages and Maintenance',
|
36 |
+
9: 'Technical Support'
|
37 |
+
}
|
38 |
+
|
39 |
+
# Prediction function
|
40 |
+
def classify_email_with_probs(text):
|
41 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
42 |
+
with torch.no_grad():
|
43 |
+
outputs = model(**inputs)
|
44 |
+
logits = outputs.logits[0]
|
45 |
+
probs = torch.nn.functional.softmax(logits, dim=0)
|
46 |
+
|
47 |
+
prob_dict = {label_map[i]: round(float(probs[i]) * 100, 2) for i in range(len(probs))}
|
48 |
+
sorted_probs = dict(sorted(prob_dict.items(), key=lambda item: item[1], reverse=True))
|
49 |
+
df = pd.DataFrame(sorted_probs.items(), columns=["Category", "Confidence (%)"])
|
50 |
+
top_label = df.iloc[0]["Category"]
|
51 |
+
return top_label, df
|
52 |
+
|
53 |
+
# Sample emails
|
54 |
+
examples = [
|
55 |
+
"Hello, I recently purchased a pair of headphones from your online store (Order #48392) and unfortunately, they arrived damaged. The left earcup is completely detached and the sound is distorted. I’d like to request a return or exchange. Please let me know the steps I need to follow and whether I need to ship the item back first. Thank you for your assistance.",
|
56 |
+
"Dear Customer Support Team,\n\nI hope this message reaches you well. I am reaching out to request detailed billing details and payment options for a QuickBooks Online subscription. Specifically, I am interested in understanding the available plans, their pricing structures, and any tailored options for institutional clients within the financial services industry.",
|
57 |
+
"Hello, I’m reaching out on behalf of a mid-sized retail company interested in your cloud-based inventory solution. We’re currently evaluating vendors and would appreciate a demo of your platform, along with pricing tiers for teams of 50+ users. Please let me know your availability this week for a call.",
|
58 |
+
"Currently facing sporadic connectivity difficulties with the cloud-native SaaS system. The suspected reason appears to be linked to orchestration resource distribution within Kubernetes-managed microservices. After restarting the affected services and examining deployment logs, the issue continues. Further investigation and escalation are required to resolve this matter swiftly."
|
59 |
+
]
|
60 |
+
|
61 |
+
# Gradio UI
|
62 |
+
with gr.Blocks() as demo:
|
63 |
+
gr.Markdown("## 📬 Email Ticket Classifier")
|
64 |
+
gr.Markdown("Classify emails into support categories using XLM-RoBERTa. See top prediction and full confidence breakdown.")
|
65 |
+
|
66 |
+
email_input = gr.Textbox(
|
67 |
+
lines=12,
|
68 |
+
label="Email Text",
|
69 |
+
placeholder="Paste your email here...",
|
70 |
+
elem_id="email_input"
|
71 |
+
)
|
72 |
+
|
73 |
+
with gr.Row():
|
74 |
+
submit_btn = gr.Button("Classify", variant="primary", elem_classes="center-btn")
|
75 |
+
|
76 |
+
gr.Markdown("<br><br>")
|
77 |
+
|
78 |
+
gr.Markdown("### Examples:")
|
79 |
+
|
80 |
+
with gr.Column():
|
81 |
+
for example in examples:
|
82 |
+
gr.Button(example).click(fn=lambda x=example: x, outputs=email_input)
|
83 |
+
|
84 |
+
|
85 |
+
top_label = gr.Label(label="Predicted Category")
|
86 |
+
prob_table = gr.Dataframe(
|
87 |
+
headers=["Category", "Confidence (%)"],
|
88 |
+
label="Confidence Breakdown",
|
89 |
+
datatype=["str", "number"],
|
90 |
+
row_count=10
|
91 |
+
)
|
92 |
+
|
93 |
+
submit_btn.click(fn=classify_email_with_probs, inputs=email_input, outputs=[top_label, prob_table])
|
94 |
+
|
95 |
+
demo.launch()
|