Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import aiohttp | |
import asyncio | |
import discord | |
import pandas as pd | |
import requests | |
from teapotai import TeapotAI, TeapotAISettings | |
from pydantic import BaseModel, Field | |
# Configure the Streamlit page: title, icon and wide layout.
st.set_page_config(
    page_title="TeapotAI Discord Bot",
    page_icon=":robot_face:",
    layout="wide",
)

# Discord bot token, read from the "discord_key" environment variable
# (None when the variable is not set).
DISCORD_TOKEN = os.environ.get("discord_key")

# ======= API KEYS =======
# Both keys are None when the corresponding environment variable is not set.
BRAVE_API_KEY = os.environ.get("brave_api_key")
WEATHER_API_KEY = os.environ.get("weather_api_key")
# ======== TOOLS =========== | |
import requests | |
from typing import Optional | |
from teapotai import TeapotTool | |
import re | |
import math | |
import pandas as pd | |
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, logging | |
### SEARCH TOOL | |
class BraveWebSearch(BaseModel):
    # Argument schema for the "websearch" tool: one required query string.
    # (Documented with comments rather than a docstring so the generated
    # schema is unchanged.)
    search_query: str = Field(
        ...,
        description="the search string to answer the question",
    )
def brave_search_context(query, count=3):
    """Search the web with the Brave Search API and return the hits as text.

    Args:
        query: Either the search string itself, or an object exposing it as
            a ``search_query`` attribute (the ``BraveWebSearch`` schema).
        count: Number of results to request from the API.

    Returns:
        One entry per result made of title, URL and description on separate
        lines, with entries separated by a blank line. An empty string is
        returned when the request fails or the API answers with a non-200
        status.
    """
    # The other tool callbacks in this file receive their schema object
    # (e.g. `weather.city_name`), while direct callers pass a plain string.
    # Accept both so the raw model object is never sent as the query.
    search_text = getattr(query, "search_query", query)

    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": BRAVE_API_KEY}
    params = {"q": search_text, "count": count}

    try:
        # A timeout keeps a stalled connection from hanging the caller forever.
        response = requests.get(url, headers=headers, params=params, timeout=10)
    except requests.RequestException as exc:
        print(f"Error: {exc}")
        return ""

    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return ""

    try:
        results = response.json().get("web", {}).get("results", [])
    except ValueError as exc:
        # Body was not valid JSON.
        print(f"Error: {exc}")
        return ""

    entries = []
    for res in results:
        # Individual results may omit a field; fall back to an empty string
        # instead of raising KeyError and losing every other result.
        description = (
            res.get("description", "")
            .replace("<strong>", "")
            .replace("</strong>", "")
        )
        entries.append(
            res.get("title", "") + "\n" + res.get("url", "") + "\n" + description
        )
    return "\n\n".join(entries)
### CALCULATOR TOOL | |
import builtins | |
def evaluate_expression(expr) -> str:
    """
    Evaluate a simple arithmetic expression string and describe the result.

    Supports +, -, *, /, **, parentheses, and the functions abs, round,
    sqrt and pow. If the first attempt fails, every character that is not a
    digit, '.', an operator, a parenthesis or a space is stripped and the
    evaluation is retried once.

    Args:
        expr: The expression to evaluate. Non-string values are converted
            with ``str`` first.

    Returns:
        "<expression> = <result>" on success, otherwise the fixed message
        "Sorry, I am unable to calculate that."
    """
    # SECURITY NOTE(review): this still relies on eval() over text that
    # ultimately comes from chat users. Removing __builtins__ is NOT a
    # sandbox. The double-underscore guard below blocks the usual
    # attribute-walking escapes, but inputs such as "9**9**9**9" can still
    # burn CPU/memory. Consider replacing eval() with an `ast`-based
    # evaluator restricted to numeric nodes.
    allowed_names = {k: getattr(builtins, k) for k in ("abs", "round")}
    allowed_names.update({k: getattr(math, k) for k in ("sqrt", "pow")})

    def safe_eval(expression):
        # Dunder attribute access is the standard route out of a
        # builtins-less eval, so refuse it outright.
        if "__" in expression:
            raise ValueError("double-underscore names are not allowed")
        return eval(expression, {"__builtins__": None}, allowed_names)

    # Guarantee a string so the regex fallback below cannot raise TypeError
    # from inside the exception handler.
    if not isinstance(expr, str):
        expr = str(expr)

    try:
        result = safe_eval(expr)
        return f"{expr} = {result}"
    except Exception as e:
        print(f"Initial evaluation failed: {e}")

    # Strip out any characters that are not numbers, parentheses, or valid
    # operators, then trim the whitespace left behind by removed words.
    cleaned_expr = re.sub(r"[^0-9.+\-*/() ]", "", expr).strip()
    try:
        result = safe_eval(cleaned_expr)
        return f"{cleaned_expr} = {result}"
    except Exception as e2:
        print(f"Retry also failed: {e2}")
        return "Sorry, I am unable to calculate that."
class Calculator(BaseModel):
    # Argument schema for the "calculator" tool: the expression text that is
    # handed to evaluate_expression().
    expression: str = Field(
        ...,
        description="mathematical expression",
    )
### Weather Tool | |
def get_weather(weather):
    """Look up the current weather for a city via the OpenWeatherMap API.

    Args:
        weather: Object exposing the city as a ``city_name`` attribute
            (the ``Weather`` schema).

    Returns:
        A one-sentence summary with the weather description and the
        temperature in degrees Fahrenheit, or a generic error message when
        the request fails or the API answers with a non-200 status.
    """
    city_name = weather.city_name
    failure_message = "City not found or there was an error with the request."

    # OpenWeatherMap API endpoint
    url = "https://api.openweathermap.org/data/2.5/weather"
    # Pass the query through `params` so requests URL-encodes it: city names
    # contain spaces, and a raw "&" in the name could otherwise inject extra
    # query parameters.
    params = {"appid": WEATHER_API_KEY, "units": "imperial", "q": city_name}

    try:
        # A timeout keeps a stalled connection from hanging the caller forever.
        response = requests.get(url, params=params, timeout=10)
    except requests.RequestException as exc:
        print(exc)
        return failure_message

    if response.status_code != 200:
        print(response.status_code)
        return failure_message

    data = response.json()
    # Extract relevant weather information
    city = data['name']
    temperature = round(data['main']['temp'])
    weather_description = data['weather'][0]['description']
    # "units=imperial" above means the temperature is in Fahrenheit. The
    # degree sign was previously mis-encoded in this message.
    return f"The weather in {city} is {weather_description} with a temperature of {temperature}°F."
class Weather(BaseModel):
    # Argument schema for the "weather" tool; get_weather() reads city_name.
    city_name: str = Field(
        ...,
        description="The name of the city to pull the weather for",
    )
### Stupid Question Tool | |
class CountNumberLetter(BaseModel):
    # Argument schema for the "letter_counter" tool; both fields are read by
    # count_number_letters().
    word: str = Field(
        ...,
        description="the word to count the number of letters in",
    )
    letter: str = Field(
        ...,
        description="the letter to count the occurences of",
    )
def count_number_letters(obj):
    """Count how often a letter occurs in a word, ignoring case.

    Args:
        obj: Object with ``word`` and ``letter`` string attributes (the
            ``CountNumberLetter`` schema). A ``letter`` of "None" (any
            casing) asks for the total length of the word instead.

    Returns:
        A sentence stating the count, using the lowercased word and letter.
    """
    letter = obj.letter.lower()
    expression = obj.word.lower()

    # `letter` has just been lowercased, so the sentinel must be compared in
    # lowercase too; comparing against "None" made this branch unreachable.
    if letter == "none":
        return f"There are {len(obj.word)} letters in '{expression}'"

    count = sum(1 for ch in expression if ch == letter)
    if count == 1:
        return f"There is 1 '{letter}' in '{expression}'"
    return f"There are {count} '{letter}'s in '{expression}'"
### Image Gen Tool | |
class ImageGen(BaseModel):
    # Argument schema for the "generate_image" tool; generate_image() reads
    # the prompt text.
    prompt: str = Field(
        ...,
        description="The prompt to use to generate the image",
    )
def generate_image(prompt):
    """Return a canned reply for an image request.

    No image is generated: both replies link to the same fixed teapot
    picture, and only the wording depends on whether the prompt text
    mentions "teapot" (case-insensitive).

    Args:
        prompt: Object exposing the request text as a ``prompt`` attribute.

    Returns:
        The reply string.
    """
    mentions_teapot = "teapot" in prompt.prompt.lower()
    if not mentions_teapot:
        return "Ok I can't generate images, but you could easily hook up an image gen model to this tool call. Check out this image I did generate for you: https://teapotai.com/assets/teapotsmile.png"
    return "I generated an image of a teapot for you: https://teapotai.com/assets/teapotsmile.png"
### Tool Creation | |
# Tools handed to the TeapotAI instance in CONFIG. Each entry pairs a pydantic
# argument schema with the function that is called with the parsed arguments.
DEFAULT_TOOLS = [
    # Web search through the Brave Search API.
    TeapotTool(
        name="websearch",
        description="Execute web searches with pagination and filtering",
        schema=BraveWebSearch,
        fn=brave_search_context,
    ),
    # Letter counting within a single word.
    TeapotTool(
        name="letter_counter",
        description="Can count how many times a letter occurs in a word.",
        schema=CountNumberLetter,
        fn=count_number_letters,
    ),
    # Arithmetic; the lambda unwraps the schema object because
    # evaluate_expression() takes the expression text itself.
    TeapotTool(
        name="calculator",
        description="Can perform calculations on numbers using addition, subtraction, multiplication, and division.",
        schema=Calculator,
        fn=lambda expression: evaluate_expression(expression.expression),
    ),
    # Placeholder image generation.
    TeapotTool(
        name="generate_image",
        description="Can generate an image for a user based on a prompt",
        schema=ImageGen,
        fn=generate_image,
        directly_return_result=True,
    ),
    # Current weather through OpenWeatherMap.
    TeapotTool(
        name="weather",
        description="Can pull today's weather information for any city.",
        schema=Weather,
        fn=get_weather,
    ),
]
# ========= CONFIG ========= | |
# Maps a server name to its TeapotAI instance. Lookups fall back to the
# "Teapot AI" entry when a server has no dedicated configuration.
# NOTE: building this dict runs at import time and both loads the pretrained
# model and downloads the documents CSV.
CONFIG = {
    # "OneTrainer": TeapotAI(
    #     documents=pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=361556791&format=csv").content.str.split('\n\n').explode().reset_index(drop=True).to_list(),
    #     settings=TeapotAISettings(rag_num_results=7)
    # ),
    "Teapot AI": TeapotAI(
        # Pinned to an exact revision of the model repository.
        model=AutoModelForSeq2SeqLM.from_pretrained(
            "teapotai/teapotllm",
            revision="5aa6f84b5bd59da85552d55cc00efb702869cbf8",
        ),
        # Take the sheet's "content" column, split every cell on blank lines
        # and flatten the pieces into one list of document strings.
        documents=(
            pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=1617599323&format=csv")
            .content.str.split('\n\n')
            .explode()
            .reset_index(drop=True)
            .to_list()
        ),
        settings=TeapotAISettings(rag_num_results=3, log_level="debug"),
        tools=DEFAULT_TOOLS,
    ),
}
# ========= DISCORD CLIENT ========= | |
# Start from discord.py's default gateway intents and make sure message
# events are switched on, then build the client the handlers below use.
intents = discord.Intents.default()
intents.messages = True

client = discord.Client(intents=intents)
async def handle_teapot_inference(server_name, user_input):
    """Answer `user_input` with the TeapotAI instance for `server_name`.

    Args:
        server_name: Key into CONFIG; unknown names use the "Teapot AI" entry.
        user_input: The text to send as the query.

    Returns:
        Whatever the instance's ``query`` call returns.
    """
    fallback_instance = CONFIG["Teapot AI"]
    teapot_instance = CONFIG.get(server_name, fallback_instance)
    print(f"Using Teapot instance for server: {server_name}")

    system_prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization. You can use tools such as a web search, a calculator and an image generator to assist users."""

    # Running query in a separate thread to avoid blocking the event loop
    # response = await asyncio.to_thread(teapot_instance.query, query=user_input, context=brave_search_context(user_input))
    return await asyncio.to_thread(
        teapot_instance.query,
        query=user_input,
        system_prompt=system_prompt,
    )
async def debug_teapot_inference(server_name, user_input):
    """Collect the retrieval context for `user_input` for debugging.

    Args:
        server_name: Key into CONFIG; unknown names use the "Teapot AI" entry.
        user_input: The text to run the web search and RAG lookup for.

    Returns:
        A ``(rag_text, search_text)`` tuple: the RAG results joined with
        blank lines, and the text returned by ``brave_search_context``.
    """
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")
    # Both calls are synchronous (HTTP request / retrieval), so run them in
    # worker threads to avoid blocking the event loop, the same way
    # handle_teapot_inference() does. Previously they ran directly on the loop.
    search_result = await asyncio.to_thread(brave_search_context, user_input)
    rag_results = await asyncio.to_thread(teapot_instance.rag, query=user_input)
    return "\n\n".join(rag_results), search_result
# Register the coroutine with the client; without the decorator discord.py
# never calls it.
@client.event
async def on_ready():
    """Log the account the bot is logged in as once the client is ready."""
    print(f'Logged in as {client.user}')
# Register the coroutine with the client; without the decorator discord.py
# never calls it.
@client.event
async def on_message(message):
    """Reply to messages that mention the bot or reply to one of its messages.

    The bot's own messages are ignored. When the incoming message is a reply
    to a bot message, that earlier message is prepended to the query as
    "agent: <text>" so the model sees the previous turn.

    Args:
        message: The received Discord message.
    """
    if message.author == client.user:
        return

    # Check if the message mentions the bot
    mentioned = f'<@{client.user.id}>' in message.content

    # Check if the message is a reply to the bot
    replied_to_bot = False
    previous_message = ""
    if message.reference and message.reference.message_id:
        try:
            replied_message = await message.channel.fetch_message(message.reference.message_id)
        except discord.HTTPException:
            # The referenced message was deleted or cannot be read; treat the
            # message as a non-reply instead of failing the whole handler.
            replied_message = None
        if replied_message is not None and replied_message.author == client.user:
            replied_to_bot = True
            previous_message = "agent: "+replied_message.content+"\n"

    # If not mentioned and not replying to the bot, ignore
    if not (mentioned or replied_to_bot):
        return

    # Direct messages have no guild, so they use the default configuration.
    server_name = message.guild.name if message.guild else "Teapot AI"
    print(server_name, message.author, message.content)

    # Show the typing indicator while the model is working.
    async with message.channel.typing():
        cleaned_message = message.content.replace(f'<@{client.user.id}>', "").strip()
        full_context = previous_message + cleaned_message
        response = await handle_teapot_inference(server_name, full_context)
        await message.reply(response)
# Register the coroutine with the client; without the decorator discord.py
# never calls it.
@client.event
async def on_reaction_add(reaction, user):
    """Post debug context in a thread when a user reacts to a bot reply.

    Only reactions by other users, with one of the trigger emoji, on a bot
    message that was itself a reply are handled. The RAG and web-search
    context for the original user message is sent to a thread on the bot's
    message.

    Args:
        reaction: The reaction that was added.
        user: The user who added it.
    """
    if user == client.user:
        return
    # NOTE(review): these two literals look like mis-decoded emoji (they are
    # identical Greek letters here); confirm the intended trigger emoji.
    if str(reaction.emoji) not in ["β", "β"]:
        return

    message = reaction.message
    # Make sure it's a bot message that was a reply
    if message.author != client.user or not message.reference:
        return

    # Fetch the original message that this bot message replied to
    cleaned_message = message.content.replace(f'<@{client.user.id}>', "").strip()
    try:
        original_message = await message.channel.fetch_message(message.reference.message_id)
    except discord.HTTPException:
        # The original message was deleted or cannot be read, so there is
        # nothing to debug.
        return
    user_input = original_message.content.strip()
    server_name = message.guild.name if message.guild else "Teapot AI"

    # Create a thread or use existing one
    thread = message.thread
    if thread is None:
        thread = await message.create_thread(name=f"Debug Thread: '{cleaned_message[0:30]}...'", auto_archive_duration=60)

    rag_result, search_result = await debug_teapot_inference(server_name, user_input)
    # Keep only the last 900 characters of each section so the combined
    # message stays short.
    debug_response = "## RAG:\n```"+discord.utils.escape_markdown(rag_result)[-900:]+"```\n\n## Search:\n```"+discord.utils.escape_markdown(search_result)[-900:]+"```"
    await thread.send(debug_response)
# ========= STREAMLIT ========= | |
def discord_loop():
    """Render the status line, then run the Discord client.

    ``client.run`` blocks until the client stops, so this function does not
    return while the bot is running.
    """
    st.session_state["initialized"] = True
    # Write the status line BEFORE starting the client: client.run() is a
    # blocking call, so anything placed after it only executes once the bot
    # has shut down and the page would otherwise stay empty.
    st.write("418 I'm a teapot")
    client.run(DISCORD_TOKEN)


discord_loop()