""" | |
MCP Server for GAIA Agent Tools | |
This implements the Model Context Protocol for better tool organization | |
""" | |
import re
import os
import sys
import subprocess
from typing import Optional

import requests
import whisper
import pandas as pd
from youtube_transcript_api import YouTubeTranscriptApi
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
try:
    from mcp.server.fastmcp import FastMCP
    mcp = FastMCP("gaia_agent_tools")
except ImportError:
    print("Warning: MCP not available. Install with: pip install mcp", file=sys.stderr)
    mcp = None
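

# Illustrative only, never invoked by this module: a minimal client sketch
# showing how a separate process could talk to this server over stdio. It
# assumes the reference `mcp` Python SDK client API and that this file is
# saved as `mcp_server.py` (the script name here is an assumption).
# To try it: asyncio.run(_example_stdio_client())
async def _example_stdio_client():
    """Minimal MCP client sketch: list the registered tools, then call one."""
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client

    server_params = StdioServerParameters(command="python", args=["mcp_server.py"])
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools = await session.list_tools()
            print("Available tools:", [t.name for t in tools.tools])
            result = await session.call_tool("test_tool", arguments={"message": "hello"})
            print("test_tool result:", result)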


class GAIAToolServer:
    """GAIA Tool Server implementing the MCP protocol."""

    def __init__(self):
        self.tools_registered = False
        if mcp:
            self.register_tools()

    def register_tools(self):
        """Register all tools with the MCP server via the @mcp.tool() decorator."""
        @mcp.tool()
        def enhanced_web_search(query: str) -> dict:
            """Advanced web search with multiple-result processing and filtering."""
            try:
                search_tool = TavilySearchResults(max_results=5)
                docs = search_tool.run(query)
                results = []
                for d in docs:
                    content = d.get("content", "").strip()
                    url = d.get("url", "")
                    # Skip empty or near-empty snippets.
                    if content and len(content) > 20:
                        results.append(f"Source: {url}\nContent: {content}")
                return {"web_results": "\n\n".join(results)}
            except Exception as e:
                return {"web_results": f"Search error: {str(e)}"}

        @mcp.tool()
        def enhanced_wiki_search(query: str) -> dict:
            """Enhanced Wikipedia search with better content extraction."""
            try:
                # Try a few query variants in case the title uses spaces
                # where the query uses underscores or hyphens.
                queries = [query, query.replace("_", " "), query.replace("-", " ")]
                for q in queries:
                    try:
                        pages = WikipediaLoader(query=q, load_max_docs=3).load()
                        if pages:
                            content = "\n\n".join([
                                f"Page: {p.metadata.get('title', 'Unknown')}\n{p.page_content[:2000]}"
                                for p in pages
                            ])
                            return {"wiki_results": content}
                    except Exception:
                        continue
                return {"wiki_results": "No Wikipedia results found"}
            except Exception as e:
                return {"wiki_results": f"Wikipedia error: {str(e)}"}

        @mcp.tool()
        def youtube_transcript_tool(url: str) -> dict:
            """Extract the transcript from a YouTube video with enhanced error handling."""
            try:
                print(f"DEBUG: Processing YouTube URL: {url}", file=sys.stderr)
                # Standard URL forms first, then a looser fallback pattern.
                video_id_patterns = [
                    r"(?:youtube\.com/watch\?v=|youtu\.be/|youtube\.com/embed/)([a-zA-Z0-9_-]{11})",
                    r"(?:v=|\/)([0-9A-Za-z_-]{11})",
                ]
                video_id = None
                for pattern in video_id_patterns:
                    match = re.search(pattern, url)
                    if match:
                        video_id = match.group(1)
                        break
                if not video_id:
                    return {"transcript": "Error: Could not extract video ID from URL"}
                print(f"DEBUG: Extracted video ID: {video_id}", file=sys.stderr)
                try:
                    transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
                    # Prefer English; otherwise fall back to the first available transcript.
                    try:
                        transcript = transcript_list.find_transcript(['en'])
                    except Exception:
                        available = [t.language_code for t in transcript_list]
                        if available:
                            transcript = transcript_list.find_transcript([available[0]])
                        else:
                            return {"transcript": "No transcripts available"}
                    transcript_data = transcript.fetch()
                    # Format each entry with its start timestamp.
                    formatted_transcript = []
                    for entry in transcript_data:
                        time_str = f"[{entry['start']:.1f}s]"
                        formatted_transcript.append(f"{time_str} {entry['text']}")
                    full_transcript = "\n".join(formatted_transcript)
                    return {"transcript": full_transcript}
                except Exception as e:
                    return {"transcript": f"Error fetching transcript: {str(e)}"}
            except Exception as e:
                return {"transcript": f"YouTube processing error: {str(e)}"}

        @mcp.tool()
        def enhanced_audio_transcribe(path: str) -> dict:
            """Enhanced audio transcription with better file handling."""
            try:
                abs_path = path if os.path.isabs(path) else os.path.abspath(path)
                print(f"DEBUG: Transcribing audio file: {abs_path}", file=sys.stderr)
                if not os.path.isfile(abs_path):
                    # Fall back to the bare filename in the current working directory.
                    current_dir_path = os.path.join(os.getcwd(), os.path.basename(path))
                    if os.path.isfile(current_dir_path):
                        abs_path = current_dir_path
                    else:
                        return {"transcript": f"Error: Audio file not found at {abs_path}"}
                # Whisper shells out to ffmpeg, so verify it is on the PATH first.
                try:
                    subprocess.run(["ffmpeg", "-version"], check=True,
                                   stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                except (FileNotFoundError, subprocess.CalledProcessError):
                    return {"transcript": "Error: ffmpeg not found. Please install ffmpeg."}
                model = whisper.load_model("base")
                result = model.transcribe(abs_path)
                transcript = result["text"].strip()
                return {"transcript": transcript}
            except Exception as e:
                return {"transcript": f"Transcription error: {str(e)}"}

        @mcp.tool()
        def enhanced_excel_analysis(path: str, query: str = "", sheet_name: Optional[str] = None) -> dict:
            """Enhanced Excel analysis with query-specific processing.

            A client-side usage sketch, _example_excel_via_mcp, appears at the
            end of this file.
            """
            try:
                abs_path = path if os.path.isabs(path) else os.path.abspath(path)
                if not os.path.isfile(abs_path):
                    current_dir_path = os.path.join(os.getcwd(), os.path.basename(path))
                    if os.path.isfile(current_dir_path):
                        abs_path = current_dir_path
                    else:
                        return {"excel_analysis": f"Error: Excel file not found at {abs_path}"}
                df = pd.read_excel(abs_path, sheet_name=sheet_name or 0)
                analysis = {
                    "columns": list(df.columns),
                    "row_count": len(df),
                    "sheet_info": f"Analyzing sheet: {sheet_name or 'default'}",
                }
                # Keyword dispatch: tailor the analysis to hints found in the query.
                query_lower = query.lower() if query else ""
                if "total" in query_lower or "sum" in query_lower:
                    numeric_cols = df.select_dtypes(include=['number']).columns
                    analysis["totals"] = {col: df[col].sum() for col in numeric_cols}
                if "food" in query_lower or "category" in query_lower:
                    # Count the distinct values in every text column.
                    for col in df.columns:
                        if df[col].dtype == 'object':
                            analysis[f"{col}_categories"] = df[col].value_counts().to_dict()
                analysis["sample_data"] = df.head(5).to_dict('records')
                numeric_cols = df.select_dtypes(include=['number']).columns
                if len(numeric_cols) > 0:
                    analysis["numeric_summary"] = df[numeric_cols].describe().to_dict()
                return {"excel_analysis": analysis}
            except Exception as e:
                return {"excel_analysis": f"Excel analysis error: {str(e)}"}

        @mcp.tool()
        def web_file_downloader(url: str) -> dict:
            """Download and analyze files from web URLs."""
            try:
                response = requests.get(url, timeout=30)
                response.raise_for_status()
                content_type = response.headers.get('content-type', '').lower()
                if 'audio' in content_type or url.endswith(('.mp3', '.wav', '.m4a')):
                    # Save to a temporary file, transcribe it, then clean up.
                    temp_path = f"temp_audio_{hash(url) % 10000}.wav"
                    with open(temp_path, 'wb') as f:
                        f.write(response.content)
                    result = enhanced_audio_transcribe(temp_path)
                    try:
                        os.remove(temp_path)
                    except OSError:
                        pass
                    return result
                elif 'text' in content_type or 'html' in content_type:
                    return {"content": response.text[:5000]}
                else:
                    return {"content": f"Downloaded {len(response.content)} bytes of {content_type}"}
            except Exception as e:
                return {"content": f"Download error: {str(e)}"}

        @mcp.tool()
        def test_tool(message: str) -> dict:
            """A simple test tool that always works."""
            print(f"DEBUG: Test tool called with: {message}", file=sys.stderr)
            return {"result": f"Test successful: {message}"}

        self.tools_registered = True
        print("DEBUG: All MCP tools registered successfully", file=sys.stderr)


# Standalone implementations for direct use (when MCP is not available).
class DirectTools:
    """Direct tool implementations for use without MCP."""

    @staticmethod
    def enhanced_web_search(query: str) -> dict:
        """Direct web search implementation."""
        try:
            search_tool = TavilySearchResults(max_results=5)
            docs = search_tool.run(query)
            results = []
            for d in docs:
                content = d.get("content", "").strip()
                url = d.get("url", "")
                if content and len(content) > 20:
                    results.append(f"Source: {url}\nContent: {content}")
            return {"web_results": "\n\n".join(results)}
        except Exception as e:
            return {"web_results": f"Search error: {str(e)}"}

    @staticmethod
    def youtube_transcript_tool(url: str) -> dict:
        """Direct YouTube transcript implementation."""
        try:
            video_id_patterns = [
                r"(?:youtube\.com/watch\?v=|youtu\.be/|youtube\.com/embed/)([a-zA-Z0-9_-]{11})",
                r"(?:v=|\/)([0-9A-Za-z_-]{11})",
            ]
            video_id = None
            for pattern in video_id_patterns:
                match = re.search(pattern, url)
                if match:
                    video_id = match.group(1)
                    break
            if not video_id:
                return {"transcript": "Error: Could not extract video ID from URL"}
            transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
            # Prefer English; otherwise fall back to the first available transcript.
            try:
                transcript = transcript_list.find_transcript(['en'])
            except Exception:
                available = [t.language_code for t in transcript_list]
                if available:
                    transcript = transcript_list.find_transcript([available[0]])
                else:
                    return {"transcript": "No transcripts available"}
            transcript_data = transcript.fetch()
            formatted_transcript = []
            for entry in transcript_data:
                time_str = f"[{entry['start']:.1f}s]"
                formatted_transcript.append(f"{time_str} {entry['text']}")
            full_transcript = "\n".join(formatted_transcript)
            return {"transcript": full_transcript}
        except Exception as e:
            return {"transcript": f"YouTube processing error: {str(e)}"}


# Initialize the server (registers the MCP tools when the package is available).
tool_server = GAIAToolServer()

if __name__ == "__main__":
    if mcp and tool_server.tools_registered:
        print("DEBUG: Starting MCP server", file=sys.stderr)
        mcp.run(transport="stdio")
    else:
        print("MCP not available. Tools can be used directly via the DirectTools class.")
        # Exercise the direct tools instead.
        print("\nTesting DirectTools:")
        test_url = "https://www.youtube.com/watch?v=1htKBjuUWec"
        result = DirectTools.youtube_transcript_tool(test_url)
        print(f"YouTube test result: {result}")