# ClinicalThought-AI-8B / Scripts / Inference_safetensors.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Insert your medical query here
MEDICAL_QUERY = """
"""
def load_model(model_path):
    # Load the weights in half precision and let accelerate place layers
    # across the available devices automatically.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer
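
# A minimal sketch, assuming a CPU-only machine (not part of the original
# script): float16 is slow or unsupported on many CPUs, so you could fall
# back to full precision instead:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_path,
#       torch_dtype=torch.float32,
#       device_map="cpu",
#   )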
def generate_response(model, tokenizer, medical_query):
    medical_query = medical_query.strip()
    # Wrap the query in the tag format the model expects.
    prompt = f"USER: <medical_query>{medical_query}</medical_query>\nASSISTANT:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=12000,
        temperature=0.3,
        top_p=0.7,
        repetition_penalty=1.05,
        do_sample=True
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the final "ASSISTANT:" marker.
    full_response = response.split("ASSISTANT:")[-1].strip()
    # Truncate at the closing </answer> tag if the model emitted one.
    if "</answer>" in full_response:
        end_pos = full_response.find("</answer>") + len("</answer>")
        return full_response[:end_pos]
    return full_response
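
# A minimal streaming variant (an assumption, not part of the original
# script): transformers' TextStreamer prints tokens to stdout as they are
# generated, which can be useful for long answers with max_new_tokens=12000.
def generate_response_streaming(model, tokenizer, medical_query):
    from transformers import TextStreamer  # imported here as part of this sketch
    medical_query = medical_query.strip()
    prompt = f"USER: <medical_query>{medical_query}</medical_query>\nASSISTANT:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # skip_prompt=True avoids re-printing the input; extra keyword arguments
    # are forwarded to tokenizer.decode.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    outputs = model.generate(
        **inputs,
        streamer=streamer,
        max_new_tokens=12000,
        temperature=0.3,
        top_p=0.7,
        repetition_penalty=1.05,
        do_sample=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)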
def run():
    model_path = "./"  # Path to the directory containing your model weight files
    model, tokenizer = load_model(model_path)
    result = generate_response(model, tokenizer, MEDICAL_QUERY)
    print(result)
if __name__ == "__main__":
    run()