import os

from dotenv import load_dotenv
from huggingface_hub import InferenceClient

# Load environment variables and read the Hugging Face API token
# (expects HF_TOKEN in a .env file or the shell environment).
load_dotenv()

API_TOKEN = os.getenv("HF_TOKEN")


def format_prompt(session_state, query, history, chat_client):
    """Assemble the model-specific prompt from the system instruction,
    prior turns, and the current query."""
    if chat_client == "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO":
        # ChatML format used by Nous-Hermes-2; each turn is closed with <|im_end|>.
        model_input = f"""<|im_start|>system
{session_state.system_instruction}<|im_end|>
"""
        for user_prompt, bot_response in history:
            model_input += f"""<|im_start|>user
{user_prompt}<|im_end|>
"""
            model_input += f"""<|im_start|>assistant
{bot_response}<|im_end|>
"""
        model_input += f"""<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant"""
        return model_input
    else:
        # Mistral-style [INST] format for the remaining chat clients.
        model_input = "<s>"
        for user_prompt, bot_response in history:
            model_input += f"[INST] {user_prompt} [/INST]"
            model_input += f" {bot_response}</s> "
        model_input += f"[INST] {query} [/INST]"
        return model_input
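

# Illustrative only (not executed): with system instruction "You are helpful."
# and history [("Hello", "Hi!")], the ChatML branch above produces:
#
#   <|im_start|>system
#   You are helpful.<|im_end|>
#   <|im_start|>user
#   Hello<|im_end|>
#   <|im_start|>assistant
#   Hi!<|im_end|>
#   <|im_start|>user
#   <current query><|im_end|>
#   <|im_start|>assistant
#
# while the else branch yields the Mistral shape:
#
#   <s>[INST] Hello [/INST] Hi!</s> [INST] <current query> [/INST]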


def chat(session_state, query, config):
    """Stream a generated response for `query` from the selected chat model."""
    chat_bot_dict = config["CHAT_BOTS"]
    chat_client = chat_bot_dict[session_state.chat_bot]
    temperature = session_state.temp
    max_new_tokens = session_state.max_tokens
    # "repetion_penalty" is the attribute name defined on session_state.
    repetition_penalty = session_state.repetion_penalty
    history = session_state.history

    client = InferenceClient(chat_client, token=API_TOKEN)

    # Sampling temperature must be strictly positive; clamp to a small floor.
    temperature = max(float(temperature), 1e-2)
    top_p = 0.95

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(session_state, query, history, chat_client)

    # Stream tokens as they are generated; details=True yields token-level metadata.
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
        truncate=32000,
    )
    return stream
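

# Minimal usage sketch (assumptions, not part of the app): a SimpleNamespace
# stands in for the real session_state object, and demo_config is a
# hypothetical config dict; the model id is one commonly served by the
# HF Inference API.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_state = SimpleNamespace(
        chat_bot="Mixtral",
        system_instruction="You are a helpful assistant.",
        temp=0.7,
        max_tokens=256,
        repetion_penalty=1.1,
        history=[],
    )
    demo_config = {"CHAT_BOTS": {"Mixtral": "mistralai/Mixtral-8x7B-Instruct-v0.1"}}

    # Print the streamed tokens as they arrive.
    for chunk in chat(demo_state, "Hello!", demo_config):
        print(chunk.token.text, end="", flush=True)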