Spaces:
Sleeping
Sleeping
import logging

import cv2
import gradio as gr
import numpy as np
import pytesseract
import torch
from PIL import Image
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    CLIPModel,
    CLIPProcessor,
    pipeline,
)
class BertClipMemeAnalyzer:
    """Meme analyzer combining BERT text sentiment, CLIP zero-shot labels and OCR.

    Text is taken from the optional user input plus Tesseract OCR of the
    uploaded image; the image itself is scored by CLIP against
    ``meme_categories``. ``analyze_meme`` renders all results as a Markdown
    report.
    """

    # Human-readable names for the labels the sentiment model may return.
    # nlptown/bert-base-multilingual-uncased-sentiment predicts "1 star".."5 stars";
    # the POSITIVE/NEGATIVE/NEUTRAL entries cover other sentiment models.
    # Labels not listed here are shown verbatim.
    _SENTIMENT_LABELS = {
        '1 star': 'Very Negative',
        '2 stars': 'Negative',
        '3 stars': 'Neutral',
        '4 stars': 'Positive',
        '5 stars': 'Very Positive',
        'POSITIVE': 'Positive',
        'NEGATIVE': 'Negative',
        'NEUTRAL': 'Neutral',
    }

    def __init__(self):
        """Load the sentiment pipeline and the CLIP model/processor.

        ``pipeline`` and ``from_pretrained`` fetch weights from the Hugging
        Face Hub when they are not cached locally, so construction can be slow.
        """
        # BERT-based multilingual sentiment classifier (1-5 star labels).
        self.bert_sentiment = pipeline(
            "sentiment-analysis",
            model="nlptown/bert-base-multilingual-uncased-sentiment",
        )
        # CLIP for zero-shot image classification against text prompts.
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # Prompts CLIP scores the image against; the "a/an ... meme" wrapper is
        # stripped for display by _short_label.
        self.meme_categories = [
            "a hateful meme",
            "a non-hateful meme",
            "a funny meme",
            "an offensive meme",
            "a wholesome meme",
            "a sarcastic meme",
            "a political meme",
            "a neutral meme",
        ]

    @staticmethod
    def _short_label(prompt):
        """Return the display name for a CLIP prompt.

        Strips a leading article ("a"/"an") and a trailing "meme", e.g.
        "an offensive meme" -> "offensive", "a non-hateful meme" -> "non-hateful".
        """
        words = prompt.split()
        if words and words[0] in ("a", "an"):
            words = words[1:]
        if words and words[-1] == "meme":
            words = words[:-1]
        return " ".join(words)

    @staticmethod
    def _score_for(results, category):
        """Return the score of the result whose category equals *category*, or 0.0 if absent."""
        # Exact match matters: a substring test for "hateful" would also match
        # "non-hateful" and return whichever of the two happens to rank higher.
        return next((r['score'] for r in results if r['category'] == category), 0.0)

    def extract_text_from_image(self, image):
        """Extract text from *image* with Tesseract OCR.

        Args:
            image: a PIL image (as supplied by ``gr.Image(type="pil")``) or an
                RGB array.

        Returns:
            The stripped OCR text, or "" if OCR fails for any reason (this is a
            best-effort step; failures are logged, not raised).
        """
        try:
            # Normalise PIL images to 3-channel RGB first: RGBA, palette and
            # grayscale uploads would otherwise make the colour conversion raise.
            if isinstance(image, Image.Image):
                image = image.convert("RGB")
            gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
            # Boost contrast a little to help Tesseract on meme fonts.
            enhanced = cv2.convertScaleAbs(gray, alpha=1.5, beta=0)
            text = pytesseract.image_to_string(enhanced)
        except Exception as exc:  # OCR is optional; never break the analysis
            logging.getLogger(__name__).warning("OCR failed: %s", exc)
            return ""
        return text.strip()

    def classify_with_clip(self, image):
        """Score *image* against every prompt in ``meme_categories`` with CLIP.

        Returns:
            A list of ``{'category': str, 'score': float}`` dicts (category is
            the short display label), sorted by score descending. Scores are a
            softmax over the prompts, so they sum to 1.
        """
        inputs = self.clip_processor(
            text=self.meme_categories,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
        # One row per image; we pass a single image, so take row 0.
        probs = outputs.logits_per_image.softmax(dim=1)[0]
        results = [
            {'category': self._short_label(prompt), 'score': prob.item()}
            for prompt, prob in zip(self.meme_categories, probs)
        ]
        results.sort(key=lambda r: r['score'], reverse=True)
        return results

    def analyze_meme(self, text_input, image):
        """Build the Markdown analysis report for a meme.

        Args:
            text_input: optional user-supplied text (may be None or blank).
            image: optional PIL image (None when nothing was uploaded).

        Returns:
            A Markdown string with OCR, sentiment, CLIP and summary sections.
        """
        output = "## π Meme Analysis Results\n\n"

        extracted_text = ""
        if image is not None:
            extracted_text = self.extract_text_from_image(image)
            if extracted_text:
                # Fenced block: OCR output often spans several lines (including
                # blank ones), which breaks an inline code span.
                output += f"### π Extracted Text (OCR):\n```\n{extracted_text}\n```\n\n"

        # Combine user text and OCR text, ignoring blank pieces.
        parts = [(text_input or "").strip(), extracted_text]
        full_text = " ".join(part for part in parts if part)

        output += self._format_sentiment(full_text)
        output += self._format_clip(image)
        output += self._format_summary(full_text, image, extracted_text)
        return output

    def _format_sentiment(self, full_text):
        """Return the Markdown section with BERT sentiment for *full_text*."""
        section = "### π§ BERT Sentiment Analysis:\n"
        if not full_text:
            return section + "No text provided for sentiment analysis\n\n"
        # truncation=True: inputs longer than the model's max length would
        # otherwise raise instead of being classified.
        sentiment = self.bert_sentiment(full_text, truncation=True)[0]
        label = self._SENTIMENT_LABELS.get(sentiment['label'], sentiment['label'])
        return section + f"**{label}** (Confidence: {sentiment['score']:.1%})\n\n"

    def _format_clip(self, image):
        """Return the Markdown section with CLIP classifications for *image*."""
        section = "### πΌοΈ CLIP Meme Classification:\n"
        if image is None:
            return section + "No image provided for classification\n\n"

        clip_results = self.classify_with_clip(image)
        top = clip_results[0]
        section += f"**Primary Classification: {top['category'].title()}** ({top['score']:.1%})\n\n"

        hateful_score = self._score_for(clip_results, "hateful")
        non_hateful_score = self._score_for(clip_results, "non-hateful")
        if hateful_score > non_hateful_score:
            section += f"β οΈ **Hate Detection: Potentially Hateful** ({hateful_score:.1%})\n\n"
        else:
            section += f"β **Hate Detection: Non-hateful** ({non_hateful_score:.1%})\n\n"

        section += "**All Classifications:**\n"
        for result in clip_results:
            # Text bar: one block per 5 percentage points.
            bar = "β" * int(result['score'] * 20)
            section += f"- {result['category'].title()}: {bar} {result['score']:.1%}\n"
        return section

    @staticmethod
    def _format_summary(full_text, image, extracted_text):
        """Return the Markdown summary section listing what was analyzed."""
        lines = ["\n### π Summary:\n"]
        if full_text:
            lines.append(f"- Text analyzed: {len(full_text.split())} words\n")
        if image is not None:
            lines.append("- Image analyzed: β\n")
        if extracted_text:
            lines.append("- OCR successful: β\n")
        return "".join(lines)
# Load the models once at import time so every Gradio request reuses them
# (``demo`` must exist at module level for Gradio's tooling).
analyzer = BertClipMemeAnalyzer()

# Interface: optional text + optional image in, Markdown report out.
demo = gr.Interface(
    fn=analyzer.analyze_meme,
    inputs=[
        gr.Textbox(
            label="π Text Input (Optional)",
            placeholder="Enter meme text or leave empty for OCR...",
            lines=2,
        ),
        gr.Image(
            label="πΌοΈ Upload Meme",
            type="pil",
        ),
    ],
    outputs=gr.Markdown(label="Analysis Results"),
    title="π BERT + CLIP Meme Analyzer",
    description="""
This analyzer uses:
- **BERT**: For sentiment analysis of text (Very Negative to Very Positive)
- **CLIP**: For zero-shot meme classification (Hateful/Non-hateful, Funny, etc.)
- **OCR**: To extract text from meme images

Upload a meme to see both text sentiment and visual classification!
""",
    theme=gr.themes.Soft(),
)

# Start the server only when run as a script (e.g. ``python app.py``), so
# importing this module does not launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch()