passagereptile455 committed on
Commit
c5c3c89
·
verified ·
1 Parent(s): 38463bb

Upload train_sft_demo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_sft_demo.py +32 -0
train_sft_demo.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "datasets", "transformers", "torch", "accelerate"]
# ///
"""Demo script: LoRA supervised fine-tuning of Qwen/Qwen2.5-0.5B with TRL.

Loads the first 500 rows of the ``train`` split of ``trl-lib/Capybara``,
trains for 100 optimizer steps with a LoRA adapter, and pushes the result
to the Hub repo ``passagereptile455/qwen-demo-sft``.

The comment block at the very top is inline script metadata listing the
packages this script depends on; keep it as the first lines of the file.
"""

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import os

# Keep the demo small: only the first 500 training examples are loaded.
dataset = load_dataset("trl-lib/Capybara", split="train[:500]")

# LoRA adapter settings: rank 16, alpha 32, targeting "all-linear" modules.
lora_settings = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)

# Training settings. With a per-device batch of 2 and 4 accumulation steps,
# each optimizer step covers 2 * 4 = 8 examples per device.
sft_settings = SFTConfig(
    output_dir="qwen-demo-sft",
    max_steps=100,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    logging_steps=10,
    push_to_hub=True,
    hub_model_id="passagereptile455/qwen-demo-sft",
    hub_private_repo=True,
)

# The model is passed by its Hub id; the trainer receives the dataset,
# the LoRA settings and the training settings built above.
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    peft_config=lora_settings,
    args=sft_settings,
)

trainer.train()
trainer.push_to_hub()
print("Training complete! Model pushed to Hub.")