Shago's picture
Update llm.py
9a57e7c verified
raw
history blame
1.17 kB
from transformers import pipeline
import torch
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
# Initialize the Hugging Face pipeline for text generation.
text_generator = pipeline(
    "text-generation",  # Task type
    model="google/gemma-3n-e4b-it",
    # Pinned to CPU. To use a GPU when one is present, replace the next line with:
    # device="cuda" if torch.cuda.is_available() else "cpu",
    device="cpu",
    torch_dtype=torch.bfloat16,
    max_new_tokens=500,  # Limit output length
    # Return only the newly generated text. The pipeline's default (True) prepends
    # the prompt to the output, which downstream line-splitting would otherwise
    # treat as part of the model's answer.
    return_full_text=False,
)

# Wrap the pipeline so it can be composed into LangChain runnables
# (e.g. ``prompt | model | StrOutputParser()``).
model = HuggingFacePipeline(pipeline=text_generator)
def generate_sentences(topic, n=1):
    """Ask the module-level LLM for ``n`` simple English sentences about ``topic``.

    Args:
        topic: Subject of the sentences; interpolated into the prompt.
        n: Number of sentences to request. Values below 1 return an empty
            list without calling the model.

    Returns:
        A list of at most ``n`` non-empty, whitespace-stripped lines taken from
        the model's reply. It can be shorter than ``n`` if the model produced
        fewer non-blank lines.
    """
    # Guard: n <= 0 would waste a model call. A negative n would also make the
    # final ``[:n]`` slice drop lines from the end instead of returning nothing.
    if n < 1:
        return []

    prompt = ChatPromptTemplate.from_template(
        "You are a helpful assistant. Generate exactly {n} simple sentences about the topic: {topic}. "
        "Each sentence must be in English and appropriate for all audiences. "
        "Return each sentence on a new line without any numbering or bullets"
    )
    inputs = {"topic": topic, "n": n}
    chain = prompt | model | StrOutputParser()
    response = chain.invoke(inputs)

    # If the underlying text-generation pipeline returns the full text (prompt +
    # completion), drop the echoed prompt so it is not returned as a "sentence".
    # The LLM receives the prompt value rendered via to_string(), so that is the
    # exact prefix to look for.
    prompt_text = prompt.invoke(inputs).to_string()
    if response.startswith(prompt_text):
        response = response[len(prompt_text):]

    stripped_lines = (line.strip() for line in response.splitlines())
    return [line for line in stripped_lines if line][:n]