Jcalemcg commited on
Commit
d7872a8
Β·
verified Β·
1 Parent(s): e737baa

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +213 -0
train.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Fine-tune Zephyr 7B on CyberSecurity Dataset Collection
4
+ Runs on Hugging Face Spaces infrastructure
5
+ """
6
+
7
+ import os
8
+ import torch
9
+ from datasets import load_dataset, concatenate_datasets
10
+ from transformers import (
11
+ AutoModelForCausalLM,
12
+ AutoTokenizer,
13
+ BitsAndBytesConfig,
14
+ TrainingArguments,
15
+ Trainer,
16
+ DataCollatorForLanguageModeling
17
+ )
18
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
19
+
20
+ # Configuration
21
+ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
22
+ OUTPUT_MODEL_NAME = "Jcalemcg/zephyr-7b-cybersecurity-finetuned"
23
+
24
+ # CyberSecurity datasets from thelordofweb collection
25
+ CYBERSECURITY_DATASETS = [
26
+ "AlicanKiraz0/All-CVE-Records-Training-Dataset",
27
+ "AlicanKiraz0/Cybersecurity-Dataset-v1",
28
+ "Bouquets/Cybersecurity-LLM-CVE",
29
+ "CyberNative/CyberSecurityEval",
30
+ "Mohabahmed03/Alpaca_Dataset_CyberSecurity_Smaller",
31
+ "CyberNative/github_cybersecurity_READMEs",
32
+ "AlicanKiraz0/Cybersecurity-Dataset-Heimdall-v1.1",
33
+ "jcordon5/cybersecurity-rules",
34
+ "Bouquets/DeepSeek-V3-Distill-Cybersecurity-en",
35
+ "Seerene/cybersecurity_dataset",
36
+ "ahmedds10/finetuning_alpaca_Cybersecurity",
37
+ "Tiamz/cybersecurity-instruction-dataset",
38
+ "OhWayTee/Cybersecurity-News_3",
39
+ "Trendyol/All-CVE-Chat-MultiTurn-1999-2025-Dataset",
40
+ "Vanessasml/cyber-reports-news-analysis-llama2-3k",
41
+ "Vanessasml/cybersecurity_32k_instruction_input_output",
42
+ "Vanessasml/enisa_cyber_news_dataset",
43
+ "Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset"
44
+ ]
45
+
46
+ def format_instruction(example):
47
+ """Format examples into Zephyr chat format"""
48
+ if "instruction" in example and "output" in example:
49
+ prompt = f"<|user|>\n{example['instruction']}"
50
+ if example.get("input", "").strip():
51
+ prompt += f"\n{example['input']}"
52
+ prompt += f"</s>\n<|assistant|>\n{example['output']}</s>"
53
+ return {"text": prompt}
54
+ elif "question" in example and "answer" in example:
55
+ return {"text": f"<|user|>\n{example['question']}</s>\n<|assistant|>\n{example['answer']}</s>"}
56
+ elif "prompt" in example and "completion" in example:
57
+ return {"text": f"<|user|>\n{example['prompt']}</s>\n<|assistant|>\n{example['completion']}</s>"}
58
+ elif "text" in example:
59
+ return {"text": example["text"]}
60
+ elif "messages" in example:
61
+ formatted_text = ""
62
+ for msg in example["messages"]:
63
+ role = msg.get("role", "")
64
+ content = msg.get("content", "")
65
+ if role == "user":
66
+ formatted_text += f"<|user|>\n{content}</s>\n"
67
+ elif role == "assistant":
68
+ formatted_text += f"<|assistant|>\n{content}</s>\n"
69
+ return {"text": formatted_text}
70
+ return {"text": str(example)}
71
+
72
+ def load_datasets():
73
+ """Load and prepare cybersecurity datasets"""
74
+ print("=" * 70)
75
+ print("LOADING CYBERSECURITY DATASETS")
76
+ print("=" * 70)
77
+ all_datasets = []
78
+
79
+ for dataset_name in CYBERSECURITY_DATASETS:
80
+ try:
81
+ print(f"\nLoading: {dataset_name}")
82
+ dataset = load_dataset(dataset_name, split="train", trust_remote_code=True)
83
+ formatted = dataset.map(
84
+ format_instruction,
85
+ remove_columns=dataset.column_names,
86
+ desc="Formatting"
87
+ )
88
+ if len(formatted) > 10000:
89
+ formatted = formatted.shuffle(seed=42).select(range(10000))
90
+ all_datasets.append(formatted)
91
+ print(f"βœ“ {len(formatted)} examples loaded")
92
+ except Exception as e:
93
+ print(f"βœ— Failed: {e}")
94
+
95
+ combined = concatenate_datasets(all_datasets)
96
+ print(f"\n{'='*70}")
97
+ print(f"TOTAL DATASET SIZE: {len(combined):,} examples")
98
+ print(f"{'='*70}\n")
99
+
100
+ combined = combined.shuffle(seed=42)
101
+ return combined.train_test_split(test_size=0.05, seed=42)
102
+
103
+ def setup_model():
104
+ """Setup model with QLoRA"""
105
+ print("Setting up Zephyr 7B with QLoRA...")
106
+
107
+ bnb_config = BitsAndBytesConfig(
108
+ load_in_4bit=True,
109
+ bnb_4bit_quant_type="nf4",
110
+ bnb_4bit_compute_dtype=torch.float16,
111
+ bnb_4bit_use_double_quant=True,
112
+ )
113
+
114
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
115
+ tokenizer.pad_token = tokenizer.eos_token
116
+ tokenizer.padding_side = "right"
117
+
118
+ model = AutoModelForCausalLM.from_pretrained(
119
+ MODEL_NAME,
120
+ quantization_config=bnb_config,
121
+ device_map="auto",
122
+ trust_remote_code=True,
123
+ )
124
+
125
+ model = prepare_model_for_kbit_training(model)
126
+
127
+ lora_config = LoraConfig(
128
+ r=16,
129
+ lora_alpha=32,
130
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
131
+ lora_dropout=0.05,
132
+ bias="none",
133
+ task_type="CAUSAL_LM"
134
+ )
135
+
136
+ model = get_peft_model(model, lora_config)
137
+ model.print_trainable_parameters()
138
+
139
+ return model, tokenizer
140
+
141
+ def main():
142
+ print("\n" + "=" * 70)
143
+ print("ZEPHYR 7B CYBERSECURITY FINE-TUNING")
144
+ print("=" * 70 + "\n")
145
+
146
+ # Load data
147
+ datasets = load_datasets()
148
+ train_data = datasets["train"]
149
+ eval_data = datasets["test"]
150
+
151
+ # Setup model
152
+ model, tokenizer = setup_model()
153
+
154
+ # Tokenize
155
+ print("\nTokenizing datasets...")
156
+ def tokenize(examples):
157
+ return tokenizer(examples["text"], truncation=True, max_length=2048, padding="max_length")
158
+
159
+ train_data = train_data.map(tokenize, batched=True, remove_columns=train_data.column_names)
160
+ eval_data = eval_data.map(tokenize, batched=True, remove_columns=eval_data.column_names)
161
+
162
+ # Training config
163
+ training_args = TrainingArguments(
164
+ output_dir="./output",
165
+ num_train_epochs=3,
166
+ per_device_train_batch_size=4,
167
+ per_device_eval_batch_size=4,
168
+ gradient_accumulation_steps=4,
169
+ learning_rate=2e-4,
170
+ fp16=True,
171
+ save_strategy="steps",
172
+ save_steps=500,
173
+ eval_strategy="steps",
174
+ eval_steps=500,
175
+ logging_steps=50,
176
+ warmup_steps=100,
177
+ lr_scheduler_type="cosine",
178
+ optim="paged_adamw_8bit",
179
+ save_total_limit=3,
180
+ load_best_model_at_end=True,
181
+ push_to_hub=True,
182
+ hub_model_id=OUTPUT_MODEL_NAME,
183
+ hub_strategy="every_save",
184
+ report_to="tensorboard",
185
+ )
186
+
187
+ # Train
188
+ trainer = Trainer(
189
+ model=model,
190
+ args=training_args,
191
+ train_dataset=train_data,
192
+ eval_dataset=eval_data,
193
+ data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
194
+ )
195
+
196
+ print("\n" + "=" * 70)
197
+ print("STARTING TRAINING")
198
+ print("=" * 70 + "\n")
199
+
200
+ trainer.train()
201
+
202
+ print("\nSaving model...")
203
+ trainer.save_model()
204
+ model.push_to_hub(OUTPUT_MODEL_NAME)
205
+ tokenizer.push_to_hub(OUTPUT_MODEL_NAME)
206
+
207
+ print("\n" + "=" * 70)
208
+ print("βœ“ TRAINING COMPLETE")
209
+ print(f"βœ“ Model: {OUTPUT_MODEL_NAME}")
210
+ print("=" * 70)
211
+
212
+ if __name__ == "__main__":
213
+ main()