ChatBotsTA committed
Commit a257837 · verified · 1 Parent(s): 12a2cb7

Update app.py

Files changed (1)
  1. app.py +131 -205
app.py CHANGED
@@ -1,215 +1,141 @@
- import os, io, re, json, base64, requests, numpy as np
  import streamlit as st
- from pypdf import PdfReader
- import matplotlib.pyplot as plt
-
- # -----------------------------
- # Config
- # -----------------------------
- st.set_page_config(page_title="PDF Summarizer + Audio + QA", page_icon="📄", layout="wide")
-
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
- HEADERS_JSON = {
-     "Authorization": f"Bearer {HF_TOKEN}" if HF_TOKEN else "",
-     "Content-Type": "application/json",
-     "Accept": "application/json",
- }
-
- SUMMARIZER_MODEL = "pszemraj/long-t5-tglobal-base-16384-book-summary"
- TTS_MODELS = [
-     "espnet/kan-bayashi_ljspeech_vits",
-     "facebook/fastspeech2-en-ljspeech"
- ]
- EMB_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
- QA_MODEL = "deepset/roberta-base-squad2"
-
- # -----------------------------
- # API helpers
- # -----------------------------
- def hf_infer_json(model_id: str, payload: dict, router=False, accept=None):
-     if router:
-         url = f"https://router.huggingface.co/hf-inference/models/{model_id}"
-     else:
-         url = f"https://api-inference.huggingface.co/models/{model_id}"
-     headers = HEADERS_JSON.copy()
-     if accept:
-         headers["Accept"] = accept
-     r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=120)
-     r.raise_for_status()
      try:
-         return r.json()
-     except requests.exceptions.JSONDecodeError:
-         return r.content
-
- def split_into_chunks(text: str, max_chars: int = 1500, overlap: int = 200):
-     text = re.sub(r"\s+", " ", text).strip()
-     chunks = []
-     i = 0
-     while i < len(text):
-         chunk = text[i:i+max_chars]
-         last_dot = chunk.rfind(". ")
-         if last_dot > 400:
-             chunk = chunk[:last_dot+1]
-             i += last_dot + 1 - overlap
-         else:
-             i += max_chars - overlap
-         chunks.append(chunk.strip())
-     return [c for c in chunks if c]
-
- def embed_texts(texts):
-     url = f"https://router.huggingface.co/hf-inference/models/{EMB_MODEL}/pipeline/feature-extraction"
-     headers = HEADERS_JSON
-     r = requests.post(url, headers=headers, data=json.dumps({"inputs": texts}), timeout=120)
-     r.raise_for_status()
-     arr = np.array(r.json(), dtype=np.float32)
-     if arr.ndim == 2:
-         return arr.mean(axis=0, keepdims=True)
-     if arr.ndim == 3:
-         pooled = [a.mean(axis=0) for a in arr]
-         return np.vstack(pooled)
-     return np.array(arr)
-
- def cosine_sim(a, b):
-     a = a / (np.linalg.norm(a, axis=-1, keepdims=True) + 1e-8)
-     b = b / (np.linalg.norm(b, axis=-1, keepdims=True) + 1e-8)
-     return a @ b.T
-
- def summarize_long_text(text: str):
-     chunks = split_into_chunks(text)
-     mini_summaries = []
-     for c in chunks:
-         out = hf_infer_json(SUMMARIZER_MODEL, {"inputs": c}, router=False)
-         if isinstance(out, list) and len(out) and "summary_text" in out[0]:
-             mini_summaries.append(out[0]["summary_text"])
-         else:
-             mini_summaries.append(c[:800])
-     return " ".join(mini_summaries), chunks
-
- def tts_wav_bytes(text: str) -> bytes:
-     for model in TTS_MODELS:
-         try:
-             res = hf_infer_json(model, {"inputs": text}, router=False, accept="audio/wav")
-             if isinstance(res, (bytes, bytearray)):
-                 return res
-             if isinstance(res, dict) and "audio" in res:
-                 return base64.b64decode(res["audio"])
-         except Exception:
-             continue
-     raise RuntimeError("All TTS models failed.")
-
- def extract_text_from_pdf(file) -> str:
-     reader = PdfReader(file)
-     pages = []
-     for p in reader.pages:
-         try:
-             pages.append(p.extract_text() or "")
-         except:
-             pages.append("")
-     return "\n".join(pages)
-
- def make_word_freq_chart(text: str, top_k=20):
-     text = text.lower()
-     stop = set(("the a an and of to in is are for with on by as at this that from be was were it its it’s into or if not your you we they their our can may such more most other also than which".split()))
-     tokens = re.findall(r"[a-zA-Z]{3,}", text)
-     freq = {}
-     for t in tokens:
-         if t in stop:
-             continue
-         freq[t] = freq.get(t, 0) + 1
-     items = sorted(freq.items(), key=lambda x: x[1], reverse=True)[:top_k]
-     if not items:
-         st.info("Not enough text to show a frequency chart.")
-         return
-     words, counts = zip(*items)
-     fig = plt.figure()
-     plt.bar(words, counts)
-     plt.xticks(rotation=60, ha="right")
-     plt.title("Top word frequencies")
-     plt.tight_layout()
-     st.pyplot(fig)
-
- # -----------------------------
- # UI
- # -----------------------------
- st.title("📄 PDF → Summary · 🔊 Audio · 📊 Chart · ❓ Q&A")
- st.caption("Free models via Hugging Face Hosted Inference API.")
-
- uploaded = st.file_uploader("Upload a PDF", type=["pdf"])
-
- if "doc_text" not in st.session_state:
-     st.session_state.doc_text = ""
-     st.session_state.chunks = []
-     st.session_state.chunk_vecs = None
-     st.session_state.summary = ""
-
  if uploaded:
-     with st.spinner("Extracting text..."):
-         text = extract_text_from_pdf(uploaded)
-         st.session_state.doc_text = text
-     st.success(f"Loaded {len(text)} characters.")
-
-     st.write("### Actions")
-     with st.container():
-         if st.button("📝 Summarize"):
-             with st.spinner("Summarizing..."):
-                 summary, chunks = summarize_long_text(st.session_state.doc_text)
-                 st.session_state.summary = summary
-                 st.session_state.chunks = chunks
-             st.success("Summary ready.")
-             st.write("#### Summary")
-             st.write(st.session_state.summary)
-
-     with st.container():
-         if st.button("🔊 Generate Audio (summary)"):
-             target_text = st.session_state.summary or st.session_state.doc_text[:1200]
-             with st.spinner("Generating audio..."):
                  try:
-                     wav = tts_wav_bytes(target_text)
-                     st.audio(wav, format="audio/wav")
-                     st.success("Audio ready.")
                  except Exception as e:
                      st.error(f"TTS failed: {e}")

-     with st.container():
-         if st.button("📊 Show Word-Frequency Chart"):
-             with st.spinner("Building chart..."):
-                 make_word_freq_chart(st.session_state.doc_text)
-
-     st.write("---")
-     st.subheader("Ask questions about the PDF")
-     question = st.text_input("Your question")
-     if st.button("Answer"):
-         if not st.session_state.chunks:
-             st.session_state.chunks = split_into_chunks(st.session_state.doc_text)
-         with st.spinner("Thinking..."):
              try:
-                 if st.session_state.chunk_vecs is None:
-                     vecs = embed_texts(st.session_state.chunks)
-                     st.session_state.chunk_vecs = vecs
-                 else:
-                     vecs = st.session_state.chunk_vecs
-
-                 q_vec = embed_texts([question])
-                 sims = cosine_sim(q_vec, vecs).flatten()
-                 top_idx = np.argsort(sims)[::-1][:3]
-                 context = "\n".join([st.session_state.chunks[i] for i in top_idx])
-
-                 qa_out = hf_infer_json(QA_MODEL, {"inputs": {"question": question, "context": context}}, router=False)
-                 if isinstance(qa_out, dict):
-                     ans = qa_out.get("answer", "")
-                     score = qa_out.get("score", 0.0)
-                 elif isinstance(qa_out, list) and len(qa_out) and isinstance(qa_out[0], dict):
-                     ans = qa_out[0].get("answer", "")
-                     score = qa_out[0].get("score", 0.0)
-                 else:
-                     ans, score = "", 0.0
-
-                 st.write("**Answer:**", ans or "_(no confident answer)_")
-                 st.caption(f"Confidence: {score:.3f}")
-                 with st.expander("Context used"):
-                     st.write(context)
              except Exception as e:
-                 st.error(f"QA failed: {e}")

- else:
-     st.info("Upload a PDF to get started.")

+ # app.py
+ import os
+ import io
+ import tempfile
  import streamlit as st
+ from huggingface_hub import InferenceClient
+ import pdfplumber
+ from PIL import Image
+ import base64
+
+ # ---------- Configuration ----------
+ HF_TOKEN = os.environ.get("HF_TOKEN")  # required
+ GROQ_KEY = os.environ.get("GROQ_API_KEY")  # optional: if you want to call Groq directly
+ USE_GROQ_PROVIDER = True  # set False to route to default HF provider
+
+ # model IDs (change if you prefer other models)
+ LLAMA_MODEL = "Groq/Llama-3-Groq-8B-Tool-Use"  # Groq Llama model on HF
+ TTS_MODEL = "espnet/kan-bayashi_ljspeech_vits"  # a HF-hosted TTS model example
+ SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL base model
+
+ # create Inference client (route via HF token by default)
+ if USE_GROQ_PROVIDER:
+     client = InferenceClient(provider="groq", api_key=HF_TOKEN)
+ else:
+     client = InferenceClient(api_key=HF_TOKEN)
+
+ # ---------- Helpers ----------
+ def pdf_to_text(uploaded_file) -> str:
+     text_chunks = []
+     with pdfplumber.open(uploaded_file) as pdf:
+         for page in pdf.pages:
+             ptext = page.extract_text()
+             if ptext:
+                 text_chunks.append(ptext)
+     return "\n\n".join(text_chunks)
+
+ def llama_summarize(text, max_tokens=512):
+     prompt = [
+         {"role": "system", "content": "You are a concise summarizer. Produce a clear summary in bullet points."},
+         {"role": "user", "content": f"Summarize the following document in <= 8 bullet points. Keep it short:\n\n{text}"}
+     ]
+     # Use chat completion endpoint style
+     resp = client.chat.completions.create(model=LLAMA_MODEL, messages=prompt, max_tokens=max_tokens)
      try:
+         summary = resp.choices[0].message.content
+     except Exception:
+         # fallback: try text generation field
+         summary = resp.choices[0].text if hasattr(resp.choices[0], "text") else str(resp)
+     return summary
+
+ def llama_chat(chat_history, user_question):
+     messages = chat_history + [{"role": "user", "content": user_question}]
+     resp = client.chat.completions.create(model=LLAMA_MODEL, messages=messages)
+     return resp.choices[0].message.content
+
+ def tts_synthesize(text) -> bytes:
+     # InferenceClient offers text->audio utilities; text_to_speech takes the text
+     # as its first positional argument and returns raw audio bytes (wav).
+     audio_bytes = client.text_to_speech(text, model=TTS_MODEL)
+     return audio_bytes
+
+ def generate_image(prompt_text) -> Image.Image:
+     # text_to_image already returns a PIL.Image.Image, so no manual decoding is needed
+     return client.text_to_image(prompt_text, model=SDXL_MODEL)
+
+ def audio_download_button(wav_bytes, filename="summary.wav"):
+     b64 = base64.b64encode(wav_bytes).decode()
+     href = f'<a href="data:audio/wav;base64,{b64}" download="{filename}">Download audio (WAV)</a>'
+     st.markdown(href, unsafe_allow_html=True)
+
+ # ---------- Streamlit UI ----------
+ st.set_page_config(page_title="PDFGPT (Groq + HF)", layout="wide")
+ st.title("PDF Summary + Speech + Chat + Diagram (Groq + HF)")
+
+ uploaded = st.file_uploader("Upload PDF", type=["pdf"])

  if uploaded:
+     with st.spinner("Extracting text from PDF..."):
+         text = pdf_to_text(uploaded)
+     st.subheader("Extracted text (preview)")
+     st.text_area("Document text", value=text[:1000], height=200)
+
+     if st.button("Create summary (Groq Llama)"):
+         with st.spinner("Summarizing with Groq Llama..."):
+             summary = llama_summarize(text)
+         st.subheader("Summary")
+         st.write(summary)
+         st.session_state["summary"] = summary
+
+     if "summary" in st.session_state:
+         summary = st.session_state["summary"]
+         if st.button("Synthesize audio from summary (TTS)"):
+             with st.spinner("Creating audio..."):
                  try:
+                     audio = tts_synthesize(summary)
+                     st.audio(audio)
+                     audio_download_button(audio)
                  except Exception as e:
                      st.error(f"TTS failed: {e}")

+     st.markdown("---")
+     st.subheader("Chat with your PDF (ask questions about document)")
+     if "chat_history" not in st.session_state:
+         # start with system + doc context (shortened)
+         doc_context = (text[:4000] + "...") if len(text) > 4000 else text
+         st.session_state["chat_history"] = [
+             {"role": "system", "content": "You are a helpful assistant that answers questions based on the provided document."},
+             {"role": "user", "content": f"Document context:\n{doc_context}"}
+         ]
+
+     user_q = st.text_input("Ask a question about the PDF")
+     if st.button("Ask") and user_q:
+         with st.spinner("Getting answer from Groq Llama..."):
+             answer = llama_chat(st.session_state["chat_history"], user_q)
+             st.session_state.setdefault("convo", []).append(("You", user_q))
+             st.session_state.setdefault("convo", []).append(("Assistant", answer))
+             # append to history for next calls
+             st.session_state["chat_history"].append({"role": "user", "content": user_q})
+             st.session_state["chat_history"].append({"role": "assistant", "content": answer})
+         st.write(answer)
+
+     st.markdown("---")
+     st.subheader("Generate a diagram from your question (SDXL)")
+     diagram_prompt = st.text_input("Describe the diagram or scene to generate")
+     if st.button("Generate diagram") and diagram_prompt:
+         with st.spinner("Generating image (SDXL)..."):
              try:
+                 img = generate_image(diagram_prompt)
+                 st.image(img, use_column_width=True)
+                 # allow download
+                 buf = io.BytesIO()
+                 img.save(buf, format="PNG")
+                 st.download_button("Download diagram (PNG)", data=buf.getvalue(), file_name="diagram.png", mime="image/png")
              except Exception as e:
+                 st.error(f"Image generation failed: {e}")

+ st.sidebar.title("Settings")
+ st.sidebar.write("Models in use:")
+ st.sidebar.write(f"LLM: {LLAMA_MODEL}")
+ st.sidebar.write(f"TTS: {TTS_MODEL}")
+ st.sidebar.write(f"Image: {SDXL_MODEL}")
+
+ st.sidebar.markdown("**Notes**\n- Set HF_TOKEN in Space secrets or environment before starting.\n- To route directly to Groq with your Groq API key, set `GROQ_API_KEY` and change the client init accordingly.")
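
A minimal sketch of the second note above ("route directly to Groq"), assuming huggingface_hub's documented provider behavior: passing a provider's own API key as api_key sends requests straight to that provider and bills your provider account, instead of routing through the Hugging Face router with HF_TOKEN. Not part of this commit; GROQ_API_KEY must be set in the environment.

# sketch: direct Groq routing with your own key (InferenceClient and os are already imported in app.py)
client = InferenceClient(provider="groq", api_key=os.environ["GROQ_API_KEY"])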