Spaces:
Runtime error
Runtime error
| import warnings | |
| import torchvision | |
| import torch | |
| import pandas as pd | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import streamlit as st | |
# Suppress torchvision beta warnings.
# disable_beta_transforms_warning() is not present in every torchvision release
# (e.g. releases that predate the transforms.v2 beta), so only call it when it
# exists: a missing helper must not crash the app over a cosmetic warning.
_disable_beta_warning = getattr(torchvision, "disable_beta_transforms_warning", None)
if callable(_disable_beta_warning):
    _disable_beta_warning()
# Hide UserWarnings attributed to torchvision modules (module regex is matched
# against the start of the module name).
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
# Load tokenizer and model (with a fallback for compatibility) and build the
# fill-mask pipeline. Streamlit re-executes this whole script on every widget
# interaction, so loading is wrapped in st.cache_resource (Streamlit >= 1.18):
# the weights are instantiated once per server process, not on every rerun.
PRIMARY_MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"
FALLBACK_MODEL_NAME = "xlm-roberta-base"


@st.cache_resource
def _load_fill_mask_resources(primary=PRIMARY_MODEL_NAME, fallback=FALLBACK_MODEL_NAME):
    """Load a masked-LM tokenizer/model and wrap them in a fill-mask pipeline.

    Tries ``primary`` first. If loading it raises, a Streamlit warning is shown
    together with the underlying error, and ``fallback`` is loaded instead; an
    error while loading ``fallback`` propagates to the caller.

    Returns:
        ``(tokenizer, model, pipe)`` for whichever model loaded.
    """
    try:
        tok = AutoTokenizer.from_pretrained(primary, use_fast=False)
        mdl = AutoModelForMaskedLM.from_pretrained(primary)
    except Exception as err:
        st.warning(f"Switching to {fallback} model due to compatibility issues.")
        # Surface the reason instead of silently discarding it.
        st.caption(f"Could not load {primary}: {type(err).__name__}: {err}")
        tok = AutoTokenizer.from_pretrained(fallback, use_fast=False)
        mdl = AutoModelForMaskedLM.from_pretrained(fallback)
    # Initialize the fill-mask pipeline on the already-loaded objects.
    fill_mask = pipeline("fill-mask", model=mdl, tokenizer=tok, framework="pt")
    return tok, mdl, fill_mask


# Module-level names used by the rest of the app.
tokenizer, model, pipe = _load_fill_mask_resources()
# Function to generate embeddings
def get_embedding(text, *, tok=None, mdl=None):
    """Return a sentence embedding for ``text`` as a ``(1, hidden_size)`` array.

    The embedding is the attention-masked mean of the model's last hidden
    layer. The model's output logits are masked-LM vocabulary prediction
    scores, not a semantic representation, so they are not used here.

    Args:
        text: The sentence to encode. Inputs longer than the tokenizer's
            ``model_max_length`` are truncated instead of erroring.
        tok: Optional tokenizer. Defaults to the module-level ``tokenizer``.
        mdl: Optional model. Defaults to the module-level ``model``; it must
            accept ``output_hidden_states=True``.

    Returns:
        A NumPy array of shape ``(1, hidden_size)``.
    """
    tok = tokenizer if tok is None else tok
    mdl = model if mdl is None else mdl
    inputs = tok(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = mdl(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]  # (batch, seq, hidden)
    # Zero out padding positions so they do not dilute the average.
    mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
    summed = (last_hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)  # guard against an all-zero mask
    return (summed / counts).cpu().numpy()
# --- Page header and usage guide --------------------------------------------
# Static Markdown rendered above the input box. st.markdown is what st.write
# uses for a plain string, so the rendering is the same.
_APP_TITLE = "Thai Full Sentence Similarity App"

_SUBTITLE_MD = """
## using Thai Law nlp dataset"""

_USAGE_GUIDE_MD = """
### How This App Works
This app uses a mask-filling model to predict possible words or phrases that could fill in the `<mask>` token in a given sentence. It then calculates the similarity of each prediction with the original sentence to determine the most contextually appropriate completion.
### Example Sentence
In this example, we have the following sentence in Thai with a `<mask>` token:
- **Input**: `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน <mask> เพื่อสัมผัสธรรมชาติ"`
- **Translation**: "Many tourists choose to visit `<mask>` to experience nature."
The `<mask>` token represents a location popular for its natural beauty.
### Potential Predictions
Here are some possible predictions the model might generate for `<mask>`:
1. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เชียงใหม่ เพื่อสัมผัสธรรมชาติ"` - Chiang Mai
2. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เขาใหญ่ เพื่อสัมผัสธรรมชาติ"` - Khao Yai
3. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เกาะสมุย เพื่อสัมผัสธรรมชาติ"` - Koh Samui
4. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน ภูเก็ต เพื่อสัมผัสธรรมชาติ"` - Phuket
### Results Table
For each prediction, the app calculates:
- **Similarity Score**: Indicates how similar the predicted sentence is to the original input.
- **Model Score**: Represents the model's confidence in the predicted word for `<mask>`.
### Most Similar Prediction
The app will display the most contextually similar prediction based on the similarity score. For example:
- **Most Similar Prediction**: `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เชียงใหม่ เพื่อสัมผัสธรรมชาติ"`
- **Similarity Score**: 0.89
- **Model Score**: 0.16
Feel free to enter your own sentence with `<mask>` and explore the predictions!
"""

st.title(_APP_TITLE)
st.markdown(_SUBTITLE_MD)
st.markdown(_USAGE_GUIDE_MD)
# User input box
st.subheader("Input Text")
# Strip surrounding whitespace so a blank (or whitespace-only) box becomes ""
# and the processing step below is skipped.
input_text = st.text_input(
    "Enter a sentence with `<mask>` to find similar predictions:",
    "เมนูโปรดของฉันคือ <mask> ที่ทำจากวัตถุดิบสดใหม่",
).strip()
# Ensure a non-empty input includes a `<mask>`. Empty input is left empty, so
# the model is not run on a bare " <mask>" and no misleading warning is shown.
if input_text and "<mask>" not in input_text:
    input_text += " <mask>"
    st.warning("`<mask>` token was missing in your input. It has been added automatically.")
def _rank_predictions(predictions, input_embedding):
    """Score fill-mask candidates by cosine similarity to the baseline sentence.

    Args:
        predictions: Fill-mask pipeline output for a single-mask input. It is
            read as an iterable of dicts with 'sequence' and 'score' keys.
            Entries that are not dicts, have an empty 'sequence', or lack
            'score' are skipped.
        input_embedding: Embedding of the baseline sentence from get_embedding.

    Returns:
        A DataFrame with columns Prediction / Similarity Score / Model Score,
        sorted by similarity (highest first). It is empty, with the same
        columns, when no usable candidate was found.
    """
    columns = ["Prediction", "Similarity Score", "Model Score"]
    rows = []
    for r in predictions:
        if not isinstance(r, dict):
            continue
        prediction_text = r.get("sequence", "")
        if not prediction_text or "score" not in r:
            continue
        prediction_embedding = get_embedding(prediction_text)
        similarity = float(cosine_similarity(input_embedding, prediction_embedding)[0][0])
        rows.append({
            "Prediction": prediction_text,
            "Similarity Score": similarity,
            "Model Score": r["score"],
        })
    return (
        pd.DataFrame(rows, columns=columns)
        .sort_values(by="Similarity Score", ascending=False)
        .reset_index(drop=True)
    )


# Process the input when available
if input_text:
    st.write(f"Input Text: {input_text}")
    mask_count = input_text.count("<mask>")
    if mask_count != 1:
        # With several masks the pipeline returns one candidate list per mask,
        # each filling only that mask, so a sentence-level ranking is not
        # meaningful (and the per-result dict access would fail).
        st.error(f"Please use exactly one `<mask>` token (found {mask_count}).")
    else:
        # Baseline = the input with the mask removed. Collapse the double space
        # the removal leaves behind so it does not add a spurious token.
        baseline_text = " ".join(input_text.replace("<mask>", "").split())
        input_embedding = get_embedding(baseline_text)

        predictions = None
        try:
            predictions = pipe(input_text)
        except Exception as err:  # UI boundary: report instead of crashing the app
            st.error(f"Mask prediction failed: {err}")

        if predictions is not None:
            df_results = _rank_predictions(predictions, input_embedding)
            if df_results.empty:
                st.error("Unexpected model output structure; unable to retrieve predictions.")
            else:
                # Display all predictions sorted by similarity score
                st.subheader("All Predictions Sorted by Similarity")
                st.dataframe(df_results)
                # Display the most similar prediction (row 0 after sorting)
                most_similar = df_results.iloc[0]
                st.subheader("Most Similar Prediction")
                st.write(f"**Prediction**: {most_similar['Prediction']}")
                st.write(f"**Similarity Score**: {most_similar['Similarity Score']:.4f}")
                st.write(f"**Model Score**: {most_similar['Model Score']:.4f}")