Spaces:
Sleeping
Sleeping
gary-boon
Claude
committed on
Commit
·
5aed1a9
1
Parent(s):
992dc8c
Add layer_stride parameter for PromptDiff optimization
Browse files- Add layer_stride parameter to control which layers are captured
- Default to 1 (all layers) for AttentionExplorer
- PromptDiff can use layer_stride=2 for every other layer
- Reduces matrix count from 20 to 10 for better visualization fit
🤖 Generated with Claude Code (https://claude.ai/code)
Co-Authored-By: Claude <[email protected]>
backend/__pycache__/model_service.cpython-310.pyc
CHANGED
|
Binary files a/backend/__pycache__/model_service.cpython-310.pyc and b/backend/__pycache__/model_service.cpython-310.pyc differ
|
|
|
backend/model_service.py
CHANGED
|
@@ -47,6 +47,7 @@ class GenerationRequest(BaseModel):
|
|
| 47 |
top_p: Optional[float] = None
|
| 48 |
extract_traces: bool = True
|
| 49 |
sampling_rate: float = 0.005
|
|
|
|
| 50 |
|
| 51 |
class AblatedGenerationRequest(BaseModel):
|
| 52 |
prompt: str
|
|
@@ -500,7 +501,8 @@ class ModelManager:
|
|
| 500 |
temperature: float = 0.7,
|
| 501 |
top_k: Optional[int] = None,
|
| 502 |
top_p: Optional[float] = None,
|
| 503 |
-
sampling_rate: float = 0.005
|
|
|
|
| 504 |
) -> Dict[str, Any]:
|
| 505 |
"""Generate text with trace extraction"""
|
| 506 |
if not self.model or not self.tokenizer:
|
|
@@ -643,8 +645,8 @@ class ModelManager:
|
|
| 643 |
# Clear previous partial traces and add complete ones
|
| 644 |
traces = [] # Reset traces to only include complete attention patterns
|
| 645 |
|
| 646 |
-
# Capture
|
| 647 |
-
for layer_idx in range(num_layers):
|
| 648 |
try:
|
| 649 |
# Get all token IDs (prompt + generated)
|
| 650 |
all_token_ids = inputs["input_ids"][0].tolist()
|
|
@@ -892,7 +894,8 @@ async def generate(request: GenerationRequest, authenticated: bool = Depends(ver
|
|
| 892 |
temperature=request.temperature,
|
| 893 |
top_k=request.top_k,
|
| 894 |
top_p=request.top_p,
|
| 895 |
-
sampling_rate=request.sampling_rate if request.extract_traces else 0
|
|
|
|
| 896 |
)
|
| 897 |
return result
|
| 898 |
|
|
|
|
| 47 |
top_p: Optional[float] = None
|
| 48 |
extract_traces: bool = True
|
| 49 |
sampling_rate: float = 0.005
|
| 50 |
+
layer_stride: int = 1 # 1 = all layers, 2 = every other layer, etc.
|
| 51 |
|
| 52 |
class AblatedGenerationRequest(BaseModel):
|
| 53 |
prompt: str
|
|
|
|
| 501 |
temperature: float = 0.7,
|
| 502 |
top_k: Optional[int] = None,
|
| 503 |
top_p: Optional[float] = None,
|
| 504 |
+
sampling_rate: float = 0.005,
|
| 505 |
+
layer_stride: int = 1 # 1 = all layers, 2 = every other layer, etc.
|
| 506 |
) -> Dict[str, Any]:
|
| 507 |
"""Generate text with trace extraction"""
|
| 508 |
if not self.model or not self.tokenizer:
|
|
|
|
| 645 |
# Clear previous partial traces and add complete ones
|
| 646 |
traces = [] # Reset traces to only include complete attention patterns
|
| 647 |
|
| 648 |
+
# Capture layers based on stride (1 = all, 2 = every other, etc.)
|
| 649 |
+
for layer_idx in range(0, num_layers, layer_stride):
|
| 650 |
try:
|
| 651 |
# Get all token IDs (prompt + generated)
|
| 652 |
all_token_ids = inputs["input_ids"][0].tolist()
|
|
|
|
| 894 |
temperature=request.temperature,
|
| 895 |
top_k=request.top_k,
|
| 896 |
top_p=request.top_p,
|
| 897 |
+
sampling_rate=request.sampling_rate if request.extract_traces else 0,
|
| 898 |
+
layer_stride=request.layer_stride
|
| 899 |
)
|
| 900 |
return result
|
| 901 |
|
components/ui/tooltip.tsx
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"use client"
|
| 2 |
+
|
| 3 |
+
import * as React from "react"
|
| 4 |
+
import * as TooltipPrimitive from "@radix-ui/react-tooltip"
|
| 5 |
+
|
| 6 |
+
import { cn } from "@/lib/utils"
|
| 7 |
+
|
| 8 |
+
const TooltipProvider = TooltipPrimitive.Provider
|
| 9 |
+
|
| 10 |
+
const Tooltip = TooltipPrimitive.Root
|
| 11 |
+
|
| 12 |
+
const TooltipTrigger = TooltipPrimitive.Trigger
|
| 13 |
+
|
| 14 |
+
const TooltipContent = React.forwardRef<
|
| 15 |
+
React.ElementRef<typeof TooltipPrimitive.Content>,
|
| 16 |
+
React.ComponentPropsWithoutRef<typeof TooltipPrimitive.Content>
|
| 17 |
+
>(({ className, sideOffset = 4, ...props }, ref) => (
|
| 18 |
+
<TooltipPrimitive.Content
|
| 19 |
+
ref={ref}
|
| 20 |
+
sideOffset={sideOffset}
|
| 21 |
+
className={cn(
|
| 22 |
+
"z-50 overflow-hidden rounded-md bg-gray-900 px-3 py-1.5 text-sm text-gray-100 shadow-md animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 border border-gray-800",
|
| 23 |
+
className
|
| 24 |
+
)}
|
| 25 |
+
{...props}
|
| 26 |
+
/>
|
| 27 |
+
))
|
| 28 |
+
TooltipContent.displayName = TooltipPrimitive.Content.displayName
|
| 29 |
+
|
| 30 |
+
export { Tooltip, TooltipTrigger, TooltipContent, TooltipProvider }
|