
Commit f7a786e

Fixed Seq2seq training issue

1 parent 73063c2 · commit f7a786e

16 files changed: +319 -65 lines changed

examples/example.py

Lines changed: 185 additions & 17 deletions
@@ -1,5 +1,4 @@
 from datasets import load_dataset
-
 from textpredict import (
     Benchmarking,
     Explainability,
@@ -10,28 +9,31 @@
     clean_text,
     initialize,
     load_data,
+    set_device,
 )
 
 
 # Function to test simple prediction using default model
 def text_simple_prediction():
+    set_device("gpu")
+
     # sentiment
     texts = ["I love this product!", "I love this product!"]
     text = "I love this product!"
-    model = initialize(task="sentiment")
+    model = initialize(task="sentiment", device="gpu")
     result = model.analyze(texts, return_probs=False)
     print(f"Simple Prediction Result: {result}")
 
     # # emotion
     text = ["I am happy today", "I am happy today"]
-    model = initialize(task="emotion")
+    model = initialize(task="emotion", device="gpu")
     result = model.analyze(text, return_probs=False)
     print(f"Emotion found : {result}")
 
     # zeroshot
     texts = ["I am happy today", "I am happy today"]
     text = "I am happy today"
-    model = initialize(task="zeroshot")
+    model = initialize(task="zeroshot", device="gpu")
 
     result = model.analyze(
         text, candidate_labels=["negative", "positive"], return_probs=True
@@ -49,6 +51,7 @@ def text_simple_prediction():
 
 # Function to test prediction using a Hugging Face model
 def text_hf_prediction():
+    set_device("cuda")
     text = "I love this product!"
 
     model = initialize(
@@ -81,7 +84,10 @@ def text_hf_prediction():
     print(f"Zeroshot Prediction Result: {result}")
 
     # ner
-    texts = ["I am in London, united kingdom", "I am in Manchester, united kingdom"]
+    texts = [
+        "I am in London, united kingdom",
+        "I am in Manchester, united kingdom",
+    ]  # noqa: F841
     text = "I am in Manchester, united kingdom"
 
     model = initialize(task="ner", source="huggingface")
@@ -91,9 +97,11 @@ def text_hf_prediction():
 
 # Function to train a sequence classification model
 def train_sequence_classification():
+    set_device("cuda")
+
     # Load and preprocess the dataset
-    raw_train_dataset = load_dataset("imdb", split="train[:10]")
-    raw_validation_dataset = load_dataset("imdb", split="test[:10]")
+    raw_train_dataset = load_dataset("imdb", split="train[:100]")
+    raw_validation_dataset = load_dataset("imdb", split="test[:100]")
 
     tokenized_train_dataset = load_data(dataset=raw_train_dataset, splits=["train"])
     tokenized_validation_dataset = load_data(
@@ -111,7 +119,7 @@ def train_sequence_classification():
     trainer = SequenceClassificationTrainer(
         model_name="bert-base-uncased",
         output_dir="./results_new",
-        device="cpu",
+        device="cuda",
         training_config=training_config,
     )
 
@@ -137,8 +145,160 @@ def train_sequence_classification():
     print("result", result)
 
 
+def train_seq2seq():
+    from datasets import load_dataset  # type: ignore
+    from textpredict import Seq2seqTrainer, load_data
+
+    ds = load_dataset("google-research-datasets/mbpp", "sanitized")
+
+    # Load dataset
+    dataset = load_data(
+        dataset=ds,
+        splits=["train", "validation", "test"],
+        text_column="prompt",
+        label_column="code",
+    )
+
+    # Initialize the trainer
+    trainer = Seq2seqTrainer(
+        model_name="google/flan-t5-small",
+        output_dir="./seq2seq_model",
+        training_config={
+            "num_train_epochs": 3,
+            "per_device_train_batch_size": 8,
+            "per_device_eval_batch_size": 8,
+            "learning_rate": 3e-5,
+            "logging_dir": "./logs",
+            "evaluation_strategy": "epoch",
+            "save_strategy": "epoch",
+            "save_total_limit": 2,
+            "load_best_model_at_end": True,
+        },
+    )
+
+    # Set datasets
+    trainer.train_dataset = dataset["train"]
+    trainer.val_dataset = dataset["validation"]
+
+    # Start training
+    trainer.train()
+
+    # Save the model
+    trainer.save()
+
+    metrics = trainer.get_metrics()
+    print(f"Training Metrics: {metrics}")
+
+    evaluate = trainer.evaluate(test_dataset=dataset["test"])
+    print(f"Evaluation Metrics: {evaluate}")
+
+    model = initialize(model_name="./seq2seq_model", task="seq2seq")
+
+    text = "Summarize the following document: ..."
+
+    result = model.analyze(text, return_probs=True)
+
+    print("result", result)
+
+
+# def train_token_classification():
+
+#     import torch  # type: ignore
+#     from textpredict import TokenClassificationTrainer  # noqa: E402
+#     from transformers import AutoTokenizer  # type: ignore
+
+#     # Set device to cuda if available
+#     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+#     # Load tokenizer
+#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+#     # Load and preprocess the dataset
+#     raw_train_dataset = load_dataset("conll2003", split="train[:100]")
+#     raw_validation_dataset = load_dataset("conll2003", split="validation[:100]")
+
+#     # Tokenize the datasets
+#     def tokenize_and_align_labels(examples):
+#         tokenized_inputs = tokenizer(
+#             examples["tokens"],
+#             truncation=True,
+#             is_split_into_words=True,
+#             padding="max_length",
+#             max_length=128,
+#         )
+#         labels = []
+#         for i, label in enumerate(examples["ner_tags"]):
+#             word_ids = tokenized_inputs.word_ids(batch_index=i)
+#             label_ids = []
+#             previous_word_idx = None
+#             for word_idx in word_ids:
+#                 if word_idx is None:
+#                     label_ids.append(-100)
+#                 elif word_idx != previous_word_idx:
+#                     label_ids.append(label[word_idx])
+#                 else:
+#                     label_ids.append(-100)
+#                 previous_word_idx = word_idx
+#             labels.append(label_ids)
+#         tokenized_inputs["labels"] = labels
+#         return tokenized_inputs
+
+#     tokenized_train_dataset = raw_train_dataset.map(
+#         tokenize_and_align_labels, batched=True
+#     )
+#     tokenized_validation_dataset = raw_validation_dataset.map(
+#         tokenize_and_align_labels, batched=True
+#     )
+
+#     # Set the format for PyTorch tensors
+#     tokenized_train_dataset.set_format(
+#         type="torch", columns=["input_ids", "attention_mask", "labels"]
+#     )
+#     tokenized_validation_dataset.set_format(
+#         type="torch", columns=["input_ids", "attention_mask", "labels"]
+#     )
+
+#     # Define training configuration
+#     training_config = {
+#         "num_train_epochs": 3,
+#         "per_device_train_batch_size": 8,
+#     }
+
+#     # Initialize the trainer
+#     trainer = TokenClassificationTrainer(
+#         model_name="bert-base-uncased",
+#         output_dir="./results_token_classification",
+#         device=device,
+#         training_config=training_config,
+#     )
+
+#     # Assign the preprocessed training data to the trainer
+#     trainer.train_dataset = tokenized_train_dataset
+#     trainer.val_dataset = tokenized_validation_dataset
+
+#     # Train the model
+#     trainer.train(from_checkpoint=False)
+#     trainer.save()
+#     metrics = trainer.get_metrics()
+#     print(f"Training Metrics: {metrics}")
+
+#     evaluate = trainer.evaluate(test_dataset=tokenized_validation_dataset)
+#     print(f"Evaluation Metrics: {evaluate}")
+
+#     model = initialize(
+#         model_name="./results_token_classification", task="token_classification"
+#     )
+
+#     text = "Hawking was a theoretical physicist."
+
+#     result = model.analyze(text, return_probs=True)
+
+#     print("result", result)
+
+
 # Function to evaluate a sequence classification model
 def evaluate_sequence_classification():
+    set_device("cuda")
     # Load and preprocess the dataset
     raw_test_dataset = load_dataset("imdb", split="test[:10]")
 
@@ -151,7 +311,7 @@ def evaluate_sequence_classification():
 
     evaluator = SequenceClassificationEvaluator(
         model_name="bert-base-uncased",
-        device="cpu",
+        device="cuda",
         evaluation_config=evaluation_config,
     )
 
@@ -223,17 +383,25 @@ def main():
     # print("Running Simple Prediction...")
     # text_simple_prediction()
 
-    print("\nRunning Hugging Face Prediction...")
-    text_hf_prediction()
+    # print("\nRunning Hugging Face Prediction...")
+    # text_hf_prediction()
+
+    # print("\nTraining Sequence Classification Model...")
+    # train_sequence_classification()
+
+    # Run the training function
+    print("\nTraining Seq2seq Model...")
+    train_seq2seq()
 
-    print("\nTraining Sequence Classification Model...")
-    train_sequence_classification()
+    # Run the training function
+    # print("\nTraining Token Classification Model...")
+    # train_token_classification()
 
-    print("\nEvaluating Sequence Classification Model...")
-    evaluate_sequence_classification()
+    # print("\nEvaluating Sequence Classification Model...")
+    # evaluate_sequence_classification()
 
-    print("\nBenchmarking Model...")
-    benchmark_model()
+    # print("\nBenchmarking Model...")
+    # benchmark_model()
 
     # print("\nVisualizing Metrics...")
     # visualize_metrics()

examples/training_seq2seq.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import logging
+
+from datasets import load_dataset  # type: ignore
+from textpredict import Seq2seqTrainer, initialize, load_data, set_device
+
+logging.basicConfig(level=logging.INFO)
+
+
+def train_seq2seq():
+    set_device("cpu")
+
+    # Load the dataset
+    ds = load_dataset("google-research-datasets/mbpp", "sanitized")
+
+    # Load and preprocess dataset with specified text column
+    dataset = load_data(
+        dataset=ds,
+        splits=["train", "test"],
+        text_column="prompt",
+        label_column="code",
+    )
+
+    # Initialize the trainer
+    trainer = Seq2seqTrainer(
+        model_name="google/flan-t5-small",
+        output_dir="./seq2seq_model",
+        training_config={
+            "num_train_epochs": 0.064,
+            "per_device_train_batch_size": 2,
+            "per_device_eval_batch_size": 2,
+            "learning_rate": 3e-5,
+            "logging_dir": "./logs",
+            "evaluation_strategy": "epoch",
+            "save_strategy": "epoch",
+            "save_total_limit": 2,
+            "load_best_model_at_end": True,
+        },
+        device="cpu",
+    )
+
+    # Set datasets
+    trainer.train_dataset = dataset["train"]
+    trainer.val_dataset = dataset["test"]
+
+    # Start training
+    trainer.train()
+
+    # Save the model
+    trainer.save()
+
+    # Get training metrics
+    metrics = trainer.get_metrics()
+    print(f"Training Metrics: {metrics}")
+
+    # Evaluate the model
+    evaluate = trainer.evaluate(test_dataset=dataset["test"])
+    print(f"Evaluation Metrics: {evaluate}")
+
+    # Load the trained model
+    model = initialize(model_name="./seq2seq_model", task="seq2seq")
+
+    # Analyze a sample text
+    text = "Summarize the following document: ..."
+    result = model.analyze(text, return_probs=True)
+    print("Result:", result)
+
+
+if __name__ == "__main__":
+    train_seq2seq()
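
The standalone example above trains google/flan-t5-small on the MBPP "sanitized" split for a fraction of an epoch (num_train_epochs=0.064) and then reloads the saved model. For reference, a minimal inference-only sketch built from the same textpredict calls used in the example (the prompt string below is illustrative and not part of the commit) might look like:

    from textpredict import initialize, set_device

    set_device("cpu")  # switch to "cuda" when a GPU is available

    # Load the seq2seq model written to ./seq2seq_model by trainer.save()
    model = initialize(model_name="./seq2seq_model", task="seq2seq")

    # Prompt in the same style as the MBPP "prompt" column used for training
    result = model.analyze(
        "Write a python function to check whether a number is prime.",
        return_probs=True,
    )
    print("Result:", result)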

0 commit comments
