import gradio as gr
import torch
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    CLIPProcessor,
    CLIPModel,
    pipeline
)
from PIL import Image
import pytesseract
import numpy as np
import cv2
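# Python dependencies (typical pip package names): gradio, torch, transformers, pillow,
# pytesseract, numpy, opencv-python-headless. pytesseract is only a wrapper: the Tesseract
# OCR binary must also be installed on the host (e.g. the "tesseract-ocr" apt package,
# listed in packages.txt on Hugging Face Spaces).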

class BertClipMemeAnalyzer:
    def __init__(self):
        # BERT for sentiment analysis
        self.bert_sentiment = pipeline(
            "sentiment-analysis",
            model="nlptown/bert-base-multilingual-uncased-sentiment"
        )

        # Alternative: Use a BERT model specifically
        # self.bert_model = BertForSequenceClassification.from_pretrained(
        #     "bert-base-uncased",
        #     num_labels=3  # positive, negative, neutral
        # )
        # self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # CLIP for zero-shot image classification
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # Define meme categories for CLIP
        self.meme_categories = [
            "a hateful meme",
            "a non-hateful meme",
            "a funny meme",
            "an offensive meme",
            "a wholesome meme",
            "a sarcastic meme",
            "a political meme",
            "a neutral meme"
        ]
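        # CLIP scores the image against each of these natural-language prompts; full phrases
        # like "a ... meme" generally work better as CLIP text prompts than bare single-word labels.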
    def extract_text_from_image(self, image):
        """Extract text from the meme image using OCR."""
        try:
            # Convert the PIL image to OpenCV format, grayscale it, and boost contrast for OCR
            image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
            enhanced = cv2.convertScaleAbs(gray, alpha=1.5, beta=0)
            text = pytesseract.image_to_string(enhanced)
            return text.strip()
        except Exception:
            # If OCR fails for any reason, fall back to an empty string
            return ""
    def classify_with_clip(self, image):
        """Use CLIP to classify the meme type via zero-shot classification."""
        # Prepare text prompts and image for CLIP
        inputs = self.clip_processor(
            text=self.meme_categories,
            images=image,
            return_tensors="pt",
            padding=True
        )

        # Get predictions
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=1)

        # Collect results, stripping the "a/an ... meme" prompt wrapping from the labels
        results = []
        for category, prob in zip(self.meme_categories, probs[0]):
            label = category.replace("an ", "").replace("a ", "").replace(" meme", "")
            results.append({
                'category': label,
                'score': prob.item()
            })

        # Sort by score, highest first
        results.sort(key=lambda x: x['score'], reverse=True)
        return results
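    # Note: the scores returned above are softmax-normalized across the eight prompts, so they
    # express relative preference among these categories rather than calibrated probabilities.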
    def analyze_meme(self, text_input, image):
        """Run BERT sentiment analysis and CLIP classification on a meme."""
        output = "## 🎭 Meme Analysis Results\n\n"

        # Extract text if an image is provided
        extracted_text = ""
        if image is not None:
            extracted_text = self.extract_text_from_image(image)
            if extracted_text:
                output += f"### 📝 Extracted Text (OCR):\n`{extracted_text}`\n\n"

        # Combine typed text and OCR text
        full_text = text_input or ""
        if extracted_text:
            full_text = f"{full_text} {extracted_text}".strip()

        # BERT Sentiment Analysis
        if full_text:
            output += "### 🧠 BERT Sentiment Analysis:\n"
            sentiment_results = self.bert_sentiment(full_text)

            # The nlptown model returns "1 star" ... "5 stars" labels; the POSITIVE/NEGATIVE/NEUTRAL
            # entries cover other sentiment pipelines that may be swapped in.
            label_map = {
                '1 star': 'Very Negative',
                '2 stars': 'Negative',
                '3 stars': 'Neutral',
                '4 stars': 'Positive',
                '5 stars': 'Very Positive',
                'POSITIVE': 'Positive',
                'NEGATIVE': 'Negative',
                'NEUTRAL': 'Neutral'
            }
            sentiment = sentiment_results[0]
            mapped_label = label_map.get(sentiment['label'], sentiment['label'])
            output += f"**{mapped_label}** (Confidence: {sentiment['score']:.1%})\n\n"
        else:
            output += "### 🧠 BERT Sentiment Analysis:\n"
            output += "No text provided for sentiment analysis\n\n"

        # CLIP Meme Classification
        if image is not None:
            output += "### 🖼️ CLIP Meme Classification:\n"
            clip_results = self.classify_with_clip(image)

            # Top prediction
            top = clip_results[0]
            output += f"**Primary Classification: {top['category'].title()}** ({top['score']:.1%})\n\n"

            # Hateful vs. non-hateful (exact match, so 'hateful' does not also match 'non-hateful')
            hateful_score = next(r['score'] for r in clip_results if r['category'] == 'hateful')
            non_hateful_score = next(r['score'] for r in clip_results if r['category'] == 'non-hateful')
            if hateful_score > non_hateful_score:
                output += f"⚠️ **Hate Detection: Potentially Hateful** ({hateful_score:.1%})\n\n"
            else:
                output += f"✅ **Hate Detection: Non-hateful** ({non_hateful_score:.1%})\n\n"

            # All classifications
            output += "**All Classifications:**\n"
            for result in clip_results:
                bar = "█" * int(result['score'] * 20)
                output += f"- {result['category'].title()}: {bar} {result['score']:.1%}\n"
        else:
            output += "### 🖼️ CLIP Meme Classification:\n"
            output += "No image provided for classification\n\n"

        # Summary
        output += "\n### 📊 Summary:\n"
        if full_text:
            output += f"- Text analyzed: {len(full_text.split())} words\n"
        if image is not None:
            output += "- Image analyzed: ✓\n"
        if extracted_text:
            output += "- OCR successful: ✓\n"
        return output
# Initialize
analyzer = BertClipMemeAnalyzer()
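# Example of calling the analyzer directly (outside Gradio), assuming an illustrative local
# file "meme.png":
#   img = Image.open("meme.png")
#   print(analyzer.analyze_meme("top text", img))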
# Create interface
demo = gr.Interface(
    fn=analyzer.analyze_meme,
    inputs=[
        gr.Textbox(
            label="📝 Text Input (Optional)",
            placeholder="Enter meme text or leave empty for OCR...",
            lines=2
        ),
        gr.Image(
            label="🖼️ Upload Meme",
            type="pil"
        )
    ],
    outputs=gr.Markdown(label="Analysis Results"),
    title="🎭 BERT + CLIP Meme Analyzer",
    description="""
    This analyzer uses:
    - **BERT**: For sentiment analysis of text (Positive/Negative/Neutral)
    - **CLIP**: For zero-shot meme classification (Hateful/Non-hateful, Funny, etc.)
    - **OCR**: To extract text from meme images

    Upload a meme to see both text sentiment and visual classification!
    """,
    theme=gr.themes.Soft()
)
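# launch() with no arguments is enough on Hugging Face Spaces; for local testing,
# demo.launch(share=True) can additionally create a temporary public link.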
demo.launch()