import gradio as gr
import torch
from transformers import (
    BertForSequenceClassification, 
    BertTokenizer,
    CLIPProcessor, 
    CLIPModel,
    pipeline
)
from PIL import Image
import pytesseract
import numpy as np
import cv2
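
# NOTE: pytesseract is only a wrapper; the Tesseract OCR engine itself must be
# installed on the system (e.g. `apt-get install tesseract-ocr` on Debian/Ubuntu)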

class BertClipMemeAnalyzer:
    def __init__(self):
        # BERT for sentiment analysis
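        # (the nlptown checkpoint rates text from "1 star" to "5 stars";
        # these labels are mapped to sentiment names in analyze_meme below)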
        self.bert_sentiment = pipeline(
            "sentiment-analysis",
            model="nlptown/bert-base-multilingual-uncased-sentiment"
        )
        
        # Alternative: Use a BERT model specifically
        # self.bert_model = BertForSequenceClassification.from_pretrained(
        #     "bert-base-uncased", 
        #     num_labels=3  # positive, negative, neutral
        # )
        # self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        
        # CLIP for zero-shot image classification
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        
        # Define meme categories for CLIP
        self.meme_categories = [
            "a hateful meme",
            "a non-hateful meme", 
            "a funny meme",
            "an offensive meme",
            "a wholesome meme",
            "a sarcastic meme",
            "a political meme",
            "a neutral meme"
        ]
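        # CLIP scores the image against each prompt above; the "a ... meme"
        # phrasing mimics the natural-language captions CLIP was trained on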
        
    def extract_text_from_image(self, image):
        """Extract text from the meme image using Tesseract OCR."""
        try:
            # Force 3-channel RGB first (uploads may be RGBA or grayscale)
            image_cv = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
            # Grayscale + contrast boost helps Tesseract on noisy meme text
            gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
            enhanced = cv2.convertScaleAbs(gray, alpha=1.5, beta=0)
            text = pytesseract.image_to_string(enhanced)
            return text.strip()
        except Exception:
            # OCR is best-effort; fall back to analyzing without extracted text
            return ""
    
    def classify_with_clip(self, image):
        """Use CLIP to classify meme type"""
        # Prepare inputs
        inputs = self.clip_processor(
            text=self.meme_categories,
            images=image,
            return_tensors="pt",
            padding=True
        )
        
        # Get predictions
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=1)
        
        # Get results
        results = []
        for category, prob in zip(self.meme_categories, probs[0]):
            # Strip the prompt template ("a/an ... meme") down to the bare label
            label = category.removeprefix("an ").removeprefix("a ").removesuffix(" meme")
            results.append({
                'category': label,
                'score': prob.item()
            })
        
        # Sort by score
        results.sort(key=lambda x: x['score'], reverse=True)
        return results
    
    def analyze_meme(self, text_input, image):
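        """Run BERT sentiment analysis and CLIP classification, returning a Markdown report."""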
        output = "## 🎭 Meme Analysis Results\n\n"
        
        # Extract text if image provided
        extracted_text = ""
        if image is not None:
            extracted_text = self.extract_text_from_image(image)
            if extracted_text:
                output += f"### πŸ“ Extracted Text (OCR):\n`{extracted_text}`\n\n"
        
        # Combine texts
        full_text = text_input or ""
        if extracted_text:
            full_text = f"{full_text} {extracted_text}".strip()
        
        # BERT Sentiment Analysis
        if full_text:
            output += "### 🧠 BERT Sentiment Analysis:\n"
            # Truncate long inputs to BERT's 512-token limit
            sentiment_results = self.bert_sentiment(full_text, truncation=True)
            
            # Map raw model labels to human-readable sentiment names
            label_map = {
                '1 star': 'Very Negative',
                '2 stars': 'Negative',
                '3 stars': 'Neutral',
                '4 stars': 'Positive',
                '5 stars': 'Very Positive',
                'POSITIVE': 'Positive',
                'NEGATIVE': 'Negative',
                'NEUTRAL': 'Neutral'
            }
            
            sentiment = sentiment_results[0]
            mapped_label = label_map.get(sentiment['label'], sentiment['label'])
            output += f"**{mapped_label}** (Confidence: {sentiment['score']:.1%})\n\n"
        else:
            output += "### 🧠 BERT Sentiment Analysis:\n"
            output += "No text provided for sentiment analysis\n\n"
        
        # CLIP Meme Classification
        if image is not None:
            output += "### πŸ–ΌοΈ CLIP Meme Classification:\n"
            clip_results = self.classify_with_clip(image)
            
            # Top prediction
            top = clip_results[0]
            output += f"**Primary Classification: {top['category'].title()}** ({top['score']:.1%})\n\n"
            
            # Hateful vs. non-hateful check; match labels exactly, since
            # 'hateful' is also a substring of 'non-hateful'
            hateful_score = next(r['score'] for r in clip_results if r['category'] == 'hateful')
            non_hateful_score = next(r['score'] for r in clip_results if r['category'] == 'non-hateful')
            
            if hateful_score > non_hateful_score:
                output += f"⚠️ **Hate Detection: Potentially Hateful** ({hateful_score:.1%})\n\n"
            else:
                output += f"βœ… **Hate Detection: Non-hateful** ({non_hateful_score:.1%})\n\n"
            
            # All classifications
            output += "**All Classifications:**\n"
            for result in clip_results:
                bar = "β–ˆ" * int(result['score'] * 20)
                output += f"- {result['category'].title()}: {bar} {result['score']:.1%}\n"
        else:
            output += "### πŸ–ΌοΈ CLIP Meme Classification:\n"
            output += "No image provided for classification\n\n"
        
        # Summary
        output += "\n### πŸ“Š Summary:\n"
        if full_text:
            output += f"- Text analyzed: {len(full_text.split())} words\n"
        if image is not None:
            output += "- Image analyzed: ✓\n"
        if extracted_text:
            output += "- OCR successful: ✓\n"
        
        return output

# Initialize
analyzer = BertClipMemeAnalyzer()

# Create interface
demo = gr.Interface(
    fn=analyzer.analyze_meme,
    inputs=[
        gr.Textbox(
            label="πŸ“ Text Input (Optional)",
            placeholder="Enter meme text or leave empty for OCR...",
            lines=2
        ),
        gr.Image(
            label="πŸ–ΌοΈ Upload Meme",
            type="pil"
        )
    ],
    outputs=gr.Markdown(label="Analysis Results"),
    title="🎭 BERT + CLIP Meme Analyzer",
    description="""
    This analyzer uses:
    - **BERT**: For sentiment analysis of text (Very Negative to Very Positive)
    - **CLIP**: For zero-shot meme classification (Hateful/Non-hateful, Funny, etc.)
    - **OCR**: To extract text from meme images
    
    Upload a meme to see both text sentiment and visual classification!
    """,
    theme=gr.themes.Soft()
)

demo.launch()
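
# demo.launch() serves the app locally (http://127.0.0.1:7860 by default);
# pass share=True to also get a temporary public Gradio link.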