apply new backend

- .gitignore +2 -0
- src/content/agent.py +13 -2
- src/content/common.py +97 -88
- src/exceptions.py +0 -4
- src/generation.py +0 -140
- src/retrieval.py +0 -75
- src/tunnel.py +0 -72
.gitignore ADDED
@@ -0,0 +1,2 @@
+.venv/
+__pycache__/
src/content/agent.py CHANGED
@@ -1,7 +1,10 @@
+import os
+import requests
+
 import numpy as np
 import streamlit as st
 
-from src.retrieval import STANDARD_QUERIES
+from src.retrieval import STANDARD_QUERIES
 from src.content.common import (
     MODEL_NAMES,
     AUDIO_SAMPLES_W_INSTRUCT,
@@ -17,6 +20,9 @@ from src.content.common import (
 )
 
 
+API_BASE_URL = os.getenv('API_BASE_URL')
+
+
 LLM_NO_AUDIO_PROMPT_TEMPLATE = """{user_question}"""
 
 
@@ -96,7 +102,12 @@ def _prepare_final_prompt_with_ui(one_time_prompt):
         return LLM_NO_AUDIO_PROMPT_TEMPLATE.format(user_question=one_time_prompt)
 
     with st.spinner("Searching appropriate querys..."):
-        [… 1 removed line illegible in source view …]
+        response = requests.get(
+            f"{API_BASE_URL}retrieve_relevant_docs",
+            params={"user_question": one_time_prompt}
+        )
+        relevant_query_indices = response.json()
+
         if len(st.session_state.ag_messages) <= 2:
             relevant_query_indices.append(0)
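The spinner block above now delegates query retrieval to an HTTP backend: agent.py issues GET {API_BASE_URL}retrieve_relevant_docs?user_question=... and expects a JSON list of indices into STANDARD_QUERIES. Note that both call sites build URLs by plain concatenation (f"{API_BASE_URL}retrieve_relevant_docs"), so the API_BASE_URL environment variable must end with a trailing slash (e.g. http://backend:8000/). The backend itself is not part of this commit; the following is a hypothetical FastAPI sketch of the contract the frontend assumes, with a trivial stand-in for the QueryRetriever deleted from src/retrieval.py further down.

# Hypothetical backend sketch (not part of this commit): serves the
# GET {API_BASE_URL}retrieve_relevant_docs call issued in agent.py above.
from typing import List

from fastapi import FastAPI

app = FastAPI()

def get_relevant_doc_indices(user_question: str) -> List[int]:
    # Stand-in for the QueryRetriever removed from src/retrieval.py below;
    # a real backend would score the question against STANDARD_QUERIES embeddings.
    keywords = {0: "transcribe", 1: "translate"}  # illustrative only
    return [i for i, kw in keywords.items() if kw in user_question.lower()]

@app.get("/retrieve_relevant_docs")
def retrieve_relevant_docs(user_question: str) -> List[int]:
    # The frontend reads this JSON array directly via response.json().
    return get_relevant_doc_indices(user_question)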
src/content/common.py CHANGED
@@ -1,6 +1,8 @@
 import os
+import re
 import copy
 import base64
+import requests
 import itertools
 from collections import OrderedDict
 from typing import List, Optional
@@ -8,17 +10,11 @@ from typing import List, Optional
 import numpy as np
 import streamlit as st
 
-from src.tunnel import start_server
-from src.retrieval import load_retriever
 from src.logger import load_logger
 from src.utils import array_to_bytes, bytes_to_array, postprocess_voice_transcription
-from src.generation import (
-    FIXED_GENERATION_CONFIG,
-    MAX_AUDIO_LENGTH,
-    load_model,
-    retrive_response
-)
+from src.generation import FIXED_GENERATION_CONFIG, MAX_AUDIO_LENGTH
 
+API_BASE_URL = os.getenv('API_BASE_URL')
 
 PLAYGROUND_DIALOGUE_STATES = dict(
     pg_audio_base64='',
@@ -65,46 +61,26 @@ DEFAULT_DIALOGUE_STATE_DICTS = [
 ]
 
 
-MODEL_NAMES = OrderedDict({
+MODEL_NAMES = OrderedDict({
+    "llm": {
+        "vllm_name": "MERaLiON-Gemma",
+        "model_name": "MERaLiON-Gemma",
+        "ui_name": "MERaLiON-Gemma"
+    },
+    "audiollm": {
+        "vllm_name": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION",
+        "model_name": "MERaLiON-AudioLLM-Whisper-SEA-LION",
+        "ui_name": "MERaLiON-AudioLLM"
+    },
+    "audiollm-it": {
+        "vllm_name": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION-it",
+        "model_name": "MERaLiON-AudioLLM-Whisper-SEA-LION-it",
+        "ui_name": "MERaLiON-AudioLLM-Instruction-Tuning"
+    }
+})
 
 
 AUDIO_SAMPLES_W_INSTRUCT = {
-    "female_pilot#1": {
-        "apperance": "Female Pilot Interview: Transcription",
-        "instructions": [
-            "Please transcribe the speech"
-        ]
-    },
-    "female_pilot#2": {
-        "apperance": "Female Pilot Interview: Aircraft name",
-        "instructions": [
-            "What does 大力士 mean in the conversation"
-        ]
-    },
-    "female_pilot#3": {
-        "apperance": "Female Pilot Interview: Air Force Personnel Count",
-        "instructions": [
-            "How many air force personnel are there?"
-        ]
-    },
-    "female_pilot#4": {
-        "apperance": "Female Pilot Interview: Air Force Personnel Name",
-        "instructions": [
-            "Can you tell me the names of the two pilots?"
-        ]
-    },
-    "female_pilot#5": {
-        "apperance": "Female Pilot Interview: Pilot Seat Restriction",
-        "instructions": [
-            "What is the concern of having a big butt for pilot?"
-        ]
-    },
-    "female_pilot#6": {
-        "apperance": "Female Pilot Interview: Conversation Mood",
-        "instructions": [
-            "What is the mood of the conversation?"
-        ]
-    },
     "7_ASR_IMDA_PART3_30_ASR_v2_2269": {
         "apperance": "7. Automatic Speech Recognition task: conversation in Singapore accent",
         "instructions": [
@@ -358,13 +334,40 @@ AUDIO_SAMPLES_W_INSTRUCT = {
         "instructions": [
             "Please follow the instruction in the speech."
         ]
+    },
+    "female_pilot#1": {
+        "apperance": "Female Pilot Interview: Transcription",
+        "instructions": [
+            "Please transcribe the speech"
+        ]
+    },
+    "female_pilot#2": {
+        "apperance": "Female Pilot Interview: Aircraft name",
+        "instructions": [
+            "What does 大力士 mean in the conversation"
+        ]
+    },
+    "female_pilot#3": {
+        "apperance": "Female Pilot Interview: Air Force Personnel Count",
+        "instructions": [
+            "How many air force personnel are there?"
+        ]
+    },
+    "female_pilot#4": {
+        "apperance": "Female Pilot Interview: Air Force Personnel Name",
+        "instructions": [
+            "Can you tell me the names of the two pilots?"
+        ]
+    },
+    "female_pilot#5": {
+        "apperance": "Female Pilot Interview: Conversation Mood",
+        "instructions": [
+            "What is the mood of the conversation?"
+        ]
     }
 }
 
 
-exec(os.getenv('APP_CONFIGS'))
-
-
 def reset_states(*state_dicts):
     for states in state_dicts:
         st.session_state.update(copy.deepcopy(states))
@@ -403,14 +406,6 @@ def init_state_section():
         st.session_state.logger = load_logger()
         st.session_state.session_id = st.session_state.logger.register_session()
 
-    if "server" not in st.session_state:
-        st.session_state.server = start_server()
-
-    if "client_mapper" not in st.session_state:
-        st.session_state.client_mapper = load_model()
-
-    if "retriever" not in st.session_state:
-        st.session_state.retriever = load_retriever()
 
     for key, value in FIXED_GENERATION_CONFIG.items():
         if key not in st.session_state:
@@ -551,54 +546,68 @@ def retrive_response_with_ui(
     if history is None:
         history = []
 
-    [… 17 removed lines illegible in source view …]
-        base64_audio_input=base64_audio_input,
-        history=history,
-        **generation_params,
-        **kwargs
-    )
-
-    if error_msg:
-        st.error(error_msg)
+    # Prepare request data
+    request_data = {
+        "text_input": str(text_input),
+        "model_name": str(model_name),
+        "array_audio_input": array_audio_input.tolist(),  # Convert numpy array to list
+        "base64_audio_input": str(base64_audio_input) if base64_audio_input else None,
+        "history": list(history) if history else None,
+        "stream": bool(stream),
+        "max_completion_tokens": int(st.session_state.max_completion_tokens),
+        "temperature": float(st.session_state.temperature),
+        "top_p": float(st.session_state.top_p),
+        "repetition_penalty": float(st.session_state.repetition_penalty),
+        "top_k": int(st.session_state.top_k),
+        "length_penalty": float(st.session_state.length_penalty),
+        "seed": int(st.session_state.seed),
+        "extra_params": {}
+    }
 
-    [… 2 removed lines illegible in source view …]
-            st.warning(warning_msg)
+    # print(request_data)
+    # print(model_name)
 
+    error_msg = ""
+    warnings = []
     response = ""
-    [… 1 removed line illegible in source view …]
+
+    try:
         if stream:
-            [… 1 removed line illegible in source view …]
+            # Streaming response
+            response_stream = requests.post(f"{API_BASE_URL}chat", json=request_data, stream=True)
+            response_stream.raise_for_status()
+
+            response_obj = itertools.chain([prefix], (chunk.decode() for chunk in response_stream))
             response = st.write_stream(response_obj)
         else:
-            [… 1 removed line illegible in source view …]
+            # Non-streaming response
+            api_response = requests.post(f"{API_BASE_URL}chat", json=request_data)
+            api_response.raise_for_status()
+            result = api_response.json()
+
+            if "warnings" in result:
+                warnings = result["warnings"]
+
+            response = result.get("response", "")
             if normalise_response:
                 response = postprocess_voice_transcription(response)
             response = prefix + response
            st.write(response)
 
+    except requests.exceptions.RequestException as e:
+        error_msg = f"API request failed: {str(e)}"
+        st.error(error_msg)
+
+    if show_warning:
+        for warning_msg in warnings:
+            st.warning(warning_msg)
+
     st.session_state.logger.register_query(
         session_id=st.session_state.session_id,
         base64_audio=base64_audio_input,
         text_input=text_input,
         history=history,
-        params=[… truncated in source view …]
+        params=request_data["extra_params"],
         response=response,
         warnings=warnings,
         error_msg=error_msg
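The rewritten retrive_response_with_ui therefore assumes a backend exposing POST {API_BASE_URL}chat: when "stream" is true the reply body is consumed as raw text chunks via st.write_stream, otherwise a JSON object with a "response" string and an optional "warnings" list is expected. That backend is not included in this commit; the sketch below is a hypothetical FastAPI endpoint satisfying this contract, with an echo standing in for model inference.

# Hypothetical /chat endpoint matching the request_data payload built above.
from typing import Any, Dict

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/chat")
def chat(payload: Dict[str, Any]):
    # Placeholder for real inference on payload["text_input"] and the audio fields.
    answer = f"Echo: {payload.get('text_input', '')}"
    if payload.get("stream"):
        # Each yielded string reaches the client as one chunk of response_stream.
        def token_stream():
            for token in answer.split():
                yield token + " "
        return StreamingResponse(token_stream(), media_type="text/plain")
    # Non-streaming shape read by result.get("response", "") / result["warnings"].
    return {"response": answer, "warnings": []}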
src/exceptions.py CHANGED
@@ -1,6 +1,2 @@
 class NoAudioException(Exception):
-    pass
-
-
-class TunnelNotRunningException(Exception):
     pass
src/generation.py CHANGED
@@ -1,15 +1,3 @@
-import os
-import re
-import time
-from typing import List, Dict, Optional
-
-import numpy as np
-import streamlit as st
-from openai import OpenAI, APIConnectionError
-
-from src.exceptions import TunnelNotRunningException
-
-
 FIXED_GENERATION_CONFIG = dict(
     max_completion_tokens=1024,
     top_k=50,
@@ -20,25 +8,6 @@ FIXED_GENERATION_CONFIG = dict(
 MAX_AUDIO_LENGTH = 120
 
 
-def load_model() -> Dict:
-    """
-    Create an OpenAI client with connection to vllm server.
-    """
-    openai_api_key = os.getenv('API_KEY')
-    local_ports = os.getenv('LOCAL_PORTS').split(" ")
-
-    name_to_client_mapper = {}
-    for port in local_ports:
-        client = OpenAI(
-            api_key=openai_api_key,
-            base_url=f"http://localhost:{port}/v1",
-        )
-
-        for model in client.models.list().data:
-            name_to_client_mapper[model.id] = client
-
-    return name_to_client_mapper
-
 
 def prepare_multimodal_content(text_input, base64_audio_input):
     return [
@@ -76,112 +45,3 @@ def change_multimodal_content(
     }
 
     return original_content
-
-
-
-def _retrive_response(
-    model: str,
-    text_input: str,
-    base64_audio_input: str,
-    history: Optional[List] = None,
-    **kwargs):
-    """
-    Send request through OpenAI client.
-    """
-    if history is None:
-        history = []
-
-    if base64_audio_input:
-        content = [
-            {
-                "type": "text",
-                "text": f"Text instruction: {text_input}"
-            },
-            {
-                "type": "audio_url",
-                "audio_url": {
-                    "url": f"data:audio/ogg;base64,{base64_audio_input}"
-                },
-            },
-        ]
-    else:
-        content = text_input
-
-    current_client = st.session_state.client_mapper[model]
-
-    return current_client.chat.completions.create(
-        messages=history + [{"role": "user", "content": content}],
-        model=model,
-        **kwargs
-    )
-
-
-def _retry_retrive_response_throws_exception(retry=3, **kwargs):
-    try:
-        response_object = _retrive_response(**kwargs)
-    except APIConnectionError as e:
-        if not st.session_state.server.is_running():
-            if retry == 0:
-                raise TunnelNotRunningException()
-
-            st.toast(f":warning: Internet connection is down. Trying to re-establish connection ({retry}).")
-
-            if st.session_state.server.is_down():
-                st.session_state.server.restart()
-            elif st.session_state.server.is_starting():
-                time.sleep(2)
-
-            return _retry_retrive_response_throws_exception(retry-1, **kwargs)
-        raise e
-
-    return response_object
-
-
-def _validate_input(text_input, array_audio_input) -> List[str]:
-    """
-    TODO: improve the input validation regex.
-    """
-    warnings = []
-    if re.search("tool|code|python|java|math|calculate", text_input):
-        warnings.append("WARNING: MERaLiON-AudioLLM is not intended for use in tool calling, math, and coding tasks.")
-
-    if re.search(r'[\u4e00-\u9fff]+', text_input):
-        warnings.append("NOTE: Please try to prompt in English for the best performance.")
-
-    if array_audio_input.shape[0] == 0:
-        warnings.append("NOTE: Please specify audio from examples or local files.")
-
-    if array_audio_input.shape[0] / 16000 > 30.0:
-        warnings.append((
-            "WARNING: MERaLiON-AudioLLM is trained to process audio up to **30 seconds**."
-            f" Audio longer than **{MAX_AUDIO_LENGTH} seconds** will be truncated."
-        ))
-
-    return warnings
-
-
-def retrive_response(
-    text_input: str,
-    array_audio_input: np.ndarray,
-    **kwargs
-):
-    warnings = _validate_input(text_input, array_audio_input)
-
-    response_object, error_msg = None, ""
-    try:
-        response_object = _retry_retrive_response_throws_exception(
-            text_input=text_input,
-            **kwargs
-        )
-    except TunnelNotRunningException:
-        error_msg = "Internet connection cannot be established. Please contact the administrator."
-    except Exception as e:
-        error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator."
-
-    return error_msg, warnings, response_object
-
-
-def postprocess_voice_transcription(text):
-    text = re.sub("<.*>:?|\(.*\)|\[.*\]", "", text)
-    text = re.sub("\s+", " ", text).strip()
-    return text
src/retrieval.py CHANGED
@@ -1,10 +1,3 @@
-from typing import List
-
-import numpy as np
-import streamlit as st
-from FlagEmbedding import BGEM3FlagModel
-
-
 STANDARD_QUERIES = [
     {
         "query_text": "Please transcribe this speech.",
@@ -43,71 +36,3 @@ STANDARD_QUERIES = [
         "ui_text": "emotion recognition"
     },
 ]
-
-
-def _colbert_score(q_reps, p_reps):
-    """Compute colbert scores of input queries and passages.
-
-    Args:
-        q_reps (np.ndarray): Multi-vector embeddings for queries.
-        p_reps (np.ndarray): Multi-vector embeddings for passages/corpus.
-
-    Returns:
-        torch.Tensor: Computed colbert scores.
-    """
-    # q_reps, p_reps = torch.from_numpy(q_reps), torch.from_numpy(p_reps)
-    token_scores = np.einsum('in,jn->ij', q_reps, p_reps)
-    scores = token_scores.max(-1)
-    scores = np.sum(scores) / q_reps.shape[0]
-    return scores
-
-class QueryRetriever:
-    def __init__(self, docs):
-        self.model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
-        self.docs = docs
-        self.doc_vectors = self.model.encode(
-            [d["doc_text"] for d in self.docs],
-            return_sparse=True,
-            return_colbert_vecs=True
-        )
-        self.scorer_attrs = {
-            "lexical_weights": {
-                "method": self.model.compute_lexical_matching_score,
-                "weight": 0.2
-            },
-            "colbert_vecs": {
-                "method": _colbert_score,
-                "weight": 0.8
-            },
-        }
-
-    def get_relevant_doc_indices(self, prompt, normalize=False) -> np.ndarray:
-        scores = np.zeros(len(self.docs))
-
-        if not prompt:
-            return scores
-
-        prompt_vector = self.model.encode(
-            prompt,
-            return_sparse=True,
-            return_colbert_vecs=True
-        )
-
-        for scorer_name, scorer_attrs in self.scorer_attrs.items():
-            for i, doc_vec in enumerate(self.doc_vectors[scorer_name]):
-                scores[i] += scorer_attrs["method"](prompt_vector[scorer_name], doc_vec)
-
-        if normalize:
-            scores = scores / np.sum(scores)
-        return scores
-
-
-@st.cache_resource()
-def load_retriever():
-    return QueryRetriever(docs=STANDARD_QUERIES)
-
-
-def retrieve_relevant_docs(user_question: str) -> List[int]:
-    scores = st.session_state.retriever.get_relevant_doc_indices(user_question, normalize=True)
-    selected_indices = np.where(scores > 0.2)[0]
-    return selected_indices.tolist()
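For reference, the deleted _colbert_score takes each query token vector's maximum dot product over the passage token vectors and averages over query tokens (the scoring now presumably lives behind the retrieve_relevant_docs endpoint). A small self-contained numpy check of that formula:

# Minimal numpy check of the deleted _colbert_score formula:
# score = mean over query tokens of (max over passage tokens of dot product).
import numpy as np

q_reps = np.array([[1.0, 0.0], [0.0, 1.0]])              # 2 query token vectors
p_reps = np.array([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])  # 3 passage token vectors

token_scores = np.einsum('in,jn->ij', q_reps, p_reps)    # pairwise dot products, shape (2, 3)
score = np.sum(token_scores.max(-1)) / q_reps.shape[0]
print(score)  # (0.9 + 0.8) / 2 = 0.85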
src/tunnel.py DELETED
@@ -1,72 +0,0 @@
-import io
-import os
-
-import paramiko
-import streamlit as st
-from sshtunnel import SSHTunnelForwarder
-
-
-DEFAULT_LOCAL_PORTS = "8000 8001"
-DEFAULT_REMOTE_PORTS = "8000 8001"
-
-
-@st.cache_resource()
-def start_server():
-    server = SSHTunnelManager()
-    server.start()
-    return server
-
-
-class SSHTunnelManager:
-    def __init__(self):
-        pkey = paramiko.RSAKey.from_private_key(io.StringIO(os.getenv('PRIVATE_KEY')))
-
-        self.server = SSHTunnelForwarder(
-            ssh_address_or_host=os.getenv('SERVER_DNS_NAME'),
-            ssh_username="ec2-user",
-            ssh_pkey=pkey,
-            local_bind_addresses=[
-                ("127.0.0.1", int(port))
-                for port in os.getenv('LOCAL_PORTS', DEFAULT_LOCAL_PORTS).split(" ")
-            ],
-            remote_bind_addresses=[
-                ("127.0.0.1", int(port))
-                for port in os.getenv('REMOTE_PORTS', DEFAULT_REMOTE_PORTS).split(" ")
-            ]
-        )
-
-        self._is_starting = False
-        self._is_running = False
-
-    def update_status(self):
-        if not self._is_starting:
-            self.server.check_tunnels()
-            self._is_running = all(
-                list(self.server.tunnel_is_up.values())
-            )
-        else:
-            self._is_running = False
-
-    def is_starting(self):
-        self.update_status()
-        return self._is_starting
-
-    def is_running(self):
-        self.update_status()
-        return self._is_running
-
-    def is_down(self):
-        self.update_status()
-        return (not self._is_running) and (not self._is_starting)
-
-    def start(self, *args, **kwargs):
-        if not self._is_starting:
-            self._is_starting = True
-            self.server.start(*args, **kwargs)
-            self._is_starting = False
-
-    def restart(self, *args, **kwargs):
-        if not self._is_starting:
-            self._is_starting = True
-            self.server.restart(*args, **kwargs)
-            self._is_starting = False