mindchain commited on
Commit
4f8d4a7
·
verified ·
1 Parent(s): 178f459

Add training script

Browse files
Files changed (1) hide show
  1. train.py +37 -0
train.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = ["trl>=0.12.0", "transformers>=4.36.0", "accelerate>=0.24.0", "trackio"]
3
+ # ///
4
+
5
+ from datasets import load_dataset
6
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
7
+ from trl import SFTTrainer, SFTConfig
8
+
9
+ dataset = load_dataset("mindchain/container-status-de", split="train")
10
+ split = dataset.train_test_split(test_size=0.15, seed=42)
11
+
12
+ def fmt(ex):
13
+ return {"text": f"Status: {ex['text']}", "label": ex["label"]}
14
+
15
+ train_ds = split["train"].map(fmt, remove_columns=split["train"].column_names)
16
+ eval_ds = split["test"].map(fmt, remove_columns=split["test"].column_names)
17
+
18
+ tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-2-270m")
19
+ model = T5ForConditionalGeneration.from_pretrained("google/t5gemma-2-270m")
20
+
21
+ config = SFTConfig(
22
+ output_dir="out",
23
+ push_to_hub=True,
24
+ hub_model_id="mindchain/t5gemma-270m-container-status",
25
+ num_train_epochs=5,
26
+ per_device_train_batch_size=2,
27
+ gradient_accumulation_steps=4,
28
+ learning_rate=3e-4,
29
+ logging_steps=5,
30
+ max_length=256,
31
+ report_to="trackio",
32
+ )
33
+
34
+ trainer = SFTTrainer(model=model, tokenizer=tokenizer, train_dataset=train_ds, eval_dataset=eval_ds, args=config)
35
+ trainer.train()
36
+ trainer.push_to_hub()
37
+ print('DONE')