monster07's picture
Update app.py
bbe970a verified
raw
history blame
6.12 kB
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
# ✅ Xception Block Definition
class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution.

    A per-channel (grouped) convolution is followed by a 1x1 convolution
    that mixes channels. Neither convolution has a bias term.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the 1x1 convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Zero padding of the depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()
        # groups == in_channels gives every input channel its own filter.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=False,
        )
        # The 1x1 convolution recombines the per-channel responses.
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise one."""
        return self.pointwise(self.depthwise(x))
class Block(nn.Module):
    """Stack of separable-conv/batch-norm stages with a projected skip path.

    Args:
        in_filters: Channels of the incoming tensor.
        out_filters: Channels of the outgoing tensor.
        reps: Total number of separable-conv stages in the main branch.
        stride: When not 1, a 3x3 max-pool with this stride ends the main
            branch and the skip convolution uses the same stride.
        start_with_relu: Only consulted when ``grow_first`` is true; puts a
            ReLU in front of the first stage.
        grow_first: If true, the channel change happens in the first stage;
            otherwise it happens in the last stage.
    """

    def __init__(self, in_filters, out_filters, reps, stride=1, start_with_relu=True, grow_first=True):
        super().__init__()

        def conv_bn(c_in, c_out):
            # One 3x3 separable convolution followed by its batch norm.
            return [SeparableConv2d(c_in, c_out, 3, 1, 1), nn.BatchNorm2d(c_out)]

        stages = []
        width = in_filters

        if grow_first:
            if start_with_relu:
                stages.append(nn.ReLU(inplace=True))
            stages += conv_bn(in_filters, out_filters)
            width = out_filters

        # Remaining stages keep the current channel width.
        for _ in range(reps - 1):
            stages.append(nn.ReLU(inplace=True))
            stages += conv_bn(width, width)

        if not grow_first:
            stages.append(nn.ReLU(inplace=True))
            stages += conv_bn(in_filters, out_filters)

        if stride != 1:
            stages.append(nn.MaxPool2d(3, stride, 1))

        self.block = nn.Sequential(*stages)
        # The skip path is always a 1x1 convolution plus batch norm.
        self.skip = nn.Conv2d(in_filters, out_filters, 1, stride, bias=False)
        self.skipbn = nn.BatchNorm2d(out_filters)

    def forward(self, inp):
        """Return main-branch output plus the projected skip connection."""
        # The main branch runs first. When it begins with an in-place ReLU,
        # `inp` is already rectified by the time the skip path reads it, so
        # the order of these statements must not be swapped.
        merged = self.block(inp)
        merged += self.skipbn(self.skip(inp))
        return merged
# ✅ Xception Architecture
class Xception(nn.Module):
    """Xception-style classifier assembled from ``Block`` units.

    Two plain convolutions feed twelve ``Block`` modules and two separable
    convolutions; the result is globally average-pooled and passed to a
    linear layer with ``num_classes`` outputs.

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes=1):
        super().__init__()

        # Stem: two ordinary convolutions.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # Down-sampling blocks that widen the feature maps.
        self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, reps=2, stride=2, grow_first=True)
        self.block3 = Block(256, 728, reps=2, stride=2, grow_first=True)

        # block4 .. block11 are identical 728-channel blocks; they are
        # registered under the same attribute names, in the same order.
        for idx in range(4, 12):
            setattr(self, f"block{idx}", Block(728, 728, reps=3))

        self.block12 = Block(728, 1024, reps=2, stride=2, grow_first=False)

        # Tail: two separable convolutions, then the classifier.
        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)
        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)
        self.fc = nn.Linear(2048, num_classes)

    def features(self, input):
        """Run the convolutional trunk and return the 2048-channel map."""
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))
        for idx in range(1, 13):
            x = getattr(self, f"block{idx}")(x)
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        return x

    def forward(self, input):
        """Return the raw (pre-sigmoid) output of the linear layer."""
        feature_map = self.features(input)
        pooled = nn.functional.adaptive_avg_pool2d(feature_map, (1, 1))
        flattened = pooled.view(pooled.size(0), -1)
        return self.fc(flattened)
# ✅ Load weights
# Instantiate the network, download the checkpoint onto the CPU, load it
# into the model, and switch the model to evaluation mode.
_WEIGHTS_URL = "https://huggingface.co/Selimsef/xception-cnn-df/resolve/main/xception-binary-weights.pt"

model = Xception()
_state_dict = torch.hub.load_state_dict_from_url(_WEIGHTS_URL, map_location="cpu")
model.load_state_dict(_state_dict)
model.eval()
# ✅ Transform
# Preprocessing applied to every PIL frame: resize to 299x299, convert to a
# tensor, then normalize each of the three channels with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
# ✅ Analyze function
def analyze_deepfake(video_path):
    """Classify a video as REAL or FAKE from its first frames.

    Up to 20 frames are decoded from the start of the video. The central
    half of each frame (25%-75% of height and width) is cropped, converted
    to RGB, preprocessed with the module-level ``transform`` and scored by
    the module-level ``model``. The sigmoid of each output is treated as
    the frame's fake probability; the verdict is FAKE when the mean of
    those probabilities exceeds 0.5.

    Args:
        video_path: Path of the video file to analyze. A falsy value means
            nothing was uploaded.

    Returns:
        A ``(markdown_text, figure)`` tuple. ``figure`` is a Matplotlib
        figure with a histogram of the per-frame probabilities, or ``None``
        when no video was given or no frame could be analyzed. The number
        reported as "Confidence" is the mean fake probability.
    """
    if not video_path:
        return "❌ No video uploaded", None

    max_frames = 20
    preds = []
    cap = cv2.VideoCapture(video_path)
    try:
        with torch.no_grad():
            # Stop before decoding a frame that would not be used.
            while len(preds) < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                h, w = frame.shape[:2]
                crop = frame[int(h * 0.25):int(h * 0.75), int(w * 0.25):int(w * 0.75)]
                if crop.size == 0:
                    # Frame too small to have a central region; skip it
                    # rather than crash in the colour conversion.
                    continue
                image = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
                input_tensor = transform(image).unsqueeze(0)
                out = model(input_tensor)
                preds.append(torch.sigmoid(out)[0].item())
    finally:
        # Release the capture handle even if decoding or inference raised.
        cap.release()

    if not preds:
        return "❌ No frames analyzed", None

    avg = np.mean(preds)
    is_fake = avg > 0.5
    label = "**FAKE**" if is_fake else "**REAL**"
    # Markdown collapses a single newline, so a blank line is needed to
    # show the confidence on its own line.
    result = f"🎯 Verdict: {label}\n\nConfidence: {avg:.2f}"

    # Build the figure directly instead of through plt.subplots(): figures
    # created by pyplot stay registered until closed, which would leak one
    # figure per request in a long-running server.
    fig = plt.Figure()
    ax = fig.subplots()
    ax.hist(preds, bins=10, color="red" if is_fake else "green", edgecolor="black")
    ax.set_title("Confidence per Frame")
    ax.set_xlabel("Fake Probability")
    ax.set_ylabel("Frames")
    ax.grid(True)
    return result, fig
# ✅ Gradio App
# Build the UI: a video input, a Markdown area for the verdict, a plot for
# the per-frame histogram, and a button that runs analyze_deepfake.
# The emoji in the title and on the button were mis-encoded (mojibake) and
# are restored here so they render correctly for users.
with gr.Blocks() as demo:
    gr.Markdown("# 🎭 Deepfake Detector with Xception (DFDC)")
    gr.Markdown("Upload a `.mp4` video. The app will classify it as REAL or FAKE based on pretrained deepfake model.")
    video = gr.Video(label="Upload Video")
    output_text = gr.Markdown()
    output_plot = gr.Plot()
    analyze = gr.Button("🔍 Analyze")
    analyze.click(fn=analyze_deepfake, inputs=video, outputs=[output_text, output_plot])

# Enable request queuing and start the server.
demo.queue().launch()