Spaces: Running on Zero

Commit 0c0a08a · 1 Parent(s): 99b54b3

Code reorganisation to better use config files. Adapted code to use Gemma 3 as the local model. Minor package updates.
Browse files

- .dockerignore +8 -1
- .gitignore +4 -1
- Dockerfile +7 -1
- README.md +1 -1
- app.py +29 -38
- requirements.txt +9 -7
- requirements_aws.txt +5 -4
- requirements_gpu.txt +7 -4
- tools/auth.py +38 -13
- tools/aws_functions.py +6 -64
- tools/chatfuncs.py +0 -240
- tools/config.py +327 -0
- tools/dedup_summaries.py +602 -0
- tools/helper_functions.py +132 -97
- tools/llm_api_call.py +54 -1005
- tools/llm_funcs.py +579 -0
- tools/verify_titles.py +12 -8
- windows_install_llama-cpp-python.txt +118 -0

.dockerignore CHANGED

@@ -14,4 +14,11 @@ dist/*
 logs/*
 usage/*
 feedback/*
-test_code/*
+test_code/*
+input/
+output/
+logs/
+usage/
+feedback/
+config/
+tmp/

.gitignore CHANGED

@@ -6,6 +6,7 @@
 *.xls
 *.xlsx
 *.csv
+*.pyc
 examples/*
 output/*
 tools/__pycache__/*
@@ -14,4 +15,6 @@ dist/*
 logs/*
 usage/*
 feedback/*
-test_code/*
+test_code/*
+config/*
+tmp/*

Dockerfile CHANGED

@@ -7,6 +7,8 @@ RUN apt-get update && apt-get install -y \
     gcc \
     g++ \
     cmake \
+    libopenblas-dev \
+    pkg-config \
     python3-dev \
     libffi-dev \
     && apt-get clean \
@@ -14,11 +16,15 @@ RUN apt-get update && apt-get install -y \
 
 WORKDIR /src
 
+# Optional: Set environment variables for OpenBLAS
+ENV OPENBLAS_VERBOSE=1
+ENV CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS"
+
 COPY requirements_aws.txt .
 
 RUN pip uninstall -y typing_extensions \
     && pip install --no-cache-dir --target=/install typing_extensions==4.12.2 \
-    && pip install --no-cache-dir --target=/install torch==2.
+    && pip install --no-cache-dir --target=/install torch==2.7.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu \
     && pip install --no-cache-dir --target=/install -r requirements_aws.txt
 
 RUN rm requirements_aws.txt

README.md CHANGED

@@ -6,7 +6,7 @@ colorTo: yellow
 sdk: gradio
 app_file: app.py
 pinned: true
-license:
+license: agpl-3.0
 ---
 
 # Large language model topic modelling

app.py CHANGED

@@ -1,38 +1,33 @@
 import os
-import socket
-import spaces
-from tools.helper_functions import ensure_output_folder_exists, add_folder_to_path, put_columns_in_df, get_connection_params, output_folder, get_or_create_env_var, reveal_feedback_buttons, wipe_logs, model_full_names, view_table, empty_output_vars_extract_topics, empty_output_vars_summarise, RUN_LOCAL_MODEL, load_in_previous_reference_file, join_cols_onto_reference_df, GEMINI_API_KEY
-from tools.aws_functions import upload_file_to_s3, RUN_AWS_FUNCTIONS
-from tools.llm_api_call import extract_topics, load_in_data_file, load_in_previous_data_files, sample_reference_table_summaries, summarise_output_topics, batch_size_default, deduplicate_topics, modify_existing_output_tables
-from tools.auth import authenticate_user
-from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, verify_titles_prompt, verify_titles_system_prompt, two_para_summary_format_prompt, single_para_summary_format_prompt
-from tools.verify_titles import verify_titles
-#from tools.aws_functions import load_data_from_aws
 import gradio as gr
 import pandas as pd
 from datetime import datetime
+from tools.helper_functions import put_columns_in_df, get_connection_params, get_or_create_env_var, reveal_feedback_buttons, wipe_logs, view_table, empty_output_vars_extract_topics, empty_output_vars_summarise, load_in_previous_reference_file, join_cols_onto_reference_df
+from tools.aws_functions import upload_file_to_s3
+from tools.llm_api_call import extract_topics, load_in_data_file, load_in_previous_data_files, modify_existing_output_tables
+from tools.dedup_summaries import sample_reference_table_summaries, summarise_output_topics, deduplicate_topics
+from tools.auth import authenticate_user
+from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, verify_titles_prompt, verify_titles_system_prompt, two_para_summary_format_prompt, single_para_summary_format_prompt
+from tools.verify_titles import verify_titles
+from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, RUN_LOCAL_MODEL, AWS_USER_POOL_ID, FILE_INPUT_HEIGHT, GEMINI_API_KEY, model_full_names, BATCH_SIZE_DEFAULT, CHOSEN_LOCAL_MODEL_TYPE
 
 today_rev = datetime.now().strftime("%Y%m%d")
 
-
-
-
-
-
-access_logs_data_folder = 'logs/' + today_rev + '/' + host_name + '/'
-feedback_data_folder = 'feedback/' + today_rev + '/' + host_name + '/'
-usage_data_folder = 'usage/' + today_rev + '/' + host_name + '/'
-file_input_height = 200
+host_name = HOST_NAME
+access_logs_data_folder = ACCESS_LOGS_FOLDER
+feedback_data_folder = FEEDBACK_LOGS_FOLDER
+usage_data_folder = USAGE_LOGS_FOLDER
+file_input_height = FILE_INPUT_HEIGHT
 
 if RUN_LOCAL_MODEL == "1":
-    default_model_choice =
+    default_model_choice = CHOSEN_LOCAL_MODEL_TYPE
 elif RUN_AWS_FUNCTIONS == "1":
     default_model_choice = "anthropic.claude-3-haiku-20240307-v1:0"
 else:
     default_model_choice = "gemini-2.0-flash-001"
 
 # Create the gradio interface
-app = gr.Blocks(theme = gr.themes.
+app = gr.Blocks(theme = gr.themes.Default(primary_hue="blue"), fill_width=True)
 
 with app:
 
@@ -53,11 +48,11 @@ with app:
     master_reference_df_state = gr.Dataframe(value=pd.DataFrame(), headers=None, col_count=0, row_count = (0, "dynamic"), label="master_reference_df_state", visible=False, type="pandas")
 
     master_modify_unique_topics_df_state = gr.Dataframe(value=pd.DataFrame(), headers=None, col_count=0, row_count = (0, "dynamic"), label="master_modify_unique_topics_df_state", visible=False, type="pandas")
-    master_modify_reference_df_state = gr.Dataframe(value=pd.DataFrame(), headers=None, col_count=0, row_count = (0, "dynamic"), label="master_modify_reference_df_state", visible=False, type="pandas")
-
+    master_modify_reference_df_state = gr.Dataframe(value=pd.DataFrame(), headers=None, col_count=0, row_count = (0, "dynamic"), label="master_modify_reference_df_state", visible=False, type="pandas")
 
     session_hash_state = gr.State()
-
+    output_folder_state = gr.State()
+    input_folder_state = gr.State()
 
     # Logging state
     log_file_name = 'log.csv'
@@ -222,7 +217,7 @@ with app:
     """)
     with gr.Accordion("Settings for LLM generation", open = True):
         temperature_slide = gr.Slider(minimum=0.1, maximum=1.0, value=0.1, label="Choose LLM temperature setting")
-        batch_size_number = gr.Number(label = "Number of responses to submit in a single LLM query", value =
+        batch_size_number = gr.Number(label = "Number of responses to submit in a single LLM query", value = BATCH_SIZE_DEFAULT, precision=0, minimum=1, maximum=100)
         random_seed = gr.Number(value=42, label="Random seed for LLM generation", visible=False)
 
     with gr.Accordion("Prompt settings", open = False):
@@ -279,13 +274,13 @@ with app:
         success(load_in_data_file,
             inputs = [in_data_files, in_colnames, batch_size_number, in_excel_sheets], outputs = [file_data_state, reference_data_file_name_textbox, total_number_of_batches], api_name="load_data").\
         success(fn=extract_topics,
-            inputs=[in_data_files, file_data_state, master_topic_df_state, master_reference_df_state, master_unique_topics_df_state, display_topic_table_markdown, reference_data_file_name_textbox, total_number_of_batches, in_api_key, temperature_slide, in_colnames, model_choice, candidate_topics, latest_batch_completed, display_topic_table_markdown, text_output_file_list_state, log_files_output_list_state, first_loop_state, conversation_metadata_textbox, initial_table_prompt_textbox, prompt_2_textbox, prompt_3_textbox, system_prompt_textbox, add_to_existing_topics_system_prompt_textbox, add_to_existing_topics_prompt_textbox, number_of_prompts, batch_size_number, context_textbox, estimated_time_taken_number, sentiment_checkbox, force_zero_shot_radio, in_excel_sheets, force_single_topic_radio],
+            inputs=[in_data_files, file_data_state, master_topic_df_state, master_reference_df_state, master_unique_topics_df_state, display_topic_table_markdown, reference_data_file_name_textbox, total_number_of_batches, in_api_key, temperature_slide, in_colnames, model_choice, candidate_topics, latest_batch_completed, display_topic_table_markdown, text_output_file_list_state, log_files_output_list_state, first_loop_state, conversation_metadata_textbox, initial_table_prompt_textbox, prompt_2_textbox, prompt_3_textbox, system_prompt_textbox, add_to_existing_topics_system_prompt_textbox, add_to_existing_topics_prompt_textbox, number_of_prompts, batch_size_number, context_textbox, estimated_time_taken_number, sentiment_checkbox, force_zero_shot_radio, in_excel_sheets, force_single_topic_radio, output_folder_state],
             outputs=[display_topic_table_markdown, master_topic_df_state, master_unique_topics_df_state, master_reference_df_state, topic_extraction_output_files, text_output_file_list_state, latest_batch_completed, log_files_output, log_files_output_list_state, conversation_metadata_textbox, estimated_time_taken_number, deduplication_input_files, summarisation_input_files, modifiable_unique_topics_df_state, modification_input_files, in_join_files], api_name="extract_topics")
 
 
     # If the output file count text box changes, keep going with redacting each data file until done. Then reveal the feedback buttons.
     # latest_batch_completed.change(fn=extract_topics,
-    # inputs=[in_data_files, file_data_state, master_topic_df_state, master_reference_df_state, master_unique_topics_df_state, display_topic_table_markdown, reference_data_file_name_textbox, total_number_of_batches, in_api_key, temperature_slide, in_colnames, model_choice, candidate_topics, latest_batch_completed, display_topic_table_markdown, text_output_file_list_state, log_files_output_list_state, second_loop_state, conversation_metadata_textbox, initial_table_prompt_textbox, prompt_2_textbox, prompt_3_textbox, system_prompt_textbox, add_to_existing_topics_system_prompt_textbox, add_to_existing_topics_prompt_textbox, number_of_prompts, batch_size_number, context_textbox, estimated_time_taken_number, sentiment_checkbox, force_zero_shot_radio, in_excel_sheets],
+    # inputs=[in_data_files, file_data_state, master_topic_df_state, master_reference_df_state, master_unique_topics_df_state, display_topic_table_markdown, reference_data_file_name_textbox, total_number_of_batches, in_api_key, temperature_slide, in_colnames, model_choice, candidate_topics, latest_batch_completed, display_topic_table_markdown, text_output_file_list_state, log_files_output_list_state, second_loop_state, conversation_metadata_textbox, initial_table_prompt_textbox, prompt_2_textbox, prompt_3_textbox, system_prompt_textbox, add_to_existing_topics_system_prompt_textbox, add_to_existing_topics_prompt_textbox, number_of_prompts, batch_size_number, context_textbox, estimated_time_taken_number, sentiment_checkbox, force_zero_shot_radio, in_excel_sheets, force_single_topic_radio, output_folder_state],
     # outputs=[display_topic_table_markdown, master_topic_df_state, master_unique_topics_df_state, master_reference_df_state, topic_extraction_output_files, text_output_file_list_state, latest_batch_completed, log_files_output, log_files_output_list_state, conversation_metadata_textbox, estimated_time_taken_number, deduplication_input_files, summarisation_input_files, modifiable_unique_topics_df_state, modification_input_files, in_join_files]).\
     # success(fn = reveal_feedback_buttons,
     # outputs=[data_feedback_radio, data_further_details_text, data_submit_feedback_btn, data_feedback_title], scroll_to_output=True)
@@ -293,24 +288,20 @@ with app:
     # If you upload data into the deduplication input box, the modifiable topic dataframe box is updated
     modification_input_files.change(fn=load_in_previous_data_files, inputs=[modification_input_files, modified_unique_table_change_bool], outputs=[modifiable_unique_topics_df_state, master_modify_reference_df_state, master_modify_unique_topics_df_state, reference_data_file_name_textbox, unique_topics_table_file_name_textbox, text_output_modify_file_list_state])
 
-
-
-
-
     # Modify output table with custom topic names
-    save_modified_files_button.click(fn=modify_existing_output_tables, inputs=[master_modify_unique_topics_df_state, modifiable_unique_topics_df_state, master_modify_reference_df_state, text_output_modify_file_list_state], outputs=[master_unique_topics_df_state, master_reference_df_state, topic_extraction_output_files, text_output_file_list_state, deduplication_input_files, summarisation_input_files, reference_data_file_name_textbox, unique_topics_table_file_name_textbox, summarised_output_markdown])
+    save_modified_files_button.click(fn=modify_existing_output_tables, inputs=[master_modify_unique_topics_df_state, modifiable_unique_topics_df_state, master_modify_reference_df_state, text_output_modify_file_list_state, output_folder_state], outputs=[master_unique_topics_df_state, master_reference_df_state, topic_extraction_output_files, text_output_file_list_state, deduplication_input_files, summarisation_input_files, reference_data_file_name_textbox, unique_topics_table_file_name_textbox, summarised_output_markdown])
 
     # When button pressed, deduplicate data
    deduplicate_previous_data_btn.click(load_in_previous_data_files, inputs=[deduplication_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, reference_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
-        success(deduplicate_topics, inputs=[master_reference_df_state, master_unique_topics_df_state, reference_data_file_name_textbox, unique_topics_table_file_name_textbox, in_excel_sheets, merge_sentiment_drop, merge_general_topics_drop, deduplicate_score_threshold, in_data_files, in_colnames], outputs=[master_reference_df_state, master_unique_topics_df_state, summarisation_input_files, log_files_output, summarised_output_markdown], scroll_to_output=True)
+        success(deduplicate_topics, inputs=[master_reference_df_state, master_unique_topics_df_state, reference_data_file_name_textbox, unique_topics_table_file_name_textbox, in_excel_sheets, merge_sentiment_drop, merge_general_topics_drop, deduplicate_score_threshold, in_data_files, in_colnames, output_folder_state], outputs=[master_reference_df_state, master_unique_topics_df_state, summarisation_input_files, log_files_output, summarised_output_markdown], scroll_to_output=True, api_name="deduplicate_topics")
 
     # When button pressed, summarise previous data
     summarise_previous_data_btn.click(empty_output_vars_summarise, inputs=None, outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox]).\
         success(load_in_previous_data_files, inputs=[summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, reference_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
-        success(sample_reference_table_summaries, inputs=[master_reference_df_state, master_unique_topics_df_state, random_seed], outputs=[summary_reference_table_sample_state, summarised_references_markdown, master_reference_df_state, master_unique_topics_df_state]).\
-        success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, in_api_key, summarised_references_markdown, temperature_slide, reference_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output])
+        success(sample_reference_table_summaries, inputs=[master_reference_df_state, master_unique_topics_df_state, random_seed], outputs=[summary_reference_table_sample_state, summarised_references_markdown, master_reference_df_state, master_unique_topics_df_state], api_name="sample_summaries").\
+        success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, in_api_key, summarised_references_markdown, temperature_slide, reference_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output], api_name="summarise_topics")
 
-    latest_summary_completed_num.change(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, in_api_key, summarised_references_markdown, temperature_slide, reference_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output], scroll_to_output=True)
+    latest_summary_completed_num.change(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, in_api_key, summarised_references_markdown, temperature_slide, reference_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output], scroll_to_output=True)
 
     # CONTINUE PREVIOUS TOPIC EXTRACTION PAGE
 
@@ -327,8 +318,8 @@ with app:
     verify_titles_btn.click(fn=empty_output_vars_extract_topics, inputs=None, outputs=[master_topic_df_state, master_unique_topics_df_state, master_reference_df_state, topic_extraction_output_files, text_output_file_list_state, latest_batch_completed, log_files_output, log_files_output_list_state, conversation_metadata_textbox, estimated_time_taken_number, file_data_state, reference_data_file_name_textbox, display_topic_table_markdown]).\
         success(load_in_data_file,
             inputs = [verify_in_data_files, verify_in_colnames, batch_size_number, verify_in_excel_sheets], outputs = [file_data_state, reference_data_file_name_textbox, total_number_of_batches], api_name="verify_load_data").\
-        success(fn=verify_titles,
-            inputs=[verify_in_data_files, file_data_state, master_topic_df_state, master_reference_df_state, master_unique_topics_df_state, display_topic_table_markdown, reference_data_file_name_textbox, total_number_of_batches, verify_in_api_key, temperature_slide, verify_in_colnames, verify_model_choice, candidate_topics, latest_batch_completed, display_topic_table_markdown, text_output_file_list_state, log_files_output_list_state, first_loop_state, conversation_metadata_textbox, verify_titles_prompt_textbox, prompt_2_textbox, prompt_3_textbox, verify_titles_system_prompt_textbox, verify_titles_system_prompt_textbox, verify_titles_prompt_textbox, number_of_prompts, batch_size_number, context_textbox, estimated_time_taken_number, sentiment_checkbox, force_zero_shot_radio, in_excel_sheets],
+        success(fn=verify_titles,
+            inputs=[verify_in_data_files, file_data_state, master_topic_df_state, master_reference_df_state, master_unique_topics_df_state, display_topic_table_markdown, reference_data_file_name_textbox, total_number_of_batches, verify_in_api_key, temperature_slide, verify_in_colnames, verify_model_choice, candidate_topics, latest_batch_completed, display_topic_table_markdown, text_output_file_list_state, log_files_output_list_state, first_loop_state, conversation_metadata_textbox, verify_titles_prompt_textbox, prompt_2_textbox, prompt_3_textbox, verify_titles_system_prompt_textbox, verify_titles_system_prompt_textbox, verify_titles_prompt_textbox, number_of_prompts, batch_size_number, context_textbox, estimated_time_taken_number, sentiment_checkbox, force_zero_shot_radio, in_excel_sheets, output_folder_state],
             outputs=[verify_display_topic_table_markdown, master_topic_df_state, master_unique_topics_df_state, master_reference_df_state, verify_titles_file_output, text_output_file_list_state, latest_batch_completed, log_files_output, log_files_output_list_state, conversation_metadata_textbox, estimated_time_taken_number, deduplication_input_files, summarisation_input_files, modifiable_unique_topics_df_state, verify_modification_input_files_placeholder], api_name="verify_descriptions")
 
     ###
@@ -346,7 +337,7 @@ with app:
     ###
     # LOGGING AND ON APP LOAD FUNCTIONS
     ###
-    app.load(get_connection_params, inputs=None, outputs=[session_hash_state,
+    app.load(get_connection_params, inputs=None, outputs=[session_hash_state, output_folder_state, session_hash_textbox, input_folder_state])
 
     # Log usernames and times of access to file (to know who is using the app when running on AWS)
     access_callback = gr.CSVLogger(dataset_file_name=log_file_name)

requirements.txt CHANGED

@@ -1,7 +1,7 @@
 pandas==2.2.3
-gradio==5.
+gradio==5.34.2
 spaces==0.34.1
-boto3==1.38.
+boto3==1.38.38
 pyarrow==19.0.1
 openpyxl==3.1.3
 markdown==3.7
@@ -11,11 +11,13 @@ google-generativeai==0.8.4
 html5lib==1.1
 beautifulsoup4==4.12.3
 rapidfuzz==3.10.1
-torch==2.
-
-#
-llama-cpp-python==0.3.
-#llama-cpp-python
+torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/cpu
+llama-cpp-python==0.3.9 -C cmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS" # Linux compatibility - for recent models like Gemma 3
+# For Windows try the following
+# llama-cpp-python==0.3.9 -C cmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS;-DBLAS_INCLUDE_DIRS=C:/<root-path-to-openblas>/OpenBLAS/include;-DBLAS_LIBRARIES=C:/<root-path-to-openblas>/OpenBLAS/lib/libopenblas.lib
+#https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.2/llama_cpp_python-0.3.2-cp311-cp311-win_amd64.whl # Use this for Windows if abov doesn't work, enough for Gemma 2b
+#llama-cpp-python==0.3.2 --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu # Use this for guaranteed Linux compatibility - enough for Gemma 2b only
 transformers==4.51.1
+python-dotenv==1.1.0
 numpy==1.26.4
 typing_extensions==4.12.2
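
The llama-cpp-python 0.3.9 pin (built with the CMake arguments above) is what allows recent GGUF models such as Gemma 3 to run locally. As a rough sanity check that the pinned stack can load a GGUF model at all, a minimal sketch follows; the repository and filename are placeholders only (the old tools/chatfuncs.py resolved these from REPO_ID / MODEL_FILE environment variables; its replacement, tools/llm_funcs.py, is not shown in this excerpt), so substitute a real Gemma 3 GGUF of your choosing.

# Sketch only, not part of the commit: load a GGUF model with the pinned build.
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

model_path = hf_hub_download(
    repo_id="example-org/gemma-3-it-GGUF",    # placeholder repository
    filename="gemma-3-example-Q4_K_M.gguf",   # placeholder filename
)

llm = Llama(
    model_path=model_path,
    n_ctx=16384,     # context length default used in the old chatfuncs.py
    n_gpu_layers=0,  # raise (or set to -1) only with the CUDA build in requirements_gpu.txt
    seed=42,
)

response = llm("List the main topics in: 'The park needs more benches.'", max_tokens=128)
print(response["choices"][0]["text"])
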
requirements_aws.txt CHANGED

@@ -1,7 +1,7 @@
 pandas==2.2.3
-gradio==5.
+gradio==5.34.2
 spaces==0.34.1
-boto3==1.38.
+boto3==1.38.38
 pyarrow==19.0.1
 openpyxl==3.1.3
 markdown==3.7
@@ -11,8 +11,9 @@ google-generativeai==0.8.4
 html5lib==1.1
 beautifulsoup4==4.12.3
 rapidfuzz==3.10.1
-
-llama-cpp-python==0.3.2 --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu
+llama-cpp-python==0.3.9
+#llama-cpp-python==0.3.2 --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu
 transformers==4.51.1
+python-dotenv==1.1.0
 numpy==1.26.4
 typing_extensions==4.12.2

requirements_gpu.txt CHANGED

@@ -1,7 +1,7 @@
 pandas==2.2.3
-gradio==5.
+gradio==5.34.2
 spaces==0.34.1
-boto3==1.38.
+boto3==1.38.38
 pyarrow==19.0.1
 openpyxl==3.1.3
 markdown==3.7
@@ -11,11 +11,14 @@ google-generativeai==0.8.4
 html5lib==1.1
 beautifulsoup4==4.12.3
 rapidfuzz==3.10.1
-torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/
+torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124
+llama-cpp-python==0.3.9 -C cmake.args="-DGGML_CUDA=on"
+# If the above doesn't work, try one of the following
 #llama-cpp-python==0.3.4 --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu121
 # Specify exact llama_cpp wheel for huggingface compatibility
 #https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.4-cu121/llama_cpp_python-0.3.4-cp311-cp311-linux_x86_64.whl
-https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.4-cu121/llama_cpp_python-0.3.4-cp311-cp311-win_amd64.whl # Windows
+#https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.4-cu121/llama_cpp_python-0.3.4-cp311-cp311-win_amd64.whl # Windows
 transformers==4.51.1
+python-dotenv==1.1.0
 numpy==1.26.4
 typing_extensions==4.12.2

tools/auth.py CHANGED

@@ -1,14 +1,20 @@
-
 import boto3
-
-
-
-
+import hmac
+import hashlib
+import base64
+from tools.config import AWS_CLIENT_ID, AWS_CLIENT_SECRET, AWS_USER_POOL_ID, AWS_REGION
 
-
-
+def calculate_secret_hash(client_id:str, client_secret:str, username:str):
+    message = username + client_id
+    dig = hmac.new(
+        str(client_secret).encode('utf-8'),
+        msg=str(message).encode('utf-8'),
+        digestmod=hashlib.sha256
+    ).digest()
+    secret_hash = base64.b64encode(dig).decode()
+    return secret_hash
 
-def authenticate_user(username, password, user_pool_id=
+def authenticate_user(username:str, password:str, user_pool_id:str=AWS_USER_POOL_ID, client_id:str=AWS_CLIENT_ID, client_secret:str=AWS_CLIENT_SECRET):
     """Authenticates a user against an AWS Cognito user pool.
 
     Args:
@@ -16,22 +22,39 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
         client_id (str): The ID of the Cognito user pool client.
         username (str): The username of the user.
         password (str): The password of the user.
+        client_secret (str): The client secret of the app client
 
     Returns:
         bool: True if the user is authenticated, False otherwise.
     """
 
-    client = boto3.client('cognito-idp') # Cognito Identity Provider client
+    client = boto3.client('cognito-idp', region_name=AWS_REGION) # Cognito Identity Provider client
+
+    # Compute the secret hash
+    secret_hash = calculate_secret_hash(client_id, client_secret, username)
 
     try:
-
+
+        if client_secret == '':
+            response = client.initiate_auth(
+                AuthFlow='USER_PASSWORD_AUTH',
+                AuthParameters={
+                    'USERNAME': username,
+                    'PASSWORD': password,
+                },
+                ClientId=client_id
+            )
+
+        else:
+            response = client.initiate_auth(
                 AuthFlow='USER_PASSWORD_AUTH',
                 AuthParameters={
                     'USERNAME': username,
                     'PASSWORD': password,
+                    'SECRET_HASH': secret_hash
                 },
                 ClientId=client_id
-
+            )
 
         # If successful, you'll receive an AuthenticationResult in the response
         if response.get('AuthenticationResult'):
@@ -44,5 +67,7 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
     except client.exceptions.UserNotFoundException:
         return False
     except Exception as e:
-
-
+        out_message = f"An error occurred: {e}"
+        print(out_message)
+        raise Exception(out_message)
+        return False
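
The new calculate_secret_hash helper follows AWS Cognito's documented SECRET_HASH formula, Base64(HMAC-SHA256(client_secret, username + client_id)), and authenticate_user only sends it when a client secret is configured. A minimal sketch of exercising the helpers directly, assuming AWS_CLIENT_ID, AWS_CLIENT_SECRET and AWS_USER_POOL_ID are supplied through tools/config.py or config/aws_config.env; the literal values below are hypothetical placeholders.

# Sketch only, not part of the commit.
from tools.auth import calculate_secret_hash, authenticate_user

# Hypothetical IDs purely for illustration; the real values come from
# AWS_CLIENT_ID / AWS_CLIENT_SECRET in tools/config.py or config/aws_config.env.
print(calculate_secret_hash(client_id="example-client-id",
                            client_secret="example-client-secret",
                            username="example.user"))

# Returns True if Cognito accepts the credentials, False for a missing user;
# other failures now print and re-raise the underlying exception.
print(authenticate_user("example.user", "example-password"))
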
tools/aws_functions.py CHANGED

@@ -3,21 +3,13 @@ import pandas as pd
 import boto3
 import tempfile
 import os
-from tools.
+from tools.config import RUN_AWS_FUNCTIONS, AWS_REGION, CONSULTATION_SUMMARY_BUCKET
 
 PandasDataFrame = Type[pd.DataFrame]
 
 # Get AWS credentials if required
 bucket_name=""
 
-AWS_REGION = get_or_create_env_var('AWS_REGION', 'eu-west-2')
-print(f'The value of AWS_REGION is {AWS_REGION}')
-
-CONSULTATION_SUMMARY_BUCKET = get_or_create_env_var('CONSULTATION_SUMMARY_BUCKET', '')
-print(f'The value of AWS_REGION is {CONSULTATION_SUMMARY_BUCKET}')
-
-
-
 if RUN_AWS_FUNCTIONS == "1":
     try:
         bucket_name = CONSULTATION_SUMMARY_BUCKET
@@ -41,11 +33,12 @@ if RUN_AWS_FUNCTIONS == "1":
     try:
         assumed_role_arn, assumed_role_name = get_assumed_role_info()
 
-        print("Assumed Role ARN:", assumed_role_arn)
-        print("Assumed Role Name:", assumed_role_name)
+        #print("Assumed Role ARN:", assumed_role_arn)
+        #print("Assumed Role Name:", assumed_role_name)
 
-
-
+        print("Successfully assumed role with AWS STS")
+
+    except Exception as e:
         print(e)
 
 # Download direct from S3 - requires login credentials
@@ -113,57 +106,6 @@ def download_files_from_s3(bucket_name, s3_folder, local_folder, filenames):
     except Exception as e:
         print(f"Error downloading 's3://{bucket_name}/{object_key}':", e)
 
-def load_data_from_aws(in_aws_keyword_file, aws_password="", bucket_name=bucket_name):
-
-    temp_dir = tempfile.mkdtemp()
-    local_address_stub = temp_dir + '/doc-redaction/'
-    files = []
-
-    if not 'LAMBETH_BOROUGH_PLAN_PASSWORD' in os.environ:
-        out_message = "Can't verify password for dataset access. Do you have a valid AWS connection? Data not loaded."
-        return files, out_message
-
-    if aws_password:
-        if "Lambeth borough plan" in in_aws_keyword_file and aws_password == os.environ['LAMBETH_BOROUGH_PLAN_PASSWORD']:
-
-            s3_folder_stub = 'example-data/lambeth-borough-plan/latest/'
-
-            local_folder_path = local_address_stub
-
-            # Check if folder exists
-            if not os.path.exists(local_folder_path):
-                print(f"Folder {local_folder_path} does not exist! Making folder.")
-
-                os.mkdir(local_folder_path)
-
-            # Check if folder is empty
-            if len(os.listdir(local_folder_path)) == 0:
-                print(f"Folder {local_folder_path} is empty")
-                # Download data
-                download_files_from_s3(bucket_name, s3_folder_stub, local_folder_path, filenames='*')
-
-                print("AWS data downloaded")
-
-            else:
-                print(f"Folder {local_folder_path} is not empty")
-
-            #files = os.listdir(local_folder_stub)
-            #print(files)
-
-            files = [os.path.join(local_folder_path, f) for f in os.listdir(local_folder_path) if os.path.isfile(os.path.join(local_folder_path, f))]
-
-            out_message = "Data successfully loaded from AWS"
-            print(out_message)
-
-        else:
-            out_message = "Data not loaded from AWS"
-            print(out_message)
-    else:
-        out_message = "No password provided. Please ask the data team for access if you need this."
-        print(out_message)
-
-    return files, out_message
-
 def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str=bucket_name, RUN_AWS_FUNCTIONS=RUN_AWS_FUNCTIONS):
     """
     Uploads a file from local machine to Amazon S3.
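
With load_data_from_aws removed, the helper the app still imports from this module is upload_file_to_s3. A minimal usage sketch with placeholder names (the function also takes RUN_AWS_FUNCTIONS as a parameter, so uploads can be switched off via config):

# Sketch only, not part of the commit: upload a generated output file to S3.
from tools.aws_functions import upload_file_to_s3

upload_file_to_s3(
    local_file_paths=["output/topic_summary.csv"],   # placeholder local file
    s3_key="usage/20250101/example-host/",           # placeholder S3 key prefix
    s3_bucket="example-consultation-bucket",         # placeholder bucket name
)
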
tools/chatfuncs.py DELETED

@@ -1,240 +0,0 @@
-from typing import TypeVar
-import torch.cuda
-import os
-import time
-from llama_cpp import Llama
-from huggingface_hub import hf_hub_download
-from tools.helper_functions import RUN_LOCAL_MODEL
-
-torch.cuda.empty_cache()
-
-PandasDataFrame = TypeVar('pd.core.frame.DataFrame')
-
-model_type = None # global variable setup
-
-full_text = "" # Define dummy source text (full text) just to enable highlight function to load
-
-model = [] # Define empty list for model functions to run
-tokenizer = [] #[] # Define empty list for model functions to run
-
-local_model_type = "Gemma 2b"
-
-# Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
-
-# Check for torch cuda
-print("Is CUDA enabled? ", torch.cuda.is_available())
-print("Is a CUDA device available on this computer?", torch.backends.cudnn.enabled)
-if torch.cuda.is_available():
-    torch_device = "cuda"
-    gpu_layers = -1
-    os.system("nvidia-smi")
-else:
-    torch_device = "cpu"
-    gpu_layers = 0
-
-print("Device used is: ", torch_device)
-
-
-print("Running on device:", torch_device)
-threads = torch.get_num_threads() # 8
-print("CPU threads:", threads)
-
-temperature: float = 0.1
-top_k: int = 3
-top_p: float = 1
-repetition_penalty: float = 1.2 # Mild repetition penalty to prevent repeating table rows
-last_n_tokens: int = 512
-max_new_tokens: int = 4096 # 200
-seed: int = 42
-reset: bool = True
-stream: bool = False
-threads: int = threads
-batch_size:int = 256
-context_length:int = 16384
-sample = True
-
-
-class llama_cpp_init_config_gpu:
-    def __init__(self,
-                 last_n_tokens=last_n_tokens,
-                 seed=seed,
-                 n_threads=threads,
-                 n_batch=batch_size,
-                 n_ctx=context_length,
-                 n_gpu_layers=gpu_layers):
-
-        self.last_n_tokens = last_n_tokens
-        self.seed = seed
-        self.n_threads = n_threads
-        self.n_batch = n_batch
-        self.n_ctx = n_ctx
-        self.n_gpu_layers = n_gpu_layers
-        # self.stop: list[str] = field(default_factory=lambda: [stop_string])
-
-    def update_gpu(self, new_value):
-        self.n_gpu_layers = new_value
-
-    def update_context(self, new_value):
-        self.n_ctx = new_value
-
-class llama_cpp_init_config_cpu(llama_cpp_init_config_gpu):
-    def __init__(self):
-        super().__init__()
-        self.n_gpu_layers = gpu_layers
-        self.n_ctx=context_length
-
-gpu_config = llama_cpp_init_config_gpu()
-cpu_config = llama_cpp_init_config_cpu()
-
-
-class LlamaCPPGenerationConfig:
-    def __init__(self, temperature=temperature,
-                 top_k=top_k,
-                 top_p=top_p,
-                 repeat_penalty=repetition_penalty,
-                 seed=seed,
-                 stream=stream,
-                 max_tokens=max_new_tokens
-                 ):
-        self.temperature = temperature
-        self.top_k = top_k
-        self.top_p = top_p
-        self.repeat_penalty = repeat_penalty
-        self.seed = seed
-        self.max_tokens=max_tokens
-        self.stream = stream
-
-    def update_temp(self, new_value):
-        self.temperature = new_value
-
-###
-# Load local model
-###
-def get_model_path():
-    repo_id = os.environ.get("REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF")# "bartowski/Llama-3.2-3B-Instruct-GGUF") # "lmstudio-community/gemma-2-2b-it-GGUF")#"QuantFactory/Phi-3-mini-128k-instruct-GGUF")
-    filename = os.environ.get("MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf") # )"Llama-3.2-3B-Instruct-Q5_K_M.gguf") #"gemma-2-2b-it-Q8_0.gguf") #"Phi-3-mini-128k-instruct.Q4_K_M.gguf")
-    model_dir = "model/gemma" #"model/phi" # Assuming this is your intended directory
-
-    # Construct the expected local path
-    local_path = os.path.join(model_dir, filename)
-
-    if os.path.exists(local_path):
-        print(f"Model already exists at: {local_path}")
-        return local_path
-    else:
-        print(f"Checking default Hugging Face folder. Downloading model from Hugging Face Hub if not found")
-        return hf_hub_download(repo_id=repo_id, filename=filename)
-
-def load_model(local_model_type:str=local_model_type, gpu_layers:int=gpu_layers, max_context_length:int=context_length, gpu_config:llama_cpp_init_config_gpu=gpu_config, cpu_config:llama_cpp_init_config_cpu=cpu_config, torch_device:str=torch_device):
-    '''
-    Load in a model from Hugging Face hub via the transformers package, or using llama_cpp_python by downloading a GGUF file from Huggingface Hub.
-    '''
-    print("Loading model ", local_model_type)
-
-    if local_model_type == "Gemma 2b":
-        if torch_device == "cuda":
-            gpu_config.update_gpu(gpu_layers)
-            gpu_config.update_context(max_context_length)
-            print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU. And a maximum context length of ", gpu_config.n_ctx)
-        else:
-            gpu_config.update_gpu(gpu_layers)
-            cpu_config.update_gpu(gpu_layers)
-
-            # Update context length according to slider
-            gpu_config.update_context(max_context_length)
-            cpu_config.update_context(max_context_length)
-
-            print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU. And a maximum context length of ", gpu_config.n_ctx)
-
-        #print(vars(gpu_config))
-        #print(vars(cpu_config))
-
-        model_path = get_model_path()
-
-        try:
-            print("GPU load variables:" , vars(gpu_config))
-            llama_model = Llama(model_path=model_path, **vars(gpu_config)) # type_k=8, type_v = 8, flash_attn=True,
-
-        except Exception as e:
-            print("GPU load failed")
-            print(e)
-            llama_model = Llama(model_path=model_path, type_k=8, **vars(cpu_config)) # type_v = 8, flash_attn=True,
-
-        tokenizer = []
-
-    model = llama_model
-    tokenizer = tokenizer
-    local_model_type = local_model_type
-
-    load_confirmation = "Finished loading model: " + local_model_type
-
-    print(load_confirmation)
-    return model, tokenizer
-
-
-def call_llama_cpp_model(formatted_string:str, gen_config:str, model=model):
-    """
-    Calls your generation model with parameters from the LlamaCPPGenerationConfig object.
-
-    Args:
-        formatted_string (str): The formatted input text for the model.
-        gen_config (LlamaCPPGenerationConfig): An object containing generation parameters.
-    """
-    # Extracting parameters from the gen_config object
-    temperature = gen_config.temperature
-    top_k = gen_config.top_k
-    top_p = gen_config.top_p
-    repeat_penalty = gen_config.repeat_penalty
-    seed = gen_config.seed
-    max_tokens = gen_config.max_tokens
-    stream = gen_config.stream
-
-    # Now you can call your model directly, passing the parameters:
-    output = model(
-        formatted_string,
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p,
-        repeat_penalty=repeat_penalty,
-        seed=seed,
-        max_tokens=max_tokens,
-        stream=stream#,
-        #stop=["<|eot_id|>", "\n\n"]
-    )
-
-    return output
-
-
-# This function is not used in this app
-def llama_cpp_streaming(history, full_prompt, temperature=temperature):
-
-    gen_config = LlamaCPPGenerationConfig()
-    gen_config.update_temp(temperature)
-
-    print(vars(gen_config))
-
-    # Pull the generated text from the streamer, and update the model output.
-    start = time.time()
-    NUM_TOKENS=0
-    print('-'*4+'Start Generation'+'-'*4)
-
-    output = model(
-        full_prompt, **vars(gen_config))
-
-    history[-1][1] = ""
-    for out in output:
-
-        if "choices" in out and len(out["choices"]) > 0 and "text" in out["choices"][0]:
-            history[-1][1] += out["choices"][0]["text"]
-            NUM_TOKENS+=1
-            yield history
-        else:
-            print(f"Unexpected output structure: {out}")
-
-    time_generate = time.time() - start
-    print('\n')
-    print('-'*4+'End Generation'+'-'*4)
-    print(f'Num of generated tokens: {NUM_TOKENS}')
-    print(f'Time for complete generation: {time_generate}s')
-    print(f'Tokens per secound: {NUM_TOKENS/time_generate}')
-    print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')

tools/config.py ADDED

@@ -0,0 +1,327 @@
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import socket
|
4 |
+
import logging
|
5 |
+
from datetime import datetime
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
|
8 |
+
today_rev = datetime.now().strftime("%Y%m%d")
|
9 |
+
HOST_NAME = socket.gethostname()
|
10 |
+
|
11 |
+
# Set or retrieve configuration variables for the redaction app
|
12 |
+
|
13 |
+
def get_or_create_env_var(var_name:str, default_value:str, print_val:bool=False):
|
14 |
+
'''
|
15 |
+
Get an environmental variable, and set it to a default value if it doesn't exist
|
16 |
+
'''
|
17 |
+
# Get the environment variable if it exists
|
18 |
+
value = os.environ.get(var_name)
|
19 |
+
|
20 |
+
# If it doesn't exist, set the environment variable to the default value
|
21 |
+
if value is None:
|
22 |
+
os.environ[var_name] = default_value
|
23 |
+
value = default_value
|
24 |
+
|
25 |
+
if print_val == True:
|
26 |
+
print(f'The value of {var_name} is {value}')
|
27 |
+
|
28 |
+
return value
|
29 |
+
|
30 |
+
def ensure_folder_exists(output_folder:str):
|
31 |
+
"""Checks if the specified folder exists, creates it if not."""
|
32 |
+
|
33 |
+
if not os.path.exists(output_folder):
|
34 |
+
# Create the folder if it doesn't exist
|
35 |
+
os.makedirs(output_folder, exist_ok=True)
|
36 |
+
print(f"Created the {output_folder} folder.")
|
37 |
+
else:
|
38 |
+
print(f"The {output_folder} folder already exists.")
|
39 |
+
|
40 |
+
def add_folder_to_path(folder_path: str):
|
41 |
+
'''
|
42 |
+
Check if a folder exists on your system. If so, get the absolute path and add it to the system PATH variable if it is not already there. This function is only relevant for locally-created executables based on this app: pyinstaller creates an _internal folder containing tesseract and poppler, which need to be added to the system path for the app to run.
|
43 |
+
'''
|
44 |
+
|
45 |
+
if os.path.exists(folder_path) and os.path.isdir(folder_path):
|
46 |
+
print(folder_path, "folder exists.")
|
47 |
+
|
48 |
+
# Resolve relative path to absolute path
|
49 |
+
absolute_path = os.path.abspath(folder_path)
|
50 |
+
|
51 |
+
current_path = os.environ['PATH']
|
52 |
+
if absolute_path not in current_path.split(os.pathsep):
|
53 |
+
full_path_extension = absolute_path + os.pathsep + current_path
|
54 |
+
os.environ['PATH'] = full_path_extension
|
55 |
+
#print(f"Updated PATH with: ", full_path_extension)
|
56 |
+
else:
|
57 |
+
print(f"Directory {folder_path} already exists in PATH.")
|
58 |
+
else:
|
59 |
+
print(f"Folder not found at {folder_path} - not added to PATH")
|
60 |
+
|
61 |
+
|
62 |
+
###
|
63 |
+
# LOAD CONFIG FROM ENV FILE
|
64 |
+
###
|
65 |
+
|
66 |
+
CONFIG_FOLDER = get_or_create_env_var('CONFIG_FOLDER', 'config/')
|
67 |
+
|
68 |
+
ensure_folder_exists(CONFIG_FOLDER)
|
69 |
+
|
70 |
+
# If you have an app_config env file in the config folder, you can load in app variables this way, e.g. 'config/app_config.env'
|
71 |
+
APP_CONFIG_PATH = get_or_create_env_var('APP_CONFIG_PATH', CONFIG_FOLDER + 'app_config.env') # e.g. config/app_config.env
|
72 |
+
|
73 |
+
if APP_CONFIG_PATH:
|
74 |
+
if os.path.exists(APP_CONFIG_PATH):
|
75 |
+
print(f"Loading app variables from config file {APP_CONFIG_PATH}")
|
76 |
+
load_dotenv(APP_CONFIG_PATH)
|
77 |
+
else: print("App config file not found at location:", APP_CONFIG_PATH)
|
78 |
+
|
79 |
+
###
|
80 |
+
# AWS OPTIONS
|
81 |
+
###
|
82 |
+
|
83 |
+
# If you have an aws_config env file in the config folder, you can load in AWS keys this way, e.g. 'config/aws_config.env'
|
84 |
+
AWS_CONFIG_PATH = get_or_create_env_var('AWS_CONFIG_PATH', '') # e.g. config/aws_config.env
|
85 |
+
|
86 |
+
if AWS_CONFIG_PATH:
|
87 |
+
if os.path.exists(AWS_CONFIG_PATH):
|
88 |
+
print(f"Loading AWS variables from config file {AWS_CONFIG_PATH}")
|
89 |
+
load_dotenv(AWS_CONFIG_PATH)
|
90 |
+
else: print("AWS config file not found at location:", AWS_CONFIG_PATH)
|
91 |
+
|
92 |
+
RUN_AWS_FUNCTIONS = get_or_create_env_var("RUN_AWS_FUNCTIONS", "1")
|
93 |
+
|
94 |
+
AWS_REGION = get_or_create_env_var('AWS_REGION', '')
|
95 |
+
|
96 |
+
AWS_CLIENT_ID = get_or_create_env_var('AWS_CLIENT_ID', '')
|
97 |
+
|
98 |
+
AWS_CLIENT_SECRET = get_or_create_env_var('AWS_CLIENT_SECRET', '')
|
99 |
+
|
100 |
+
AWS_USER_POOL_ID = get_or_create_env_var('AWS_USER_POOL_ID', '')
|
101 |
+
|
102 |
+
AWS_ACCESS_KEY = get_or_create_env_var('AWS_ACCESS_KEY', '')
|
103 |
+
if AWS_ACCESS_KEY: print(f'AWS_ACCESS_KEY found in environment variables')
|
104 |
+
|
105 |
+
AWS_SECRET_KEY = get_or_create_env_var('AWS_SECRET_KEY', '')
|
106 |
+
if AWS_SECRET_KEY: print(f'AWS_SECRET_KEY found in environment variables')
|
107 |
+
|
108 |
+
CONSULTATION_SUMMARY_BUCKET = get_or_create_env_var('CONSULTATION_SUMMARY_BUCKET', '')
|
109 |
+
|
110 |
+
# Custom headers e.g. if routing traffic through Cloudfront
|
111 |
+
# Retrieving or setting CUSTOM_HEADER
|
112 |
+
CUSTOM_HEADER = get_or_create_env_var('CUSTOM_HEADER', '')
|
113 |
+
|
114 |
+
# Retrieving or setting CUSTOM_HEADER_VALUE
|
115 |
+
CUSTOM_HEADER_VALUE = get_or_create_env_var('CUSTOM_HEADER_VALUE', '')
|
116 |
+
|
117 |
+
###
|
118 |
+
# File I/O
|
119 |
+
###
|
120 |
+
SESSION_OUTPUT_FOLDER = get_or_create_env_var('SESSION_OUTPUT_FOLDER', 'False') # i.e. should input and output files be saved within a subfolder (named after the session hash value) inside the output/input folders
|
121 |
+
|
122 |
+
OUTPUT_FOLDER = get_or_create_env_var('GRADIO_OUTPUT_FOLDER', 'output/') # 'output/'
|
123 |
+
INPUT_FOLDER = get_or_create_env_var('GRADIO_INPUT_FOLDER', 'input/') # 'input/'
|
124 |
+
|
125 |
+
ensure_folder_exists(OUTPUT_FOLDER)
|
126 |
+
ensure_folder_exists(INPUT_FOLDER)
|
127 |
+
|
128 |
+
# Allow for files to be saved in a temporary folder for increased security in some instances
|
129 |
+
if OUTPUT_FOLDER == "TEMP" or INPUT_FOLDER == "TEMP":
|
130 |
+
# Create a temporary directory
|
131 |
+
temp_dir = tempfile.mkdtemp() # mkdtemp persists beyond this block; TemporaryDirectory used as a context manager would delete the folder as soon as the block exits, before the app could use it
|
132 |
+
print(f'Temporary directory created at: {temp_dir}')
|
133 |
+
|
134 |
+
if OUTPUT_FOLDER == "TEMP": OUTPUT_FOLDER = temp_dir + "/"
|
135 |
+
if INPUT_FOLDER == "TEMP": INPUT_FOLDER = temp_dir + "/"
|
136 |
+
|
137 |
+
|
138 |
+
GRADIO_TEMP_DIR = get_or_create_env_var('GRADIO_TEMP_DIR', 'tmp/gradio_tmp/') # Default Gradio temp folder
|
139 |
+
MPLCONFIGDIR = get_or_create_env_var('MPLCONFIGDIR', 'tmp/matplotlib_cache/') # Matplotlib cache folder
|
140 |
+
|
141 |
+
ensure_folder_exists(GRADIO_TEMP_DIR)
|
142 |
+
ensure_folder_exists(MPLCONFIGDIR)
|
143 |
+
|
144 |
+
# TLDEXTRACT_CACHE = get_or_create_env_var('TLDEXTRACT_CACHE', 'tmp/tld/')
|
145 |
+
# try:
|
146 |
+
# extract = TLDExtract(cache_dir=TLDEXTRACT_CACHE)
|
147 |
+
# except:
|
148 |
+
# extract = TLDExtract(cache_dir=None)
|
149 |
+
|
150 |
+
###
|
151 |
+
# LOGGING OPTIONS
|
152 |
+
###
|
153 |
+
|
154 |
+
# By default, logs are put into a subfolder named for today's date and the host name of the instance running the app. This is to avoid, as far as possible, log files from one instance overwriting the logs of another instance on S3. If the app always runs on the same system, or just locally, the log folders do not need to be this specific.
|
155 |
+
# Another way to address this would be to write logs to another type of storage, e.g. a database such as DynamoDB. I may look into this in future.
|
156 |
+
|
157 |
+
SAVE_LOGS_TO_CSV = get_or_create_env_var('SAVE_LOGS_TO_CSV', 'True')
|
158 |
+
|
159 |
+
USE_LOG_SUBFOLDERS = get_or_create_env_var('USE_LOG_SUBFOLDERS', 'True')
|
160 |
+
|
161 |
+
if USE_LOG_SUBFOLDERS == "True":
|
162 |
+
day_log_subfolder = today_rev + '/'
|
163 |
+
host_name_subfolder = HOST_NAME + '/'
|
164 |
+
full_log_subfolder = day_log_subfolder + host_name_subfolder
|
165 |
+
else:
|
166 |
+
full_log_subfolder = ""
|
167 |
+
|
168 |
+
FEEDBACK_LOGS_FOLDER = get_or_create_env_var('FEEDBACK_LOGS_FOLDER', 'feedback/' + full_log_subfolder)
|
169 |
+
ACCESS_LOGS_FOLDER = get_or_create_env_var('ACCESS_LOGS_FOLDER', 'logs/' + full_log_subfolder)
|
170 |
+
USAGE_LOGS_FOLDER = get_or_create_env_var('USAGE_LOGS_FOLDER', 'usage/' + full_log_subfolder)
|
171 |
+
|
172 |
+
ensure_folder_exists(FEEDBACK_LOGS_FOLDER)
|
173 |
+
ensure_folder_exists(ACCESS_LOGS_FOLDER)
|
174 |
+
ensure_folder_exists(USAGE_LOGS_FOLDER)
|
175 |
+
|
176 |
+
# Should file names be included in the logs? In some instances, the names of the files themselves could be sensitive and should not be disclosed beyond the app, so by default this is false.
|
177 |
+
DISPLAY_FILE_NAMES_IN_LOGS = get_or_create_env_var('DISPLAY_FILE_NAMES_IN_LOGS', 'False')
|
178 |
+
|
179 |
+
# Further customisation options for CSV logs
|
180 |
+
|
181 |
+
CSV_ACCESS_LOG_HEADERS = get_or_create_env_var('CSV_ACCESS_LOG_HEADERS', '') # If blank, uses component labels
|
182 |
+
CSV_FEEDBACK_LOG_HEADERS = get_or_create_env_var('CSV_FEEDBACK_LOG_HEADERS', '') # If blank, uses component labels
|
183 |
+
CSV_USAGE_LOG_HEADERS = get_or_create_env_var('CSV_USAGE_LOG_HEADERS', '["session_hash_textbox", "doc_full_file_name_textbox", "data_full_file_name_textbox", "actual_time_taken_number", "total_page_count", "textract_query_number", "pii_detection_method", "comprehend_query_number", "cost_code", "textract_handwriting_signature", "host_name_textbox", "text_extraction_method", "is_this_a_textract_api_call"]') # If blank, uses component labels
|
184 |
+
|
185 |
+
### DYNAMODB logs. Whether to save to DynamoDB, and the headers of the table
|
186 |
+
|
187 |
+
SAVE_LOGS_TO_DYNAMODB = get_or_create_env_var('SAVE_LOGS_TO_DYNAMODB', 'False')
|
188 |
+
|
189 |
+
ACCESS_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var('ACCESS_LOG_DYNAMODB_TABLE_NAME', 'redaction_access_log')
|
190 |
+
DYNAMODB_ACCESS_LOG_HEADERS = get_or_create_env_var('DYNAMODB_ACCESS_LOG_HEADERS', '')
|
191 |
+
|
192 |
+
FEEDBACK_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var('FEEDBACK_LOG_DYNAMODB_TABLE_NAME', 'redaction_feedback')
|
193 |
+
DYNAMODB_FEEDBACK_LOG_HEADERS = get_or_create_env_var('DYNAMODB_FEEDBACK_LOG_HEADERS', '')
|
194 |
+
|
195 |
+
USAGE_LOG_DYNAMODB_TABLE_NAME = get_or_create_env_var('USAGE_LOG_DYNAMODB_TABLE_NAME', 'redaction_usage')
|
196 |
+
DYNAMODB_USAGE_LOG_HEADERS = get_or_create_env_var('DYNAMODB_USAGE_LOG_HEADERS', '')
|
197 |
+
|
198 |
+
# Report logging to console?
|
199 |
+
LOGGING = get_or_create_env_var('LOGGING', 'False')
|
200 |
+
|
201 |
+
if LOGGING == 'True':
|
202 |
+
# Configure logging
|
203 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
204 |
+
|
205 |
+
LOG_FILE_NAME = get_or_create_env_var('LOG_FILE_NAME', 'log.csv')
|
206 |
+
|
207 |
+
###
|
208 |
+
# LLM variables
|
209 |
+
###
|
210 |
+
|
211 |
+
MAX_TOKENS = int(get_or_create_env_var('MAX_TOKENS', '4096')) # Maximum number of output tokens
|
212 |
+
TIMEOUT_WAIT = int(get_or_create_env_var('TIMEOUT_WAIT', '30')) # AWS now seems to have a 60 second minimum wait between API calls
|
213 |
+
NUMBER_OF_RETRY_ATTEMPTS = int(get_or_create_env_var('NUMBER_OF_RETRY_ATTEMPTS', '5'))
|
214 |
+
# Try up to 3 times to get a valid markdown table response with LLM calls, otherwise retry with temperature changed
|
215 |
+
MAX_OUTPUT_VALIDATION_ATTEMPTS = int(get_or_create_env_var('MAX_OUTPUT_VALIDATION_ATTEMPTS', '3'))
|
216 |
+
MAX_TIME_FOR_LOOP = int(get_or_create_env_var('MAX_TIME_FOR_LOOP', '99999'))
|
217 |
+
BATCH_SIZE_DEFAULT = int(get_or_create_env_var('BATCH_SIZE_DEFAULT', '5'))
|
218 |
+
DEDUPLICATION_THRESHOLD = int(get_or_create_env_var('DEDUPLICATION_THRESHOLD', '90'))
|
219 |
+
MAX_COMMENT_CHARS = int(get_or_create_env_var('MAX_COMMENT_CHARS', '14000'))
|
220 |
+
|
221 |
+
RUN_LOCAL_MODEL = get_or_create_env_var("RUN_LOCAL_MODEL", "1")
|
222 |
+
RUN_GEMINI_MODELS = get_or_create_env_var("RUN_GEMINI_MODELS", "1")
|
223 |
+
GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')
|
224 |
+
|
225 |
+
# Build up options for models
|
226 |
+
|
227 |
+
model_full_names = []
|
228 |
+
model_short_names = []
|
229 |
+
|
230 |
+
CHOSEN_LOCAL_MODEL_TYPE = get_or_create_env_var("CHOSEN_LOCAL_MODEL_TYPE", "Gemma 3 4B") # Gemma 3 1B # "Gemma 2b"
|
231 |
+
|
232 |
+
if RUN_LOCAL_MODEL == "1" and CHOSEN_LOCAL_MODEL_TYPE:
|
233 |
+
model_full_names.append(CHOSEN_LOCAL_MODEL_TYPE)
|
234 |
+
model_short_names.append(CHOSEN_LOCAL_MODEL_TYPE)
|
235 |
+
|
236 |
+
if RUN_AWS_FUNCTIONS == "1":
|
237 |
+
model_full_names.extend(["anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"])
|
238 |
+
model_short_names.extend(["haiku", "sonnet"])
|
239 |
+
|
240 |
+
if RUN_GEMINI_MODELS == "1":
|
241 |
+
model_full_names.extend(["gemini-2.0-flash-001", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-exp-05-06" ]) # , # Gemini pro No longer available on free tier
|
242 |
+
model_short_names.extend(["gemini_flash_2", "gemini_flash_2.5", "gemini_pro"])
|
243 |
+
|
244 |
+
print("model_short_names:", model_short_names)
|
245 |
+
print("model_full_names:", model_full_names)
|
246 |
+
|
247 |
+
model_name_map = {full: short for full, short in zip(model_full_names, model_short_names)} # maps each full model name to its short name
|
248 |
+
|
249 |
+
# HF token may or may not be needed for downloading models from Hugging Face
|
250 |
+
HF_TOKEN = get_or_create_env_var('HF_TOKEN', '')
|
251 |
+
|
252 |
+
GEMMA2_REPO_ID = get_or_create_env_var("GEMMA2_2B_REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF")# "bartowski/Llama-3.2-3B-Instruct-GGUF") # "lmstudio-community/gemma-2-2b-it-GGUF")#"QuantFactory/Phi-3-mini-128k-instruct-GGUF")
|
253 |
+
GEMMA2_MODEL_FILE = get_or_create_env_var("GEMMA2_2B_MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf") # )"Llama-3.2-3B-Instruct-Q5_K_M.gguf") #"gemma-2-2b-it-Q8_0.gguf") #"Phi-3-mini-128k-instruct.Q4_K_M.gguf")
|
254 |
+
GEMMA2_MODEL_FOLDER = get_or_create_env_var("GEMMA2_2B_MODEL_FOLDER", "model/gemma") #"model/phi" # Assuming this is your intended directory
|
255 |
+
|
256 |
+
GEMMA3_REPO_ID = get_or_create_env_var("GEMMA3_REPO_ID", "ggml-org/gemma-3-1b-it-GGUF")# "bartowski/Llama-3.2-3B-Instruct-GGUF") # "lmstudio-community/gemma-2-2b-it-GGUF")#"QuantFactory/Phi-3-mini-128k-instruct-GGUF")
|
257 |
+
GEMMA3_MODEL_FILE = get_or_create_env_var("GEMMA3_MODEL_FILE", "gemma-3-1b-it-Q8_0.gguf") # )"Llama-3.2-3B-Instruct-Q5_K_M.gguf") #"gemma-2-2b-it-Q8_0.gguf") #"Phi-3-mini-128k-instruct.Q4_K_M.gguf")
|
258 |
+
GEMMA3_MODEL_FOLDER = get_or_create_env_var("GEMMA3_MODEL_FOLDER", "model/gemma")
|
259 |
+
|
260 |
+
GEMMA3_4B_REPO_ID = get_or_create_env_var("GEMMA3_4B_REPO_ID", "ggml-org/gemma-3-4b-it-GGUF")# "bartowski/Llama-3.2-3B-Instruct-GGUF") # "lmstudio-community/gemma-2-2b-it-GGUF")#"QuantFactory/Phi-3-mini-128k-instruct-GGUF")
|
261 |
+
GEMMA3_4B_MODEL_FILE = get_or_create_env_var("GEMMA3_4B_MODEL_FILE", "gemma-3-4b-it-Q4_K_M.gguf") # )"Llama-3.2-3B-Instruct-Q5_K_M.gguf") #"gemma-2-2b-it-Q8_0.gguf") #"Phi-3-mini-128k-instruct.Q4_K_M.gguf")
|
262 |
+
GEMMA3_4B_MODEL_FOLDER = get_or_create_env_var("GEMMA3_4B_MODEL_FOLDER", "model/gemma3_4b")
|
263 |
+
|
264 |
+
|
265 |
+
if CHOSEN_LOCAL_MODEL_TYPE == "Gemma 2b":
|
266 |
+
LOCAL_REPO_ID = GEMMA2_REPO_ID
|
267 |
+
LOCAL_MODEL_FILE = GEMMA2_MODEL_FILE
|
268 |
+
LOCAL_MODEL_FOLDER = GEMMA2_MODEL_FOLDER
|
269 |
+
|
270 |
+
elif CHOSEN_LOCAL_MODEL_TYPE == "Gemma 3 1B":
|
271 |
+
LOCAL_REPO_ID = GEMMA3_REPO_ID
|
272 |
+
LOCAL_MODEL_FILE = GEMMA3_MODEL_FILE
|
273 |
+
LOCAL_MODEL_FOLDER = GEMMA3_MODEL_FOLDER
|
274 |
+
|
275 |
+
elif CHOSEN_LOCAL_MODEL_TYPE == "Gemma 3 4B":
|
276 |
+
LOCAL_REPO_ID = GEMMA3_4B_REPO_ID
|
277 |
+
LOCAL_MODEL_FILE = GEMMA3_4B_MODEL_FILE
|
278 |
+
LOCAL_MODEL_FOLDER = GEMMA3_4B_MODEL_FOLDER
|
279 |
+
|
280 |
+
print("CHOSEN_LOCAL_MODEL_TYPE:", CHOSEN_LOCAL_MODEL_TYPE)
|
281 |
+
print("LOCAL_REPO_ID:", LOCAL_REPO_ID)
|
282 |
+
print("LOCAL_MODEL_FILE:", LOCAL_MODEL_FILE)
|
283 |
+
print("LOCAL_MODEL_FOLDER:", LOCAL_MODEL_FOLDER)
|
284 |
+
|
285 |
+
LLM_TEMPERATURE = float(get_or_create_env_var('LLM_TEMPERATURE', '0.1'))
|
286 |
+
LLM_TOP_K = int(get_or_create_env_var('LLM_TOP_K','3'))
|
287 |
+
LLM_TOP_P = float(get_or_create_env_var('LLM_TOP_P', '1'))
|
288 |
+
LLM_REPETITION_PENALTY = float(get_or_create_env_var('LLM_REPETITION_PENALTY', '1.2')) # Mild repetition penalty to prevent repeating table rows
|
289 |
+
LLM_LAST_N_TOKENS = int(get_or_create_env_var('LLM_LAST_N_TOKENS', '512'))
|
290 |
+
LLM_MAX_NEW_TOKENS = int(get_or_create_env_var('LLM_MAX_NEW_TOKENS', '4096'))
|
291 |
+
LLM_SEED = int(get_or_create_env_var('LLM_SEED', '42'))
|
292 |
+
LLM_RESET = get_or_create_env_var('LLM_RESET', 'True')
|
293 |
+
LLM_STREAM = get_or_create_env_var('LLM_STREAM', 'False')
|
294 |
+
LLM_THREADS = int(get_or_create_env_var('LLM_THREADS', '4'))
|
295 |
+
LLM_BATCH_SIZE = int(get_or_create_env_var('LLM_BATCH_SIZE', '256'))
|
296 |
+
LLM_CONTEXT_LENGTH = int(get_or_create_env_var('LLM_CONTEXT_LENGTH', '16384'))
|
297 |
+
LLM_SAMPLE = get_or_create_env_var('LLM_SAMPLE', 'True')
|
298 |
+
|
299 |
+
###
|
300 |
+
# Gradio app variables
|
301 |
+
###
|
302 |
+
|
303 |
+
# Get some environment variables and Launch the Gradio app
|
304 |
+
COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
|
305 |
+
|
306 |
+
RUN_DIRECT_MODE = get_or_create_env_var('RUN_DIRECT_MODE', '0')
|
307 |
+
|
308 |
+
MAX_QUEUE_SIZE = int(get_or_create_env_var('MAX_QUEUE_SIZE', '5'))
|
309 |
+
|
310 |
+
MAX_FILE_SIZE = get_or_create_env_var('MAX_FILE_SIZE', '250mb')
|
311 |
+
|
312 |
+
GRADIO_SERVER_PORT = int(get_or_create_env_var('GRADIO_SERVER_PORT', '7860'))
|
313 |
+
|
314 |
+
ROOT_PATH = get_or_create_env_var('ROOT_PATH', '')
|
315 |
+
|
316 |
+
DEFAULT_CONCURRENCY_LIMIT = get_or_create_env_var('DEFAULT_CONCURRENCY_LIMIT', '3')
|
317 |
+
|
318 |
+
GET_DEFAULT_ALLOW_LIST = get_or_create_env_var('GET_DEFAULT_ALLOW_LIST', '')
|
319 |
+
|
320 |
+
ALLOW_LIST_PATH = get_or_create_env_var('ALLOW_LIST_PATH', '') # config/default_allow_list.csv
|
321 |
+
|
322 |
+
S3_ALLOW_LIST_PATH = get_or_create_env_var('S3_ALLOW_LIST_PATH', '') # default_allow_list.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
|
323 |
+
|
324 |
+
if ALLOW_LIST_PATH: OUTPUT_ALLOW_LIST_PATH = ALLOW_LIST_PATH
|
325 |
+
else: OUTPUT_ALLOW_LIST_PATH = 'config/default_allow_list.csv'
|
326 |
+
|
327 |
+
FILE_INPUT_HEIGHT = get_or_create_env_var('FILE_INPUT_HEIGHT', '200')
|
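As a rough illustration of how the settings above resolve, a minimal sketch follows: an environment variable that is already set wins, then a value loaded from the optional config/app_config.env file via load_dotenv (which does not override existing variables by default), then the hard-coded default passed to get_or_create_env_var. The example variable values below are assumptions for illustration, not defaults taken from this commit.

# Minimal sketch of the precedence used in tools/config.py (example values are assumed)
import os
from dotenv import load_dotenv

os.environ["RUN_LOCAL_MODEL"] = "0"      # 1. already-set environment variables win

load_dotenv("config/app_config.env")     # 2. optional env file, e.g. containing GRADIO_OUTPUT_FOLDER=output/

def get_or_create_env_var(var_name, default_value):   # mirrors the helper defined in tools/config.py
    value = os.environ.get(var_name)
    if value is None:                    # 3. otherwise fall back to the hard-coded default
        os.environ[var_name] = default_value
        value = default_value
    return value

print(get_or_create_env_var("RUN_LOCAL_MODEL", "1"))                   # "0" - taken from the environment
print(get_or_create_env_var("CHOSEN_LOCAL_MODEL_TYPE", "Gemma 3 4B"))  # env file value if present, else the default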
tools/dedup_summaries.py
ADDED
@@ -0,0 +1,602 @@
1 |
+
import pandas as pd
|
2 |
+
from rapidfuzz import process, fuzz
|
3 |
+
from typing import List, Tuple
|
4 |
+
import re
|
5 |
+
import spaces
|
6 |
+
import gradio as gr
|
7 |
+
import time # time.perf_counter() is called below, so import the module rather than the time() function
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from tools.prompts import summarise_topic_descriptions_prompt, summarise_topic_descriptions_system_prompt, system_prompt
|
11 |
+
from tools.llm_funcs import construct_gemini_generative_model, process_requests, ResponseObject, load_model
|
12 |
+
from tools.helper_functions import create_unique_table_df_from_reference_table, load_in_data_file, get_basic_response_data, convert_reference_table_to_pivot_table, wrap_text
|
13 |
+
from tools.config import OUTPUT_FOLDER, RUN_LOCAL_MODEL, MAX_COMMENT_CHARS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map
|
14 |
+
|
15 |
+
max_tokens = MAX_TOKENS
|
16 |
+
timeout_wait = TIMEOUT_WAIT
|
17 |
+
number_of_api_retry_attempts = NUMBER_OF_RETRY_ATTEMPTS
|
18 |
+
max_time_for_loop = MAX_TIME_FOR_LOOP
|
19 |
+
batch_size_default = BATCH_SIZE_DEFAULT
|
20 |
+
deduplication_threshold = DEDUPLICATION_THRESHOLD
|
21 |
+
max_comment_character_length = MAX_COMMENT_CHARS
|
22 |
+
|
23 |
+
# DEDUPLICATION/SUMMARISATION FUNCTIONS
|
24 |
+
def deduplicate_categories(category_series: pd.Series, join_series: pd.Series, reference_df: pd.DataFrame, general_topic_series: pd.Series = None, merge_general_topics = "No", merge_sentiment:str="No", threshold: float = 90) -> pd.DataFrame:
|
25 |
+
"""
|
26 |
+
Deduplicates similar category names in a pandas Series based on a fuzzy matching threshold,
|
27 |
+
merging smaller topics into larger topics.
|
28 |
+
|
29 |
+
Parameters:
|
30 |
+
category_series (pd.Series): Series containing category names to deduplicate.
|
31 |
+
join_series (pd.Series): Additional series used for joining back to original results.
|
32 |
+
reference_df (pd.DataFrame): DataFrame containing the reference data to count occurrences.
|
33 |
+
general_topic_series (pd.Series): Optional series of General Topic values, used to restrict matching to within the same general topic.
merge_general_topics (str): "Yes" to allow matches across different General Topics; otherwise matching stays within a topic.
merge_sentiment (str): "Yes" to allow matches across different sentiment values; otherwise sentiment boundaries are kept.
threshold (float): Similarity score (0-100) above which two category names are considered duplicates.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
pd.DataFrame: DataFrame with columns ['old_category', 'deduplicated_category'].
|
37 |
+
"""
|
38 |
+
# Count occurrences of each category in the reference_df
|
39 |
+
category_counts = reference_df['Subtopic'].value_counts().to_dict()
|
40 |
+
|
41 |
+
# Initialize dictionaries for both category mapping and scores
|
42 |
+
deduplication_map = {}
|
43 |
+
match_scores = {} # New dictionary to store match scores
|
44 |
+
|
45 |
+
# First pass: Handle exact matches
|
46 |
+
for category in category_series.unique():
|
47 |
+
if category in deduplication_map:
|
48 |
+
continue
|
49 |
+
|
50 |
+
# Find all exact matches
|
51 |
+
exact_matches = category_series[category_series.str.lower() == category.lower()].index.tolist()
|
52 |
+
if len(exact_matches) > 1:
|
53 |
+
# Find the variant with the highest count
|
54 |
+
match_counts = {match: category_counts.get(category_series[match], 0) for match in exact_matches}
|
55 |
+
most_common = max(match_counts.items(), key=lambda x: x[1])[0]
|
56 |
+
most_common_category = category_series[most_common]
|
57 |
+
|
58 |
+
# Map all exact matches to the most common variant and store score
|
59 |
+
for match in exact_matches:
|
60 |
+
deduplication_map[category_series[match]] = most_common_category
|
61 |
+
match_scores[category_series[match]] = 100 # Exact matches get score of 100
|
62 |
+
|
63 |
+
# Second pass: Handle fuzzy matches for remaining categories
|
64 |
+
# Create a DataFrame to maintain the relationship between categories and general topics
|
65 |
+
categories_df = pd.DataFrame({
|
66 |
+
'category': category_series,
|
67 |
+
'general_topic': general_topic_series
|
68 |
+
}).drop_duplicates()
|
69 |
+
|
70 |
+
for _, row in categories_df.iterrows():
|
71 |
+
category = row['category']
|
72 |
+
if category in deduplication_map:
|
73 |
+
continue
|
74 |
+
|
75 |
+
current_general_topic = row['general_topic']
|
76 |
+
|
77 |
+
# Filter potential matches to only those within the same General Topic if relevant
|
78 |
+
if merge_general_topics == "No":
|
79 |
+
potential_matches = categories_df[
|
80 |
+
(categories_df['category'] != category) &
|
81 |
+
(categories_df['general_topic'] == current_general_topic)
|
82 |
+
]['category'].tolist()
|
83 |
+
else:
|
84 |
+
potential_matches = categories_df[
|
85 |
+
(categories_df['category'] != category)
|
86 |
+
]['category'].tolist()
|
87 |
+
|
88 |
+
matches = process.extract(category,
|
89 |
+
potential_matches,
|
90 |
+
scorer=fuzz.WRatio,
|
91 |
+
score_cutoff=threshold)
|
92 |
+
|
93 |
+
if matches:
|
94 |
+
best_match = max(matches, key=lambda x: x[1])
|
95 |
+
match, score, _ = best_match
|
96 |
+
|
97 |
+
if category_counts.get(category, 0) < category_counts.get(match, 0):
|
98 |
+
deduplication_map[category] = match
|
99 |
+
match_scores[category] = score
|
100 |
+
else:
|
101 |
+
deduplication_map[match] = category
|
102 |
+
match_scores[match] = score
|
103 |
+
else:
|
104 |
+
deduplication_map[category] = category
|
105 |
+
match_scores[category] = 100
|
106 |
+
|
107 |
+
# Create the result DataFrame with scores
|
108 |
+
result_df = pd.DataFrame({
|
109 |
+
'old_category': category_series + " | " + join_series,
|
110 |
+
'deduplicated_category': category_series.map(lambda x: deduplication_map.get(x, x)),
|
111 |
+
'match_score': category_series.map(lambda x: match_scores.get(x, 100)) # Add scores column
|
112 |
+
})
|
113 |
+
|
114 |
+
#print(result_df)
|
115 |
+
|
116 |
+
return result_df
|
117 |
+
|
118 |
+
def deduplicate_topics(reference_df:pd.DataFrame,
|
119 |
+
unique_topics_df:pd.DataFrame,
|
120 |
+
reference_table_file_name:str,
|
121 |
+
unique_topics_table_file_name:str,
|
122 |
+
in_excel_sheets:str="",
|
123 |
+
merge_sentiment:str= "No",
|
124 |
+
merge_general_topics:str="No",
|
125 |
+
score_threshold:int=90,
|
126 |
+
in_data_files:List[str]=[],
|
127 |
+
chosen_cols:List[str]="",
|
128 |
+
deduplicate_topics:str="Yes",
|
129 |
+
output_folder:str=OUTPUT_FOLDER
|
130 |
+
):
|
131 |
+
'''
|
132 |
+
Deduplicate topics based on a reference and unique topics table
|
133 |
+
'''
|
134 |
+
output_files = []
|
135 |
+
log_output_files = []
|
136 |
+
file_data = pd.DataFrame()
|
137 |
+
|
138 |
+
reference_table_file_name_no_ext = reference_table_file_name #get_file_name_no_ext(reference_table_file_name)
|
139 |
+
unique_topics_table_file_name_no_ext = unique_topics_table_file_name #get_file_name_no_ext(unique_topics_table_file_name)
|
140 |
+
|
141 |
+
# For checking that data is not lost during the process
|
142 |
+
initial_unique_references = len(reference_df["Response References"].unique())
|
143 |
+
|
144 |
+
if unique_topics_df.empty:
|
145 |
+
unique_topics_df = create_unique_table_df_from_reference_table(reference_df)
|
146 |
+
|
147 |
+
# Then merge the topic numbers back to the original dataframe
|
148 |
+
reference_df = reference_df.merge(
|
149 |
+
unique_topics_df[['General Topic', 'Subtopic', 'Sentiment', 'Topic_number']],
|
150 |
+
on=['General Topic', 'Subtopic', 'Sentiment'],
|
151 |
+
how='left'
|
152 |
+
)
|
153 |
+
|
154 |
+
if in_data_files and chosen_cols:
|
155 |
+
file_data, data_file_names_textbox, total_number_of_batches = load_in_data_file(in_data_files, chosen_cols, 1, in_excel_sheets)
|
156 |
+
else:
|
157 |
+
out_message = "No file data found, pivot table output will not be created."
|
158 |
+
print(out_message)
|
159 |
+
#raise Exception(out_message)
|
160 |
+
|
161 |
+
# Run through this x times to try to get all duplicate topics
|
162 |
+
if deduplicate_topics == "Yes":
|
163 |
+
for i in range(0, 8):
|
164 |
+
if merge_sentiment == "No":
|
165 |
+
if merge_general_topics == "No":
|
166 |
+
reference_df["old_category"] = reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
|
167 |
+
reference_df_unique = reference_df.drop_duplicates("old_category")
|
168 |
+
|
169 |
+
deduplicated_topic_map_df = reference_df_unique.groupby(["General Topic", "Sentiment"]).apply(
|
170 |
+
lambda group: deduplicate_categories(
|
171 |
+
group["Subtopic"],
|
172 |
+
group["Sentiment"],
|
173 |
+
reference_df,
|
174 |
+
general_topic_series=group["General Topic"],
|
175 |
+
merge_general_topics="No",
|
176 |
+
threshold=score_threshold
|
177 |
+
)
|
178 |
+
).reset_index(drop=True)
|
179 |
+
else:
|
180 |
+
# This case should allow cross-topic matching but is still grouping by Sentiment
|
181 |
+
reference_df["old_category"] = reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
|
182 |
+
reference_df_unique = reference_df.drop_duplicates("old_category")
|
183 |
+
|
184 |
+
deduplicated_topic_map_df = reference_df_unique.groupby("Sentiment").apply(
|
185 |
+
lambda group: deduplicate_categories(
|
186 |
+
group["Subtopic"],
|
187 |
+
group["Sentiment"],
|
188 |
+
reference_df,
|
189 |
+
general_topic_series=None, # Set to None to allow cross-topic matching
|
190 |
+
merge_general_topics="Yes",
|
191 |
+
threshold=score_threshold
|
192 |
+
)
|
193 |
+
).reset_index(drop=True)
|
194 |
+
else:
|
195 |
+
if merge_general_topics == "No":
|
196 |
+
# Update this case to maintain general topic boundaries
|
197 |
+
reference_df["old_category"] = reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
|
198 |
+
reference_df_unique = reference_df.drop_duplicates("old_category")
|
199 |
+
|
200 |
+
deduplicated_topic_map_df = reference_df_unique.groupby("General Topic").apply(
|
201 |
+
lambda group: deduplicate_categories(
|
202 |
+
group["Subtopic"],
|
203 |
+
group["Sentiment"],
|
204 |
+
reference_df,
|
205 |
+
general_topic_series=group["General Topic"],
|
206 |
+
merge_general_topics="No",
|
207 |
+
merge_sentiment=merge_sentiment,
|
208 |
+
threshold=score_threshold
|
209 |
+
)
|
210 |
+
).reset_index(drop=True)
|
211 |
+
else:
|
212 |
+
# For complete merging across all categories
|
213 |
+
reference_df["old_category"] = reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
|
214 |
+
reference_df_unique = reference_df.drop_duplicates("old_category")
|
215 |
+
|
216 |
+
deduplicated_topic_map_df = deduplicate_categories(
|
217 |
+
reference_df_unique["Subtopic"],
|
218 |
+
reference_df_unique["Sentiment"],
|
219 |
+
reference_df,
|
220 |
+
general_topic_series=None, # Set to None to allow cross-topic matching
|
221 |
+
merge_general_topics="Yes",
|
222 |
+
merge_sentiment=merge_sentiment,
|
223 |
+
threshold=score_threshold
|
224 |
+
).reset_index(drop=True)
|
225 |
+
|
226 |
+
if deduplicated_topic_map_df['deduplicated_category'].isnull().all():
|
227 |
+
# Check if 'deduplicated_category' contains any values
|
228 |
+
print("No deduplicated categories found, skipping the following code.")
|
229 |
+
|
230 |
+
else:
|
231 |
+
# Remove rows where 'deduplicated_category' is blank or NaN
|
232 |
+
deduplicated_topic_map_df = deduplicated_topic_map_df.loc[(deduplicated_topic_map_df['deduplicated_category'].str.strip() != '') & ~(deduplicated_topic_map_df['deduplicated_category'].isnull()), ['old_category','deduplicated_category', 'match_score']]
|
233 |
+
|
234 |
+
#deduplicated_topic_map_df.to_csv(output_folder + "deduplicated_topic_map_df_" + str(i) + ".csv", index=None)
|
235 |
+
|
236 |
+
reference_df = reference_df.merge(deduplicated_topic_map_df, on="old_category", how="left")
|
237 |
+
|
238 |
+
reference_df.rename(columns={"Subtopic": "Subtopic_old", "Sentiment": "Sentiment_old"}, inplace=True)
|
239 |
+
# Extract subtopic and sentiment from deduplicated_category
|
240 |
+
reference_df["Subtopic"] = reference_df["deduplicated_category"].str.extract(r'^(.*?) \|')[0] # Extract subtopic
|
241 |
+
reference_df["Sentiment"] = reference_df["deduplicated_category"].str.extract(r'\| (.*)$')[0] # Extract sentiment
|
242 |
+
|
243 |
+
# Combine with old values to ensure no data is lost
|
244 |
+
reference_df["Subtopic"] = reference_df["deduplicated_category"].combine_first(reference_df["Subtopic_old"])
|
245 |
+
reference_df["Sentiment"] = reference_df["Sentiment"].combine_first(reference_df["Sentiment_old"])
|
246 |
+
|
247 |
+
|
248 |
+
reference_df.drop(['old_category', 'deduplicated_category', "Subtopic_old", "Sentiment_old"], axis=1, inplace=True, errors="ignore")
|
249 |
+
|
250 |
+
reference_df = reference_df[["Response References", "General Topic", "Subtopic", "Sentiment", "Summary", "Start row of group"]]
|
251 |
+
|
252 |
+
#reference_df["General Topic"] = reference_df["General Topic"].str.lower().str.capitalize()
|
253 |
+
#reference_df["Subtopic"] = reference_df["Subtopic"].str.lower().str.capitalize()
|
254 |
+
#reference_df["Sentiment"] = reference_df["Sentiment"].str.lower().str.capitalize()
|
255 |
+
|
256 |
+
if merge_general_topics == "Yes":
|
257 |
+
# Replace General topic names for each Subtopic with that for the Subtopic with the most responses
|
258 |
+
# Step 1: Count the number of occurrences for each General Topic and Subtopic combination
|
259 |
+
count_df = reference_df.groupby(['Subtopic', 'General Topic']).size().reset_index(name='Count')
|
260 |
+
|
261 |
+
# Step 2: Find the General Topic with the maximum count for each Subtopic
|
262 |
+
max_general_topic = count_df.loc[count_df.groupby('Subtopic')['Count'].idxmax()]
|
263 |
+
|
264 |
+
# Step 3: Map the General Topic back to the original DataFrame
|
265 |
+
reference_df = reference_df.merge(max_general_topic[['Subtopic', 'General Topic']], on='Subtopic', suffixes=('', '_max'), how='left')
|
266 |
+
|
267 |
+
reference_df['General Topic'] = reference_df["General Topic_max"].combine_first(reference_df["General Topic"])
|
268 |
+
|
269 |
+
if merge_sentiment == "Yes":
|
270 |
+
# Step 1: Count the number of occurrences for each General Topic and Subtopic combination
|
271 |
+
count_df = reference_df.groupby(['Subtopic', 'Sentiment']).size().reset_index(name='Count')
|
272 |
+
|
273 |
+
# Step 2: Determine the number of unique Sentiment values for each Subtopic
|
274 |
+
unique_sentiments = count_df.groupby('Subtopic')['Sentiment'].nunique().reset_index(name='UniqueCount')
|
275 |
+
|
276 |
+
# Step 3: Update Sentiment to 'Mixed' where there is more than one unique sentiment
|
277 |
+
reference_df = reference_df.merge(unique_sentiments, on='Subtopic', how='left')
|
278 |
+
reference_df['Sentiment'] = reference_df.apply(
|
279 |
+
lambda row: 'Mixed' if row['UniqueCount'] > 1 else row['Sentiment'],
|
280 |
+
axis=1
|
281 |
+
)
|
282 |
+
|
283 |
+
# Clean up the DataFrame by dropping the UniqueCount column
|
284 |
+
reference_df.drop(columns=['UniqueCount'], inplace=True)
|
285 |
+
|
286 |
+
reference_df = reference_df[["Response References", "General Topic", "Subtopic", "Sentiment", "Summary", "Start row of group"]]
|
287 |
+
|
288 |
+
# Update reference summary column with all summaries
|
289 |
+
reference_df["Summary"] = reference_df.groupby(
|
290 |
+
["Response References", "General Topic", "Subtopic", "Sentiment"]
|
291 |
+
)["Summary"].transform(' <br> '.join)
|
292 |
+
|
293 |
+
# Check that we have not inadvertently removed some data during the above process
|
294 |
+
end_unique_references = len(reference_df["Response References"].unique())
|
295 |
+
|
296 |
+
if initial_unique_references != end_unique_references:
|
297 |
+
raise Exception(f"Number of unique references changed during processing: Initial={initial_unique_references}, Final={end_unique_references}")
|
298 |
+
|
299 |
+
# Drop duplicates in the reference table - each comment should only have the same topic referred to once
|
300 |
+
reference_df.drop_duplicates(['Response References', 'General Topic', 'Subtopic', 'Sentiment'], inplace=True)
|
301 |
+
|
302 |
+
|
303 |
+
# Remake unique_topics_df based on new reference_df
|
304 |
+
unique_topics_df = create_unique_table_df_from_reference_table(reference_df)
|
305 |
+
|
306 |
+
# Then merge the topic numbers back to the original dataframe
|
307 |
+
reference_df = reference_df.merge(
|
308 |
+
unique_topics_df[['General Topic', 'Subtopic', 'Sentiment', 'Topic_number']],
|
309 |
+
on=['General Topic', 'Subtopic', 'Sentiment'],
|
310 |
+
how='left'
|
311 |
+
)
|
312 |
+
|
313 |
+
if not file_data.empty:
|
314 |
+
basic_response_data = get_basic_response_data(file_data, chosen_cols)
|
315 |
+
reference_df_pivot = convert_reference_table_to_pivot_table(reference_df, basic_response_data)
|
316 |
+
|
317 |
+
reference_pivot_file_path = output_folder + reference_table_file_name_no_ext + "_pivot_dedup.csv"
|
318 |
+
reference_df_pivot.to_csv(reference_pivot_file_path, index=None, encoding='utf-8')
|
319 |
+
log_output_files.append(reference_pivot_file_path)
|
320 |
+
|
321 |
+
#reference_table_file_name_no_ext = get_file_name_no_ext(reference_table_file_name)
|
322 |
+
#unique_topics_table_file_name_no_ext = get_file_name_no_ext(unique_topics_table_file_name)
|
323 |
+
|
324 |
+
reference_file_path = output_folder + reference_table_file_name_no_ext + "_dedup.csv"
|
325 |
+
unique_topics_file_path = output_folder + unique_topics_table_file_name_no_ext + "_dedup.csv"
|
326 |
+
reference_df.to_csv(reference_file_path, index = None, encoding='utf-8')
|
327 |
+
unique_topics_df.to_csv(unique_topics_file_path, index=None, encoding='utf-8')
|
328 |
+
|
329 |
+
output_files.append(reference_file_path)
|
330 |
+
output_files.append(unique_topics_file_path)
|
331 |
+
|
332 |
+
# Outputs for markdown table output
|
333 |
+
unique_table_df_revised_display = unique_topics_df.apply(lambda col: col.map(lambda x: wrap_text(x, max_text_length=500)))
|
334 |
+
|
335 |
+
deduplicated_unique_table_markdown = unique_table_df_revised_display.to_markdown(index=False)
|
336 |
+
|
337 |
+
return reference_df, unique_topics_df, output_files, log_output_files, deduplicated_unique_table_markdown
|
338 |
+
|
339 |
+
def sample_reference_table_summaries(reference_df:pd.DataFrame,
|
340 |
+
unique_topics_df:pd.DataFrame,
|
341 |
+
random_seed:int,
|
342 |
+
no_of_sampled_summaries:int=150):
|
343 |
+
|
344 |
+
'''
|
345 |
+
Sample up to a set number of topic summaries to pass to the summarisation model, so that the input token length is not too long.
|
346 |
+
'''
|
347 |
+
|
348 |
+
all_summaries = pd.DataFrame()
|
349 |
+
output_files = []
|
350 |
+
|
351 |
+
reference_df_grouped = reference_df.groupby(["General Topic", "Subtopic", "Sentiment"])
|
352 |
+
|
353 |
+
if 'Revised summary' in reference_df.columns:
|
354 |
+
out_message = "Summary has already been created for this file"
|
355 |
+
print(out_message)
|
356 |
+
raise Exception(out_message)
|
357 |
+
|
358 |
+
for group_keys, reference_df_group in reference_df_grouped:
|
359 |
+
#print(f"Group: {group_keys}")
|
360 |
+
#print(f"Data: {reference_df_group}")
|
361 |
+
|
362 |
+
if len(reference_df_group["General Topic"]) > 1:
|
363 |
+
|
364 |
+
filtered_reference_df = reference_df_group.reset_index()
|
365 |
+
|
366 |
+
filtered_reference_df_unique = filtered_reference_df.drop_duplicates(["General Topic", "Subtopic", "Sentiment", "Summary"])
|
367 |
+
|
368 |
+
# Sample n of the unique topic summaries. To limit the length of the text going into the summarisation tool
|
369 |
+
filtered_reference_df_unique_sampled = filtered_reference_df_unique.sample(min(no_of_sampled_summaries, len(filtered_reference_df_unique)), random_state=random_seed)
|
370 |
+
|
371 |
+
#topic_summary_table_markdown = filtered_reference_df_unique_sampled.to_markdown(index=False)
|
372 |
+
|
373 |
+
#print(filtered_reference_df_unique_sampled)
|
374 |
+
|
375 |
+
all_summaries = pd.concat([all_summaries, filtered_reference_df_unique_sampled])
|
376 |
+
|
377 |
+
summarised_references = all_summaries.groupby(["General Topic", "Subtopic", "Sentiment"]).agg({
|
378 |
+
'Response References': 'size', # Count the number of references
|
379 |
+
'Summary': lambda x: '\n'.join([s.split(': ', 1)[1] for s in x if ': ' in s]) # Join substrings after ': '
|
380 |
+
}).reset_index()
|
381 |
+
|
382 |
+
summarised_references = summarised_references.loc[(summarised_references["Sentiment"] != "Not Mentioned") & (summarised_references["Response References"] > 1)]
|
383 |
+
|
384 |
+
summarised_references_markdown = summarised_references.to_markdown(index=False)
|
385 |
+
|
386 |
+
return summarised_references, summarised_references_markdown, reference_df, unique_topics_df
|
387 |
+
|
388 |
+
def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str, local_model=[]):
|
389 |
+
conversation_history = []
|
390 |
+
whole_conversation_metadata = []
|
391 |
+
|
392 |
+
# Prepare Gemini models before query
|
393 |
+
if "gemini" in model_choice:
|
394 |
+
print("Using Gemini model:", model_choice)
|
395 |
+
model, config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=system_prompt, max_tokens=max_tokens)
|
396 |
+
else:
|
397 |
+
print("Using AWS Bedrock model:", model_choice)
|
398 |
+
model = model_choice
|
399 |
+
config = {}
|
400 |
+
|
401 |
+
whole_conversation = [summarise_topic_descriptions_system_prompt]
|
402 |
+
|
403 |
+
# Process requests to large language model
|
404 |
+
responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature, local_model=local_model)
|
405 |
+
|
406 |
+
print("Finished summary query")
|
407 |
+
|
408 |
+
if isinstance(responses[-1], ResponseObject):
|
409 |
+
response_texts = [resp.text for resp in responses]
|
410 |
+
elif "choices" in responses[-1]:
|
411 |
+
response_texts = [resp["choices"][0]['text'] for resp in responses]
|
412 |
+
else:
|
413 |
+
response_texts = [resp.text for resp in responses]
|
414 |
+
|
415 |
+
latest_response_text = response_texts[-1]
|
416 |
+
|
417 |
+
#print("latest_response_text:", latest_response_text)
|
418 |
+
#print("Whole conversation metadata:", whole_conversation_metadata)
|
419 |
+
|
420 |
+
return latest_response_text, conversation_history, whole_conversation_metadata
|
421 |
+
|
422 |
+
@spaces.GPU
|
423 |
+
def summarise_output_topics(summarised_references:pd.DataFrame,
|
424 |
+
unique_table_df:pd.DataFrame,
|
425 |
+
reference_table_df:pd.DataFrame,
|
426 |
+
model_choice:str,
|
427 |
+
in_api_key:str,
|
428 |
+
topic_summary_table_markdown:str,
|
429 |
+
temperature:float,
|
430 |
+
table_file_name:str,
|
431 |
+
summarised_outputs:list = [],
|
432 |
+
latest_summary_completed:int = 0,
|
433 |
+
out_metadata_str:str = "",
|
434 |
+
in_data_files:List[str]=[],
|
435 |
+
in_excel_sheets:str="",
|
436 |
+
chosen_cols:List[str]=[],
|
437 |
+
log_output_files:list[str]=[],
|
438 |
+
summarise_format_radio:str="Return a summary up to two paragraphs long that includes as much detail as possible from the original text",
|
439 |
+
output_folder:str=OUTPUT_FOLDER,
|
440 |
+
output_files:list[str] = [],
|
441 |
+
summarise_topic_descriptions_prompt:str=summarise_topic_descriptions_prompt, summarise_topic_descriptions_system_prompt:str=summarise_topic_descriptions_system_prompt,
|
442 |
+
do_summaries:str="Yes",
|
443 |
+
progress=gr.Progress(track_tqdm=True)):
|
444 |
+
'''
|
445 |
+
Create better summaries of the raw batch-level summaries created in the first run of the model.
|
446 |
+
'''
|
447 |
+
out_metadata = []
|
448 |
+
local_model = []
|
449 |
+
summarised_output_markdown = ""
|
450 |
+
|
451 |
+
|
452 |
+
# Check for data for summarisations
|
453 |
+
if not unique_table_df.empty and not reference_table_df.empty:
|
454 |
+
print("Unique table and reference table data found.")
|
455 |
+
else:
|
456 |
+
out_message = "Please upload a unique topic table and reference table file to continue with summarisation."
|
457 |
+
print(out_message)
|
458 |
+
raise Exception(out_message)
|
459 |
+
|
460 |
+
if 'Revised summary' in reference_table_df.columns:
|
461 |
+
out_message = "Summary has already been created for this file"
|
462 |
+
print(out_message)
|
463 |
+
raise Exception(out_message)
|
464 |
+
|
465 |
+
# Load in data file and chosen columns if exists to create pivot table later
|
466 |
+
if in_data_files and chosen_cols:
|
467 |
+
file_data, data_file_names_textbox, total_number_of_batches = load_in_data_file(in_data_files, chosen_cols, 1, in_excel_sheets=in_excel_sheets)
|
468 |
+
else:
|
469 |
+
out_message = "No file data found, pivot table output will not be created."
|
470 |
+
print(out_message)
|
471 |
+
raise Exception(out_message)
|
472 |
+
|
473 |
+
|
474 |
+
all_summaries = summarised_references["Summary"].tolist()
|
475 |
+
length_all_summaries = len(all_summaries)
|
476 |
+
|
477 |
+
# If all summaries completed, make final outputs
|
478 |
+
if latest_summary_completed >= length_all_summaries:
|
479 |
+
print("All summaries completed. Creating outputs.")
|
480 |
+
|
481 |
+
model_choice_clean = model_name_map[model_choice]
|
482 |
+
file_name = re.search(r'(.*?)(?:_batch_|_col_)', table_file_name).group(1) if re.search(r'(.*?)(?:_batch_|_col_)', table_file_name) else table_file_name
|
483 |
+
latest_batch_completed = int(re.search(r'batch_(\d+)_', table_file_name).group(1)) if 'batch_' in table_file_name else ""
|
484 |
+
batch_size_number = int(re.search(r'size_(\d+)_', table_file_name).group(1)) if 'size_' in table_file_name else ""
|
485 |
+
in_column_cleaned = re.search(r'col_(.*?)_reference', table_file_name).group(1) if 'col_' in table_file_name else ""
|
486 |
+
|
487 |
+
# Save outputs for each batch. If master file created, label file as master
|
488 |
+
if latest_batch_completed:
|
489 |
+
batch_file_path_details = f"{file_name}_batch_{latest_batch_completed}_size_{batch_size_number}_col_{in_column_cleaned}"
|
490 |
+
else:
|
491 |
+
batch_file_path_details = f"{file_name}_col_{in_column_cleaned}"
|
492 |
+
|
493 |
+
summarised_references["Revised summary"] = summarised_outputs
|
494 |
+
|
495 |
+
join_cols = ["General Topic", "Subtopic", "Sentiment"]
|
496 |
+
join_plus_summary_cols = ["General Topic", "Subtopic", "Sentiment", "Revised summary"]
|
497 |
+
|
498 |
+
summarised_references_j = summarised_references[join_plus_summary_cols].drop_duplicates(join_plus_summary_cols)
|
499 |
+
|
500 |
+
unique_table_df_revised = unique_table_df.merge(summarised_references_j, on = join_cols, how = "left")
|
501 |
+
|
502 |
+
# If no new summary is available, keep the original
|
503 |
+
unique_table_df_revised["Revised summary"] = unique_table_df_revised["Revised summary"].combine_first(unique_table_df_revised["Summary"])
|
504 |
+
|
505 |
+
unique_table_df_revised = unique_table_df_revised[["General Topic", "Subtopic", "Sentiment", "Response References", "Revised summary"]]
|
506 |
+
|
507 |
+
reference_table_df_revised = reference_table_df.merge(summarised_references_j, on = join_cols, how = "left")
|
508 |
+
# If no new summary is available, keep the original
|
509 |
+
reference_table_df_revised["Revised summary"] = reference_table_df_revised["Revised summary"].combine_first(reference_table_df_revised["Summary"])
|
510 |
+
reference_table_df_revised = reference_table_df_revised.drop("Summary", axis=1)
|
511 |
+
|
512 |
+
# Remove topics that are tagged as 'Not Mentioned'
|
513 |
+
unique_table_df_revised = unique_table_df_revised.loc[unique_table_df_revised["Sentiment"] != "Not Mentioned", :]
|
514 |
+
reference_table_df_revised = reference_table_df_revised.loc[reference_table_df_revised["Sentiment"] != "Not Mentioned", :]
|
515 |
+
|
516 |
+
|
517 |
+
|
518 |
+
|
519 |
+
if not file_data.empty:
|
520 |
+
basic_response_data = get_basic_response_data(file_data, chosen_cols)
|
521 |
+
reference_table_df_revised_pivot = convert_reference_table_to_pivot_table(reference_table_df_revised, basic_response_data)
|
522 |
+
|
523 |
+
### Save pivot file to log area
|
524 |
+
reference_table_df_revised_pivot_path = output_folder + batch_file_path_details + "_summarised_reference_table_pivot_" + model_choice_clean + ".csv"
|
525 |
+
reference_table_df_revised_pivot.to_csv(reference_table_df_revised_pivot_path, index=None, encoding='utf-8')
|
526 |
+
log_output_files.append(reference_table_df_revised_pivot_path)
|
527 |
+
|
528 |
+
# Save to file
|
529 |
+
unique_table_df_revised_path = output_folder + batch_file_path_details + "_summarised_unique_topic_table_" + model_choice_clean + ".csv"
|
530 |
+
unique_table_df_revised.to_csv(unique_table_df_revised_path, index = None, encoding='utf-8')
|
531 |
+
|
532 |
+
reference_table_df_revised_path = output_folder + batch_file_path_details + "_summarised_reference_table_" + model_choice_clean + ".csv"
|
533 |
+
reference_table_df_revised.to_csv(reference_table_df_revised_path, index = None, encoding='utf-8')
|
534 |
+
|
535 |
+
output_files.extend([reference_table_df_revised_path, unique_table_df_revised_path])
|
536 |
+
|
537 |
+
###
|
538 |
+
unique_table_df_revised_display = unique_table_df_revised.apply(lambda col: col.map(lambda x: wrap_text(x, max_text_length=500)))
|
539 |
+
|
540 |
+
summarised_output_markdown = unique_table_df_revised_display.to_markdown(index=False)
|
541 |
+
|
542 |
+
# Ensure same file name not returned twice
|
543 |
+
output_files = list(set(output_files))
|
544 |
+
log_output_files = list(set(log_output_files))
|
545 |
+
|
546 |
+
return summarised_references, unique_table_df_revised, reference_table_df_revised, output_files, summarised_outputs, latest_summary_completed, out_metadata_str, summarised_output_markdown, log_output_files
|
547 |
+
|
548 |
+
tic = time.perf_counter()
|
549 |
+
|
550 |
+
#print("Starting with:", latest_summary_completed)
|
551 |
+
#print("Last summary number:", length_all_summaries)
|
552 |
+
|
553 |
+
if (model_choice == "gemma_2b_it_local") & (RUN_LOCAL_MODEL == "1"):
|
554 |
+
progress(0.1, "Loading in Gemma 2b model")
|
555 |
+
local_model, tokenizer = load_model()
|
556 |
+
#print("Local model loaded:", local_model)
|
557 |
+
|
558 |
+
summary_loop_description = "Creating summaries. " + str(latest_summary_completed) + " summaries completed so far."
|
559 |
+
summary_loop = tqdm(range(latest_summary_completed, length_all_summaries), desc="Creating summaries", unit="summaries")
|
560 |
+
|
561 |
+
if do_summaries == "Yes":
|
562 |
+
for summary_no in summary_loop:
|
563 |
+
|
564 |
+
print("Current summary number is:", summary_no)
|
565 |
+
|
566 |
+
summary_text = all_summaries[summary_no]
|
567 |
+
#print("summary_text:", summary_text)
|
568 |
+
formatted_summary_prompt = [summarise_topic_descriptions_prompt.format(summaries=summary_text, summary_format=summarise_format_radio)]
|
569 |
+
|
570 |
+
try:
|
571 |
+
response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, summarise_topic_descriptions_system_prompt, local_model)
|
572 |
+
summarised_output = response
|
573 |
+
summarised_output = re.sub(r'\n{2,}', '\n', summarised_output) # Replace multiple line breaks with a single line break
|
574 |
+
summarised_output = re.sub(r'^\n{1,}', '', summarised_output) # Remove one or more line breaks at the start
|
575 |
+
summarised_output = summarised_output.strip()
|
576 |
+
except Exception as e:
|
577 |
+
print(e)
|
578 |
+
summarised_output = ""
|
579 |
+
|
580 |
+
summarised_outputs.append(summarised_output)
|
581 |
+
out_metadata.extend(metadata)
|
582 |
+
out_metadata_str = '. '.join(out_metadata)
|
583 |
+
|
584 |
+
latest_summary_completed += 1
|
585 |
+
|
586 |
+
# Check if beyond max time allowed for processing and break if necessary
|
587 |
+
toc = time.perf_counter()
|
588 |
+
time_taken = toc - tic
|
589 |
+
|
590 |
+
if time_taken > max_time_for_loop:
|
591 |
+
print("Time taken for loop is greater than maximum time allowed. Exiting and restarting loop")
|
592 |
+
summary_loop.close()
|
593 |
+
tqdm._instances.clear()
|
594 |
+
break
|
595 |
+
|
596 |
+
# If all summaries completed
|
597 |
+
if latest_summary_completed >= length_all_summaries:
|
598 |
+
print("At last summary.")
|
599 |
+
|
600 |
+
output_files = list(set(output_files))
|
601 |
+
|
602 |
+
return summarised_references, unique_table_df, reference_table_df, output_files, summarised_outputs, latest_summary_completed, out_metadata_str, summarised_output_markdown, log_output_files
|
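To make the fuzzy deduplication above easier to follow, here is a small self-contained sketch of the core idea in deduplicate_categories: subtopic names that score above the threshold under rapidfuzz's WRatio scorer are merged, and the spelling backed by more responses is kept. The toy subtopics, counts and the 90 threshold are assumptions for illustration only, not data from this commit.

# Toy illustration (assumed data) of the fuzzy merge used in deduplicate_categories
import pandas as pd
from rapidfuzz import process, fuzz

counts = {"Street lighting": 12, "Street lightng": 2, "Road safety": 7}   # responses per subtopic
subtopics = list(counts.keys())

merged = {}
for topic in subtopics:
    others = [t for t in subtopics if t != topic]
    matches = process.extract(topic, others, scorer=fuzz.WRatio, score_cutoff=90)
    if matches:
        best, score, _ = max(matches, key=lambda m: m[1])
        merged[topic] = best if counts[best] > counts[topic] else topic   # keep the higher-count spelling
    else:
        merged[topic] = topic                                             # nothing similar enough - keep as-is

print(pd.Series(merged))   # expect "Street lightng" to map to "Street lighting"; the rest map to themselves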
tools/helper_functions.py
CHANGED
@@ -1,9 +1,13 @@
|
|
1 |
import os
|
2 |
import re
|
|
|
3 |
import gradio as gr
|
4 |
import pandas as pd
|
|
|
5 |
from typing import List
|
6 |
import math
|
|
|
|
|
7 |
|
8 |
def empty_output_vars_extract_topics():
|
9 |
# Empty output objects before processing a new file
|
@@ -37,7 +41,6 @@ def empty_output_vars_summarise():
|
|
37 |
|
38 |
return summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox
|
39 |
|
40 |
-
|
41 |
def get_or_create_env_var(var_name, default_value):
|
42 |
# Get the environment variable if it exists
|
43 |
value = os.environ.get(var_name)
|
@@ -49,45 +52,6 @@ def get_or_create_env_var(var_name, default_value):
|
|
49 |
|
50 |
return value
|
51 |
|
52 |
-
RUN_AWS_FUNCTIONS = get_or_create_env_var("RUN_AWS_FUNCTIONS", "1")
|
53 |
-
print(f'The value of RUN_AWS_FUNCTIONS is {RUN_AWS_FUNCTIONS}')
|
54 |
-
|
55 |
-
RUN_LOCAL_MODEL = get_or_create_env_var("RUN_LOCAL_MODEL", "1")
|
56 |
-
print(f'The value of RUN_LOCAL_MODEL is {RUN_LOCAL_MODEL}')
|
57 |
-
|
58 |
-
RUN_GEMINI_MODELS = get_or_create_env_var("RUN_GEMINI_MODELS", "1")
|
59 |
-
print(f'The value of RUN_GEMINI_MODELS is {RUN_GEMINI_MODELS}')
|
60 |
-
|
61 |
-
GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')
|
62 |
-
|
63 |
-
# Build up options for models
|
64 |
-
model_full_names = []
|
65 |
-
model_short_names = []
|
66 |
-
|
67 |
-
if RUN_LOCAL_MODEL == "1":
|
68 |
-
model_full_names.append("gemma_2b_it_local")
|
69 |
-
model_short_names.append("gemma_local")
|
70 |
-
|
71 |
-
if RUN_AWS_FUNCTIONS == "1":
|
72 |
-
model_full_names.extend(["anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"])
|
73 |
-
model_short_names.extend(["haiku", "sonnet"])
|
74 |
-
|
75 |
-
if RUN_GEMINI_MODELS == "1":
|
76 |
-
model_full_names.extend(["gemini-2.0-flash-001", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-exp-05-06" ]) # , # Gemini pro No longer available on free tier
|
77 |
-
model_short_names.extend(["gemini_flash_2", "gemini_flash_2.5", "gemini_pro"])
|
78 |
-
|
79 |
-
print("model_short_names:", model_short_names)
|
80 |
-
print("model_full_names:", model_full_names)
|
81 |
-
|
82 |
-
model_name_map = {short: full for short, full in zip(model_full_names, model_short_names)}
|
83 |
-
|
84 |
-
# Retrieving or setting output folder
|
85 |
-
env_var_name = 'GRADIO_OUTPUT_FOLDER'
|
86 |
-
default_value = 'output/'
|
87 |
-
|
88 |
-
output_folder = get_or_create_env_var(env_var_name, default_value)
|
89 |
-
print(f'The value of {env_var_name} is {output_folder}')
|
90 |
-
|
91 |
def get_file_path_with_extension(file_path):
|
92 |
# First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
|
93 |
basename = os.path.basename(file_path)
|
@@ -222,7 +186,7 @@ def load_in_previous_reference_file(file:str):
|
|
222 |
|
223 |
return reference_file_data, reference_file_name
|
224 |
|
225 |
-
def join_cols_onto_reference_df(reference_df:pd.DataFrame, original_data_df:pd.DataFrame, join_columns:List[str], original_file_name:str, output_folder:str=
|
226 |
|
227 |
#print("original_data_df columns:", original_data_df.columns)
|
228 |
#print("original_data_df:", original_data_df)
|
@@ -246,6 +210,75 @@ def join_cols_onto_reference_df(reference_df:pd.DataFrame, original_data_df:pd.D
|
|
246 |
|
247 |
return out_reference_df, file_data_outputs
|
248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
# Wrap text in each column to the specified max width, including whole words
|
250 |
def wrap_text(text:str, max_width=60, max_text_length=None):
|
251 |
if not isinstance(text, str):
|
@@ -332,7 +365,7 @@ def wrap_text(text:str, max_width=60, max_text_length=None):
|
|
332 |
|
333 |
return '<br>'.join(wrapped_lines)
|
334 |
|
335 |
-
def initial_clean(text):
|
336 |
#### Some of my cleaning functions
|
337 |
html_pattern_regex = r'<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});|\xa0| '
|
338 |
html_start_pattern_end_dots_regex = r'<(.*?)\.\.'
|
@@ -445,7 +478,7 @@ def add_folder_to_path(folder_path: str):
|
|
445 |
def reveal_feedback_buttons():
|
446 |
return gr.Radio(visible=True), gr.Textbox(visible=True), gr.Button(visible=True), gr.Markdown(visible=True)
|
447 |
|
448 |
-
def wipe_logs(feedback_logs_loc, usage_logs_loc):
|
449 |
try:
|
450 |
os.remove(feedback_logs_loc)
|
451 |
except Exception as e:
|
@@ -454,65 +487,67 @@ def wipe_logs(feedback_logs_loc, usage_logs_loc):
|
|
454 |
os.remove(usage_logs_loc)
|
455 |
except Exception as e:
|
456 |
print("Could not remove usage logs file", e)
|
457 |
-
|
458 |
-
async def get_connection_params(request: gr.Request
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
# print("Request headers dictionary:", request.headers)
|
471 |
-
# print("All host elements", request.client)
|
472 |
-
# print("IP address:", request.client.host)
|
473 |
-
# print("Query parameters:", dict(request.query_params))
|
474 |
-
# To get the underlying FastAPI items you would need to use await and some fancy @ stuff for a live query: https://fastapi.tiangolo.com/vi/reference/request/
|
475 |
-
#print("Request dictionary to object:", request.request.body())
|
476 |
-
print("Session hash:", request.session_hash)
|
477 |
-
|
478 |
-
# Retrieving or setting CUSTOM_CLOUDFRONT_HEADER
|
479 |
-
CUSTOM_CLOUDFRONT_HEADER_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER', '')
|
480 |
-
#print(f'The value of CUSTOM_CLOUDFRONT_HEADER is {CUSTOM_CLOUDFRONT_HEADER_var}')
|
481 |
-
|
482 |
-
# Retrieving or setting CUSTOM_CLOUDFRONT_HEADER_VALUE
|
483 |
-
CUSTOM_CLOUDFRONT_HEADER_VALUE_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER_VALUE', '')
|
484 |
-
#print(f'The value of CUSTOM_CLOUDFRONT_HEADER_VALUE_var is {CUSTOM_CLOUDFRONT_HEADER_VALUE_var}')
|
485 |
-
|
486 |
-
if CUSTOM_CLOUDFRONT_HEADER_var and CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
|
487 |
-
if CUSTOM_CLOUDFRONT_HEADER_var in request.headers:
|
488 |
-
supplied_cloudfront_custom_value = request.headers[CUSTOM_CLOUDFRONT_HEADER_var]
|
489 |
-
if supplied_cloudfront_custom_value == CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
|
490 |
-
print("Custom Cloudfront header found:", supplied_cloudfront_custom_value)
|
491 |
else:
|
492 |
-
|
|
|
|
|
|
|
|
|
493 |
|
494 |
-
|
495 |
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
print("Request username found:", out_session_hash)
|
500 |
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
print("Cognito ID found:", out_session_hash)
|
505 |
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
510 |
|
511 |
-
|
512 |
-
|
513 |
-
|
|
|
514 |
|
515 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
516 |
else:
|
517 |
-
|
518 |
-
|
|
|
|
|
|
|
|
|
|
|
|
 import os
 import re
+import boto3
 import gradio as gr
 import pandas as pd
+import numpy as np
 from typing import List
 import math
+from botocore.exceptions import ClientError
+from tools.config import OUTPUT_FOLDER, INPUT_FOLDER, SESSION_OUTPUT_FOLDER, CUSTOM_HEADER, CUSTOM_HEADER_VALUE, AWS_USER_POOL_ID
 
 def empty_output_vars_extract_topics():
     # Empty output objects before processing a new file
 
     return summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox
 
 def get_or_create_env_var(var_name, default_value):
     # Get the environment variable if it exists
     value = os.environ.get(var_name)
 
     return value
 
 def get_file_path_with_extension(file_path):
     # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
     basename = os.path.basename(file_path)
 
     return reference_file_data, reference_file_name
 
+def join_cols_onto_reference_df(reference_df:pd.DataFrame, original_data_df:pd.DataFrame, join_columns:List[str], original_file_name:str, output_folder:str=OUTPUT_FOLDER):
 
     #print("original_data_df columns:", original_data_df.columns)
     #print("original_data_df:", original_data_df)
 
     return out_reference_df, file_data_outputs
 
+
+def get_basic_response_data(file_data:pd.DataFrame, chosen_cols:List[str], verify_titles:bool=False) -> pd.DataFrame:
+
+    if not isinstance(chosen_cols, list):
+        chosen_cols = [chosen_cols]
+    else:
+        chosen_cols = chosen_cols
+
+    basic_response_data = file_data[chosen_cols].reset_index(names="Reference")
+    basic_response_data["Reference"] = basic_response_data["Reference"].astype(int) + 1
+
+    if verify_titles == True:
+        basic_response_data = basic_response_data.rename(columns={chosen_cols[0]: "Response", chosen_cols[1]: "Title"})
+        basic_response_data["Title"] = basic_response_data["Title"].str.strip()
+        basic_response_data["Title"] = basic_response_data["Title"].apply(initial_clean)
+    else:
+        basic_response_data = basic_response_data.rename(columns={chosen_cols[0]: "Response"})
+
+    basic_response_data["Response"] = basic_response_data["Response"].str.strip()
+    basic_response_data["Response"] = basic_response_data["Response"].apply(initial_clean)
+
+    return basic_response_data
+
+def convert_reference_table_to_pivot_table(df:pd.DataFrame, basic_response_data:pd.DataFrame=pd.DataFrame()):
+
+    df_in = df[['Response References', 'General Topic', 'Subtopic', 'Sentiment']].copy()
+
+    df_in['Response References'] = df_in['Response References'].astype(int)
+
+    # Create a combined category column
+    df_in['Category'] = df_in['General Topic'] + ' - ' + df_in['Subtopic'] + ' - ' + df_in['Sentiment']
+
+    # Create pivot table counting occurrences of each unique combination
+    pivot_table = pd.crosstab(
+        index=df_in['Response References'],
+        columns=[df_in['General Topic'], df_in['Subtopic'], df_in['Sentiment']],
+        margins=True
+    )
+
+    # Flatten column names to make them more readable
+    pivot_table.columns = [' - '.join(col) for col in pivot_table.columns]
+
+    pivot_table.reset_index(inplace=True)
+
+    if not basic_response_data.empty:
+        pivot_table = basic_response_data.merge(pivot_table, right_on="Response References", left_on="Reference", how="left")
+
+        pivot_table.drop("Response References", axis=1, inplace=True)
+
+    pivot_table.columns = pivot_table.columns.str.replace("Not assessed - ", "").str.replace("- Not assessed", "")
+
+    return pivot_table
+
+def create_unique_table_df_from_reference_table(reference_df:pd.DataFrame):
+
+    out_unique_topics_df = (reference_df.groupby(["General Topic", "Subtopic", "Sentiment"])
+                            .agg({
+                                'Response References': 'size',  # Count the number of references
+                                'Summary': lambda x: '<br>'.join(
+                                    sorted(set(x), key=lambda summary: reference_df.loc[reference_df['Summary'] == summary, 'Start row of group'].min())
+                                )
+                            })
+                            .reset_index()
+                            .sort_values('Response References', ascending=False)  # Sort by size, biggest first
+                            .assign(Topic_number=lambda df: np.arange(1, len(df) + 1))  # Add numbering 1 to x
+                            )
+
+    return out_unique_topics_df
+
 # Wrap text in each column to the specified max width, including whole words
 def wrap_text(text:str, max_width=60, max_text_length=None):
     if not isinstance(text, str):
 
     return '<br>'.join(wrapped_lines)
 
+def initial_clean(text:str):
     #### Some of my cleaning functions
     html_pattern_regex = r'<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});|\xa0| '
     html_start_pattern_end_dots_regex = r'<(.*?)\.\.'
 
 def reveal_feedback_buttons():
     return gr.Radio(visible=True), gr.Textbox(visible=True), gr.Button(visible=True), gr.Markdown(visible=True)
 
+def wipe_logs(feedback_logs_loc:str, usage_logs_loc:str):
     try:
         os.remove(feedback_logs_loc)
     except Exception as e:
 
         os.remove(usage_logs_loc)
     except Exception as e:
         print("Could not remove usage logs file", e)
+
+async def get_connection_params(request: gr.Request,
+                                output_folder_textbox:str=OUTPUT_FOLDER,
+                                input_folder_textbox:str=INPUT_FOLDER,
+                                session_output_folder:str=SESSION_OUTPUT_FOLDER):
+
+    #print("Session hash:", request.session_hash)
+
+    if CUSTOM_HEADER and CUSTOM_HEADER_VALUE:
+        if CUSTOM_HEADER in request.headers:
+            supplied_custom_header_value = request.headers[CUSTOM_HEADER]
+            if supplied_custom_header_value == CUSTOM_HEADER_VALUE:
+                print("Custom header supplied and matches CUSTOM_HEADER_VALUE")
+            else:
+                print("Custom header value does not match expected value.")
+                raise ValueError("Custom header value does not match expected value.")
+        else:
+            print("Custom header value not found.")
+            raise ValueError("Custom header value not found.")
+
+    # Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
+
+    if request.username:
+        out_session_hash = request.username
+        #print("Request username found:", out_session_hash)
+
+    elif 'x-cognito-id' in request.headers:
+        out_session_hash = request.headers['x-cognito-id']
+        #print("Cognito ID found:", out_session_hash)
+
+    elif 'x-amzn-oidc-identity' in request.headers:
+        out_session_hash = request.headers['x-amzn-oidc-identity']
+
+        # Fetch email address using Cognito client
+        cognito_client = boto3.client('cognito-idp')
+        try:
+            response = cognito_client.admin_get_user(
+                UserPoolId=AWS_USER_POOL_ID, # Replace with your User Pool ID
+                Username=out_session_hash
+            )
+            email = next(attr['Value'] for attr in response['UserAttributes'] if attr['Name'] == 'email')
+            #print("Email address found:", email)
+
+            out_session_hash = email
+        except ClientError as e:
+            print("Error fetching user details:", e)
+            email = None
+
+        print("Cognito ID found:", out_session_hash)
+
+    else:
+        out_session_hash = request.session_hash
+
+    if session_output_folder == 'True':
+        output_folder = output_folder_textbox + out_session_hash + "/"
+        input_folder = input_folder_textbox + out_session_hash + "/"
+    else:
+        output_folder = output_folder_textbox
+        input_folder = input_folder_textbox
+
+    if not os.path.exists(output_folder): os.mkdir(output_folder)
+    if not os.path.exists(input_folder): os.mkdir(input_folder)
+
+    return out_session_hash, output_folder, out_session_hash, input_folder
tools/llm_api_call.py
CHANGED
@@ -6,51 +6,41 @@ import gradio as gr
 import markdown
 import time
 import boto3
-import json
-import math
 import string
 import re
 import spaces
-from rapidfuzz import process, fuzz
 from tqdm import tqdm
 from gradio import Progress
 from typing import List, Tuple
 from io import StringIO
 
 GradioFileData = gr.FileData
 
-from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt,
-from tools.helper_functions import
-from tools.
-
-# ResponseObject class for AWS Bedrock calls
-class ResponseObject:
-    def __init__(self, text, usage_metadata):
-        self.text = text
-        self.usage_metadata = usage_metadata
-
-max_tokens = 4096 # Maximum number of output tokens
-timeout_wait = 30 # AWS now seems to have a 60 second minimum wait between API calls
-number_of_api_retry_attempts = 5
-# Try up to 3 times to get a valid markdown table response with LLM calls, otherwise retry with temperature changed
-MAX_OUTPUT_VALIDATION_ATTEMPTS = 3
-max_time_for_loop = 99999
-batch_size_default = 5
-deduplication_threshold = 90
 
-
-
-
-bedrock_runtime =
 
 ### HELPER FUNCTIONS
 
-def normalise_string(text):
     # Replace two or more dashes with a single dash
     text = re.sub(r'-{2,}', '-', text)
 
@@ -130,30 +120,7 @@ def load_in_previous_data_files(file_paths_partial_output:List[str], for_modifie
 
     return gr.Dataframe(value=unique_file_data, headers=None, col_count=(unique_file_data.shape[1], "fixed"), row_count = (unique_file_data.shape[0], "fixed"), visible=True, type="pandas"), reference_file_data, unique_file_data, reference_file_name, unique_file_name, out_file_names
 
-
-def get_basic_response_data(file_data:pd.DataFrame, chosen_cols:List[str], verify_titles:bool=False) -> pd.DataFrame:
-
-    if not isinstance(chosen_cols, list):
-        chosen_cols = [chosen_cols]
-    else:
-        chosen_cols = chosen_cols
-
-    basic_response_data = file_data[chosen_cols].reset_index(names="Reference")
-    basic_response_data["Reference"] = basic_response_data["Reference"].astype(int) + 1
-
-    if verify_titles == True:
-        basic_response_data = basic_response_data.rename(columns={chosen_cols[0]: "Response", chosen_cols[1]: "Title"})
-        basic_response_data["Title"] = basic_response_data["Title"].str.strip()
-        basic_response_data["Title"] = basic_response_data["Title"].apply(initial_clean)
-    else:
-        basic_response_data = basic_response_data.rename(columns={chosen_cols[0]: "Response"})
-
-    basic_response_data["Response"] = basic_response_data["Response"].str.strip()
-    basic_response_data["Response"] = basic_response_data["Response"].apply(initial_clean)
-
-    return basic_response_data
-
-def data_file_to_markdown_table(file_data:pd.DataFrame, file_name:str, chosen_cols: List[str], output_folder: str, batch_number: int, batch_size: int, verify_titles:bool=False) -> Tuple[str, str, str]:
     """
     Processes a file by simplifying its content based on chosen columns and saves the result to a specified output folder.
 
@@ -161,7 +128,6 @@ def data_file_to_markdown_table(file_data:pd.DataFrame, file_name:str, chosen_co
     - file_data (pd.DataFrame): Tabular data file with responses.
     - file_name (str): File name with extension.
     - chosen_cols (List[str]): A list of column names to include in the simplified file.
-    - output_folder (str): The directory where the simplified file will be saved.
     - batch_number (int): The current batch number for processing.
     - batch_size (int): The number of rows to process in each batch.
 
@@ -230,309 +196,6 @@ def replace_punctuation_with_underscore(input_string):
     # Translate the input string using the translation table
     return input_string.translate(translation_table)
 
-### LLM FUNCTIONS
-
-def construct_gemini_generative_model(in_api_key: str, temperature: float, model_choice: str, system_prompt: str, max_tokens: int) -> Tuple[object, dict]:
-    """
-    Constructs a GenerativeModel for Gemini API calls.
-
-    Parameters:
-    - in_api_key (str): The API key for authentication.
-    - temperature (float): The temperature parameter for the model, controlling the randomness of the output.
-    - model_choice (str): The choice of model to use for generation.
-    - system_prompt (str): The system prompt to guide the generation.
-    - max_tokens (int): The maximum number of tokens to generate.
-
-    Returns:
-    - Tuple[object, dict]: A tuple containing the constructed GenerativeModel and its configuration.
-    """
-    # Construct a GenerativeModel
-    try:
-        if in_api_key:
-            #print("Getting API key from textbox")
-            api_key = in_api_key
-            ai.configure(api_key=api_key)
-        elif "GOOGLE_API_KEY" in os.environ:
-            #print("Searching for API key in environmental variables")
-            api_key = os.environ["GOOGLE_API_KEY"]
-            ai.configure(api_key=api_key)
-        else:
-            print("No API key foound")
-            raise gr.Error("No API key found.")
-    except Exception as e:
-        print(e)
-
-    config = ai.GenerationConfig(temperature=temperature, max_output_tokens=max_tokens)
-
-    #model = ai.GenerativeModel.from_cached_content(cached_content=cache, generation_config=config)
-    model = ai.GenerativeModel(model_name='models/' + model_choice, system_instruction=system_prompt, generation_config=config)
-
-    # Upload CSV file (replace with your actual file path)
-    #file_id = ai.upload_file(upload_file_path)
-
-    # if file_type == 'xlsx':
-    #     print("Running through all xlsx sheets")
-    #     #anon_xlsx = pd.ExcelFile(upload_file_path)
-    #     if not in_excel_sheets:
-    #         out_message.append("No Excel sheets selected. Please select at least one to anonymise.")
-    #         continue
-
-    #     anon_xlsx = pd.ExcelFile(upload_file_path)
-
-    #     # Create xlsx file:
-    #     anon_xlsx_export_file_name = output_folder + file_name + "_redacted.xlsx"
-
-    ### QUERYING LARGE LANGUAGE MODEL ###
-    # Prompt caching the table and system prompt. See here: https://ai.google.dev/gemini-api/docs/caching?lang=python
-    # Create a cache with a 5 minute TTL. ONLY FOR CACHES OF AT LEAST 32k TOKENS!
-    # cache = ai.caching.CachedContent.create(
-    # model='models/' + model_choice,
-    # display_name=file_name, # used to identify the cache
-    # system_instruction=system_prompt,
-    # ttl=datetime.timedelta(minutes=5),
-    # )
-
-    return model, config
-
-def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice: str) -> ResponseObject:
-    """
-    This function sends a request to AWS Claude with the following parameters:
-    - prompt: The user's input prompt to be processed by the model.
-    - system_prompt: A system-defined prompt that provides context or instructions for the model.
-    - temperature: A value that controls the randomness of the model's output, with higher values resulting in more diverse responses.
-    - max_tokens: The maximum number of tokens (words or characters) in the model's response.
-    - model_choice: The specific model to use for processing the request.
-
-    The function constructs the request configuration, invokes the model, extracts the response text, and returns a ResponseObject containing the text and metadata.
-    """
-
-    prompt_config = {
-        "anthropic_version": "bedrock-2023-05-31",
-        "max_tokens": max_tokens,
-        "top_p": 0.999,
-        "temperature":temperature,
-        "system": system_prompt,
-        "messages": [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": prompt},
-                ],
-            }
-        ],
-    }
-
-    body = json.dumps(prompt_config)
-
-    modelId = model_choice
-    accept = "application/json"
-    contentType = "application/json"
-
-    request = bedrock_runtime.invoke_model(
-        body=body, modelId=modelId, accept=accept, contentType=contentType
-    )
-
-    # Extract text from request
-    response_body = json.loads(request.get("body").read())
-    text = response_body.get("content")[0].get("text")
-
-    response = ResponseObject(
-        text=text,
-        usage_metadata=request['ResponseMetadata']
-    )
-
-    # Now you can access both the text and metadata
-    #print("Text:", response.text)
-    #print("Metadata:", response.usage_metadata)
-    #print("Text:", response.text)
-
-    return response
-
-# Function to send a request and update history
-def send_request(prompt: str, conversation_history: List[dict], model: object, config: dict, model_choice: str, system_prompt: str, temperature: float, local_model=[], progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
-    """
-    This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
-    It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
-    If the model choice is specific to AWS Claude, it calls the `call_aws_claude` function; otherwise, it uses the `model.generate_content` method.
-    The function returns the response text and the updated conversation history.
-    """
-    # Constructing the full prompt from the conversation history
-    full_prompt = "Conversation history:\n"
-
-    for entry in conversation_history:
-        role = entry['role'].capitalize() # Assuming the history is stored with 'role' and 'parts'
-        message = ' '.join(entry['parts']) # Combining all parts of the message
-        full_prompt += f"{role}: {message}\n"
-
-    # Adding the new user prompt
-    full_prompt += f"\nUser: {prompt}"
-
-    # Clear any existing progress bars
-    tqdm._instances.clear()
-
-    progress_bar = range(0,number_of_api_retry_attempts)
-
-    # Generate the model's response
-    if "gemini" in model_choice:
-
-        for i in progress_bar:
-            try:
-                print("Calling Gemini model, attempt", i + 1)
-                #print("full_prompt:", full_prompt)
-                #print("generation_config:", config)
-
-                response = model.generate_content(contents=full_prompt, generation_config=config)
-
-                #progress_bar.close()
-                #tqdm._instances.clear()
-
-                print("Successful call to Gemini model.")
-                break
-            except Exception as e:
-                # If fails, try again after X seconds in case there is a throttle limit
-                print("Call to Gemini model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
-
-                time.sleep(timeout_wait)
-
-            if i == number_of_api_retry_attempts:
-                return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
-    elif "anthropic.claude" in model_choice:
-        for i in progress_bar:
-            try:
-                print("Calling AWS Claude model, attempt", i + 1)
-                response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
-
-                #progress_bar.close()
-                #tqdm._instances.clear()
-
-                print("Successful call to Claude model.")
-                break
-            except Exception as e:
-                # If fails, try again after X seconds in case there is a throttle limit
-                print("Call to Claude model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
-
-                time.sleep(timeout_wait)
-                #response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
-
-            if i == number_of_api_retry_attempts:
-                return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
-    else:
-        # This is the Gemma model
-        for i in progress_bar:
-            try:
-                print("Calling Gemma 2B Instruct model, attempt", i + 1)
-
-                gen_config = LlamaCPPGenerationConfig()
-                gen_config.update_temp(temperature)
-
-                response = call_llama_cpp_model(prompt, gen_config, model=local_model)
-
-                #progress_bar.close()
-                #tqdm._instances.clear()
-
-                print("Successful call to Gemma model.")
-                print("Response:", response)
-                break
-            except Exception as e:
-                # If fails, try again after X seconds in case there is a throttle limit
-                print("Call to Gemma model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
-
-                time.sleep(timeout_wait)
-                #response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
-
-            if i == number_of_api_retry_attempts:
-                return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
-
-    # Update the conversation history with the new prompt and response
-    conversation_history.append({'role': 'user', 'parts': [prompt]})
-
-    # Check if is a LLama.cpp model response
-    # Check if the response is a ResponseObject
-    if isinstance(response, ResponseObject):
-        conversation_history.append({'role': 'assistant', 'parts': [response.text]})
-    elif 'choices' in response:
-        conversation_history.append({'role': 'assistant', 'parts': [response['choices'][0]['text']]})
-    else:
-        conversation_history.append({'role': 'assistant', 'parts': [response.text]})
-
-    # Print the updated conversation history
-    #print("conversation_history:", conversation_history)
-
-    return response, conversation_history
-
-def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], model: object, config: dict, model_choice: str, temperature: float, batch_no:int = 1, local_model = [], master:bool = False) -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
-    """
-    Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
-
-    Args:
-        prompts (List[str]): A list of prompts to be processed.
-        system_prompt (str): The system prompt.
-        conversation_history (List[dict]): The history of the conversation.
-        whole_conversation (List[str]): The complete conversation including prompts and responses.
-        whole_conversation_metadata (List[str]): Metadata about the whole conversation.
-        model (object): The model to use for processing the prompts.
-        config (dict): Configuration for the model.
-        model_choice (str): The choice of model to use.
-        temperature (float): The temperature parameter for the model.
-        batch_no (int): Batch number of the large language model request.
-        local_model: Local gguf model (if loaded)
-        master (bool): Is this request for the master table.
-
-    Returns:
-        Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
-    """
-    responses = []
-
-    # Clear any existing progress bars
-    tqdm._instances.clear()
-
-    for prompt in prompts:
-
-        #print("prompt to LLM:", prompt)
-
-        response, conversation_history = send_request(prompt, conversation_history, model=model, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model)
-
-        if isinstance(response, ResponseObject):
-            response_text = response.text
-        elif 'choices' in response:
-            response_text = response['choices'][0]['text']
-        else:
-            response_text = response.text
-
-        responses.append(response)
-        whole_conversation.append(prompt)
-        whole_conversation.append(response_text)
-
-        # Create conversation metadata
-        if master == False:
-            whole_conversation_metadata.append(f"Query batch {batch_no} prompt {len(responses)} metadata:")
-        else:
-            whole_conversation_metadata.append(f"Query summary metadata:")
-
-        if not isinstance(response, str):
-            try:
-                print("model_choice:", model_choice)
-                if "claude" in model_choice:
-                    print("Appending selected metadata items to metadata")
-                    whole_conversation_metadata.append('x-amzn-bedrock-output-token-count:')
-                    whole_conversation_metadata.append(str(response.usage_metadata['HTTPHeaders']['x-amzn-bedrock-output-token-count']))
-                    whole_conversation_metadata.append('x-amzn-bedrock-input-token-count:')
-                    whole_conversation_metadata.append(str(response.usage_metadata['HTTPHeaders']['x-amzn-bedrock-input-token-count']))
-                elif "gemini" in model_choice:
-                    whole_conversation_metadata.append(str(response.usage_metadata))
-                else:
-                    whole_conversation_metadata.append(str(response['usage']))
-            except KeyError as e:
-                print(f"Key error: {e} - Check the structure of response.usage_metadata")
-        else:
-            print("Response is a string object.")
-            whole_conversation_metadata.append("Length prompt: " + str(len(prompt)) + ". Length response: " + str(len(response)))
-
-    return responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text
-
 ### INITIAL TOPIC MODEL DEVELOPMENT FUNCTIONS
 
 def clean_markdown_table(text: str):
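
The send_request code removed in the hunk above retries each backend up to number_of_api_retry_attempts times with a fixed timeout_wait pause between attempts. A minimal standalone sketch of that retry pattern, using a generic call_fn and constants that mirror (but are not) the module-level values above:

    import time

    NUMBER_OF_API_RETRY_ATTEMPTS = 5   # assumption: mirrors number_of_api_retry_attempts above
    TIMEOUT_WAIT = 30                  # assumption: mirrors timeout_wait above

    def call_with_retries(call_fn, *args, **kwargs):
        """Call call_fn, retrying on any exception with a fixed pause between attempts."""
        last_error = None
        for attempt in range(1, NUMBER_OF_API_RETRY_ATTEMPTS + 1):
            try:
                return call_fn(*args, **kwargs)
            except Exception as e:  # e.g. throttling from the model API
                last_error = e
                print(f"Attempt {attempt} failed: {e}. Waiting {TIMEOUT_WAIT}s before retrying.")
                time.sleep(TIMEOUT_WAIT)
        # All attempts failed; surface the last error to the caller
        raise last_error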
@@ -620,22 +283,6 @@ def clean_column_name(column_name, max_length=20):
     # Truncate to max_length
     return column_name[:max_length]
 
-def create_unique_table_df_from_reference_table(reference_df:pd.DataFrame):
-
-    out_unique_topics_df = (reference_df.groupby(["General Topic", "Subtopic", "Sentiment"])
-                            .agg({
-                                'Response References': 'size', # Count the number of references
-                                'Summary': lambda x: '<br>'.join(
-                                    sorted(set(x), key=lambda summary: reference_df.loc[reference_df['Summary'] == summary, 'Start row of group'].min())
-                                )
-                            })
-                            .reset_index()
-                            .sort_values('Response References', ascending=False) # Sort by size, biggest first
-                            .assign(Topic_number=lambda df: np.arange(1, len(df) + 1)) # Add numbering 1 to x
-                            )
-
-    return out_unique_topics_df
-
 # Convert output table to markdown and then to a pandas dataframe to csv
 def remove_before_last_term(input_string: str) -> str:
     # Use regex to find the last occurrence of the term
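
A minimal illustration of the groupby/agg pattern used by create_unique_table_df_from_reference_table, run on an invented toy reference table; the summary ordering by 'Start row of group' is simplified here to a plain sort:

    import numpy as np
    import pandas as pd

    reference_df = pd.DataFrame({
        "General Topic": ["Transport", "Transport", "Housing"],
        "Subtopic": ["Buses", "Buses", "Rent"],
        "Sentiment": ["Negative", "Negative", "Positive"],
        "Response References": [1, 2, 3],
        "Summary": ["Late buses", "Crowded buses", "Rent is affordable"],
    })

    unique_topics = (reference_df.groupby(["General Topic", "Subtopic", "Sentiment"])
                     .agg({"Response References": "size",
                           "Summary": lambda x: "<br>".join(sorted(set(x)))})
                     .reset_index()
                     .sort_values("Response References", ascending=False)
                     .assign(Topic_number=lambda df: np.arange(1, len(df) + 1)))

    print(unique_topics)  # Buses/Negative counts 2 references and gets Topic_number 1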
@@ -754,8 +401,9 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
     temperature: float,
     reported_batch_no: int,
     local_model: object,
-    MAX_OUTPUT_VALIDATION_ATTEMPTS: int,
-    master:bool=False
     """
     Call the large language model with checks for a valid markdown table.
 
@@ -791,7 +439,7 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
     call_temperature, reported_batch_no, local_model, master=master
     )
 
-    if
         stripped_response = responses[-1].text.strip()
     else:
         stripped_response = responses[-1]['choices'][0]['text'].strip()
@@ -824,8 +472,9 @@ def write_llm_output_and_logs(responses: List[ResponseObject],
     existing_reference_df:pd.DataFrame,
     existing_topics_df:pd.DataFrame,
     batch_size_number:int,
-    in_column:str,
-    first_run: bool = False
     """
    Writes the output of the large language model requests and logs to files.
 
@@ -933,12 +582,24 @@ def write_llm_output_and_logs(responses: List[ResponseObject],
     # If no numbers found in the Response References column, check the Summary column in case reference numbers were put there by mistake
     if not references:
         references = re.findall(r'\d+', str(row.iloc[4])) if pd.notna(row.iloc[4]) else []
 
     topic = row.iloc[0] if pd.notna(row.iloc[0]) else ""
     subtopic = row.iloc[1] if pd.notna(row.iloc[1]) else ""
     sentiment = row.iloc[2] if pd.notna(row.iloc[2]) else ""
     summary = row.iloc[4] if pd.notna(row.iloc[4]) else ""
     # If the reference response column is very long, and there's nothing in the summary column, assume that the summary was put in the reference column
-    if not summary and len(str(row.iloc[3]) > 30):
         summary = row.iloc[3]
 
     summary = row_number_string_start + summary
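
The parsing shown above falls back to scanning a cell with a plain \d+ regular expression when the Response References column is empty. A small self-contained illustration of that extraction step, with invented input strings:

    import re
    import pandas as pd

    def extract_reference_numbers(cell) -> list[int]:
        """Pull all integer reference numbers out of a cell, returning [] for missing values."""
        if pd.isna(cell):
            return []
        return [int(n) for n in re.findall(r"\d+", str(cell))]

    print(extract_reference_numbers("1, 4, 12"))           # [1, 4, 12]
    print(extract_reference_numbers("See responses 3-5"))  # [3, 5]
    print(extract_reference_numbers(None))                 # []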
@@ -1172,10 +833,12 @@ def extract_topics(in_data_file,
     force_zero_shot_radio:str = "No",
     in_excel_sheets:List[str] = [],
     force_single_topic_radio:str = "No",
     force_single_topic_prompt:str=force_single_topic_prompt,
     max_tokens:int=max_tokens,
     model_name_map:dict=model_name_map,
-    max_time_for_loop:int=max_time_for_loop,
     progress=Progress(track_tqdm=True)):
 
     '''
@@ -1215,10 +878,12 @@ def extract_topics(in_data_file,
     - force_zero_shot_radio (str, optional): Should responses be forced into a zero shot topic or not.
     - in_excel_sheets (List[str], optional): List of excel sheets to load from input file.
     - force_single_topic_radio (str, optional): Should the model be forced to assign only one single topic to each response (effectively a classifier).
     - force_single_topic_prompt (str, optional): The prompt for forcing the model to assign only one single topic to each response.
     - max_tokens (int): The maximum number of tokens for the model.
     - model_name_map (dict, optional): A dictionary mapping full model name to shortened.
     - max_time_for_loop (int, optional): The number of seconds maximum that the function should run for before breaking (to run again, this is to avoid timeouts with some AWS services if deployed there).
     - progress (Progress): A progress tracker.
     '''
 
@@ -1265,9 +930,9 @@ def extract_topics(in_data_file,
     out_message = []
     out_file_paths = []
 
-    if (model_choice ==
-        progress(0.1, "Loading in
-        local_model, tokenizer = load_model()
 
     if num_batches > 0:
         progress_measure = round(latest_batch_completed / num_batches, 1)
@@ -1305,7 +970,7 @@ def extract_topics(in_data_file,
     print("Running query batch", str(reported_batch_no))
 
     # Call the function to prepare the input table
-    simplified_csv_table_path, normalised_simple_markdown_table, start_row, end_row, batch_basic_response_df = data_file_to_markdown_table(file_data, file_name, chosen_cols,
     #log_files_output_paths.append(simplified_csv_table_path)
 
     # Conversation history
@@ -1420,7 +1085,7 @@ def extract_topics(in_data_file,
     responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature, reported_batch_no, local_model, MAX_OUTPUT_VALIDATION_ATTEMPTS, master = True)
 
     # Return output tables
-    topic_table_out_path, reference_table_out_path, unique_topics_df_out_path, new_topic_df, new_markdown_table, new_reference_df, new_unique_topics_df, master_batch_out_file_part, is_error = write_llm_output_and_logs(responses, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_unique_topics_df, batch_size, chosen_cols, first_run=False)
 
     # Write final output to text file for logging purposes
     try:
@@ -1484,8 +1149,8 @@ def extract_topics(in_data_file,
     if "gemini" in model_choice:
         print("Using Gemini model:", model_choice)
         model, config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=system_prompt, max_tokens=max_tokens)
-    elif
-        print("Using local
     else:
         print("Using AWS Bedrock model:", model_choice)
 
@@ -1512,7 +1177,7 @@ def extract_topics(in_data_file,
     responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature, reported_batch_no, local_model, MAX_OUTPUT_VALIDATION_ATTEMPTS)
 
-    topic_table_out_path, reference_table_out_path, unique_topics_df_out_path, topic_table_df, markdown_table, reference_df, new_unique_topics_df, batch_file_path_details, is_error = write_llm_output_and_logs(responses, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_unique_topics_df, batch_size, chosen_cols, first_run=True)
 
     # If error in table parsing, leave function
     if is_error == True:
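
As the hunks above show, extract_topics picks a backend by inspecting the model name ("gemini" models, a locally loaded model, or an AWS Bedrock model ID). A minimal sketch of that dispatch pattern; the function and label names here are placeholders rather than the project's API:

    def choose_backend(model_choice: str, local_model_name: str = "gemma_local") -> str:
        """Return which backend a given model name would be routed to (illustrative only)."""
        if "gemini" in model_choice:
            return "gemini"            # Google Gemini API
        elif model_choice == local_model_name:
            return "local-llama-cpp"   # locally loaded gguf model
        else:
            return "aws-bedrock"       # e.g. anthropic.claude-* model IDs

    print(choose_backend("gemini-2.0-flash-001"))                    # gemini
    print(choose_backend("gemma_local"))                             # local-llama-cpp
    print(choose_backend("anthropic.claude-3-haiku-20240307-v1:0"))  # aws-bedrock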
@@ -1697,39 +1362,6 @@ def extract_topics(in_data_file,
 
     return unique_table_df_display_table_markdown, existing_topics_table, existing_unique_topics_df, existing_reference_df, out_file_paths, out_file_paths, latest_batch_completed, log_files_output_paths, log_files_output_paths, whole_conversation_metadata_str, final_time, out_file_paths, out_file_paths, gr.Dataframe(value=modifiable_unique_topics_df, headers=None, col_count=(modifiable_unique_topics_df.shape[1], "fixed"), row_count = (modifiable_unique_topics_df.shape[0], "fixed"), visible=True, type="pandas"), out_file_paths, join_file_paths
 
-def convert_reference_table_to_pivot_table(df:pd.DataFrame, basic_response_data:pd.DataFrame=pd.DataFrame()):
-
-    df_in = df[['Response References', 'General Topic', 'Subtopic', 'Sentiment']].copy()
-
-    df_in['Response References'] = df_in['Response References'].astype(int)
-
-    # Create a combined category column
-    df_in['Category'] = df_in['General Topic'] + ' - ' + df_in['Subtopic'] + ' - ' + df_in['Sentiment']
-
-    # Create pivot table counting occurrences of each unique combination
-    pivot_table = pd.crosstab(
-        index=df_in['Response References'],
-        columns=[df_in['General Topic'], df_in['Subtopic'], df_in['Sentiment']],
-        margins=True
-    )
-
-    # Flatten column names to make them more readable
-    pivot_table.columns = [' - '.join(col) for col in pivot_table.columns]
-
-    pivot_table.reset_index(inplace=True)
-
-    if not basic_response_data.empty:
-        pivot_table = basic_response_data.merge(pivot_table, right_on="Response References", left_on="Reference", how="left")
-
-        pivot_table.drop("Response References", axis=1, inplace=True)
-
-    pivot_table.columns = pivot_table.columns.str.replace("Not assessed - ", "").str.replace("- Not assessed", "")
-
-    # print("pivot_table:", pivot_table)
-
-    return pivot_table
-
-
 def join_modified_topic_names_to_ref_table(modified_unique_topics_df:pd.DataFrame, original_unique_topics_df:pd.DataFrame, reference_df:pd.DataFrame):
     '''
     Take a unique topic table that has been modified by the user, and apply the topic name changes to the long-form reference table.
@@ -1770,7 +1402,7 @@ def join_modified_topic_names_to_ref_table(modified_unique_topics_df:pd.DataFram
     return modified_reference_df
 
 # MODIFY EXISTING TABLE
-def modify_existing_output_tables(original_unique_topics_df:pd.DataFrame, modifiable_unique_topics_df:pd.DataFrame, reference_df:pd.DataFrame, text_output_file_list_state:List[str]) -> Tuple:
     '''
     Take a unique_topics table that has been modified, apply these new topic names to the long-form reference_df, and save both tables to file.
     '''
@@ -1787,9 +1419,6 @@ def modify_existing_output_tables(original_unique_topics_df:pd.DataFrame, modifi
     reference_file_path = os.path.basename(reference_files[0]) if reference_files else None
     unique_table_file_path = os.path.basename(unique_files[0]) if unique_files else None
 
-    print("Reference File:", reference_file_path)
-    print("Unique Table File:", unique_table_file_path)
-
     output_file_list = []
 
     if reference_file_path and unique_table_file_path:
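
The topic deduplication removed in the next hunk leans on rapidfuzz fuzzy matching to merge near-identical topic names. A minimal sketch of that matching step, with invented category names and a threshold of 90 echoing the deduplication_threshold constant above:

    from rapidfuzz import process, fuzz

    categories = ["Bus services", "Bus service", "Road maintenance", "Roads maintenance"]
    threshold = 90  # assumption: mirrors deduplication_threshold above

    for category in categories:
        # Compare each category against the others and keep matches scoring >= threshold
        others = [c for c in categories if c != category]
        matches = process.extract(category, others, scorer=fuzz.WRatio, score_cutoff=threshold)
        if matches:
            best_match, score, _ = max(matches, key=lambda m: m[1])
            print(f"{category!r} -> merge candidate {best_match!r} (score {score:.0f})")
        else:
            print(f"{category!r} -> no merge candidate")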
@@ -1833,583 +1462,3 @@ def modify_existing_output_tables(original_unique_topics_df:pd.DataFrame, modifi
|
|
1833 |
|
1834 |
|
1835 |
return modifiable_unique_topics_df, reference_df, output_file_list, output_file_list, output_file_list, output_file_list, reference_table_file_name, unique_table_file_name, deduplicated_unique_table_markdown
|
1836 |
-
|
1837 |
-
|
1838 |
-
# DEDUPLICATION/SUMMARISATION FUNCTIONS
|
1839 |
-
def deduplicate_categories(category_series: pd.Series, join_series: pd.Series, reference_df: pd.DataFrame, general_topic_series: pd.Series = None, merge_general_topics = "No", merge_sentiment:str="No", threshold: float = 90) -> pd.DataFrame:
|
1840 |
-
"""
|
1841 |
-
Deduplicates similar category names in a pandas Series based on a fuzzy matching threshold,
|
1842 |
-
merging smaller topics into larger topics.
|
1843 |
-
|
1844 |
-
Parameters:
|
1845 |
-
category_series (pd.Series): Series containing category names to deduplicate.
|
1846 |
-
join_series (pd.Series): Additional series used for joining back to original results.
|
1847 |
-
reference_df (pd.DataFrame): DataFrame containing the reference data to count occurrences.
|
1848 |
-
threshold (float): Similarity threshold for considering two strings as duplicates.
|
1849 |
-
|
1850 |
-
Returns:
|
1851 |
-
pd.DataFrame: DataFrame with columns ['old_category', 'deduplicated_category'].
|
1852 |
-
"""
|
1853 |
-
# Count occurrences of each category in the reference_df
|
1854 |
-
category_counts = reference_df['Subtopic'].value_counts().to_dict()
|
1855 |
-
|
1856 |
-
# Initialize dictionaries for both category mapping and scores
|
1857 |
-
deduplication_map = {}
|
1858 |
-
match_scores = {} # New dictionary to store match scores
|
1859 |
-
|
1860 |
-
# First pass: Handle exact matches
|
1861 |
-
for category in category_series.unique():
|
1862 |
-
if category in deduplication_map:
|
1863 |
-
continue
|
1864 |
-
|
1865 |
-
# Find all exact matches
|
1866 |
-
exact_matches = category_series[category_series.str.lower() == category.lower()].index.tolist()
|
1867 |
-
if len(exact_matches) > 1:
|
1868 |
-
# Find the variant with the highest count
|
1869 |
-
match_counts = {match: category_counts.get(category_series[match], 0) for match in exact_matches}
|
1870 |
-
most_common = max(match_counts.items(), key=lambda x: x[1])[0]
|
1871 |
-
most_common_category = category_series[most_common]
|
1872 |
-
|
1873 |
-
# Map all exact matches to the most common variant and store score
|
1874 |
-
for match in exact_matches:
|
1875 |
-
deduplication_map[category_series[match]] = most_common_category
|
1876 |
-
match_scores[category_series[match]] = 100 # Exact matches get score of 100
|
1877 |
-
|
1878 |
-
# Second pass: Handle fuzzy matches for remaining categories
|
1879 |
-
# Create a DataFrame to maintain the relationship between categories and general topics
|
1880 |
-
categories_df = pd.DataFrame({
|
1881 |
-
'category': category_series,
|
1882 |
-
'general_topic': general_topic_series
|
1883 |
-
}).drop_duplicates()
|
1884 |
-
|
1885 |
-
for _, row in categories_df.iterrows():
|
1886 |
-
category = row['category']
|
1887 |
-
if category in deduplication_map:
|
1888 |
-
continue
|
1889 |
-
|
1890 |
-
current_general_topic = row['general_topic']
|
1891 |
-
|
1892 |
-
# Filter potential matches to only those within the same General Topic if relevant
|
1893 |
-
if merge_general_topics == "No":
|
1894 |
-
potential_matches = categories_df[
|
1895 |
-
(categories_df['category'] != category) &
|
1896 |
-
(categories_df['general_topic'] == current_general_topic)
|
1897 |
-
]['category'].tolist()
|
1898 |
-
else:
|
1899 |
-
potential_matches = categories_df[
|
1900 |
-
(categories_df['category'] != category)
|
1901 |
-
]['category'].tolist()
|
1902 |
-
|
1903 |
-
matches = process.extract(category,
|
1904 |
-
potential_matches,
|
1905 |
-
scorer=fuzz.WRatio,
|
1906 |
-
score_cutoff=threshold)
|
1907 |
-
|
1908 |
-
if matches:
|
1909 |
-
best_match = max(matches, key=lambda x: x[1])
|
1910 |
-
match, score, _ = best_match
|
1911 |
-
|
1912 |
-
if category_counts.get(category, 0) < category_counts.get(match, 0):
|
1913 |
-
deduplication_map[category] = match
|
1914 |
-
match_scores[category] = score
|
1915 |
-
else:
|
1916 |
-
deduplication_map[match] = category
|
1917 |
-
match_scores[match] = score
|
1918 |
-
else:
|
1919 |
-
deduplication_map[category] = category
|
1920 |
-
match_scores[category] = 100
|
1921 |
-
|
1922 |
-
# Create the result DataFrame with scores
|
1923 |
-
result_df = pd.DataFrame({
|
1924 |
-
'old_category': category_series + " | " + join_series,
|
1925 |
-
'deduplicated_category': category_series.map(lambda x: deduplication_map.get(x, x)),
|
1926 |
-
'match_score': category_series.map(lambda x: match_scores.get(x, 100)) # Add scores column
|
1927 |
-
})
|
1928 |
-
|
1929 |
-
#print(result_df)
|
1930 |
-
|
1931 |
-
return result_df
|
1932 |
-
|
1933 |
-
def deduplicate_topics(reference_df:pd.DataFrame,
|
1934 |
-
unique_topics_df:pd.DataFrame,
|
1935 |
-
reference_table_file_name:str,
|
1936 |
-
unique_topics_table_file_name:str,
|
1937 |
-
in_excel_sheets:str="",
|
1938 |
-
merge_sentiment:str= "No",
|
1939 |
-
merge_general_topics:str="No",
|
1940 |
-
score_threshold:int=90,
|
1941 |
-
in_data_files:List[str]=[],
|
1942 |
-
chosen_cols:List[str]="",
|
1943 |
-
deduplicate_topics:str="Yes"
|
1944 |
-
):
|
1945 |
-
'''
|
1946 |
-
Deduplicate topics based on a reference and unique topics table
|
1947 |
-
'''
|
1948 |
-
output_files = []
|
1949 |
-
log_output_files = []
|
1950 |
-
file_data = pd.DataFrame()
|
1951 |
-
|
1952 |
-
reference_table_file_name_no_ext = reference_table_file_name #get_file_name_no_ext(reference_table_file_name)
|
1953 |
-
unique_topics_table_file_name_no_ext = unique_topics_table_file_name #get_file_name_no_ext(unique_topics_table_file_name)
|
1954 |
-
|
1955 |
-
# For checking that data is not lost during the process
|
1956 |
-
initial_unique_references = len(reference_df["Response References"].unique())
|
1957 |
-
|
1958 |
-
if unique_topics_df.empty:
|
1959 |
-
unique_topics_df = create_unique_table_df_from_reference_table(reference_df)
|
1960 |
-
|
1961 |
-
# Then merge the topic numbers back to the original dataframe
|
1962 |
-
reference_df = reference_df.merge(
|
1963 |
-
unique_topics_df[['General Topic', 'Subtopic', 'Sentiment', 'Topic_number']],
|
1964 |
-
on=['General Topic', 'Subtopic', 'Sentiment'],
|
1965 |
-
how='left'
|
1966 |
-
)
|
1967 |
-
|
1968 |
-
if in_data_files and chosen_cols:
|
1969 |
-
file_data, data_file_names_textbox, total_number_of_batches = load_in_data_file(in_data_files, chosen_cols, 1, in_excel_sheets)
|
1970 |
-
else:
|
1971 |
-
out_message = "No file data found, pivot table output will not be created."
|
1972 |
-
print(out_message)
|
1973 |
-
#raise Exception(out_message)
|
1974 |
-
|
1975 |
-
# Run through this x times to try to get all duplicate topics
|
1976 |
-
if deduplicate_topics == "Yes":
|
1977 |
-
for i in range(0, 8):
|
1978 |
-
if merge_sentiment == "No":
|
1979 |
-
if merge_general_topics == "No":
|
1980 |
-
reference_df["old_category"] = reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
|
1981 |
-
reference_df_unique = reference_df.drop_duplicates("old_category")
|
1982 |
-
|
1983 |
-
deduplicated_topic_map_df = reference_df_unique.groupby(["General Topic", "Sentiment"]).apply(
|
1984 |
-
lambda group: deduplicate_categories(
|
1985 |
-
group["Subtopic"],
|
1986 |
-
group["Sentiment"],
|
1987 |
-
reference_df,
|
1988 |
-
general_topic_series=group["General Topic"],
|
1989 |
-
merge_general_topics="No",
|
1990 |
-
threshold=score_threshold
|
1991 |
-
)
|
1992 |
-
).reset_index(drop=True)
|
1993 |
-
else:
|
1994 |
-
# This case should allow cross-topic matching but is still grouping by Sentiment
|
1995 |
-
reference_df["old_category"] = reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
|
1996 |
-
reference_df_unique = reference_df.drop_duplicates("old_category")
|
1997 |
-
|
1998 |
-
deduplicated_topic_map_df = reference_df_unique.groupby("Sentiment").apply(
|
1999 |
-
lambda group: deduplicate_categories(
|
2000 |
-
group["Subtopic"],
|
2001 |
-
group["Sentiment"],
|
2002 |
-
reference_df,
|
2003 |
-
general_topic_series=None, # Set to None to allow cross-topic matching
|
2004 |
-
merge_general_topics="Yes",
|
2005 |
-
threshold=score_threshold
|
2006 |
-
)
|
2007 |
-
).reset_index(drop=True)
|
2008 |
-
else:
|
2009 |
-
if merge_general_topics == "No":
|
2010 |
-
# Update this case to maintain general topic boundaries
|
2011 |
-
reference_df["old_category"] = reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
|
2012 |
-
reference_df_unique = reference_df.drop_duplicates("old_category")
|
2013 |
-
|
2014 |
-
deduplicated_topic_map_df = reference_df_unique.groupby("General Topic").apply(
|
2015 |
-
lambda group: deduplicate_categories(
|
2016 |
-
group["Subtopic"],
|
2017 |
-
group["Sentiment"],
|
2018 |
-
reference_df,
|
2019 |
-
general_topic_series=group["General Topic"],
|
2020 |
-
merge_general_topics="No",
|
2021 |
-
merge_sentiment=merge_sentiment,
|
2022 |
-
threshold=score_threshold
|
2023 |
-
)
|
2024 |
-
).reset_index(drop=True)
|
2025 |
-
else:
|
2026 |
-
# For complete merging across all categories
|
2027 |
-
reference_df["old_category"] = reference_df["Subtopic"] + " | " + reference_df["Sentiment"]
|
2028 |
-
reference_df_unique = reference_df.drop_duplicates("old_category")
|
2029 |
-
|
2030 |
-
deduplicated_topic_map_df = deduplicate_categories(
|
2031 |
-
reference_df_unique["Subtopic"],
|
2032 |
-
reference_df_unique["Sentiment"],
|
2033 |
-
reference_df,
|
2034 |
-
general_topic_series=None, # Set to None to allow cross-topic matching
|
2035 |
-
merge_general_topics="Yes",
|
2036 |
-
merge_sentiment=merge_sentiment,
|
2037 |
-
threshold=score_threshold
|
2038 |
-
).reset_index(drop=True)
|
2039 |
-
|
2040 |
-
if deduplicated_topic_map_df['deduplicated_category'].isnull().all():
|
2041 |
-
# Check if 'deduplicated_category' contains any values
|
2042 |
-
print("No deduplicated categories found, skipping the following code.")
|
2043 |
-
|
2044 |
-
else:
|
2045 |
-
# Remove rows where 'deduplicated_category' is blank or NaN
|
2046 |
-
deduplicated_topic_map_df = deduplicated_topic_map_df.loc[(deduplicated_topic_map_df['deduplicated_category'].str.strip() != '') & ~(deduplicated_topic_map_df['deduplicated_category'].isnull()), ['old_category','deduplicated_category', 'match_score']]
|
2047 |
-
|
2048 |
-
#deduplicated_topic_map_df.to_csv(output_folder + "deduplicated_topic_map_df_" + str(i) + ".csv", index=None)
|
2049 |
-
|
2050 |
-
reference_df = reference_df.merge(deduplicated_topic_map_df, on="old_category", how="left")
|
2051 |
-
|
2052 |
-
reference_df.rename(columns={"Subtopic": "Subtopic_old", "Sentiment": "Sentiment_old"}, inplace=True)
|
2053 |
-
# Extract subtopic and sentiment from deduplicated_category
|
2054 |
-
reference_df["Subtopic"] = reference_df["deduplicated_category"].str.extract(r'^(.*?) \|')[0] # Extract subtopic
|
2055 |
-
reference_df["Sentiment"] = reference_df["deduplicated_category"].str.extract(r'\| (.*)$')[0] # Extract sentiment
|
2056 |
-
|
2057 |
-
# Combine with old values to ensure no data is lost
|
2058 |
-
reference_df["Subtopic"] = reference_df["deduplicated_category"].combine_first(reference_df["Subtopic_old"])
|
2059 |
-
reference_df["Sentiment"] = reference_df["Sentiment"].combine_first(reference_df["Sentiment_old"])
|
2060 |
-
|
2061 |
-
|
2062 |
-
-     reference_df.drop(['old_category', 'deduplicated_category', "Subtopic_old", "Sentiment_old"], axis=1, inplace=True, errors="ignore")
-
-     reference_df = reference_df[["Response References", "General Topic", "Subtopic", "Sentiment", "Summary", "Start row of group"]]
-
-     #reference_df["General Topic"] = reference_df["General Topic"].str.lower().str.capitalize()
-     #reference_df["Subtopic"] = reference_df["Subtopic"].str.lower().str.capitalize()
-     #reference_df["Sentiment"] = reference_df["Sentiment"].str.lower().str.capitalize()
-
-     if merge_general_topics == "Yes":
-         # Replace the General Topic name for each Subtopic with that of the Subtopic with the most responses
-         # Step 1: Count the number of occurrences for each General Topic and Subtopic combination
-         count_df = reference_df.groupby(['Subtopic', 'General Topic']).size().reset_index(name='Count')
-
-         # Step 2: Find the General Topic with the maximum count for each Subtopic
-         max_general_topic = count_df.loc[count_df.groupby('Subtopic')['Count'].idxmax()]
-
-         # Step 3: Map the General Topic back to the original DataFrame
-         reference_df = reference_df.merge(max_general_topic[['Subtopic', 'General Topic']], on='Subtopic', suffixes=('', '_max'), how='left')
-
-         reference_df['General Topic'] = reference_df["General Topic_max"].combine_first(reference_df["General Topic"])
-
-     if merge_sentiment == "Yes":
-         # Step 1: Count the number of occurrences for each Subtopic and Sentiment combination
-         count_df = reference_df.groupby(['Subtopic', 'Sentiment']).size().reset_index(name='Count')
-
-         # Step 2: Determine the number of unique Sentiment values for each Subtopic
-         unique_sentiments = count_df.groupby('Subtopic')['Sentiment'].nunique().reset_index(name='UniqueCount')
-
-         # Step 3: Update Sentiment to 'Mixed' where there is more than one unique sentiment
-         reference_df = reference_df.merge(unique_sentiments, on='Subtopic', how='left')
-         reference_df['Sentiment'] = reference_df.apply(
-             lambda row: 'Mixed' if row['UniqueCount'] > 1 else row['Sentiment'],
-             axis=1
-         )
-
-         # Clean up the DataFrame by dropping the UniqueCount column
-         reference_df.drop(columns=['UniqueCount'], inplace=True)
-
-     reference_df = reference_df[["Response References", "General Topic", "Subtopic", "Sentiment", "Summary", "Start row of group"]]
-
-     # Update the reference summary column with all summaries
-     reference_df["Summary"] = reference_df.groupby(
-         ["Response References", "General Topic", "Subtopic", "Sentiment"]
-     )["Summary"].transform(' <br> '.join)
-
-     # Check that we have not inadvertently removed some data during the above process
-     end_unique_references = len(reference_df["Response References"].unique())
-
-     if initial_unique_references != end_unique_references:
-         raise Exception(f"Number of unique references changed during processing: Initial={initial_unique_references}, Final={end_unique_references}")
-
-     # Drop duplicates in the reference table - each comment should only have the same topic referred to once
-     reference_df.drop_duplicates(['Response References', 'General Topic', 'Subtopic', 'Sentiment'], inplace=True)
-
-     # Remake unique_topics_df based on the new reference_df
-     unique_topics_df = create_unique_table_df_from_reference_table(reference_df)
-
-     # Then merge the topic numbers back to the original dataframe
-     reference_df = reference_df.merge(
-         unique_topics_df[['General Topic', 'Subtopic', 'Sentiment', 'Topic_number']],
-         on=['General Topic', 'Subtopic', 'Sentiment'],
-         how='left'
-     )
-
-     if not file_data.empty:
-         basic_response_data = get_basic_response_data(file_data, chosen_cols)
-         reference_df_pivot = convert_reference_table_to_pivot_table(reference_df, basic_response_data)
-
-         reference_pivot_file_path = output_folder + reference_table_file_name_no_ext + "_pivot_dedup.csv"
-         reference_df_pivot.to_csv(reference_pivot_file_path, index=None, encoding='utf-8')
-         log_output_files.append(reference_pivot_file_path)
-
-     #reference_table_file_name_no_ext = get_file_name_no_ext(reference_table_file_name)
-     #unique_topics_table_file_name_no_ext = get_file_name_no_ext(unique_topics_table_file_name)
-
-     reference_file_path = output_folder + reference_table_file_name_no_ext + "_dedup.csv"
-     unique_topics_file_path = output_folder + unique_topics_table_file_name_no_ext + "_dedup.csv"
-     reference_df.to_csv(reference_file_path, index=None, encoding='utf-8')
-     unique_topics_df.to_csv(unique_topics_file_path, index=None, encoding='utf-8')
-
-     output_files.append(reference_file_path)
-     output_files.append(unique_topics_file_path)
-
-     # Outputs for markdown table output
-     unique_table_df_revised_display = unique_topics_df.apply(lambda col: col.map(lambda x: wrap_text(x, max_text_length=500)))
-
-     deduplicated_unique_table_markdown = unique_table_df_revised_display.to_markdown(index=False)
-
-     return reference_df, unique_topics_df, output_files, log_output_files, deduplicated_unique_table_markdown
-
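As a side note, the "most frequent General Topic per Subtopic" step removed above is a standard pandas pattern; a minimal, self-contained sketch with toy data (names here are illustrative only):

import pandas as pd

df = pd.DataFrame({
    "Subtopic": ["Parking", "Parking", "Parking", "Staff"],
    "General Topic": ["Access", "Access", "Facilities", "Service"],
})

# Count occurrences of each (Subtopic, General Topic) pair
counts = df.groupby(["Subtopic", "General Topic"]).size().reset_index(name="Count")

# Keep the General Topic with the highest count for each Subtopic
modal = counts.loc[counts.groupby("Subtopic")["Count"].idxmax()]
print(modal)  # Parking -> Access, Staff -> Service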
- def sample_reference_table_summaries(reference_df:pd.DataFrame,
-                                      unique_topics_df:pd.DataFrame,
-                                      random_seed:int,
-                                      no_of_sampled_summaries:int=150):
-     '''
-     Sample x number of summaries from which to produce summaries, so that the input token length is not too long.
-     '''
-     all_summaries = pd.DataFrame()
-     output_files = []
-
-     reference_df_grouped = reference_df.groupby(["General Topic", "Subtopic", "Sentiment"])
-
-     if 'Revised summary' in reference_df.columns:
-         out_message = "Summary has already been created for this file"
-         print(out_message)
-         raise Exception(out_message)
-
-     for group_keys, reference_df_group in reference_df_grouped:
-         #print(f"Group: {group_keys}")
-         #print(f"Data: {reference_df_group}")
-
-         if len(reference_df_group["General Topic"]) > 1:
-
-             filtered_reference_df = reference_df_group.reset_index()
-
-             filtered_reference_df_unique = filtered_reference_df.drop_duplicates(["General Topic", "Subtopic", "Sentiment", "Summary"])
-
-             # Sample n of the unique topic summaries. To limit the length of the text going into the summarisation tool
-             filtered_reference_df_unique_sampled = filtered_reference_df_unique.sample(min(no_of_sampled_summaries, len(filtered_reference_df_unique)), random_state=random_seed)
-
-             #topic_summary_table_markdown = filtered_reference_df_unique_sampled.to_markdown(index=False)
-             #print(filtered_reference_df_unique_sampled)
-
-             all_summaries = pd.concat([all_summaries, filtered_reference_df_unique_sampled])
-
-     summarised_references = all_summaries.groupby(["General Topic", "Subtopic", "Sentiment"]).agg({
-         'Response References': 'size',  # Count the number of references
-         'Summary': lambda x: '\n'.join([s.split(': ', 1)[1] for s in x if ': ' in s])  # Join substrings after ': '
-     }).reset_index()
-
-     summarised_references = summarised_references.loc[(summarised_references["Sentiment"] != "Not Mentioned") & (summarised_references["Response References"] > 1)]
-
-     summarised_references_markdown = summarised_references.to_markdown(index=False)
-
-     return summarised_references, summarised_references_markdown, reference_df, unique_topics_df
-
- def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str, local_model=[]):
-     conversation_history = []
-     whole_conversation_metadata = []
-
-     # Prepare Gemini models before query
-     if "gemini" in model_choice:
-         print("Using Gemini model:", model_choice)
-         model, config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=system_prompt, max_tokens=max_tokens)
-     else:
-         print("Using AWS Bedrock model:", model_choice)
-         model = model_choice
-         config = {}
-
-     whole_conversation = [summarise_topic_descriptions_system_prompt]
-
-     # Process requests to large language model
-     responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature, local_model=local_model)
-
-     print("Finished summary query")
-
-     if isinstance(responses[-1], ResponseObject):
-         response_texts = [resp.text for resp in responses]
-     elif "choices" in responses[-1]:
-         response_texts = [resp["choices"][0]['text'] for resp in responses]
-     else:
-         response_texts = [resp.text for resp in responses]
-
-     latest_response_text = response_texts[-1]
-
-     #print("latest_response_text:", latest_response_text)
-     #print("Whole conversation metadata:", whole_conversation_metadata)
-
-     return latest_response_text, conversation_history, whole_conversation_metadata
-
- @spaces.GPU
- def summarise_output_topics(summarised_references:pd.DataFrame,
-                             unique_table_df:pd.DataFrame,
-                             reference_table_df:pd.DataFrame,
-                             model_choice:str,
-                             in_api_key:str,
-                             topic_summary_table_markdown:str,
-                             temperature:float,
-                             table_file_name:str,
-                             summarised_outputs:list = [],
-                             latest_summary_completed:int = 0,
-                             out_metadata_str:str = "",
-                             in_data_files:List[str]=[],
-                             in_excel_sheets:str="",
-                             chosen_cols:List[str]=[],
-                             log_output_files:list[str]=[],
-                             summarise_format_radio:str="Return a summary up to two paragraphs long that includes as much detail as possible from the original text",
-                             output_files:list[str] = [],
-                             summarise_topic_descriptions_prompt:str=summarise_topic_descriptions_prompt, summarise_topic_descriptions_system_prompt:str=summarise_topic_descriptions_system_prompt,
-                             do_summaries="Yes",
-                             progress=gr.Progress(track_tqdm=True)):
-     '''
-     Create better summaries of the raw batch-level summaries created in the first run of the model.
-     '''
-     out_metadata = []
-     local_model = []
-     summarised_output_markdown = ""
-
-     # Check for data for summarisations
-     if not unique_table_df.empty and not reference_table_df.empty:
-         print("Unique table and reference table data found.")
-     else:
-         out_message = "Please upload a unique topic table and reference table file to continue with summarisation."
-         print(out_message)
-         raise Exception(out_message)
-
-     if 'Revised summary' in reference_table_df.columns:
-         out_message = "Summary has already been created for this file"
-         print(out_message)
-         raise Exception(out_message)
-
-     # Load in the data file and chosen columns if they exist, to create a pivot table later
-     if in_data_files and chosen_cols:
-         file_data, data_file_names_textbox, total_number_of_batches = load_in_data_file(in_data_files, chosen_cols, 1, in_excel_sheets=in_excel_sheets)
-     else:
-         out_message = "No file data found, pivot table output will not be created."
-         print(out_message)
-         raise Exception(out_message)
-
-     all_summaries = summarised_references["Summary"].tolist()
-     length_all_summaries = len(all_summaries)
-
-     # If all summaries are completed, make the final outputs
-     if latest_summary_completed >= length_all_summaries:
-         print("All summaries completed. Creating outputs.")
-
-         model_choice_clean = model_name_map[model_choice]
-         file_name = re.search(r'(.*?)(?:_batch_|_col_)', table_file_name).group(1) if re.search(r'(.*?)(?:_batch_|_col_)', table_file_name) else table_file_name
-         latest_batch_completed = int(re.search(r'batch_(\d+)_', table_file_name).group(1)) if 'batch_' in table_file_name else ""
-         batch_size_number = int(re.search(r'size_(\d+)_', table_file_name).group(1)) if 'size_' in table_file_name else ""
-         in_column_cleaned = re.search(r'col_(.*?)_reference', table_file_name).group(1) if 'col_' in table_file_name else ""
-
-         # Save outputs for each batch. If a master file is created, label the file as master
-         if latest_batch_completed:
-             batch_file_path_details = f"{file_name}_batch_{latest_batch_completed}_size_{batch_size_number}_col_{in_column_cleaned}"
-         else:
-             batch_file_path_details = f"{file_name}_col_{in_column_cleaned}"
-
-         summarised_references["Revised summary"] = summarised_outputs
-
-         join_cols = ["General Topic", "Subtopic", "Sentiment"]
-         join_plus_summary_cols = ["General Topic", "Subtopic", "Sentiment", "Revised summary"]
-
-         summarised_references_j = summarised_references[join_plus_summary_cols].drop_duplicates(join_plus_summary_cols)
-
-         unique_table_df_revised = unique_table_df.merge(summarised_references_j, on=join_cols, how="left")
-
-         # If no new summary is available, keep the original
-         unique_table_df_revised["Revised summary"] = unique_table_df_revised["Revised summary"].combine_first(unique_table_df_revised["Summary"])
-
-         unique_table_df_revised = unique_table_df_revised[["General Topic", "Subtopic", "Sentiment", "Response References", "Revised summary"]]
-
-         reference_table_df_revised = reference_table_df.merge(summarised_references_j, on=join_cols, how="left")
-         # If no new summary is available, keep the original
-         reference_table_df_revised["Revised summary"] = reference_table_df_revised["Revised summary"].combine_first(reference_table_df_revised["Summary"])
-         reference_table_df_revised = reference_table_df_revised.drop("Summary", axis=1)
-
-         # Remove topics that are tagged as 'Not Mentioned'
-         unique_table_df_revised = unique_table_df_revised.loc[unique_table_df_revised["Sentiment"] != "Not Mentioned", :]
-         reference_table_df_revised = reference_table_df_revised.loc[reference_table_df_revised["Sentiment"] != "Not Mentioned", :]
-
-         if not file_data.empty:
-             basic_response_data = get_basic_response_data(file_data, chosen_cols)
-             reference_table_df_revised_pivot = convert_reference_table_to_pivot_table(reference_table_df_revised, basic_response_data)
-
-             ### Save pivot file to log area
-             reference_table_df_revised_pivot_path = output_folder + batch_file_path_details + "_summarised_reference_table_pivot_" + model_choice_clean + ".csv"
-             reference_table_df_revised_pivot.to_csv(reference_table_df_revised_pivot_path, index=None, encoding='utf-8')
-             log_output_files.append(reference_table_df_revised_pivot_path)
-
-         # Save to file
-         unique_table_df_revised_path = output_folder + batch_file_path_details + "_summarised_unique_topic_table_" + model_choice_clean + ".csv"
-         unique_table_df_revised.to_csv(unique_table_df_revised_path, index=None, encoding='utf-8')
-
-         reference_table_df_revised_path = output_folder + batch_file_path_details + "_summarised_reference_table_" + model_choice_clean + ".csv"
-         reference_table_df_revised.to_csv(reference_table_df_revised_path, index=None, encoding='utf-8')
-
-         output_files.extend([reference_table_df_revised_path, unique_table_df_revised_path])
-
-         ###
-         unique_table_df_revised_display = unique_table_df_revised.apply(lambda col: col.map(lambda x: wrap_text(x, max_text_length=500)))
-
-         summarised_output_markdown = unique_table_df_revised_display.to_markdown(index=False)
-
-         # Ensure the same file name is not returned twice
-         output_files = list(set(output_files))
-         log_output_files = list(set(log_output_files))
-
-         return summarised_references, unique_table_df_revised, reference_table_df_revised, output_files, summarised_outputs, latest_summary_completed, out_metadata_str, summarised_output_markdown, log_output_files
-
-     tic = time.perf_counter()
-
-     #print("Starting with:", latest_summary_completed)
-     #print("Last summary number:", length_all_summaries)
-
-     if (model_choice == "gemma_2b_it_local") & (RUN_LOCAL_MODEL == "1"):
-         progress(0.1, "Loading in Gemma 2b model")
-         local_model, tokenizer = load_model()
-         #print("Local model loaded:", local_model)
-
-     summary_loop_description = "Creating summaries. " + str(latest_summary_completed) + " summaries completed so far."
-     summary_loop = tqdm(range(latest_summary_completed, length_all_summaries), desc="Creating summaries", unit="summaries")
-
-     if do_summaries == "Yes":
-         for summary_no in summary_loop:
-
-             print("Current summary number is:", summary_no)
-
-             summary_text = all_summaries[summary_no]
-             #print("summary_text:", summary_text)
-             formatted_summary_prompt = [summarise_topic_descriptions_prompt.format(summaries=summary_text, summary_format=summarise_format_radio)]
-
-             try:
-                 response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, summarise_topic_descriptions_system_prompt, local_model)
-                 summarised_output = response
-                 summarised_output = re.sub(r'\n{2,}', '\n', summarised_output)  # Replace multiple line breaks with a single line break
-                 summarised_output = re.sub(r'^\n{1,}', '', summarised_output)  # Remove one or more line breaks at the start
-                 summarised_output = summarised_output.strip()
-             except Exception as e:
-                 print(e)
-                 summarised_output = ""
-
-             summarised_outputs.append(summarised_output)
-             out_metadata.extend(metadata)
-             out_metadata_str = '. '.join(out_metadata)
-
-             latest_summary_completed += 1
-
-             # Check if beyond the maximum time allowed for processing, and break if necessary
-             toc = time.perf_counter()
-             time_taken = tic - toc
-
-             if time_taken > max_time_for_loop:
-                 print("Time taken for loop is greater than maximum time allowed. Exiting and restarting loop")
-                 summary_loop.close()
-                 tqdm._instances.clear()
-                 break
-
-     # If all summaries completed
-     if latest_summary_completed >= length_all_summaries:
-         print("At last summary.")
-
-     output_files = list(set(output_files))
-
-     return summarised_references, unique_table_df, reference_table_df, output_files, summarised_outputs, latest_summary_completed, out_metadata_str, summarised_output_markdown, log_output_files

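Taken together, the deduplication and summarisation helpers above are designed to be chained. A rough usage sketch of that flow, assuming reference_df and unique_topics_df have already been produced by the topic-extraction step; the model name and file names below are placeholders, not values taken from this commit:

summarised_references, summarised_markdown, reference_df, unique_topics_df = sample_reference_table_summaries(
    reference_df, unique_topics_df, random_seed=42)

# One pass over the sampled summaries; the function keeps its own progress state
# (latest_summary_completed, summarised_outputs) so it can be re-entered if it times out.
results = summarise_output_topics(
    summarised_references, unique_topics_df, reference_df,
    model_choice="gemini-2.0-flash", in_api_key="",
    topic_summary_table_markdown=summarised_markdown, temperature=0.3,
    table_file_name="consultation_col_Response_reference_table.csv",
    in_data_files=["consultation.csv"], chosen_cols=["Response"])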
  import markdown
  import time
  import boto3
  import string
  import re
  import spaces
  from tqdm import tqdm
+
  from gradio import Progress
  from typing import List, Tuple
  from io import StringIO

  GradioFileData = gr.FileData

+ from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, force_existing_topics_prompt, allow_new_topics_prompt, force_single_topic_prompt
+ from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_unique_table_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data
+ from tools.llm_funcs import ResponseObject, process_requests, construct_gemini_generative_model
+ from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, RUN_AWS_FUNCTIONS, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER

+ if RUN_LOCAL_MODEL == "1":
+     from tools.llm_funcs import load_model

+ max_tokens = MAX_TOKENS
+ timeout_wait = TIMEOUT_WAIT
+ number_of_api_retry_attempts = NUMBER_OF_RETRY_ATTEMPTS
+ max_time_for_loop = MAX_TIME_FOR_LOOP
+ batch_size_default = BATCH_SIZE_DEFAULT
+ deduplication_threshold = DEDUPLICATION_THRESHOLD
+ max_comment_character_length = MAX_COMMENT_CHARS

+ if RUN_AWS_FUNCTIONS == '1':
+     bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
+ else:
+     bedrock_runtime = []

  ### HELPER FUNCTIONS

+ def normalise_string(text:str):
      # Replace two or more dashes with a single dash
      text = re.sub(r'-{2,}', '-', text)
      return gr.Dataframe(value=unique_file_data, headers=None, col_count=(unique_file_data.shape[1], "fixed"), row_count = (unique_file_data.shape[0], "fixed"), visible=True, type="pandas"), reference_file_data, unique_file_data, reference_file_name, unique_file_name, out_file_names

+ def data_file_to_markdown_table(file_data:pd.DataFrame, file_name:str, chosen_cols: List[str], batch_number: int, batch_size: int, verify_titles:bool=False) -> Tuple[str, str, str]:
      """
      Processes a file by simplifying its content based on chosen columns and saves the result to a specified output folder.

      - file_data (pd.DataFrame): Tabular data file with responses.
      - file_name (str): File name with extension.
      - chosen_cols (List[str]): A list of column names to include in the simplified file.
      - batch_number (int): The current batch number for processing.
      - batch_size (int): The number of rows to process in each batch.

...

      # Translate the input string using the translation table
      return input_string.translate(translation_table)

  ### INITIAL TOPIC MODEL DEVELOPMENT FUNCTIONS

  def clean_markdown_table(text: str):

...

      # Truncate to max_length
      return column_name[:max_length]

  # Convert output table to markdown and then to a pandas dataframe to csv
  def remove_before_last_term(input_string: str) -> str:
      # Use regex to find the last occurrence of the term
      temperature: float,
      reported_batch_no: int,
      local_model: object,
+     MAX_OUTPUT_VALIDATION_ATTEMPTS: int,
+     master:bool=False,
+     CHOSEN_LOCAL_MODEL_TYPE:str=CHOSEN_LOCAL_MODEL_TYPE) -> Tuple[List[ResponseObject], List[dict], List[str], List[str], str]:
      """
      Call the large language model with checks for a valid markdown table.

...

          call_temperature, reported_batch_no, local_model, master=master
      )

+     if model_choice != CHOSEN_LOCAL_MODEL_TYPE:
          stripped_response = responses[-1].text.strip()
      else:
          stripped_response = responses[-1]['choices'][0]['text'].strip()

...

      existing_reference_df:pd.DataFrame,
      existing_topics_df:pd.DataFrame,
      batch_size_number:int,
+     in_column:str,
+     first_run: bool = False,
+     output_folder:str=OUTPUT_FOLDER) -> None:
      """
      Writes the output of the large language model requests and logs to files.

...

      # If no numbers found in the Response References column, check the Summary column in case reference numbers were put there by mistake
      if not references:
          references = re.findall(r'\d+', str(row.iloc[4])) if pd.notna(row.iloc[4]) else []
+
+     # Filter out references that are outside the valid range
+     if references:
+         try:
+             # Convert all references to integers and keep only those within the valid range
+             ref_numbers = [int(ref) for ref in references]
+             references = [str(ref) for ref in ref_numbers if 1 <= ref <= batch_size_number]
+         except ValueError:
+             # If any reference can't be converted to int, skip this row
+             print("Response value could not be converted to number:", references)
+             continue
+
      topic = row.iloc[0] if pd.notna(row.iloc[0]) else ""
      subtopic = row.iloc[1] if pd.notna(row.iloc[1]) else ""
      sentiment = row.iloc[2] if pd.notna(row.iloc[2]) else ""
      summary = row.iloc[4] if pd.notna(row.iloc[4]) else ""
      # If the reference response column is very long, and there's nothing in the summary column, assume that the summary was put in the reference column
+     if not summary and (len(str(row.iloc[3])) > 30):
          summary = row.iloc[3]

      summary = row_number_string_start + summary
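The reference-filtering addition above is easy to test in isolation; a minimal standalone sketch of the same idea, with hypothetical values and a batch size of 5:

import re

raw = "Responses 1, 3 and 12 mention parking"
batch_size_number = 5

references = re.findall(r'\d+', raw)  # ['1', '3', '12']
ref_numbers = [int(ref) for ref in references]
# Keep only reference numbers that fall inside the current batch
references = [str(ref) for ref in ref_numbers if 1 <= ref <= batch_size_number]
print(references)  # ['1', '3']; 12 is outside the current batch and is dropped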
      force_zero_shot_radio:str = "No",
      in_excel_sheets:List[str] = [],
      force_single_topic_radio:str = "No",
+     output_folder:str=OUTPUT_FOLDER,
      force_single_topic_prompt:str=force_single_topic_prompt,
      max_tokens:int=max_tokens,
      model_name_map:dict=model_name_map,
+     max_time_for_loop:int=max_time_for_loop,
+     CHOSEN_LOCAL_MODEL_TYPE:str=CHOSEN_LOCAL_MODEL_TYPE,
      progress=Progress(track_tqdm=True)):

      '''

...

      - force_zero_shot_radio (str, optional): Should responses be forced into a zero shot topic or not.
      - in_excel_sheets (List[str], optional): List of excel sheets to load from input file.
      - force_single_topic_radio (str, optional): Should the model be forced to assign only one single topic to each response (effectively a classifier).
+     - output_folder (str, optional): Output folder where results will be stored.
      - force_single_topic_prompt (str, optional): The prompt for forcing the model to assign only one single topic to each response.
      - max_tokens (int): The maximum number of tokens for the model.
      - model_name_map (dict, optional): A dictionary mapping full model name to shortened.
      - max_time_for_loop (int, optional): The maximum number of seconds that the function should run for before breaking (so it can be run again; this is to avoid timeouts with some AWS services if deployed there).
+     - CHOSEN_LOCAL_MODEL_TYPE (str, optional): The name of the chosen local model.
      - progress (Progress): A progress tracker.
      '''

...

      out_message = []
      out_file_paths = []

+     if (model_choice == CHOSEN_LOCAL_MODEL_TYPE) & (RUN_LOCAL_MODEL == "1"):
+         progress(0.1, f"Loading in local model: {CHOSEN_LOCAL_MODEL_TYPE}")
+         local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER)

      if num_batches > 0:
          progress_measure = round(latest_batch_completed / num_batches, 1)

...

      print("Running query batch", str(reported_batch_no))

      # Call the function to prepare the input table
+     simplified_csv_table_path, normalised_simple_markdown_table, start_row, end_row, batch_basic_response_df = data_file_to_markdown_table(file_data, file_name, chosen_cols, latest_batch_completed, batch_size)
      #log_files_output_paths.append(simplified_csv_table_path)

      # Conversation history

...

      responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature, reported_batch_no, local_model, MAX_OUTPUT_VALIDATION_ATTEMPTS, master = True)

      # Return output tables
+     topic_table_out_path, reference_table_out_path, unique_topics_df_out_path, new_topic_df, new_markdown_table, new_reference_df, new_unique_topics_df, master_batch_out_file_part, is_error = write_llm_output_and_logs(responses, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_unique_topics_df, batch_size, chosen_cols, first_run=False, output_folder=output_folder)

      # Write final output to text file for logging purposes
      try:

...

      if "gemini" in model_choice:
          print("Using Gemini model:", model_choice)
          model, config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=system_prompt, max_tokens=max_tokens)
+     elif model_choice == CHOSEN_LOCAL_MODEL_TYPE:
+         print("Using local model:", model_choice)
      else:
          print("Using AWS Bedrock model:", model_choice)

...

      responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature, reported_batch_no, local_model, MAX_OUTPUT_VALIDATION_ATTEMPTS)

+     topic_table_out_path, reference_table_out_path, unique_topics_df_out_path, topic_table_df, markdown_table, reference_df, new_unique_topics_df, batch_file_path_details, is_error = write_llm_output_and_logs(responses, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_unique_topics_df, batch_size, chosen_cols, first_run=True, output_folder=output_folder)

      # If error in table parsing, leave function
      if is_error == True:

...

      return unique_table_df_display_table_markdown, existing_topics_table, existing_unique_topics_df, existing_reference_df, out_file_paths, out_file_paths, latest_batch_completed, log_files_output_paths, log_files_output_paths, whole_conversation_metadata_str, final_time, out_file_paths, out_file_paths, gr.Dataframe(value=modifiable_unique_topics_df, headers=None, col_count=(modifiable_unique_topics_df.shape[1], "fixed"), row_count = (modifiable_unique_topics_df.shape[0], "fixed"), visible=True, type="pandas"), out_file_paths, join_file_paths

  def join_modified_topic_names_to_ref_table(modified_unique_topics_df:pd.DataFrame, original_unique_topics_df:pd.DataFrame, reference_df:pd.DataFrame):
      '''
      Take a unique topic table that has been modified by the user, and apply the topic name changes to the long-form reference table.

...

      return modified_reference_df

  # MODIFY EXISTING TABLE
+ def modify_existing_output_tables(original_unique_topics_df:pd.DataFrame, modifiable_unique_topics_df:pd.DataFrame, reference_df:pd.DataFrame, text_output_file_list_state:List[str], output_folder:str=OUTPUT_FOLDER) -> Tuple:
      '''
      Take a unique_topics table that has been modified, apply these new topic names to the long-form reference_df, and save both tables to file.
      '''

...

      reference_file_path = os.path.basename(reference_files[0]) if reference_files else None
      unique_table_file_path = os.path.basename(unique_files[0]) if unique_files else None

      output_file_list = []

      if reference_file_path and unique_table_file_path:

...

      return modifiable_unique_topics_df, reference_df, output_file_list, output_file_list, output_file_list, output_file_list, reference_table_file_name, unique_table_file_name, deduplicated_unique_table_markdown
tools/llm_funcs.py
ADDED
@@ -0,0 +1,579 @@
+ from typing import TypeVar
+ import torch.cuda
+ import os
+ import time
+ import boto3
+ import json
+ from tqdm import tqdm
+
+ from huggingface_hub import hf_hub_download
+ from typing import List, Tuple
+ import google.generativeai as ai
+ import gradio as gr
+ from gradio import Progress
+
+ torch.cuda.empty_cache()
+
+ PandasDataFrame = TypeVar('pd.core.frame.DataFrame')
+
+ model_type = None  # global variable setup
+ full_text = ""  # Define dummy source text (full text) just to enable highlight function to load
+ model = []  # Define empty list for model functions to run
+ tokenizer = []  # Define empty list for model functions to run
+
+ # Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
+ # Check for torch cuda
+ print("Is CUDA enabled? ", torch.cuda.is_available())
+ print("Is a CUDA device available on this computer?", torch.backends.cudnn.enabled)
+ if torch.cuda.is_available():
+     torch_device = "cuda"
+     gpu_layers = -1
+     os.system("nvidia-smi")
+ else:
+     torch_device = "cpu"
+     gpu_layers = 0
+
+ print("Device used is: ", torch_device)
+ print("Running on device:", torch_device)
+
+ from tools.config import RUN_AWS_FUNCTIONS, AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, RUN_LOCAL_MODEL, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN
+
+ if RUN_LOCAL_MODEL == "1":
+     print("Running local model - importing llama-cpp-python")
+     from llama_cpp import Llama
+
+ max_tokens = MAX_TOKENS
+ timeout_wait = TIMEOUT_WAIT
+ number_of_api_retry_attempts = NUMBER_OF_RETRY_ATTEMPTS
+ max_time_for_loop = MAX_TIME_FOR_LOOP
+ batch_size_default = BATCH_SIZE_DEFAULT
+ deduplication_threshold = DEDUPLICATION_THRESHOLD
+ max_comment_character_length = MAX_COMMENT_CHARS
+
+ if RUN_AWS_FUNCTIONS == '1':
+     bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
+ else:
+     bedrock_runtime = []
+
+ if not LLM_THREADS:
+     threads = torch.get_num_threads()  # 8
+ else: threads = LLM_THREADS
+ print("CPU threads:", threads)
+
+ if LLM_RESET == 'True': reset = True
+ else: reset = False
+
+ if LLM_STREAM == 'True': stream = True
+ else: stream = False
+
+ if LLM_SAMPLE == 'True': sample = True
+ else: sample = False
+
+ temperature = LLM_TEMPERATURE
+ top_k = LLM_TOP_K
+ top_p = LLM_TOP_P
+ repetition_penalty = LLM_REPETITION_PENALTY  # Mild repetition penalty to prevent repeating table rows
+ last_n_tokens = LLM_LAST_N_TOKENS
+ max_new_tokens: int = LLM_MAX_NEW_TOKENS
+ seed: int = LLM_SEED
+ reset: bool = reset
+ stream: bool = stream
+ threads: int = threads
+ batch_size:int = LLM_BATCH_SIZE
+ context_length:int = LLM_CONTEXT_LENGTH
+ sample = LLM_SAMPLE
+
+ class llama_cpp_init_config_gpu:
+     def __init__(self,
+                  last_n_tokens=last_n_tokens,
+                  seed=seed,
+                  n_threads=threads,
+                  n_batch=batch_size,
+                  n_ctx=context_length,
+                  n_gpu_layers=gpu_layers):
+
+         self.last_n_tokens = last_n_tokens
+         self.seed = seed
+         self.n_threads = n_threads
+         self.n_batch = n_batch
+         self.n_ctx = n_ctx
+         self.n_gpu_layers = n_gpu_layers
+         # self.stop: list[str] = field(default_factory=lambda: [stop_string])
+
+     def update_gpu(self, new_value):
+         self.n_gpu_layers = new_value
+
+     def update_context(self, new_value):
+         self.n_ctx = new_value
+
+ class llama_cpp_init_config_cpu(llama_cpp_init_config_gpu):
+     def __init__(self):
+         super().__init__()
+         self.n_gpu_layers = gpu_layers
+         self.n_ctx = context_length
+
+ gpu_config = llama_cpp_init_config_gpu()
+ cpu_config = llama_cpp_init_config_cpu()
+
+ class LlamaCPPGenerationConfig:
+     def __init__(self, temperature=temperature,
+                  top_k=top_k,
+                  top_p=top_p,
+                  repeat_penalty=repetition_penalty,
+                  seed=seed,
+                  stream=stream,
+                  max_tokens=max_new_tokens
+                  ):
+         self.temperature = temperature
+         self.top_k = top_k
+         self.top_p = top_p
+         self.repeat_penalty = repeat_penalty
+         self.seed = seed
+         self.max_tokens = max_tokens
+         self.stream = stream
+
+     def update_temp(self, new_value):
+         self.temperature = new_value
+
+ # ResponseObject class for AWS Bedrock calls
+ class ResponseObject:
+     def __init__(self, text, usage_metadata):
+         self.text = text
+         self.usage_metadata = usage_metadata
+
+ ###
+ # LOCAL MODEL FUNCTIONS
+ ###
+
+ def get_model_path(repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER, hf_token=HF_TOKEN):
+     # Construct the expected local path
+     local_path = os.path.join(model_dir, model_filename)
+
+     print("local path for model load:", local_path)
+
+     if os.path.exists(local_path):
+         print(f"Model already exists at: {local_path}")
+
+         return local_path
+     else:
+         print("Checking default Hugging Face folder. Downloading model from Hugging Face Hub if not found")
+         if hf_token:
+             downloaded_model_path = hf_hub_download(repo_id=repo_id, token=hf_token, filename=model_filename)
+         else:
+             downloaded_model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
+
+         return downloaded_model_path
+
+ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE, gpu_layers:int=gpu_layers, max_context_length:int=context_length, gpu_config:llama_cpp_init_config_gpu=gpu_config, cpu_config:llama_cpp_init_config_cpu=cpu_config, torch_device:str=torch_device, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER):
+     '''
+     Load in a model from Hugging Face hub via the transformers package, or using llama_cpp_python by downloading a GGUF file from Huggingface Hub.
+     '''
+     print("Loading model ", local_model_type)
+     model_path = get_model_path(repo_id=repo_id, model_filename=model_filename, model_dir=model_dir)
+
+     print("model_path:", model_path)
+
+     # GPU mode
+     if torch_device == "cuda":
+         gpu_config.update_gpu(gpu_layers)
+         gpu_config.update_context(max_context_length)
+
+         try:
+             print("GPU load variables:", vars(gpu_config))
+             llama_model = Llama(model_path=model_path, **vars(gpu_config))  # type_k=8, type_v=8, flash_attn=True,
+
+         except Exception as e:
+             print("GPU load failed due to:", e)
+             llama_model = Llama(model_path=model_path, type_k=8, **vars(cpu_config))  # type_v=8, flash_attn=True,
+
+         print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU. And a maximum context length of ", gpu_config.n_ctx)
+
+     # CPU mode
+     else:
+         gpu_config.update_gpu(gpu_layers)
+         cpu_config.update_gpu(gpu_layers)
+
+         # Update context length according to slider
+         gpu_config.update_context(max_context_length)
+         cpu_config.update_context(max_context_length)
+
+         llama_model = Llama(model_path=model_path, type_k=8, **vars(cpu_config))  # type_v=8, flash_attn=True,
+
+         print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU. And a maximum context length of ", gpu_config.n_ctx)
+
+     tokenizer = []
+
+     print("Finished loading model:", local_model_type)
+     return llama_model, tokenizer
+
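For reference, a rough usage sketch of the local-model helpers in this file, assuming RUN_LOCAL_MODEL is enabled, streaming is disabled in the config, and the GGUF repo and filename come from LOCAL_REPO_ID and LOCAL_MODEL_FILE in tools/config.py; the prompt text is a placeholder:

from tools.llm_funcs import load_model, LlamaCPPGenerationConfig, call_llama_cpp_model

# Download (if needed) and load the GGUF model configured in tools/config.py
llm, _tokenizer = load_model()

# Build a generation config and lower the temperature for more deterministic output
gen_config = LlamaCPPGenerationConfig()
gen_config.update_temp(0.2)

# call_llama_cpp_model returns the raw llama-cpp-python completion dict when not streaming
result = call_llama_cpp_model("Summarise: the response praised the new opening hours.", gen_config, model=llm)
print(result["choices"][0]["text"])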
+ def call_llama_cpp_model(formatted_string:str, gen_config:str, model=model):
+     """
+     Calls your generation model with parameters from the LlamaCPPGenerationConfig object.
+
+     Args:
+         formatted_string (str): The formatted input text for the model.
+         gen_config (LlamaCPPGenerationConfig): An object containing generation parameters.
+     """
+     # Extracting parameters from the gen_config object
+     temperature = gen_config.temperature
+     top_k = gen_config.top_k
+     top_p = gen_config.top_p
+     repeat_penalty = gen_config.repeat_penalty
+     seed = gen_config.seed
+     max_tokens = gen_config.max_tokens
+     stream = gen_config.stream
+
+     # Now you can call your model directly, passing the parameters:
+     output = model(
+         formatted_string,
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p,
+         repeat_penalty=repeat_penalty,
+         seed=seed,
+         max_tokens=max_tokens,
+         stream=stream#,
+         #stop=["<|eot_id|>", "\n\n"]
+     )
+
+     return output
+
+ # This function is not used in this app
+ def llama_cpp_streaming(history, full_prompt, temperature=temperature):
+
+     gen_config = LlamaCPPGenerationConfig()
+     gen_config.update_temp(temperature)
+
+     print(vars(gen_config))
+
+     # Pull the generated text from the streamer, and update the model output.
+     start = time.time()
+     NUM_TOKENS = 0
+     print('-'*4+'Start Generation'+'-'*4)
+
+     output = model(
+         full_prompt, **vars(gen_config))
+
+     history[-1][1] = ""
+     for out in output:
+
+         if "choices" in out and len(out["choices"]) > 0 and "text" in out["choices"][0]:
+             history[-1][1] += out["choices"][0]["text"]
+             NUM_TOKENS += 1
+             yield history
+         else:
+             print(f"Unexpected output structure: {out}")
+
+     time_generate = time.time() - start
+     print('\n')
+     print('-'*4+'End Generation'+'-'*4)
+     print(f'Num of generated tokens: {NUM_TOKENS}')
+     print(f'Time for complete generation: {time_generate}s')
+     print(f'Tokens per second: {NUM_TOKENS/time_generate}')
+     print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')
+
+ ###
+ # LLM FUNCTIONS
+ ###
+
+ def construct_gemini_generative_model(in_api_key: str, temperature: float, model_choice: str, system_prompt: str, max_tokens: int) -> Tuple[object, dict]:
+     """
+     Constructs a GenerativeModel for Gemini API calls.
+
+     Parameters:
+     - in_api_key (str): The API key for authentication.
+     - temperature (float): The temperature parameter for the model, controlling the randomness of the output.
+     - model_choice (str): The choice of model to use for generation.
+     - system_prompt (str): The system prompt to guide the generation.
+     - max_tokens (int): The maximum number of tokens to generate.
+
+     Returns:
+     - Tuple[object, dict]: A tuple containing the constructed GenerativeModel and its configuration.
+     """
+     # Construct a GenerativeModel
+     try:
+         if in_api_key:
+             #print("Getting API key from textbox")
+             api_key = in_api_key
+             ai.configure(api_key=api_key)
+         elif "GOOGLE_API_KEY" in os.environ:
+             #print("Searching for API key in environmental variables")
+             api_key = os.environ["GOOGLE_API_KEY"]
+             ai.configure(api_key=api_key)
+         else:
+             print("No API key found")
+             raise gr.Error("No API key found.")
+     except Exception as e:
+         print(e)
+
+     config = ai.GenerationConfig(temperature=temperature, max_output_tokens=max_tokens)
+
+     #model = ai.GenerativeModel.from_cached_content(cached_content=cache, generation_config=config)
+     model = ai.GenerativeModel(model_name='models/' + model_choice, system_instruction=system_prompt, generation_config=config)
+
+     # Upload CSV file (replace with your actual file path)
+     #file_id = ai.upload_file(upload_file_path)
+
+     # if file_type == 'xlsx':
+     #     print("Running through all xlsx sheets")
+     #     #anon_xlsx = pd.ExcelFile(upload_file_path)
+     #     if not in_excel_sheets:
+     #         out_message.append("No Excel sheets selected. Please select at least one to anonymise.")
+     #         continue
+
+     #     anon_xlsx = pd.ExcelFile(upload_file_path)
+
+     #     # Create xlsx file:
+     #     anon_xlsx_export_file_name = output_folder + file_name + "_redacted.xlsx"
+
+     ### QUERYING LARGE LANGUAGE MODEL ###
+     # Prompt caching the table and system prompt. See here: https://ai.google.dev/gemini-api/docs/caching?lang=python
+     # Create a cache with a 5 minute TTL. ONLY FOR CACHES OF AT LEAST 32k TOKENS!
+     # cache = ai.caching.CachedContent.create(
+     #     model='models/' + model_choice,
+     #     display_name=file_name,  # used to identify the cache
+     #     system_instruction=system_prompt,
+     #     ttl=datetime.timedelta(minutes=5),
+     # )
+
+     return model, config
+
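A rough usage sketch for the Gemini path above, assuming a valid key is set in GOOGLE_API_KEY; the model name and prompt here are placeholders, not values taken from this commit:

# Build the model and config, then query it the same way send_request does below
model, config = construct_gemini_generative_model(
    in_api_key="", temperature=0.2, model_choice="gemini-2.0-flash",
    system_prompt="You are a topic extraction assistant.", max_tokens=1024)

response = model.generate_content(
    contents="List the topics in: 'The staff were friendly but parking was poor.'",
    generation_config=config)
print(response.text)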
+ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice: str) -> ResponseObject:
+     """
+     This function sends a request to AWS Claude with the following parameters:
+     - prompt: The user's input prompt to be processed by the model.
+     - system_prompt: A system-defined prompt that provides context or instructions for the model.
+     - temperature: A value that controls the randomness of the model's output, with higher values resulting in more diverse responses.
+     - max_tokens: The maximum number of tokens (words or characters) in the model's response.
+     - model_choice: The specific model to use for processing the request.
+
+     The function constructs the request configuration, invokes the model, extracts the response text, and returns a ResponseObject containing the text and metadata.
+     """
+
+     prompt_config = {
+         "anthropic_version": "bedrock-2023-05-31",
+         "max_tokens": max_tokens,
+         "top_p": 0.999,
+         "temperature": temperature,
+         "system": system_prompt,
+         "messages": [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": prompt},
+                 ],
+             }
+         ],
+     }
+
+     body = json.dumps(prompt_config)
+
+     modelId = model_choice
+     accept = "application/json"
+     contentType = "application/json"
+
+     request = bedrock_runtime.invoke_model(
+         body=body, modelId=modelId, accept=accept, contentType=contentType
+     )
+
+     # Extract text from request
+     response_body = json.loads(request.get("body").read())
+     text = response_body.get("content")[0].get("text")
+
+     response = ResponseObject(
+         text=text,
+         usage_metadata=request['ResponseMetadata']
+     )
+
+     # Now you can access both the text and metadata
+     #print("Text:", response.text)
+     #print("Metadata:", response.usage_metadata)
+
+     return response
+
+ # Function to send a request and update history
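The send_request function that follows wraps each backend call in a fixed retry loop with a sleep between attempts (timeout_wait and NUMBER_OF_RETRY_ATTEMPTS from tools/config.py). A minimal standalone sketch of that retry pattern, using a hypothetical call_model function:

import time

def call_with_retries(call_model, prompt, attempts=5, wait_seconds=30):
    # Try the call up to `attempts` times, sleeping between failures (e.g. API throttling)
    for i in range(attempts):
        try:
            return call_model(prompt)
        except Exception as e:
            print(f"Attempt {i + 1} failed: {e}. Waiting {wait_seconds}s before retrying.")
            time.sleep(wait_seconds)
    # Mirror send_request's behaviour of returning an empty 'FAILED' response after exhausting retries
    return None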
+ def send_request(prompt: str, conversation_history: List[dict], model: object, config: dict, model_choice: str, system_prompt: str, temperature: float, local_model=[], progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
+     """
+     This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
+     It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
+     If the model choice is specific to AWS Claude, it calls the `call_aws_claude` function; otherwise, it uses the `model.generate_content` method.
+     The function returns the response text and the updated conversation history.
+     """
+     # Constructing the full prompt from the conversation history
+     full_prompt = "Conversation history:\n"
+
+     for entry in conversation_history:
+         role = entry['role'].capitalize()  # Assuming the history is stored with 'role' and 'parts'
+         message = ' '.join(entry['parts'])  # Combining all parts of the message
+         full_prompt += f"{role}: {message}\n"
+
+     # Adding the new user prompt
+     full_prompt += f"\nUser: {prompt}"
+
+     # Clear any existing progress bars
+     tqdm._instances.clear()
+
+     progress_bar = range(0, number_of_api_retry_attempts)
+
+     # Generate the model's response
+     if "gemini" in model_choice:
+
+         for i in progress_bar:
+             try:
+                 print("Calling Gemini model, attempt", i + 1)
+                 #print("full_prompt:", full_prompt)
+                 #print("generation_config:", config)
+
+                 response = model.generate_content(contents=full_prompt, generation_config=config)
+
+                 #progress_bar.close()
+                 #tqdm._instances.clear()
+
+                 print("Successful call to Gemini model.")
+                 break
+             except Exception as e:
+                 # If the call fails, try again after X seconds in case there is a throttle limit
+                 print("Call to Gemini model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
+
+                 time.sleep(timeout_wait)
+
+             if i == number_of_api_retry_attempts:
+                 return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
+     elif "anthropic.claude" in model_choice:
+         for i in progress_bar:
+             try:
+                 print("Calling AWS Claude model, attempt", i + 1)
+                 response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
+
+                 #progress_bar.close()
+                 #tqdm._instances.clear()
+
+                 print("Successful call to Claude model.")
+                 break
+             except Exception as e:
+                 # If the call fails, try again after X seconds in case there is a throttle limit
+                 print("Call to Claude model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
+
+                 time.sleep(timeout_wait)
+                 #response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
+
+             if i == number_of_api_retry_attempts:
+                 return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
+     else:
+         # This is the local model
+         for i in progress_bar:
+             try:
+                 print("Calling local model, attempt", i + 1)
+
+                 gen_config = LlamaCPPGenerationConfig()
+                 gen_config.update_temp(temperature)
+
+                 response = call_llama_cpp_model(prompt, gen_config, model=local_model)
+
+                 #progress_bar.close()
+                 #tqdm._instances.clear()
+
+                 print("Successful call to local model. Response:", response)
+                 break
+             except Exception as e:
+                 # If the call fails, try again after X seconds in case there is a throttle limit
+                 print("Call to Gemma model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
+
+                 time.sleep(timeout_wait)
+                 #response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
+
+             if i == number_of_api_retry_attempts:
+                 return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history
+
+     # Update the conversation history with the new prompt and response
+     conversation_history.append({'role': 'user', 'parts': [prompt]})
+
+     # Check whether this is a Llama.cpp model response
+     # Check if the response is a ResponseObject
+     if isinstance(response, ResponseObject):
+         conversation_history.append({'role': 'assistant', 'parts': [response.text]})
+     elif 'choices' in response:
+         conversation_history.append({'role': 'assistant', 'parts': [response['choices'][0]['text']]})
+     else:
+         conversation_history.append({'role': 'assistant', 'parts': [response.text]})
+
+     # Print the updated conversation history
+     #print("conversation_history:", conversation_history)
+
+     return response, conversation_history
+
+ def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], model: object, config: dict, model_choice: str, temperature: float, batch_no:int = 1, local_model = [], master:bool = False) -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
+     """
+     Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
+
+     Args:
+         prompts (List[str]): A list of prompts to be processed.
+         system_prompt (str): The system prompt.
+         conversation_history (List[dict]): The history of the conversation.
+         whole_conversation (List[str]): The complete conversation including prompts and responses.
+         whole_conversation_metadata (List[str]): Metadata about the whole conversation.
+         model (object): The model to use for processing the prompts.
+         config (dict): Configuration for the model.
+         model_choice (str): The choice of model to use.
+         temperature (float): The temperature parameter for the model.
+         batch_no (int): Batch number of the large language model request.
+         local_model: Local gguf model (if loaded)
+         master (bool): Is this request for the master table.
+
+     Returns:
+         Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
+     """
+     responses = []
+
+     # Clear any existing progress bars
+     tqdm._instances.clear()
+
+     for prompt in prompts:
+
+         #print("prompt to LLM:", prompt)
+
+         response, conversation_history = send_request(prompt, conversation_history, model=model, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model)
+
+         if isinstance(response, ResponseObject):
+             response_text = response.text
+         elif 'choices' in response:
+             response_text = response['choices'][0]['text']
+         else:
+             response_text = response.text
+
+         responses.append(response)
|
550 |
+
whole_conversation.append(prompt)
|
551 |
+
whole_conversation.append(response_text)
|
552 |
+
|
553 |
+
# Create conversation metadata
|
554 |
+
if master == False:
|
555 |
+
whole_conversation_metadata.append(f"Query batch {batch_no} prompt {len(responses)} metadata:")
|
556 |
+
else:
|
557 |
+
whole_conversation_metadata.append(f"Query summary metadata:")
|
558 |
+
|
559 |
+
if not isinstance(response, str):
|
560 |
+
try:
|
561 |
+
print("model_choice:", model_choice)
|
562 |
+
if "claude" in model_choice:
|
563 |
+
print("Appending selected metadata items to metadata")
|
564 |
+
whole_conversation_metadata.append('x-amzn-bedrock-output-token-count:')
|
565 |
+
whole_conversation_metadata.append(str(response.usage_metadata['HTTPHeaders']['x-amzn-bedrock-output-token-count']))
|
566 |
+
whole_conversation_metadata.append('x-amzn-bedrock-input-token-count:')
|
567 |
+
whole_conversation_metadata.append(str(response.usage_metadata['HTTPHeaders']['x-amzn-bedrock-input-token-count']))
|
568 |
+
elif "gemini" in model_choice:
|
569 |
+
whole_conversation_metadata.append(str(response.usage_metadata))
|
570 |
+
else:
|
571 |
+
whole_conversation_metadata.append(str(response['usage']))
|
572 |
+
except KeyError as e:
|
573 |
+
print(f"Key error: {e} - Check the structure of response.usage_metadata")
|
574 |
+
else:
|
575 |
+
print("Response is a string object.")
|
576 |
+
whole_conversation_metadata.append("Length prompt: " + str(len(prompt)) + ". Length response: " + str(len(response)))
|
577 |
+
|
578 |
+
|
579 |
+
return responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text
|
tools/verify_titles.py
CHANGED
@@ -17,12 +17,12 @@ from gradio import Progress
 from typing import List, Tuple
 from io import StringIO
 
-from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt,
-from tools.helper_functions import
-from tools.
-from tools.llm_api_call import
+from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt
+from tools.helper_functions import put_columns_in_df, wrap_text
+from tools.llm_funcs import load_model, construct_gemini_generative_model
+from tools.llm_api_call import load_in_data_file, get_basic_response_data, data_file_to_markdown_table, clean_column_name, convert_response_text_to_markdown_table, call_llm_with_markdown_table_checks, ResponseObject, max_tokens, max_time_for_loop, batch_size_default, GradioFileData
 
-
+from tools.config import MAX_OUTPUT_VALIDATION_ATTEMPTS, RUN_LOCAL_MODEL, model_name_map, OUTPUT_FOLDER
 
 def write_llm_output_and_logs_verify(responses: List[ResponseObject],
                                      whole_conversation: List[str],
@@ -37,8 +37,9 @@ def write_llm_output_and_logs_verify(responses: List[ResponseObject],
                                      existing_reference_df:pd.DataFrame,
                                      existing_topics_df:pd.DataFrame,
                                      batch_size_number:int,
-                                     in_column:str,
-                                     first_run: bool = False
+                                     in_column:str,
+                                     first_run: bool = False,
+                                     output_folder:str=OUTPUT_FOLDER) -> None:
     """
     Writes the output of the large language model requests and logs to files.
 
@@ -56,6 +57,7 @@ def write_llm_output_and_logs_verify(responses: List[ResponseObject],
     - existing_reference_df (pd.DataFrame): The existing reference dataframe mapping response numbers to topics.
     - existing_topics_df (pd.DataFrame): The existing unique topics dataframe
     - first_run (bool): A boolean indicating if this is the first run through this function in this process. Defaults to False.
+    - output_folder (str): A string indicating the folder to output to
     """
     unique_topics_df_out_path = []
     topic_table_out_path = "topic_table_error.csv"
@@ -236,6 +238,7 @@ def verify_titles(in_data_file,
                   sentiment_checkbox:str = "Negative, Neutral, or Positive",
                   force_zero_shot_radio:str = "No",
                   in_excel_sheets:List[str] = [],
+                  output_folder:str=OUTPUT_FOLDER,
                   max_tokens:int=max_tokens,
                   model_name_map:dict=model_name_map,
                   max_time_for_loop:int=max_time_for_loop,
@@ -276,7 +279,8 @@
     - time_taken (float, optional): The amount of time taken to process the responses up until this point.
     - sentiment_checkbox (str, optional): What type of sentiment analysis should the topic modeller do?
     - force_zero_shot_radio (str, optional): Should responses be forced into a zero shot topic or not.
-    - in_excel_sheets (List[str], optional): List of excel sheets to load from input file
+    - in_excel_sheets (List[str], optional): List of excel sheets to load from input file.
+    - output_folder (str): The output folder where files will be saved.
     - max_tokens (int): The maximum number of tokens for the model.
     - model_name_map (dict, optional): A dictionary mapping full model name to shortened.
     - max_time_for_loop (int, optional): The number of seconds maximum that the function should run for before breaking (to run again, this is to avoid timeouts with some AWS services if deployed there).
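
A note on the pattern these changes introduce: output paths are now passed as an output_folder argument whose default comes from tools.config.OUTPUT_FOLDER, so callers can override the folder per call while the app-wide default lives in one place. A minimal, self-contained sketch of that pattern, using illustrative stand-in names rather than the project's real helpers:

import os

OUTPUT_FOLDER = os.environ.get("OUTPUT_FOLDER", "output/")  # stand-in for tools.config.OUTPUT_FOLDER

def write_topic_table(rows, output_folder: str = OUTPUT_FOLDER) -> str:
    # Per-call override is possible; otherwise the config default is used.
    os.makedirs(output_folder, exist_ok=True)
    out_path = os.path.join(output_folder, "topic_table.csv")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(",".join(map(str, row)) for row in rows))
    return out_path

print(write_topic_table([("General", "Transport", 3)]))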
windows_install_llama-cpp-python.txt
ADDED
@@ -0,0 +1,118 @@
+---
+
+set PKG_CONFIG_PATH=C:\<path-to-openblas>\OpenBLAS\lib\pkgconfig # Set this in environment variables
+
+pip install llama-cpp-python==0.3.9 --force-reinstall --verbose --no-cache-dir -Ccmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS;-DBLAS_INCLUDE_DIRS=C:/<path-to-openblas>/OpenBLAS/include;-DBLAS_LIBRARIES=C:/<path-to-openblas>/OpenBLAS/lib/libopenblas.lib"
+---
+
+# With CUDA
+
+pip install llama-cpp-python==0.3.9 --force-reinstall --no-cache-dir --verbose -C cmake.args="-DGGML_CUDA=on"
+
+---
+
+How to Make it Work: Step-by-Step Guide
+To successfully run the commands above, you need to set up a proper C++ development environment.
+
+Step 1: Install the C++ Compiler
+Go to the Visual Studio downloads page.
+
+Scroll down to "Tools for Visual Studio" and download the "Build Tools for Visual Studio". This is a standalone installer that gives you the C++ compiler and libraries without installing the full Visual Studio IDE.
+
+Run the installer. In the "Workloads" tab, check the box for "Desktop development with C++", including:
+
+MSVC v143
+C++ ATL
+C++ Profiling tools
+C++ CMake tools for Windows
+C++ MFC
+C++ Modules
+Windows 10 SDK (10.0.20348.0)
+
+Proceed with the installation.
+
+You will need to use the 'x64 Native Tools Command Prompt for VS 2022' to install. Run it as administrator.
+
+Step 2: Install CMake
+Go to the CMake download page.
+
+Download the latest Windows installer (e.g., cmake-x.xx.x-windows-x86_64.msi).
+
+Run the installer. Crucially, when prompted, select the option to "Add CMake to the system PATH for all users" or "for the current user". This allows you to run cmake from any command prompt.
+
+Step 3: Download and Place OpenBLAS
+This is often the trickiest part.
+
+Go to the OpenBLAS releases on GitHub.
+
+Find a recent release and download the pre-compiled version for Windows. It will typically be a file named something like OpenBLAS-0.3.21-x64.zip (the version number will change). Make sure you get the 64-bit (x64) version if you are using 64-bit Python.
+
+Create a folder somewhere easily accessible, for example, C:\libs\.
+
+Extract the contents of the OpenBLAS zip file into that folder. Your final directory structure should look something like this:
+
+C:\libs\OpenBLAS\
+├── bin\
+├── include\
+└── lib\
+
+Step 3b: Install Chocolatey and pkg-config
+https://chocolatey.org/install
+
+Install Chocolatey (if you don't already have it):
+Open PowerShell as an Administrator (right-click the Start Menu -> "Windows PowerShell (Admin)" or "Terminal (Admin)").
+
+Run the following command to install Chocolatey. It's a single, long line:
+
+Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
+
+Wait for it to finish. Once it's done, close the Administrator PowerShell window.
+
+Install pkg-config-lite using Chocolatey:
+IMPORTANT: Open a NEW command prompt or PowerShell window (as a regular user is fine). This is necessary so it recognises the new choco command.
+
+Run the following command to install a lightweight version of pkg-config:
+
+choco install pkgconfiglite
+
+Approve the installation by typing Y or A if prompted.
+
+Step 4: Run the Installation Command
+Now you have all the pieces. The final step is to run the command in a terminal that is aware of your new build environment.
+
+Open the "Developer Command Prompt for VS" from your Start Menu. This is important! This special command prompt automatically configures all the necessary paths for the C++ compiler.
+
+## For CPU
+
+set PKG_CONFIG_PATH=C:\<path-to-openblas>\OpenBLAS\lib\pkgconfig # Set this in environment variables
+
+pip install llama-cpp-python==0.3.9 --force-reinstall --verbose --no-cache-dir -Ccmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS;-DBLAS_INCLUDE_DIRS=C:/<path-to-openblas>/OpenBLAS/include;-DBLAS_LIBRARIES=C:/<path-to-openblas>/OpenBLAS/lib/libopenblas.lib"
+
+## With CUDA
+
+Use an NVIDIA GPU (cuBLAS): if you have an NVIDIA GPU, using cuBLAS is often easier because the CUDA Toolkit installer handles most of the setup.
+
+Install the NVIDIA CUDA Toolkit.
+
+Run the install command specifying CUDA:
+
+set PKG_CONFIG_PATH=C:\<path-to-openblas>\OpenBLAS\lib\pkgconfig # Set this in environment variables
+
+pip install llama-cpp-python==0.3.9 --force-reinstall --no-cache-dir --verbose -C cmake.args="-DGGML_CUDA=on"
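
Once either build above succeeds, a quick smoke test from Python can confirm the wheel imports and runs. The model path below is a placeholder (any local GGUF file works), and n_gpu_layers only has an effect on the CUDA build:

import llama_cpp
from llama_cpp import Llama

print("llama-cpp-python version:", llama_cpp.__version__)

llm = Llama(
    model_path=r"C:\models\gemma-3-4b-it-Q4_K_M.gguf",  # placeholder path to a local GGUF file
    n_ctx=2048,
    n_gpu_layers=-1,  # offload all layers if built with -DGGML_CUDA=on; ignored on CPU-only builds
    verbose=True,     # startup log indicates which BLAS/CUDA backend the wheel was built with
)

out = llm("Say hello in one short sentence.", max_tokens=32)
print(out["choices"][0]["text"])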