Update app.py
app.py
CHANGED
@@ -1,215 +1,141 @@
-import streamlit as st
-# … [old lines 1, 3–69 lost in extraction: the remaining imports (numpy,
-#    matplotlib.pyplot, re, base64, a PdfReader), the model constants
-#    SUMMARIZER_MODEL, TTS_MODELS, QA_MODEL, and the helpers hf_infer_json,
-#    split_into_chunks, embed_texts that the surviving code below calls]
-    return np.array(arr)
-
-def cosine_sim(a, b):
-    a = a / (np.linalg.norm(a, axis=-1, keepdims=True) + 1e-8)
-    b = b / (np.linalg.norm(b, axis=-1, keepdims=True) + 1e-8)
-    return a @ b.T
-
-def summarize_long_text(text: str):
-    chunks = split_into_chunks(text)
-    mini_summaries = []
-    for c in chunks:
-        out = hf_infer_json(SUMMARIZER_MODEL, {"inputs": c}, router=False)
-        if isinstance(out, list) and len(out) and "summary_text" in out[0]:
-            mini_summaries.append(out[0]["summary_text"])
-        else:
-            mini_summaries.append(c[:800])
-    return " ".join(mini_summaries), chunks
-
-def tts_wav_bytes(text: str) -> bytes:
-    for model in TTS_MODELS:
-        try:
-            res = hf_infer_json(model, {"inputs": text}, router=False, accept="audio/wav")
-            if isinstance(res, (bytes, bytearray)):
-                return res
-            if isinstance(res, dict) and "audio" in res:
-                return base64.b64decode(res["audio"])
-        except Exception:
-            continue
-    raise RuntimeError("All TTS models failed.")
-
-def extract_text_from_pdf(file) -> str:
-    reader = PdfReader(file)
-    pages = []
-    for p in reader.pages:
-        try:
-            pages.append(p.extract_text() or "")
-        except:
-            pages.append("")
-    return "\n".join(pages)
-
-def make_word_freq_chart(text: str, top_k=20):
-    text = text.lower()
-    stop = set(("the a an and of to in is are for with on by as at this that from be was were it its it’s into or if not your you we they their our can may such more most other also than which".split()))
-    tokens = re.findall(r"[a-zA-Z]{3,}", text)
-    freq = {}
-    for t in tokens:
-        if t in stop:
-            continue
-        freq[t] = freq.get(t, 0) + 1
-    items = sorted(freq.items(), key=lambda x: x[1], reverse=True)[:top_k]
-    if not items:
-        st.info("Not enough text to show a frequency chart.")
-        return
-    words, counts = zip(*items)
-    fig = plt.figure()
-    plt.bar(words, counts)
-    plt.xticks(rotation=60, ha="right")
-    plt.title("Top word frequencies")
-    plt.tight_layout()
-    st.pyplot(fig)
-
-# -----------------------------
-# UI
-# -----------------------------
-st.title("📄 PDF → Summary · 🔊 Audio · 📊 Chart · ❓ Q&A")
-st.caption("Free models via Hugging Face Hosted Inference API.")
-
-uploaded = st.file_uploader("Upload a PDF", type=["pdf"])
-
-if "doc_text" not in st.session_state:
-    st.session_state.doc_text = ""
-    st.session_state.chunks = []
-    st.session_state.chunk_vecs = None
-    st.session_state.summary = ""
-
-if uploaded:
-    with st.spinner("Extracting text..."):
-        text =
-# … [old lines 148–161 lost in extraction; only an "st." fragment survives]
-
-    with st.container():
-        if st.button("🔊 Generate Audio (summary)"):
-            target_text = st.session_state.summary or st.session_state.doc_text[:1200]
-            with st.spinner("Generating audio..."):
-                try:
-# … [old line 167 lost in extraction]
-                    st.audio(
-# … [old line 169 lost in extraction]
-                except Exception as e:
-                    st.error(f"TTS failed: {e}")
-
-# … [old lines 173–191 lost in extraction: the Q&A input and retrieval setup;
-#    the surviving code references `question` and `vecs` from this gap]
-        try:
-            q_vec = embed_texts([question])
-            sims = cosine_sim(q_vec, vecs).flatten()
-            top_idx = np.argsort(sims)[::-1][:3]
-            context = "\n".join([st.session_state.chunks[i] for i in top_idx])
-
-            qa_out = hf_infer_json(QA_MODEL, {"inputs": {"question": question, "context": context}}, router=False)
-            if isinstance(qa_out, dict):
-                ans = qa_out.get("answer", "")
-                score = qa_out.get("score", 0.0)
-            elif isinstance(qa_out, list) and len(qa_out) and isinstance(qa_out[0], dict):
-                ans = qa_out[0].get("answer", "")
-                score = qa_out[0].get("score", 0.0)
-            else:
-                ans, score = "", 0.0
-
-            st.write("**Answer:**", ans or "_(no confident answer)_")
-            st.caption(f"Confidence: {score:.3f}")
-            with st.expander("Context used"):
-                st.write(context)
-        except Exception as e:
-            st.error(f"
-# … [old lines 213–215 lost in extraction]
+# app.py
+import os
+import io
+import tempfile
+import streamlit as st
+from huggingface_hub import InferenceClient
+import pdfplumber
+from PIL import Image
+import base64
+
+# ---------- Configuration ----------
+HF_TOKEN = os.environ.get("HF_TOKEN")  # required
+GROQ_KEY = os.environ.get("GROQ_API_KEY")  # optional: if you want to call Groq directly
+USE_GROQ_PROVIDER = True  # set False to route to the default HF provider
+
+# model IDs (change if you prefer other models)
+LLAMA_MODEL = "Groq/Llama-3-Groq-8B-Tool-Use"  # Groq Llama model on HF
+TTS_MODEL = "espnet/kan-bayashi_ljspeech_vits"  # an HF-hosted TTS model example
+SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL base model
+
+# create the Inference client (routed via the HF token by default)
+if USE_GROQ_PROVIDER:
+    client = InferenceClient(provider="groq", api_key=HF_TOKEN)
+else:
+    client = InferenceClient(api_key=HF_TOKEN)
+
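+# (Assumption, not verified by this commit: with provider="groq", recent
+#  huggingface_hub versions route the call through HF's inference-provider
+#  layer to Groq, authenticating with the HF token above; supplying a Groq
+#  key instead would authenticate directly against Groq.)
+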
+# ---------- Helpers ----------
+def pdf_to_text(uploaded_file) -> str:
+    text_chunks = []
+    with pdfplumber.open(uploaded_file) as pdf:
+        for page in pdf.pages:
+            ptext = page.extract_text()
+            if ptext:
+                text_chunks.append(ptext)
+    return "\n\n".join(text_chunks)
+
+def llama_summarize(text, max_tokens=512):
+    prompt = [
+        {"role": "system", "content": "You are a concise summarizer. Produce a clear summary in bullet points."},
+        {"role": "user", "content": f"Summarize the following document in <= 8 bullet points. Keep it short:\n\n{text}"}
+    ]
+    # Use the chat-completion endpoint; forward max_tokens so the parameter takes effect
+    resp = client.chat.completions.create(model=LLAMA_MODEL, messages=prompt, max_tokens=max_tokens)
+    try:
+        summary = resp.choices[0].message.content
+    except Exception:
+        # fallback: try the text-generation field
+        summary = resp.choices[0].text if hasattr(resp.choices[0], "text") else str(resp)
+    return summary
+
+def llama_chat(chat_history, user_question):
+    messages = chat_history + [{"role": "user", "content": user_question}]
+    resp = client.chat.completions.create(model=LLAMA_MODEL, messages=messages)
+    return resp.choices[0].message.content
+
+def tts_synthesize(text) -> bytes:
+    # InferenceClient.text_to_speech returns raw audio bytes (wav);
+    # the text is the first positional argument, the model a keyword.
+    audio_bytes = client.text_to_speech(text, model=TTS_MODEL)
+    return audio_bytes
+
+def generate_image(prompt_text) -> Image.Image:
+    # InferenceClient.text_to_image already returns a PIL Image,
+    # so no bytes round-trip is needed
+    return client.text_to_image(prompt_text, model=SDXL_MODEL)
+
+def audio_download_button(wav_bytes, filename="summary.wav"):
+    b64 = base64.b64encode(wav_bytes).decode()
+    href = f'<a href="data:audio/wav;base64,{b64}" download="{filename}">Download audio (WAV)</a>'
+    st.markdown(href, unsafe_allow_html=True)
+
+# ---------- Streamlit UI ----------
+st.set_page_config(page_title="PDFGPT (Groq + HF)", layout="wide")
+st.title("PDF → Summary + Speech + Chat + Diagram (Groq + HF)")
+
+uploaded = st.file_uploader("Upload PDF", type=["pdf"])
+if uploaded:
+    with st.spinner("Extracting text from PDF..."):
+        text = pdf_to_text(uploaded)
+    st.subheader("Extracted text (preview)")
+    st.text_area("Document text", value=text[:1000], height=200)
+
+    if st.button("Create summary (Groq Llama)"):
+        with st.spinner("Summarizing with Groq Llama..."):
+            summary = llama_summarize(text)
+        st.subheader("Summary")
+        st.write(summary)
+        st.session_state["summary"] = summary
+
+    if "summary" in st.session_state:
+        summary = st.session_state["summary"]
+        if st.button("Synthesize audio from summary (TTS)"):
+            with st.spinner("Creating audio..."):
+                try:
+                    audio = tts_synthesize(summary)
+                    st.audio(audio)
+                    audio_download_button(audio)
+                except Exception as e:
+                    st.error(f"TTS failed: {e}")
+
+    st.markdown("---")
+    st.subheader("Chat with your PDF (ask questions about document)")
+    if "chat_history" not in st.session_state:
+        # start with system + doc context (shortened)
+        doc_context = (text[:4000] + "...") if len(text) > 4000 else text
+        st.session_state["chat_history"] = [
+            {"role": "system", "content": "You are a helpful assistant that answers questions based on the provided document."},
+            {"role": "user", "content": f"Document context:\n{doc_context}"}
+        ]
+
+    user_q = st.text_input("Ask a question about the PDF")
+    if st.button("Ask") and user_q:
+        with st.spinner("Getting answer from Groq Llama..."):
+            answer = llama_chat(st.session_state["chat_history"], user_q)
+        st.session_state.setdefault("convo", []).append(("You", user_q))
+        st.session_state.setdefault("convo", []).append(("Assistant", answer))
+        # append to the history for subsequent calls
+        st.session_state["chat_history"].append({"role": "user", "content": user_q})
+        st.session_state["chat_history"].append({"role": "assistant", "content": answer})
+        st.write(answer)
+
+st.markdown("---")
+st.subheader("Generate a diagram from your question (SDXL)")
+diagram_prompt = st.text_input("Describe the diagram or scene to generate")
+if st.button("Generate diagram") and diagram_prompt:
+    with st.spinner("Generating image (SDXL)..."):
+        try:
+            img = generate_image(diagram_prompt)
+            st.image(img, use_column_width=True)
+            # allow download
+            buf = io.BytesIO()
+            img.save(buf, format="PNG")
+            st.download_button("Download diagram (PNG)", data=buf.getvalue(), file_name="diagram.png", mime="image/png")
+        except Exception as e:
+            st.error(f"Image generation failed: {e}")
+
+st.sidebar.title("Settings")
+st.sidebar.write("Models in use:")
+st.sidebar.write(f"LLM: {LLAMA_MODEL}")
+st.sidebar.write(f"TTS: {TTS_MODEL}")
+st.sidebar.write(f"Image: {SDXL_MODEL}")
+
+st.sidebar.markdown("**Notes**\n- Set HF_TOKEN in Space secrets or environment before starting.\n- To route directly to Groq with your Groq API key, set `GROQ_API_KEY` and change the client init accordingly.")
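
A minimal sketch of the direct-Groq init the second sidebar note alludes to. The commit itself never wires GROQ_KEY into the client, so the exact call below is an assumption based on huggingface_hub's provider routing, not part of this change:

import os
from huggingface_hub import InferenceClient

# Passing a Groq API key instead of an HF token makes provider routing
# authenticate straight against your Groq account.
client = InferenceClient(provider="groq", api_key=os.environ["GROQ_API_KEY"])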