import os
from threading import Thread

import gradio as gr
import spaces
import torch
from datasets import Dataset, load_dataset
# import faiss  # not needed directly; RAGatouille uses faiss itself when use_faiss=True
from ragatouille import RAGPretrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

token = os.environ["HF_TOKEN"]
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    torch_dtype=torch.float16,
    token=token,
)
tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda")
model = model.to(device)
RAG = RAGPretrainedModel.from_pretrained("mixedbread-ai/mxbai-colbert-v1")
# prepare the data
# the full dump is too big, so we only keep the first ~3,000 rows
dataset = load_dataset(
    "wikimedia/wikipedia", "20231101.en", split="train", streaming=True
)

# build a small in-memory Dataset from the stream
data = Dataset.from_dict({})
for i, entry in enumerate(dataset):
    # each entry has the columns ['id', 'url', 'title', 'text']
    data = data.add_item(entry)  # add_item returns a new Dataset; it does not mutate in place
    if i == 3000:
        break

# free memory
del dataset  # we keep data

# index the passages
documents = data["text"]
RAG.index(documents, index_name="wikipedia", use_faiss=True)

# free memory
del documents
def search(query, k: int = 5):
    results = RAG.search(query, k=k)
    # results are ordered by score; each result is a dict with keys:
    #   'content'     : the retrieved passage text
    #   'score'       : relevance score (float)
    #   'rank'        : position after sorting by score (1, 2, 3, ...)
    #   'document_id' : id of the source document inside the index
    #   'passage_id'  : index of the passage in the indexed collection,
    #                   i.e. the row number in `data`
    return [result["passage_id"] for result in results]
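
# Usage sketch (hypothetical query; the ids returned depend on what was indexed):
#   passage_ids = search("what is machine learning", k=3)
#   # -> a list of up to 3 row numbers into `data`, best match first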
def prepare_prompt(query, indexes, data=data):
    prompt = (
        f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    )
    titles = []
    urls = []
    for i in indexes:
        title = data["title"][i]
        text = data["text"][i]
        url = data["url"][i]
        titles.append(title)
        urls.append(url)
        prompt += f"Title: {title}, Text: {text}\n"
    return prompt, (titles, urls)
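
# Sketch of how the two helpers compose (illustrative only):
#   idxs = search("what is machine learning")
#   prompt, (titles, urls) = prepare_prompt("what is machine learning", idxs)
#   # `prompt` now holds the query plus the retrieved passages; `titles`/`urls`
#   # feed the RESOURCES footer appended after generation in talk() below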
@spaces.GPU  # assumption: running on a ZeroGPU Space; `spaces` was imported but never used
def talk(message, history):
    indexes = search(message)
    message, metadata = prepare_prompt(message, indexes)
    resources = "\nRESOURCES:\n"
    for title, url in zip(*metadata):  # metadata is (titles, urls)
        resources += f"[{title}]({url}), "
    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            # strip the RESOURCES footer from past answers before re-feeding them
            cleaned_past = item[1].split("\nRESOURCES:\n")[0]
            chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": message})
    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # tokenize the full prompt string
    model_inputs = tok([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # run generation in a background thread so tokens can be streamed as they arrive
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # append the sources once generation has finished
    partial_text += resources
    yield partial_text
TITLE = "RAG"
DESCRIPTION = """
## Resources used to build this project
* https://huggingface.co/mixedbread-ai/mxbai-colbert-large-v1
* me
## Models
The models used in this Space:
* google/gemma-7b-it
* mixedbread-ai/mxbai-colbert-v1
"""
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="soft",
    examples=[["what is machine learning"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()